[
  {
    "path": ".axolotl-complete.bash",
    "content": "#!/bin/bash\n\n_axolotl_completions() {\n    local cur prev\n    COMPREPLY=()\n    cur=\"${COMP_WORDS[COMP_CWORD]}\"\n    prev=\"${COMP_WORDS[COMP_CWORD-1]}\"\n\n    # If we're completing the first argument (the command)\n    if [[ $COMP_CWORD -eq 1 ]]; then\n        mapfile -t COMPREPLY < <(compgen -W \"delinearize-llama4 fetch lm-eval merge-sharded-fsdp-weights quantize vllm-serve evaluate inference merge-lora preprocess train\" -- \"$cur\")\n        return 0\n    fi\n\n    # Commands that should complete with directories and YAML files\n    local -a yaml_commands=(\"merge-sharded-fsdp-weights\" \"quantize\" \"vllm-serve\" \"evaluate\" \"inference\" \"merge-lora\" \"preprocess\" \"train\")\n\n    # Check if previous word is in our list\n    if [[ \" ${yaml_commands[*]} \" =~ (^|[[:space:]])$prev($|[[:space:]]) ]]; then\n        # Use filename completion which handles directories properly\n        compopt -o filenames\n        mapfile -t COMPREPLY < <(compgen -f -- \"$cur\")\n\n        # Filter to only include directories and YAML files\n        local -a filtered=()\n        for item in \"${COMPREPLY[@]}\"; do\n            if [[ -d \"$item\" ]] || [[ \"$item\" == *.yaml ]] || [[ \"$item\" == *.yml ]]; then\n                filtered+=(\"$item\")\n            fi\n        done\n        COMPREPLY=(\"${filtered[@]}\")\n\n        return 0\n    fi\n\n    # Default: no completion\n    return 0\n}\n\n# Remove the -o nospace option - let filenames handle it\ncomplete -F _axolotl_completions axolotl\n"
  },
  {
    "path": ".bandit",
    "content": "[bandit]\nexclude = tests\nskips = B101,B615,B102,B110\n"
  },
  {
    "path": ".coderabbit.yaml",
    "content": "# yaml-language-server: $schema=https://coderabbit.ai/integrations/schema.v2.json\nlanguage: \"en-US\"\nearly_access: false\nreviews:\n  profile: \"chill\"\n  request_changes_workflow: false\n  high_level_summary: true\n  review_status: true\n  collapse_walkthrough: true\n  poem: false\n  sequence_diagrams: false\n  auto_review:\n    enabled: true\n    drafts: false\n    auto_incremental_review: false\nchat:\n  auto_reply: true\n"
  },
  {
    "path": ".coveragerc",
    "content": "[run]\nsource = axolotl\nomit =\n    */tests/*\n    setup.py\n\n[report]\nexclude_lines =\n    pragma: no cover\n    def __repr__\n    raise NotImplementedError\n    if __name__ == .__main__.:\n    pass\n    raise ImportError\n"
  },
  {
    "path": ".editorconfig",
    "content": "root = true\n\n[*]\nend_of_line = lf\ninsert_final_newline = true\ntrim_trailing_whitespace = true\n\n[*.py]\nindent_style = space\nindent_size = 4\n\n[**.yml]\nindent_style = space\nindent_size = 2\n"
  },
  {
    "path": ".gitattributes",
    "content": "data/*.jsonl filter=lfs diff=lfs merge=lfs -text\n"
  },
  {
    "path": ".github/CODE_OF_CONDUCT.md",
    "content": "# Contributor Covenant Code of Conduct\n\n## Our Pledge\n\nWe as members, contributors, and leaders pledge to make participation in our\ncommunity a harassment-free experience for everyone, regardless of age, body\nsize, visible or invisible disability, ethnicity, sex characteristics, gender\nidentity and expression, level of experience, education, socio-economic status,\nnationality, personal appearance, race, religion, or sexual identity\nand orientation.\n\nWe pledge to act and interact in ways that contribute to an open, welcoming,\ndiverse, inclusive, and healthy community.\n\n## Our Standards\n\nExamples of behavior that contributes to a positive environment for our\ncommunity include:\n\n* Demonstrating empathy and kindness toward other people\n* Being respectful of differing opinions, viewpoints, and experiences\n* Giving and gracefully accepting constructive feedback\n* Accepting responsibility and apologizing to those affected by our mistakes,\n  and learning from the experience\n* Focusing on what is best not just for us as individuals, but for the\n  overall community\n\nExamples of unacceptable behavior include:\n\n* The use of sexualized language or imagery, and sexual attention or\n  advances of any kind\n* Trolling, insulting or derogatory comments, and personal or political attacks\n* Public or private harassment\n* Publishing others' private information, such as a physical or email\n  address, without their explicit permission\n* Other conduct which could reasonably be considered inappropriate in a\n  professional setting\n\n## Enforcement Responsibilities\n\nCommunity leaders are responsible for clarifying and enforcing our standards of\nacceptable behavior and will take appropriate and fair corrective action in\nresponse to any behavior that they deem inappropriate, threatening, offensive,\nor harmful.\n\nCommunity leaders have the right and responsibility to remove, edit, or reject\ncomments, commits, code, wiki edits, issues, 
and other contributions that are\nnot aligned to this Code of Conduct, and will communicate reasons for moderation\ndecisions when appropriate.\n\n## Scope\n\nThis Code of Conduct applies within all community spaces, and also applies when\nan individual is officially representing the community in public spaces.\nExamples of representing our community include using an official e-mail address,\nposting via an official social media account, or acting as an appointed\nrepresentative at an online or offline event.\n\n## Enforcement\n\nInstances of abusive, harassing, or otherwise unacceptable behavior may be\nreported to the community leaders responsible for enforcement on Discord\nat https://discord.gg/QYF8QrtEUm\n\nAll complaints will be reviewed and investigated promptly and fairly.\n\nAll community leaders are obligated to respect the privacy and security of the\nreporter of any incident.\n\n## Enforcement Guidelines\n\nCommunity leaders will follow these Community Impact Guidelines in determining\nthe consequences for any action they deem in violation of this Code of Conduct:\n\n### 1. Correction\n\n**Community Impact**: Use of inappropriate language or other behavior deemed\nunprofessional or unwelcome in the community.\n\n**Consequence**: A private, written warning from community leaders, providing\nclarity around the nature of the violation and an explanation of why the\nbehavior was inappropriate. A public apology may be requested.\n\n### 2. Warning\n\n**Community Impact**: A violation through a single incident or series\nof actions.\n\n**Consequence**: A warning with consequences for continued behavior. No\ninteraction with the people involved, including unsolicited interaction with\nthose enforcing the Code of Conduct, for a specified period of time. This\nincludes avoiding interactions in community spaces as well as external channels\nlike social media. Violating these terms may lead to a temporary or\npermanent ban.\n\n### 3. 
Temporary Ban\n\n**Community Impact**: A serious violation of community standards, including\nsustained inappropriate behavior.\n\n**Consequence**: A temporary ban from any sort of interaction or public\ncommunication with the community for a specified period of time. No public or\nprivate interaction with the people involved, including unsolicited interaction\nwith those enforcing the Code of Conduct, is allowed during this period.\nViolating these terms may lead to a permanent ban.\n\n### 4. Permanent Ban\n\n**Community Impact**: Demonstrating a pattern of violation of community\nstandards, including sustained inappropriate behavior,  harassment of an\nindividual, or aggression toward or disparagement of classes of individuals.\n\n**Consequence**: A permanent ban from any sort of public interaction within\nthe community.\n\n## Attribution\n\nThis Code of Conduct is adapted from the [Contributor Covenant][homepage],\nversion 2.0, available at\nhttps://www.contributor-covenant.org/version/2/0/code_of_conduct.html.\n\nCommunity Impact Guidelines were inspired by [Mozilla's code of conduct\nenforcement ladder](https://github.com/mozilla/diversity).\n\n[homepage]: https://www.contributor-covenant.org\n\nFor answers to common questions about this code of conduct, see the FAQ at\nhttps://www.contributor-covenant.org/faq. Translations are available at\nhttps://www.contributor-covenant.org/translations.\n"
  },
  {
    "path": ".github/CONTRIBUTING.md",
    "content": "# Contributing to axolotl\n\nFirst of all, thank you for your interest in contributing to axolotl! We appreciate the time and effort you're willing to invest in making our project better. This document provides guidelines and information to make the contribution process as smooth as possible.\n\n## Table of Contents\n\n- [Code of Conduct](#code-of-conduct)\n- [Getting Started](#getting-started)\n- [How to Contribute](#how-to-contribute)\n  - [Reporting Bugs](#reporting-bugs)\n  - [Suggesting Enhancements](#suggesting-enhancements)\n  - [Submitting Pull Requests](#submitting-pull-requests)\n- [Style Guidelines](#style-guidelines)\n  - [Code Style](#code-style)\n  - [Commit Messages](#commit-messages)\n- [Additional Resources](#additional-resources)\n\n## Code of Conduct\n\nAll contributors are expected to adhere to our [Code of Conduct](CODE_OF_CONDUCT.md). Please read it before participating in the axolotl community.\n\n## Getting Started\n\nFound a bug? Please check whether an issue is already open; if not, create a new [Issue](https://github.com/axolotl-ai-cloud/axolotl/issues/new).\n\nPRs are **greatly welcome**!\n\n1. Fork the repository and clone it to your local machine.\n2. Set up the development environment by following the instructions in the [README.md](https://github.com/axolotl-ai-cloud/axolotl/tree/main/README.md) file.\n3. Explore the codebase, run tests, and verify that everything works as expected.\n\nPlease run the following to set up your environment:\n```bash\npip3 install -r requirements-dev.txt -r requirements-tests.txt\npre-commit install\n\n# test\npytest tests/\n```\n\n## How to Contribute\n\n### Reporting Bugs\n\nIf you encounter a bug or issue while using axolotl, please open a new issue on the [GitHub Issues](https://github.com/axolotl-ai-cloud/axolotl/issues) page. Provide a clear and concise description of the problem, steps to reproduce it, and any relevant error messages or logs.\n\n### Suggesting Enhancements\n\nWe welcome ideas for improvements and new features. 
To suggest an enhancement, open a new issue on the [GitHub Issues](https://github.com/axolotl-ai-cloud/axolotl/issues) page. Describe the enhancement in detail, explain the use case, and outline the benefits it would bring to the project.\n\n### Submitting Pull Requests\n\n1. Create a new branch for your feature or bugfix. Use a descriptive name like `feature/your-feature-name` or `fix/your-bugfix-name`.\n2. Make your changes, following the [Style Guidelines](#style-guidelines) below.\n3. Test your changes and ensure that they don't introduce new issues or break existing functionality.\n4. Commit your changes, following the [commit message guidelines](#commit-messages).\n5. Push your branch to your fork on GitHub.\n6. Open a new pull request against the `main` branch of the axolotl repository. Include a clear and concise description of your changes, referencing any related issues.\n\n#### Skipping CI Checks\n\nYou can skip certain CI checks by including specific keywords in your commit messages:\n\n- `[skip ci]` or `skip ci` - Skips all CI checks for that commit\n- `[skip-e2e]` or `skip-e2e` - Skips only end-to-end tests while running other CI checks. You may also include this in the title of your PR to disable end-to-end tests for the entire PR.\n\n## Style Guidelines\n\n### Code Style\n\naxolotl uses [Ruff](https://docs.astral.sh/ruff/) as its code style guide. Please ensure that your code follows these guidelines.\n\nUse the pre-commit linter to ensure that your code is formatted consistently.\n```bash\npre-commit run --all-files\n```\n\n### Commit Messages\n\nWrite clear and concise commit messages that briefly describe the changes made in each commit. 
Use the imperative mood and start with a capitalized verb, e.g., \"Add new feature\" or \"Fix bug in function\".\n\n## Additional Resources\n\n- [GitHub Help](https://help.github.com/)\n- [GitHub Pull Request Documentation](https://docs.github.com/en/github/collaborating-with-issues-and-pull-requests)\n- [Ruff](https://docs.astral.sh/ruff/)\n\nThank you once again for your interest in contributing to axolotl. We look forward to collaborating with you and creating an even better project together!\n"
  },
  {
    "path": ".github/FUNDING.yml",
    "content": "# These are supported funding model platforms\n\ngithub: # Replace with up to 4 GitHub Sponsors-enabled usernames e.g., [user1, user2]\npatreon: # Replace with a single Patreon username\nopen_collective: # Replace with a single Open Collective username\nko_fi: # Replace with a single Ko-fi username\ntidelift: # Replace with a single Tidelift platform-name/package-name e.g., npm/babel\ncommunity_bridge: # Replace with a single Community Bridge project-name e.g., cloud-foundry\nliberapay: # Replace with a single Liberapay username\nissuehunt: # Replace with a single IssueHunt username\notechie: # Replace with a single Otechie username\nlfx_crowdfunding: # Replace with a single LFX Crowdfunding project-name e.g., cloud-foundry\ncustom: # Replace with up to 4 custom sponsorship URLs e.g., ['link1', 'link2']\n"
  },
  {
    "path": ".github/ISSUE_TEMPLATE/bug-report.yaml",
    "content": "name: Bug Report\ndescription: File a bug report\nlabels: [\"bug\", \"needs triage\"]\nbody:\n  - type: markdown\n    attributes:\n      value: |\n        ## Before you start\n        Please **make sure you are on the latest version.**\n        If you encountered the issue after you installed, updated, or reloaded, **please try restarting before reporting the bug**.\n\n  - type: checkboxes\n    id: no-duplicate-issues\n    attributes:\n      label: \"Please check that this issue hasn't been reported before.\"\n      description: \"The **Label filters** may help make your search more focused.\"\n      options:\n        - label: \"I searched previous [Bug Reports](https://github.com/axolotl-ai-cloud/axolotl/labels/bug) and didn't find any similar reports.\"\n          required: true\n\n  - type: textarea\n    id: expected\n    attributes:\n      label: Expected Behavior\n      description: Tell us what **should** happen.\n    validations:\n      required: true\n\n  - type: textarea\n    id: what-happened\n    attributes:\n      label: Current Behavior\n      description: |\n        Tell us what happens instead of the expected behavior.\n        Provide a stack trace and/or screenshots.\n    validations:\n      required: true\n\n  - type: textarea\n    id: reproduce\n    attributes:\n      label: Steps to reproduce\n      description: |\n        Which exact steps can a developer take to reproduce the issue?\n        The more detail you provide, the easier it will be to narrow down and fix the bug.\n        Please paste in tasks and/or queries **as text, not screenshots**.\n      placeholder: |\n        Example of the level of detail needed to reproduce any bugs efficiently and reliably.\n        1. Go to the '...' page.\n        2. Click on the '...' button.\n        3. Scroll down to '...'.\n        4. 
Observe the error.\n    validations:\n      required: true\n\n  - type: textarea\n    id: config\n    attributes:\n      label: Config yaml\n      description: |\n        Please attach your config YAML.\n      render: yaml\n\n  - type: textarea\n    id: possible-solution\n    attributes:\n      label: Possible solution\n      description: |\n        Not obligatory, but please suggest a fix or reason for the bug, if you have an idea.\n\n\n  - type: checkboxes\n    id: operating-systems\n    attributes:\n      label: Which Operating Systems are you using?\n      description: You may select more than one.\n      options:\n        - label: Linux\n        - label: macOS\n        - label: Windows\n\n  - type: input\n    id: Python-version\n    attributes:\n      label: Python Version\n      description: Which Python version are you using?\n      placeholder: 3.10 / please change accordingly\n    validations:\n      required: true\n\n  - type: input\n    id: axolotl-branch-commit\n    attributes:\n      label: axolotl branch-commit\n      description: Which branch/commit are you on?\n      placeholder: main/4d6490b\n    validations:\n      required: true\n\n  - type: checkboxes\n    id: acknowledgements\n    attributes:\n      label: 'Acknowledgements'\n      description: 'Please confirm the following:'\n      options:\n        - label: 'My issue title is concise, descriptive, and in title casing.'\n          required: true\n        - label: 'I have searched the existing issues to make sure this bug has not been reported yet.'\n          required: true\n        - label: 'I am using the latest version of axolotl.'\n          required: true\n        - label: 'I have provided enough information for the maintainers to reproduce and diagnose the issue.'\n          required: true\n"
  },
  {
    "path": ".github/ISSUE_TEMPLATE/config.yml",
    "content": "blank_issues_enabled: false\ncontact_links:\n  - name: Ask a question\n    url: https://github.com/axolotl-ai-cloud/axolotl/discussions/categories/q-a\n    about: Ask questions and discuss with other community members\n  - name: Discuss the Project in Discord\n    url: https://discord.gg/HhrNrHJPRb\n    about: Chat with the community and maintainers on Discord\n"
  },
  {
    "path": ".github/ISSUE_TEMPLATE/docs.yml",
    "content": "name: Documentation Improvement / Clarity\ndescription: Make a suggestion to improve the project documentation.\nlabels: ['needs triage', 'docs']\nbody:\n  - type: markdown\n    attributes:\n      value: '## :book: Documentation :book:'\n  - type: markdown\n    attributes:\n      value: |\n        * Ask questions in [Discord](https://discord.gg/HhrNrHJPRb).\n        * Before you file an issue, read the [Contributing guide](./CONTRIBUTING.md).\n        * Check to make sure someone hasn't already opened a [similar issue](https://github.com/axolotl-ai-cloud/axolotl/issues).\n  - type: textarea\n    attributes:\n      label: What piece of documentation is affected?\n      description: Please link to the article you'd like to see updated.\n    validations:\n      required: true\n  - type: textarea\n    attributes:\n      label: What part(s) of the article would you like to see updated?\n      description: |\n        - Give as much detail as you can to help us understand the change you want to see.\n        - Why should the docs be changed? What use cases does it support?\n        - What is the expected outcome?\n    validations:\n      required: true\n  - type: textarea\n    attributes:\n      label: Additional Information\n      description: Add any other context or screenshots about the documentation change here.\n    validations:\n      required: false\n  - type: checkboxes\n    id: acknowledgements\n    attributes:\n      label: 'Acknowledgements'\n      description: 'Please confirm the following:'\n      options:\n        - label: 'My issue title is concise, descriptive, and in title casing.'\n          required: true\n        - label: 'I have searched the existing issues to make sure this documentation issue has not been reported yet.'\n          required: true\n        - label: 'I have provided enough information for the maintainers to understand and evaluate this request.'\n          required: true\n"
  },
  {
    "path": ".github/ISSUE_TEMPLATE/feature-request.yaml",
    "content": "name: Feature Request / Enhancement\ndescription: Suggest a new feature or feature enhancement for the project\nlabels: [\"enhancement\", \"needs triage\"]\nbody:\n  - type: checkboxes\n    id: no-duplicate-issues\n    attributes:\n      label: \"⚠️ Please check that this feature request hasn't been suggested before.\"\n      description: \"There are two locations for previous feature requests. Please search in both. Thank you. The **Label filters** may help make your search more focused.\"\n      options:\n        - label: \"I searched previous [Ideas in Discussions](https://github.com/axolotl-ai-cloud/axolotl/discussions/categories/ideas) and didn't find any similar feature requests.\"\n          required: true\n        - label: \"I searched previous [Issues](https://github.com/axolotl-ai-cloud/axolotl/labels/enhancement) and didn't find any similar feature requests.\"\n          required: true\n\n  - type: textarea\n    id: feature-description\n    validations:\n      required: true\n    attributes:\n      label: \"🔖 Feature description\"\n      description: \"A clear and concise description of what the feature request is.\"\n      placeholder: \"You should add ...\"\n\n  - type: textarea\n    id: solution\n    validations:\n      required: true\n    attributes:\n      label: \"✔️ Solution\"\n      description: \"A clear and concise description of what you want to happen, and why.\"\n      placeholder: \"In my use-case, ...\"\n\n  - type: textarea\n    id: alternatives\n    validations:\n      required: false\n    attributes:\n      label: \"❓ Alternatives\"\n      description: \"A clear and concise description of any alternative solutions or features you've considered.\"\n      placeholder: \"I have considered ...\"\n\n  - type: textarea\n    id: additional-context\n    validations:\n      required: false\n    attributes:\n      label: \"📝 Additional Context\"\n      description: \"Add any other context or screenshots about the feature request 
here.\"\n      placeholder: \"...\"\n\n  - type: checkboxes\n    id: acknowledgements\n    attributes:\n      label: 'Acknowledgements'\n      description: 'Please confirm the following:'\n      options:\n        - label: 'My issue title is concise, descriptive, and in title casing.'\n          required: true\n        - label: 'I have searched the existing issues to make sure this feature has not been requested yet.'\n          required: true\n        - label: 'I have provided enough information for the maintainers to understand and evaluate this request.'\n          required: true\n"
  },
  {
    "path": ".github/PULL_REQUEST_TEMPLATE.md",
    "content": "<!--- Provide a general summary of your changes in the Title above -->\n\n# Description\n\n<!--- Describe your changes in detail -->\n\n## Motivation and Context\n\n<!--- Why is this change required? What problem does it solve? -->\n<!--- If it fixes an open issue, please link to the issue here. -->\n\n## How has this been tested?\n\n<!--- Please describe in detail how you tested your changes. -->\n<!--- Include details of your testing environment, tests run to see how -->\n<!--- your change affects other areas of the code, etc. -->\n\n## AI Usage Disclaimer\n\n<!--- Was AI (e.g., ChatGPT, Claude, Copilot) used to generate or assist with this PR? -->\n<!--- Please indicate: No / Yes (specify which tool and to what extent) -->\n\n## Screenshots (if appropriate)\n\n## Types of changes\n\n<!--- What types of changes does your code introduce? Put an `x` in all the boxes that apply: -->\n- [ ] Bug fix (non-breaking change that fixes an issue)\n- [ ] New feature (non-breaking change that adds functionality)\n- [ ] Breaking change (fix or feature that would cause existing functionality to change)\n- [ ] Documentation update\n\n## Social Handles (Optional)\n\n<!-- Thanks for submitting a bugfix or enhancement. -->\n<!-- We'd love to show our thanks to you on Twitter & Discord if you provide your handle -->\n"
  },
  {
    "path": ".github/SECURITY.md",
    "content": "# Security Policy\n\n## Supported Versions\n\nBecause of this project's fast pace of development, only the latest released version is supported.\n\n## Reporting a Vulnerability\n\nIf you find a vulnerability, please contact us on [Discord](https://discord.gg/xcu3ECkH9a) rather than creating a public GitHub issue, so that we have time to fix it before it becomes publicly known.\n"
  },
  {
    "path": ".github/SUPPORT.md",
    "content": "# Support\n\nIf you need help with this project or have questions, please:\n\n1. Check the documentation.\n2. Search the existing issues and pull requests.\n3. Create a new issue if your question is not answered or your problem is not solved.\n4. Ask in our [Discord server](https://discord.gg/HhrNrHJPRb).\n\nPlease note that this project is maintained by volunteers who have limited availability. We'll do our best to address your questions and concerns in a timely manner.\n"
  },
  {
    "path": ".github/release-drafter.yml",
    "content": "name-template: 'v$RESOLVED_VERSION'\ntag-template: 'v$RESOLVED_VERSION'\ncategories:\n  - title: '🚀 Features'\n    labels:\n      - 'feature'\n      - 'enhancement'\n  - title: '🐛 Bug Fixes'\n    labels:\n      - 'fix'\n      - 'bugfix'\n      - 'bug'\n  - title: '🧰 Maintenance'\n    label: 'chore'\nchange-template: '- $TITLE @$AUTHOR (#$NUMBER)'\nchange-title-escapes: '\\<*_&' # You can add # and @ to disable mentions, and add ` to disable code blocks.\nversion-resolver:\n  major:\n    labels:\n      - 'major'\n  minor:\n    labels:\n      - 'minor'\n  patch:\n    labels:\n      - 'patch'\n  default: patch\ntemplate: |\n  ## What’s Changed\n\n  $CHANGES\n"
  },
  {
    "path": ".github/workflows/base.yml",
    "content": "name: ci-cd-base\n\non:\n  push:\n    branches:\n      - \"main\"\n    paths:\n      - 'docker/Dockerfile-base'\n      - 'docker/Dockerfile-uv-base'\n      - '.github/workflows/base.yml'\n  pull_request:\n    paths:\n      - 'docker/Dockerfile-base'\n      - 'docker/Dockerfile-uv-base'\n      - '.github/workflows/base.yml'\n  workflow_dispatch:\n\npermissions:\n  contents: read\n\njobs:\n  build-base:\n    if: ${{ github.repository_owner == 'axolotl-ai-cloud' && (github.event_name != 'pull_request' || !github.event.pull_request.draft) }}\n    timeout-minutes: 480\n    # this job needs to be run on self-hosted GPU runners...\n    runs-on: ubuntu-latest-m\n    env:\n      HAS_DOCKERHUB_CREDS: ${{ secrets.DOCKERHUB_USERNAME != '' && secrets.DOCKERHUB_TOKEN != '' }}\n    strategy:\n      fail-fast: false\n      matrix:\n        include:\n          - cuda: \"128\"\n            cuda_version: 12.8.1\n            cudnn_version: \"\"\n            python_version: \"3.11\"\n            pytorch: 2.8.0\n            torch_cuda_arch_list: \"7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX\"\n            dockerfile: \"Dockerfile-base\"\n            platforms: \"linux/amd64\"\n          - cuda: \"128\"\n            cuda_version: 12.8.1\n            cudnn_version: \"\"\n            python_version: \"3.11\"\n            pytorch: 2.9.0\n            torch_cuda_arch_list: \"7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX\"\n            dockerfile: \"Dockerfile-base\"\n            platforms: \"linux/amd64,linux/arm64\"\n          - cuda: \"128\"\n            cuda_version: 12.8.1\n            cudnn_version: \"\"\n            python_version: \"3.11\"\n            pytorch: 2.9.1\n            torch_cuda_arch_list: \"7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX\"\n            dockerfile: \"Dockerfile-base\"\n            platforms: \"linux/amd64,linux/arm64\"\n          - cuda: \"128\"\n            cuda_version: 12.8.1\n            cudnn_version: \"\"\n            python_version: \"3.11\"\n            pytorch: 2.10.0\n     
       torch_cuda_arch_list: \"7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX\"\n            dockerfile: \"Dockerfile-base\"\n            platforms: \"linux/amd64,linux/arm64\"\n          - cuda: \"128\"\n            cuda_version: 12.8.1\n            cudnn_version: \"\"\n            python_version: \"3.12\"\n            pytorch: 2.10.0\n            torch_cuda_arch_list: \"7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX\"\n            dockerfile: \"Dockerfile-base\"\n            platforms: \"linux/amd64,linux/arm64\"\n#          - cuda: \"129\"\n#            cuda_version: 12.9.1\n#            cudnn_version: \"\"\n#            python_version: \"3.12\"\n#            pytorch: 2.9.1\n#            torch_cuda_arch_list: \"7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX\"\n#            dockerfile: \"Dockerfile-base\"\n#            platforms: \"linux/amd64,linux/arm64\"\n          - cuda: \"130\"\n            cuda_version: 13.0.0\n            cudnn_version: \"\"\n            python_version: \"3.11\"\n            pytorch: 2.9.1\n            torch_cuda_arch_list: \"9.0+PTX\"\n            dockerfile: \"Dockerfile-base\"\n            platforms: \"linux/amd64,linux/arm64\"\n          - cuda: \"130\"\n            cuda_version: 13.0.0\n            cudnn_version: \"\"\n            python_version: \"3.12\"\n            pytorch: 2.9.1\n            torch_cuda_arch_list: \"9.0+PTX\"\n            dockerfile: \"Dockerfile-base\"\n            platforms: \"linux/amd64,linux/arm64\"\n          - cuda: \"130\"\n            cuda_version: 13.0.0\n            cudnn_version: \"\"\n            python_version: \"3.12\"\n            pytorch: 2.10.0\n            torch_cuda_arch_list: \"9.0+PTX\"\n            dockerfile: \"Dockerfile-base\"\n            platforms: \"linux/amd64,linux/arm64\"\n#          - cuda: \"128\"\n#            cuda_version: 12.8.1\n#            cudnn_version: \"\"\n#            python_version: \"3.11\"\n#            pytorch: nightly\n#            torch_cuda_arch_list: \"7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX\"\n#            
dockerfile: \"Dockerfile-base-nightly\"\n#          # \"next\" is for release candidates of pytorch\n#          - cuda: \"128\"\n#            cuda_version: 12.8.1\n#            cudnn_version: \"\"\n#            python_version: \"3.11\"\n#            pytorch: next\n#            torch_cuda_arch_list: \"7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX\"\n#            dockerfile: \"Dockerfile-base-next\"\n    steps:\n      - name: Checkout\n        uses: actions/checkout@v4\n      - name: Docker metadata\n        id: metadata\n        uses: docker/metadata-action@v5\n        with:\n          images: |\n            axolotlai/axolotl-base\n      - name: Login to Docker Hub\n        uses: docker/login-action@v3\n        if: ${{ github.event_name != 'pull_request' && env.HAS_DOCKERHUB_CREDS == 'true' }}\n        with:\n          username: ${{ secrets.DOCKERHUB_USERNAME }}\n          password: ${{ secrets.DOCKERHUB_TOKEN }}\n      - name: Set up Docker Buildx\n        uses: docker/setup-buildx-action@v3\n      - name: Build\n        uses: docker/build-push-action@v5\n        with:\n          context: .\n          file: ./docker/${{ matrix.dockerfile }}\n          platforms: ${{ matrix.platforms }}\n          push: ${{ github.event_name != 'pull_request' }}\n          tags: ${{ steps.metadata.outputs.tags }}-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}\n          labels: ${{ steps.metadata.outputs.labels }}\n          build-args: |\n            CUDA_VERSION=${{ matrix.cuda_version }}\n            CUDNN_VERSION=${{ matrix.cudnn_version }}\n            CUDA=${{ matrix.cuda }}\n            PYTHON_VERSION=${{ matrix.python_version }}\n            PYTORCH_VERSION=${{ matrix.pytorch }}\n            TORCH_CUDA_ARCH_LIST=${{ matrix.torch_cuda_arch_list }}\n  build-base-uv:\n    if: ${{ github.repository_owner == 'axolotl-ai-cloud' && (github.event_name != 'pull_request' || 
!github.event.pull_request.draft) }}\n    timeout-minutes: 480\n    runs-on: ubuntu-latest-m\n    env:\n      HAS_DOCKERHUB_CREDS: ${{ secrets.DOCKERHUB_USERNAME != '' && secrets.DOCKERHUB_TOKEN != '' }}\n    strategy:\n      fail-fast: false\n      matrix:\n        include:\n          - cuda: \"128\"\n            cuda_version: 12.8.1\n            cudnn_version: \"\"\n            python_version: \"3.11\"\n            pytorch: 2.8.0\n            torch_cuda_arch_list: \"7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX\"\n            dockerfile: \"Dockerfile-uv-base\"\n            platforms: \"linux/amd64\"\n          - cuda: \"128\"\n            cuda_version: 12.8.1\n            cudnn_version: \"\"\n            python_version: \"3.11\"\n            pytorch: 2.9.1\n            torch_cuda_arch_list: \"7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX\"\n            dockerfile: \"Dockerfile-uv-base\"\n            platforms: \"linux/amd64,linux/arm64\"\n          - cuda: \"128\"\n            cuda_version: 12.8.1\n            cudnn_version: \"\"\n            python_version: \"3.12\"\n            pytorch: 2.9.1\n            torch_cuda_arch_list: \"7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX\"\n            dockerfile: \"Dockerfile-uv-base\"\n            platforms: \"linux/amd64,linux/arm64\"\n          - cuda: \"128\"\n            cuda_version: 12.8.1\n            cudnn_version: \"\"\n            python_version: \"3.11\"\n            pytorch: 2.9.0\n            torch_cuda_arch_list: \"7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX\"\n            dockerfile: \"Dockerfile-uv-base\"\n            platforms: \"linux/amd64,linux/arm64\"\n          - cuda: \"128\"\n            cuda_version: 12.8.1\n            cudnn_version: \"\"\n            python_version: \"3.11\"\n            pytorch: 2.10.0\n            torch_cuda_arch_list: \"7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX\"\n            dockerfile: \"Dockerfile-uv-base\"\n            platforms: \"linux/amd64,linux/arm64\"\n          - cuda: \"128\"\n            cuda_version: 12.8.1\n            
cudnn_version: \"\"\n            python_version: \"3.12\"\n            pytorch: 2.10.0\n            torch_cuda_arch_list: \"7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX\"\n            dockerfile: \"Dockerfile-uv-base\"\n            platforms: \"linux/amd64,linux/arm64\"\n#          - cuda: \"129\"\n#            cuda_version: 12.9.1\n#            cudnn_version: \"\"\n#            python_version: \"3.12\"\n#            pytorch: 2.9.1\n#            torch_cuda_arch_list: \"7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX\"\n#            dockerfile: \"Dockerfile-uv-base\"\n#            platforms: \"linux/amd64,linux/arm64\"\n          - cuda: \"130\"\n            cuda_version: 13.0.0\n            cudnn_version: \"\"\n            python_version: \"3.11\"\n            pytorch: 2.9.1\n            torch_cuda_arch_list: \"9.0+PTX\"\n            dockerfile: \"Dockerfile-uv-base\"\n            platforms: \"linux/amd64,linux/arm64\"\n          - cuda: \"130\"\n            cuda_version: 13.0.0\n            cudnn_version: \"\"\n            python_version: \"3.12\"\n            pytorch: 2.9.1\n            torch_cuda_arch_list: \"9.0+PTX\"\n            dockerfile: \"Dockerfile-uv-base\"\n            platforms: \"linux/amd64,linux/arm64\"\n          - cuda: \"130\"\n            cuda_version: 13.0.0\n            cudnn_version: \"\"\n            python_version: \"3.12\"\n            pytorch: 2.10.0\n            torch_cuda_arch_list: \"9.0+PTX\"\n            dockerfile: \"Dockerfile-uv-base\"\n            platforms: \"linux/amd64,linux/arm64\"\n    steps:\n      - name: Checkout\n        uses: actions/checkout@v4\n      - name: Docker metadata\n        id: metadata\n        uses: docker/metadata-action@v5\n        with:\n          images: |\n            axolotlai/axolotl-base-uv\n      - name: Login to Docker Hub\n        uses: docker/login-action@v3\n        if: ${{ github.event_name != 'pull_request' && env.HAS_DOCKERHUB_CREDS == 'true' }}\n        with:\n          username: ${{ secrets.DOCKERHUB_USERNAME }}\n    
      password: ${{ secrets.DOCKERHUB_TOKEN }}\n      - name: Set up Docker Buildx\n        uses: docker/setup-buildx-action@v3\n      - name: Build\n        uses: docker/build-push-action@v5\n        with:\n          context: .\n          file: ./docker/${{ matrix.dockerfile }}\n          platforms: ${{ matrix.platforms }}\n          push: ${{ github.event_name != 'pull_request' }}\n          tags: ${{ steps.metadata.outputs.tags }}-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}\n          labels: ${{ steps.metadata.outputs.labels }}\n          build-args: |\n            CUDA_VERSION=${{ matrix.cuda_version }}\n            CUDNN_VERSION=${{ matrix.cudnn_version }}\n            CUDA=${{ matrix.cuda }}\n            PYTHON_VERSION=${{ matrix.python_version }}\n            PYTORCH_VERSION=${{ matrix.pytorch }}\n            TORCH_CUDA_ARCH_LIST=${{ matrix.torch_cuda_arch_list }}\n"
  },
  {
    "path": ".github/workflows/docs.yml",
    "content": "name: Publish Docs\non:\n  push:\n    branches:\n      - main\n\npermissions:\n    contents: write\n    pages: write\n\njobs:\n    build-deploy:\n        runs-on: ubuntu-latest\n        steps:\n        - name: cleanup node\n          run: |\n            sudo rm -rf /usr/share/dotnet /usr/local/lib/android /opt/ghc /opt/hostedtoolcache/CodeQL\n        - name: Check out repository\n          uses: actions/checkout@v4\n        - name: Set up Quarto\n          uses: quarto-dev/quarto-actions/setup@v2\n        - name: Setup Python\n          uses: actions/setup-python@v5\n          with:\n            python-version: '3.11'\n        - name: Install dependencies\n          run: |\n            python3 -m pip install jupyter quartodoc\n            python3 -m pip install -e .\n        - name: Build autodoc\n          run: quartodoc build\n        - name: Publish to GitHub Pages (and render)\n          uses: quarto-dev/quarto-actions/publish@v2\n          with:\n            target: gh-pages\n          env:\n            GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}\n"
  },
  {
    "path": ".github/workflows/lint.yml",
    "content": "name: lint\non:\n  # check on PRs, and manual triggers\n  merge_group:\n  pull_request:\n      types: [opened, synchronize, reopened, ready_for_review]\n      paths:\n       - '**.py'\n       - 'requirements.txt'\n       - '.github/workflows/*.yml'\n       - \"*.[q]md\"\n       - \"examples/**/*.y[a]?ml\"\n       - \".pre-commit-config.yaml\"\n  workflow_dispatch:\n\npermissions:\n  contents: read\n\njobs:\n  pre-commit:\n    name: pre-commit\n    runs-on: ubuntu-latest\n    if: ${{ !github.event.pull_request.draft }}\n    steps:\n      - uses: actions/checkout@v4\n      - uses: actions/setup-python@v5\n        with:\n          python-version: \"3.11\"\n          cache: 'pip' # caching pip dependencies\n      - uses: pre-commit/action@v3.0.1\n"
  },
  {
    "path": ".github/workflows/main.yml",
    "content": "name: ci-cd\n\non:\n  push:\n    branches:\n      - \"main\"\n    tags:\n      - \"v*\"\n  workflow_dispatch:\n\npermissions:\n  contents: read\n\njobs:\n  build-axolotl:\n    if: ${{ ! contains(github.event.commits[0].message, '[skip docker]') && github.repository_owner == 'axolotl-ai-cloud' }}\n    strategy:\n      fail-fast: false\n      matrix:\n        include:\n          - cuda: 128\n            cuda_version: 12.8.1\n            python_version: \"3.11\"\n            pytorch: 2.8.0\n            axolotl_extras:\n            platforms: \"linux/amd64\"\n          - cuda: 128\n            cuda_version: 12.8.1\n            python_version: \"3.11\"\n            pytorch: 2.9.0\n            axolotl_extras:\n            platforms: \"linux/amd64,linux/arm64\"\n          - cuda: 128\n            cuda_version: 12.8.1\n            python_version: \"3.11\"\n            pytorch: 2.9.1\n            axolotl_extras:\n            platforms: \"linux/amd64,linux/arm64\"\n            is_latest: true\n          - cuda: 128\n            cuda_version: 12.8.1\n            python_version: \"3.12\"\n            pytorch: 2.10.0\n            axolotl_extras:\n            platforms: \"linux/amd64,linux/arm64\"\n#          - cuda: 129\n#            cuda_version: 12.9.1\n#            python_version: \"3.12\"\n#            pytorch: 2.9.1\n#            axolotl_extras:\n#            platforms: \"linux/amd64,linux/arm64\"\n          - cuda: 130\n            cuda_version: 13.0.0\n            python_version: \"3.11\"\n            pytorch: 2.9.1\n            axolotl_extras:\n            platforms: \"linux/amd64,linux/arm64\"\n          - cuda: 130\n            cuda_version: 13.0.0\n            python_version: \"3.12\"\n            pytorch: 2.10.0\n            axolotl_extras:\n            platforms: \"linux/amd64,linux/arm64\"\n    runs-on: axolotl-gpu-runner\n    steps:\n      - name: Checkout\n        uses: actions/checkout@v4\n      - name: Docker metadata\n        id: metadata\n    
    uses: docker/metadata-action@v5\n        with:\n          images: |\n            axolotlai/axolotl\n          tags: |\n            type=ref,event=branch\n            type=pep440,pattern={{version}}\n      - name: Set up Docker Buildx\n        uses: docker/setup-buildx-action@v3\n      - name: Login to Docker Hub\n        uses: docker/login-action@v3\n        with:\n          username: ${{ secrets.DOCKERHUB_USERNAME }}\n          password: ${{ secrets.DOCKERHUB_TOKEN }}\n      # guidance for testing before pushing: https://docs.docker.com/build/ci/github-actions/test-before-push/\n      - name: Build and export to Docker\n        uses: docker/build-push-action@v5\n        with:\n          context: .\n          platforms: ${{ matrix.platforms }}\n          build-args: |\n            BASE_TAG=${{ github.ref_type == 'tag' && 'main' || github.ref_name }}-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}\n            CUDA=${{ matrix.cuda }}\n            PYTORCH_VERSION=${{ matrix.pytorch }}\n            AXOLOTL_ARGS=${{ matrix.axolotl_args }}\n            AXOLOTL_EXTRAS=${{ matrix.axolotl_extras}}\n          file: ./docker/Dockerfile\n          push: ${{ github.event_name != 'pull_request' }}\n          tags: |\n            ${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}\n            ${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}\n            ${{ (matrix.is_latest) && format('{0}-latest', steps.metadata.outputs.tags) || '' }}\n          labels: ${{ steps.metadata.outputs.labels }}\n\n  build-axolotl-uv:\n    if: ${{ ! 
contains(github.event.commits[0].message, '[skip docker]') && github.repository_owner == 'axolotl-ai-cloud' }}\n    strategy:\n      fail-fast: false\n      matrix:\n        include:\n          - cuda: 128\n            cuda_version: 12.8.1\n            python_version: \"3.11\"\n            pytorch: 2.9.1\n            axolotl_extras:\n            platforms: \"linux/amd64,linux/arm64\"\n          - cuda: 128\n            cuda_version: 12.8.1\n            python_version: \"3.12\"\n            pytorch: 2.9.1\n            axolotl_extras:\n            platforms: \"linux/amd64,linux/arm64\"\n            is_latest: true\n          - cuda: 128\n            cuda_version: 12.8.1\n            python_version: \"3.12\"\n            pytorch: 2.10.0\n            axolotl_extras:\n            platforms: \"linux/amd64,linux/arm64\"\n          - cuda: 130\n            cuda_version: 13.0.0\n            python_version: \"3.11\"\n            pytorch: 2.9.1\n            axolotl_extras:\n            platforms: \"linux/amd64,linux/arm64\"\n          - cuda: 130\n            cuda_version: 13.0.0\n            python_version: \"3.12\"\n            pytorch: 2.10.0\n            axolotl_extras:\n            platforms: \"linux/amd64,linux/arm64\"\n    runs-on: axolotl-gpu-runner\n    steps:\n      - name: Checkout\n        uses: actions/checkout@v4\n      - name: Docker metadata\n        id: metadata\n        uses: docker/metadata-action@v5\n        with:\n          images: |\n            axolotlai/axolotl-uv\n          tags: |\n            type=ref,event=branch\n            type=pep440,pattern={{version}}\n      - name: Set up Docker Buildx\n        uses: docker/setup-buildx-action@v3\n      - name: Login to Docker Hub\n        uses: docker/login-action@v3\n        with:\n          username: ${{ secrets.DOCKERHUB_USERNAME }}\n          password: ${{ secrets.DOCKERHUB_TOKEN }}\n      # guidance for testing before pushing: https://docs.docker.com/build/ci/github-actions/test-before-push/\n      - 
name: Build and export to Docker\n        uses: docker/build-push-action@v5\n        with:\n          context: .\n          platforms: ${{ matrix.platforms }}\n          build-args: |\n            BASE_TAG=${{ github.ref_type == 'tag' && 'main' || github.ref_name }}-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}\n            CUDA=${{ matrix.cuda }}\n            PYTORCH_VERSION=${{ matrix.pytorch }}\n            AXOLOTL_ARGS=${{ matrix.axolotl_args }}\n            AXOLOTL_EXTRAS=${{ matrix.axolotl_extras}}\n          file: ./docker/Dockerfile-uv\n          push: ${{ github.event_name != 'pull_request' }}\n          tags: |\n            ${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}\n            ${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}\n            ${{ (matrix.is_latest) && format('{0}-latest', steps.metadata.outputs.tags) || '' }}\n          labels: ${{ steps.metadata.outputs.labels }}\n\n  build-axolotl-cloud:\n    needs: build-axolotl\n    if: ${{ ! 
contains(github.event.commits[0].message, '[skip docker]') && github.repository_owner == 'axolotl-ai-cloud' }}\n    # this job needs to be run on self-hosted GPU runners...\n    strategy:\n      fail-fast: false\n      matrix:\n        include:\n          - cuda: 128\n            cuda_version: 12.8.1\n            python_version: \"3.11\"\n            pytorch: 2.8.0\n            axolotl_extras:\n            platforms: \"linux/amd64\"\n          - cuda: 128\n            cuda_version: 12.8.1\n            python_version: \"3.11\"\n            pytorch: 2.9.0\n            axolotl_extras:\n            platforms: \"linux/amd64,linux/arm64\"\n          - cuda: 128\n            cuda_version: 12.8.1\n            python_version: \"3.11\"\n            pytorch: 2.9.1\n            axolotl_extras:\n            is_latest: true\n            platforms: \"linux/amd64,linux/arm64\"\n          - cuda: 128\n            cuda_version: 12.8.1\n            python_version: \"3.12\"\n            pytorch: 2.10.0\n            axolotl_extras:\n            platforms: \"linux/amd64,linux/arm64\"\n#          - cuda: 129\n#            cuda_version: 12.9.1\n#            python_version: \"3.12\"\n#            pytorch: 2.9.1\n#            axolotl_extras:\n#            platforms: \"linux/amd64,linux/arm64\"\n          - cuda: 130\n            cuda_version: 13.0.0\n            python_version: \"3.11\"\n            pytorch: 2.9.1\n            axolotl_extras:\n            platforms: \"linux/amd64,linux/arm64\"\n          - cuda: 130\n            cuda_version: 13.0.0\n            python_version: \"3.12\"\n            pytorch: 2.10.0\n            axolotl_extras:\n            platforms: \"linux/amd64,linux/arm64\"\n    runs-on: axolotl-gpu-runner\n    steps:\n      - name: Checkout\n        uses: actions/checkout@v4\n      - name: Docker metadata\n        id: metadata\n        uses: docker/metadata-action@v5\n        with:\n          images: |\n            axolotlai/axolotl-cloud\n          tags: |\n           
 type=ref,event=branch\n            type=pep440,pattern={{version}}\n      - name: Login to Docker Hub\n        uses: docker/login-action@v3\n        with:\n          username: ${{ secrets.DOCKERHUB_USERNAME }}\n          password: ${{ secrets.DOCKERHUB_TOKEN }}\n      - name: Set up Docker Buildx\n        uses: docker/setup-buildx-action@v3\n      - name: Build\n        uses: docker/build-push-action@v5\n        with:\n          context: .\n          platforms: ${{ matrix.platforms }}\n          build-args: |\n            BASE_TAG=${{ github.ref_type == 'tag' && 'main' || github.ref_name }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}\n            CUDA=${{ matrix.cuda }}\n          file: ./docker/Dockerfile-cloud\n          push: ${{ github.event_name != 'pull_request' }}\n          tags: |\n             ${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}\n             ${{ (matrix.is_latest) && format('{0}-latest', steps.metadata.outputs.tags) || '' }}\n          labels: ${{ steps.metadata.outputs.labels }}\n\n  build-axolotl-cloud-uv:\n    needs: build-axolotl-uv\n    if: ${{ ! 
contains(github.event.commits[0].message, '[skip docker]') && github.repository_owner == 'axolotl-ai-cloud' }}\n    # this job needs to be run on self-hosted GPU runners...\n    strategy:\n      fail-fast: false\n      matrix:\n        include:\n          - cuda: 128\n            cuda_version: 12.8.1\n            python_version: \"3.11\"\n            pytorch: 2.9.1\n            axolotl_extras:\n            platforms: \"linux/amd64,linux/arm64\"\n          - cuda: 128\n            cuda_version: 12.8.1\n            python_version: \"3.12\"\n            pytorch: 2.9.1\n            axolotl_extras:\n            is_latest: true\n            platforms: \"linux/amd64,linux/arm64\"\n          - cuda: 128\n            cuda_version: 12.8.1\n            python_version: \"3.12\"\n            pytorch: 2.10.0\n            axolotl_extras:\n            platforms: \"linux/amd64,linux/arm64\"\n          - cuda: 130\n            cuda_version: 13.0.0\n            python_version: \"3.11\"\n            pytorch: 2.9.1\n            axolotl_extras:\n            platforms: \"linux/amd64,linux/arm64\"\n          - cuda: 130\n            cuda_version: 13.0.0\n            python_version: \"3.12\"\n            pytorch: 2.10.0\n            axolotl_extras:\n            platforms: \"linux/amd64,linux/arm64\"\n    runs-on: axolotl-gpu-runner\n    steps:\n      - name: Checkout\n        uses: actions/checkout@v4\n      - name: Docker metadata\n        id: metadata\n        uses: docker/metadata-action@v5\n        with:\n          images: |\n            axolotlai/axolotl-cloud-uv\n          tags: |\n            type=ref,event=branch\n            type=pep440,pattern={{version}}\n      - name: Login to Docker Hub\n        uses: docker/login-action@v3\n        with:\n          username: ${{ secrets.DOCKERHUB_USERNAME }}\n          password: ${{ secrets.DOCKERHUB_TOKEN }}\n      - name: Set up Docker Buildx\n        uses: docker/setup-buildx-action@v3\n      - name: Build\n        uses: 
docker/build-push-action@v5\n        with:\n          context: .\n          platforms: ${{ matrix.platforms }}\n          build-args: |\n            BASE_TAG=${{ github.ref_type == 'tag' && 'main' || github.ref_name }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}\n            CUDA=${{ matrix.cuda }}\n          file: ./docker/Dockerfile-cloud-uv\n          push: ${{ github.event_name != 'pull_request' }}\n          tags: |\n             ${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}\n             ${{ (matrix.is_latest) && format('{0}-latest', steps.metadata.outputs.tags) || '' }}\n          labels: ${{ steps.metadata.outputs.labels }}\n\n  build-axolotl-cloud-no-tmux:\n    needs: build-axolotl\n    if: ${{ ! contains(github.event.commits[0].message, '[skip docker]') && github.repository_owner == 'axolotl-ai-cloud' }}\n    # this job needs to be run on self-hosted GPU runners...\n    strategy:\n      fail-fast: false\n      matrix:\n        include:\n          - cuda: 128\n            cuda_version: 12.8.1\n            python_version: \"3.11\"\n            pytorch: 2.9.1\n            axolotl_extras:\n            is_latest: true\n          - cuda: 130\n            cuda_version: 13.0.0\n            python_version: \"3.11\"\n            pytorch: 2.9.1\n            axolotl_extras:\n            is_latest:\n    runs-on: axolotl-gpu-runner\n    steps:\n      - name: Checkout\n        uses: actions/checkout@v4\n      - name: Docker metadata\n        id: metadata\n        uses: docker/metadata-action@v5\n        with:\n          images: |\n            axolotlai/axolotl-cloud-term\n          tags: |\n            type=ref,event=branch\n            type=pep440,pattern={{version}}\n      - name: Login to Docker Hub\n        uses: 
docker/login-action@v3\n        with:\n          username: ${{ secrets.DOCKERHUB_USERNAME }}\n          password: ${{ secrets.DOCKERHUB_TOKEN }}\n      - name: Set up Docker Buildx\n        uses: docker/setup-buildx-action@v3\n      - name: Build\n        uses: docker/build-push-action@v5\n        with:\n          context: .\n          platforms: linux/amd64,linux/arm64\n          build-args: |\n            BASE_TAG=${{ github.ref_type == 'tag' && 'main' || github.ref_name }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}\n            CUDA=${{ matrix.cuda }}\n          file: ./docker/Dockerfile-cloud-no-tmux\n          push: ${{ github.event_name != 'pull_request' }}\n          tags: |\n             ${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}\n             ${{ (matrix.is_latest) && format('{0}-latest', steps.metadata.outputs.tags) || '' }}\n          labels: ${{ steps.metadata.outputs.labels }}\n"
  },
  {
    "path": ".github/workflows/multi-gpu-e2e.yml",
    "content": "name: docker-multigpu-tests-biweekly\n\non:\n  pull_request:\n    paths:\n      - 'tests/e2e/multigpu/**.py'\n      - 'requirements.txt'\n      - 'setup.py'\n      - 'pyproject.toml'\n      - '.github/workflows/multi-gpu-e2e.yml'\n      - 'scripts/cutcrossentropy_install.py'\n      - 'src/axolotl/core/trainers/mixins/sequence_parallel.py'\n      - 'src/axolotl/utils/distributed.py'\n  workflow_dispatch:\n  schedule:\n    - cron: '0 0 * * 1,4'  # Runs at 00:00 UTC every monday & thursday\n\n# Cancel jobs on the same ref if a new one is triggered\nconcurrency:\n  group: ${{ github.workflow }}-${{ github.ref }}\n  cancel-in-progress: ${{ github.ref != 'refs/heads/main' }}\n\npermissions:\n  contents: read\n\nenv:\n  MODAL_IMAGE_BUILDER_VERSION: \"2025.06\"\n\njobs:\n  test-axolotl-multigpu:\n    if: ${{ ! contains(github.event.commits[0].message, '[skip e2e]') && github.repository_owner == 'axolotl-ai-cloud' && (github.event_name != 'pull_request' || !github.event.pull_request.draft) }}\n    strategy:\n      fail-fast: false\n      matrix:\n        include:\n          - cuda: 128\n            cuda_version: 12.8.1\n            python_version: \"3.11\"\n            pytorch: 2.8.0\n            axolotl_extras: fbgemm-gpu\n            num_gpus: 2\n#          - cuda: 129\n#            cuda_version: 12.9.1\n#            python_version: \"3.12\"\n#            pytorch: 2.9.1\n#            axolotl_extras: \"fbgemm-gpu\"\n#            num_gpus: 2\n#            dockerfile: \"Dockerfile-uv.jinja\"\n          - cuda: 130\n            cuda_version: 13.0.0\n            python_version: \"3.11\"\n            pytorch: 2.9.1\n            axolotl_extras:\n#            axolotl_extras: fbgemm-gpu\n            num_gpus: 2\n          - cuda: 128\n            cuda_version: 12.8.1\n            python_version: \"3.11\"\n            pytorch: 2.10.0\n            axolotl_extras: \"fbgemm-gpu\"\n            num_gpus: 2\n            dockerfile: \"Dockerfile-uv.jinja\"\n    runs-on: 
[self-hosted, modal]\n    timeout-minutes: 120\n    steps:\n      - name: Checkout\n        uses: actions/checkout@v4\n      - name: Install Python\n        uses: actions/setup-python@v5\n        with:\n          python-version: \"3.11\"\n      - name: Install Modal\n        run: |\n          python -m pip install --upgrade pip\n          pip install modal==1.3.0.post1 jinja2\n      - name: Update env vars\n        run: |\n          echo \"BASE_TAG=main-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}\" >> $GITHUB_ENV\n          echo \"PYTORCH_VERSION=${{ matrix.pytorch}}\" >> $GITHUB_ENV\n          echo \"AXOLOTL_ARGS=${{ matrix.axolotl_args}}\" >> $GITHUB_ENV\n          echo \"AXOLOTL_EXTRAS=${{ matrix.axolotl_extras}}\" >> $GITHUB_ENV\n          echo \"CUDA=${{ matrix.cuda }}\" >> $GITHUB_ENV\n          echo \"N_GPUS=${{ matrix.num_gpus }}\" >> $GITHUB_ENV\n          echo \"E2E_DOCKERFILE=${{ matrix.dockerfile || 'Dockerfile.jinja'}}\" >> $GITHUB_ENV\n      - name: Run tests job on Modal\n        env:\n          CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}\n        run: |\n          modal run -m cicd.multigpu\n"
  },
  {
    "path": ".github/workflows/nightlies.yml",
    "content": "name: docker-nightlies\n\non:\n  workflow_dispatch:\n  schedule:\n    - cron: '0 0 * * *'  # Runs at 00:00 UTC every day\n\npermissions:\n  contents: read\n\njobs:\n  build-axolotl:\n    if: ${{ ! contains(github.event.commits[0].message, '[skip docker]') && github.repository_owner == 'axolotl-ai-cloud' }}\n    strategy:\n      fail-fast: false\n      matrix:\n        include:\n          - cuda: 128\n            cuda_version: 12.8.1\n            python_version: \"3.11\"\n            pytorch: 2.8.0\n            axolotl_extras:\n          - cuda: 128\n            cuda_version: 12.8.1\n            python_version: \"3.11\"\n            pytorch: 2.9.1\n            axolotl_extras:\n    runs-on: axolotl-gpu-runner\n    steps:\n      - name: Checkout\n        uses: actions/checkout@v4\n      - name: Docker metadata\n        id: metadata\n        uses: docker/metadata-action@v5\n        with:\n          images: |\n            axolotlai/axolotl\n          tags: |\n            type=raw,value={{ branch }}-{{ date 'YYYYMMDD' }}\n      - name: Set up Docker Buildx\n        uses: docker/setup-buildx-action@v3\n      - name: Login to Docker Hub\n        uses: docker/login-action@v3\n        with:\n          username: ${{ secrets.DOCKERHUB_USERNAME }}\n          password: ${{ secrets.DOCKERHUB_TOKEN }}\n      # guidance for testing before pushing: https://docs.docker.com/build/ci/github-actions/test-before-push/\n      - name: Build and export to Docker\n        uses: docker/build-push-action@v5\n        with:\n          context: .\n          build-args: |\n            BASE_TAG=${{ github.ref_name }}-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}\n            CUDA=${{ matrix.cuda }}\n            PYTORCH_VERSION=${{ matrix.pytorch }}\n            AXOLOTL_ARGS=${{ matrix.axolotl_args }}\n          file: ./docker/Dockerfile\n          push: ${{ github.event_name != 'pull_request' }}\n          tags: |\n            ${{ 
steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}\n          labels: ${{ steps.metadata.outputs.labels }}\n\n  build-axolotl-cloud:\n    needs: build-axolotl\n    if: ${{ ! contains(github.event.commits[0].message, '[skip docker]') && github.repository_owner == 'axolotl-ai-cloud' }}\n    # this job needs to be run on self-hosted GPU runners...\n    strategy:\n      matrix:\n        include:\n          - cuda: 128\n            cuda_version: 12.8.1\n            python_version: \"3.11\"\n            pytorch: 2.8.0\n            axolotl_extras:\n          - cuda: 128\n            cuda_version: 12.8.1\n            python_version: \"3.11\"\n            pytorch: 2.9.1\n            axolotl_extras:\n    runs-on: axolotl-gpu-runner\n    steps:\n      - name: Checkout\n        uses: actions/checkout@v4\n      - name: Docker metadata\n        id: metadata\n        uses: docker/metadata-action@v5\n        with:\n          images: |\n            axolotlai/axolotl-cloud\n          tags: |\n            type=raw,value={{ branch }}-{{ date 'YYYYMMDD' }}\n      - name: Login to Docker Hub\n        uses: docker/login-action@v3\n        with:\n          username: ${{ secrets.DOCKERHUB_USERNAME }}\n          password: ${{ secrets.DOCKERHUB_TOKEN }}\n      - name: Set up Docker Buildx\n        uses: docker/setup-buildx-action@v3\n      - name: Build\n        uses: docker/build-push-action@v5\n        with:\n          context: .\n          build-args: |\n            BASE_TAG=${{ github.ref_name }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}\n            CUDA=${{ matrix.cuda }}\n          file: ./docker/Dockerfile-cloud\n          push: ${{ github.event_name != 'pull_request' }}\n          tags: |\n             ${{ steps.metadata.outputs.tags }}-py${{ 
matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}\n          labels: ${{ steps.metadata.outputs.labels }}\n"
  },
  {
    "path": ".github/workflows/precommit-autoupdate.yml",
    "content": "name: Pre-commit auto-update\n\non:\n  schedule:\n    - cron: '0 0 1 * *'  # Run monthly\n  workflow_dispatch:  # Manual kickoff\n\npermissions: {}\n\njobs:\n  auto-update:\n    runs-on: ubuntu-latest\n    permissions:\n      contents: write\n      pull-requests: write\n    steps:\n      - uses: actions/checkout@v4\n\n      - uses: actions/setup-python@v5\n        with:\n          python-version: '3.11'\n\n      - name: Update pre-commit hooks\n        id: update\n        run: |\n          pip install pre-commit\n          pre-commit autoupdate\n          if [[ -n $(git status --porcelain) ]]; then\n            echo \"changes=true\" >> $GITHUB_OUTPUT\n          fi\n\n      - name: Create Pull Request\n        if: steps.update.outputs.changes == 'true'\n        uses: peter-evans/create-pull-request@v6\n        with:\n          token: ${{ secrets.GITHUB_TOKEN }}\n          branch: update/pre-commit-hooks\n          delete-branch: true\n          title: \"chore: update pre-commit hooks\"\n          commit-message: \"chore: update pre-commit hooks\"\n          body: |\n            Automated PR to update pre-commit hooks to their latest versions.\n"
  },
  {
    "path": ".github/workflows/preview-docs.yml",
    "content": "name: Preview\non:\n  workflow_dispatch:\n  pull_request:\n    types: [opened, synchronize, reopened, ready_for_review]\n\n    # Run the workflow only when one of these files changes\n    paths:\n      - '**/*.md'      # any Markdown file\n      - '**/*.qmd'     # any Quarto file\n      - '_quarto.yml'\n      - docs/scripts/generate_config_docs.py\n      - src/axolotl/utils/schemas/**.py\n      - .github/workflows/preview-docs.yml\n\npermissions:\n  contents: read\n  pull-requests: write\n\njobs:\n  preview:\n    runs-on: ubuntu-latest\n    if: ${{ !github.event.pull_request.draft }}\n    steps:\n      - name: cleanup node\n        run: |\n          sudo rm -rf /usr/share/dotnet /usr/local/lib/android /opt/ghc /opt/hostedtoolcache/CodeQL\n\n      - name: Check out repository\n        uses: actions/checkout@v4\n        with:\n          ref: ${{ github.event.pull_request.head.sha }}\n\n      - name: Set up Quarto\n        uses: quarto-dev/quarto-actions/setup@v2\n\n      - name: Setup Python\n        uses: actions/setup-python@v5\n        with:\n          python-version: '3.11'\n\n      - name: Install dependencies\n        run: |\n          python3 -m pip install jupyter quartodoc\n          python3 -m pip install -e .\n\n      - name: Build autodoc\n        run: quartodoc build\n\n      - name: Quarto render\n        run: quarto render\n\n      - name: Netlify Publish\n        uses: nwtgck/actions-netlify@v3.0\n        if: ${{ github.event.pull_request.head.repo.full_name == github.repository }}\n        id: netlify\n        with:\n          publish-dir: './_site'\n          enable-pull-request-comment: false\n          enable-github-deployment: false\n          github-token: ${{ secrets.GITHUB_TOKEN }}\n          deploy-message: \"Deployed On Netlify\"\n          github-deployment-environment: 'preview'\n          github-deployment-description: 'Preview Deployment'\n        env:\n          NETLIFY_AUTH_TOKEN: ${{ secrets.NETLIFY_AUTH_TOKEN }}\n     
     NETLIFY_SITE_ID: ${{ secrets.NETLIFY_SITE_ID }}\n\n      - name: Update PR with preview link\n        if: ${{ steps.netlify.outcome == 'success' }}\n        uses: marocchino/sticky-pull-request-comment@v2\n        with:\n          GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}\n          message: |\n            📖 **Documentation Preview**: ${{ steps.netlify.outputs.deploy-url }}\n\n            Deployed on Netlify from commit ${{ github.event.pull_request.head.sha }}\n"
  },
  {
    "path": ".github/workflows/pypi.yml",
    "content": "name: publish pypi\n\non:\n  push:\n    tags:\n      - \"v*\"\n  workflow_dispatch:\n\npermissions: {}\n\njobs:\n  setup_release:\n    name: Create Release\n    runs-on: ubuntu-latest\n    permissions:\n      contents: write\n    steps:\n      - name: Checkout code\n        uses: actions/checkout@v4\n\n      - name: Create release\n        env:\n          GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}\n        run: gh release create \"$GITHUB_REF_NAME\" --generate-notes\n  pypi-publish:\n    name: Upload release to PyPI\n    runs-on: ubuntu-latest\n    needs: [setup_release]\n    environment:\n      name: pypi\n      url: https://pypi.org/p/axolotl\n    permissions:\n      contents: read\n      id-token: write # IMPORTANT: this permission is mandatory for trusted publishing\n    steps:\n      - name: Check out repository code\n        uses: actions/checkout@v4\n\n      - name: Setup Python\n        uses: actions/setup-python@v5\n        with:\n          python-version: \"3.11\"\n\n      - name: Install dependencies\n        run: |\n          pip3 install wheel packaging==26.0\n          pip3 install --no-build-isolation -e .\n          pip3 install -r requirements-dev.txt -r requirements-tests.txt\n\n      - name: Extract tag name\n        id: tag\n        run: echo \"TAG_NAME=$(echo $GITHUB_REF | cut -d / -f 3)\" >> \"$GITHUB_OUTPUT\"\n\n      - name: Update version in VERSION file\n        run: |\n          echo \"${{ steps.tag.outputs.TAG_NAME }}\" | sed 's/^v//' > VERSION\n\n      - name: Build a source dist\n        run: |\n          python setup.py sdist\n\n      - name: Publish package distributions to PyPI\n        uses: pypa/gh-action-pypi-publish@release/v1\n"
  },
  {
    "path": ".github/workflows/tests-nightly.yml",
    "content": "name: Tests Nightly against upstream main\non:\n  workflow_dispatch:\n  schedule:\n    - cron: '0 0 * * *'  # Runs at 00:00 UTC every day\n  pull_request:\n    types: [opened, synchronize, reopened, ready_for_review]\n    paths:\n      - '.github/workflows/tests-nightly.yml'\n\npermissions:\n  contents: read\n\njobs:\n  pre-commit:\n    name: pre-commit\n    runs-on: ubuntu-latest\n    steps:\n      - uses: actions/checkout@v4\n      - uses: actions/setup-python@v5\n        with:\n          python-version: \"3.11\"\n          cache: 'pip' # caching pip dependencies\n      - uses: pre-commit/action@v3.0.1\n        env:\n          SKIP: no-commit-to-branch\n\n  prime-cdn-s3-cache:\n    name: Prefetch S3 once to prime the CDN cache\n    runs-on: ubuntu-latest\n    if: ${{ !github.event.pull_request.draft }}\n    timeout-minutes: 10\n    steps:\n      - name: Restore Cache from S3\n        id: hf-cache-restore-s3\n        run: |\n          curl -v -H \"Range: bytes=0-1023\" -L https://axolotl-ci.b-cdn.net/hf-cache.tar.zst > /dev/null\n\n  pytest:\n    name: PyTest\n    runs-on: ubuntu-latest\n    needs: [prime-cdn-s3-cache]\n    strategy:\n      fail-fast: false\n      matrix:\n        python_version: [\"3.12\"]  # TODO include py3.14 once https://github.com/mistralai/mistral-common/pull/194 is merged\n        pytorch_version: [\"2.8.0\", \"2.9.1\", \"2.10.0\"]\n    timeout-minutes: 20\n\n    steps:\n      - name: Check out repository code\n        uses: actions/checkout@v4\n\n      - name: Restore Cache from S3\n        id: hf-cache-restore-s3\n        run: |\n          mkdir -p /home/runner/.cache/huggingface/hub\n          curl -L https://axolotl-ci.b-cdn.net/hf-cache.tar.zst | tar -xf - -C /home/runner/.cache/huggingface/hub/  --use-compress-program unzstd\n\n      - name: Setup Python\n        uses: actions/setup-python@v5\n        with:\n          python-version: ${{ matrix.python_version }}\n          cache: 'pip' # caching pip dependencies\n\n   
   - name: upgrade pip\n        run: |\n          pip3 install --upgrade pip\n          pip3 install --upgrade packaging==26.0 setuptools==78.1.1 wheel\n\n      - name: Install PyTorch\n        run: |\n          pip3 install torch==${{ matrix.pytorch_version }} torchvision\n\n      - name: Update requirements.txt\n        run: |\n          sed -i 's#^transformers.*#transformers @ git+https://github.com/huggingface/transformers.git@main#' requirements.txt\n          sed -i 's#^peft.*#peft @ git+https://github.com/huggingface/peft.git@main#' requirements.txt\n          sed -i 's#^accelerate.*#accelerate @ git+https://github.com/huggingface/accelerate.git@main#' requirements.txt\n          sed -i 's#^trl.*#trl @ git+https://github.com/huggingface/trl.git@main#' requirements.txt\n          sed -i 's#^datasets.*#datasets @ git+https://github.com/huggingface/datasets.git@main#' requirements.txt\n\n      - name: Install dependencies\n        run: |\n          pip3 show torch\n          pip3 install --no-build-isolation -U -e .\n          python scripts/unsloth_install.py | sh\n          python scripts/cutcrossentropy_install.py | sh\n          pip3 install -r requirements-dev.txt -r requirements-tests.txt\n\n      - name: Make sure PyTorch version wasn't clobbered\n        run: |\n          python -c \"import torch; assert '${{ matrix.pytorch_version }}' in torch.__version__\"\n\n      - name: Ensure axolotl CLI was installed\n        run: |\n          axolotl --help\n\n      - name: Run tests\n        run: |\n          pytest -v --durations=10 -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ --ignore=tests/cli/ tests/\n          pytest -v --durations=10 tests/patched/\n          pytest -v --durations=10 tests/cli/\n\n      - name: cleanup pip cache\n        run: |\n          find \"$(pip cache dir)/http-v2\" -type f -mtime +14 -exec rm {} \\;\n\n  docker-e2e-tests:\n    if: github.repository_owner == 'axolotl-ai-cloud'\n    # this job needs to be run on 
self-hosted GPU runners...\n    runs-on: [self-hosted, modal]\n    timeout-minutes: 120\n    needs: [pre-commit, pytest]\n\n    strategy:\n      fail-fast: false\n      matrix:\n        include:\n          - cuda: 128\n            cuda_version: 12.8.1\n            python_version: \"3.11\"\n            pytorch: 2.9.1\n            num_gpus: 1\n            axolotl_extras:\n            nightly_build: \"true\"\n          - cuda: 128\n            cuda_version: 12.8.1\n            python_version: \"3.11\"\n            pytorch: 2.10.0\n            num_gpus: 1\n            axolotl_extras:\n          - cuda: 130\n            cuda_version: 13.0.0\n            python_version: \"3.12\"\n            pytorch: 2.9.1\n            num_gpus: 1\n            axolotl_extras:\n            dockerfile: \"Dockerfile-uv.jinja\"\n            nightly_build: \"true\"\n    steps:\n      - name: Checkout\n        uses: actions/checkout@v4\n      - name: Install Python\n        uses: actions/setup-python@v5\n        with:\n          python-version: \"3.11\"\n      - name: Install Modal\n        run: |\n          python -m pip install --upgrade pip\n          pip install modal==1.3.0.post1 jinja2\n      - name: Update env vars\n        run: |\n          echo \"BASE_TAG=main-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}\" >> $GITHUB_ENV\n          echo \"PYTORCH_VERSION=${{ matrix.pytorch}}\" >> $GITHUB_ENV\n          echo \"AXOLOTL_ARGS=${{ matrix.axolotl_args}}\" >> $GITHUB_ENV\n          echo \"AXOLOTL_EXTRAS=${{ matrix.axolotl_extras}}\" >> $GITHUB_ENV\n          echo \"CUDA=${{ matrix.cuda }}\" >> $GITHUB_ENV\n          echo \"N_GPUS=${{ matrix.num_gpus }}\" >> $GITHUB_ENV\n          echo \"E2E_DOCKERFILE=${{ matrix.dockerfile || 'Dockerfile.jinja'}}\" >> $GITHUB_ENV\n          echo \"NIGHTLY_BUILD=${{ matrix.nightly_build }}\" >> $GITHUB_ENV\n      - name: Run tests job on Modal\n        env:\n          CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}\n        run: 
|\n          modal run cicd.e2e_tests\n  docker-e2e-multigpu-tests:\n    if: github.repository_owner == 'axolotl-ai-cloud'\n    # this job needs to be run on self-hosted GPU runners...\n    runs-on: [self-hosted, modal]\n    timeout-minutes: 120\n    needs: [pre-commit, pytest, docker-e2e-tests]\n\n    strategy:\n      fail-fast: false\n      matrix:\n        include:\n          - cuda: 128\n            cuda_version: 12.8.1\n            python_version: \"3.11\"\n            pytorch: 2.9.1\n            num_gpus: 2\n            axolotl_extras:\n            nightly_build: \"true\"\n    steps:\n      - name: Checkout\n        uses: actions/checkout@v4\n      - name: Install Python\n        uses: actions/setup-python@v5\n        with:\n          python-version: \"3.11\"\n      - name: Install Modal\n        run: |\n          python -m pip install --upgrade pip\n          pip install modal==1.3.0.post1 jinja2\n      - name: Update env vars\n        run: |\n          echo \"BASE_TAG=main-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}\" >> $GITHUB_ENV\n          echo \"PYTORCH_VERSION=${{ matrix.pytorch}}\" >> $GITHUB_ENV\n          echo \"AXOLOTL_ARGS=${{ matrix.axolotl_args}}\" >> $GITHUB_ENV\n          echo \"AXOLOTL_EXTRAS=${{ matrix.axolotl_extras}}\" >> $GITHUB_ENV\n          echo \"CUDA=${{ matrix.cuda }}\" >> $GITHUB_ENV\n          echo \"N_GPUS=${{ matrix.num_gpus }}\" >> $GITHUB_ENV\n          echo \"NIGHTLY_BUILD=${{ matrix.nightly_build }}\" >> $GITHUB_ENV\n      - name: Run tests job on Modal\n        env:\n          CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}\n        run: |\n          modal run cicd.multigpu\n"
  },
  {
    "path": ".github/workflows/tests.yml",
    "content": "name: Tests\non:\n  # check on push/merge to main, PRs, and manual triggers\n  merge_group:\n  push:\n    branches:\n      - \"main\"\n    paths:\n      - '**.py'\n      - 'requirements.txt'\n      - '.github/workflows/*.yml'\n      - 'requirements-tests.txt'\n      - 'cicd/cicd.sh'\n      - 'cicd/Dockerfile.jinja'\n  pull_request:\n      types: [opened, synchronize, reopened, ready_for_review]\n      paths:\n       - '**.py'\n       - 'requirements.txt'\n       - '.github/workflows/*.yml'\n       - 'requirements-tests.txt'\n       - 'cicd/cicd.sh'\n       - 'cicd/Dockerfile.jinja'\n  workflow_dispatch:\n\n# Cancel jobs on the same ref if a new one is triggered\nconcurrency:\n  group: ${{ github.workflow }}-${{ github.ref }}\n  cancel-in-progress: ${{ github.ref != 'refs/heads/main' }}\n\npermissions:\n  contents: read\n\nenv:\n  TRANSFORMERS_IS_CI: \"yes\"\n\njobs:\n  pre-commit:\n    name: pre-commit\n    runs-on: ubuntu-latest\n    if: ${{ !github.event.pull_request.draft }}\n    steps:\n      - uses: actions/checkout@v4\n      - uses: actions/setup-python@v5\n        with:\n          python-version: \"3.11\"\n          cache: 'pip' # caching pip dependencies\n      - uses: pre-commit/action@v3.0.1\n        env:\n          SKIP: no-commit-to-branch\n\n  prime-cdn-s3-cache:\n    name: Prefetch S3 once to prime the CDN cache\n    runs-on: ubuntu-latest\n    if: ${{ !github.event.pull_request.draft }}\n    timeout-minutes: 10\n    steps:\n      - name: Restore Cache from S3\n        id: hf-cache-restore-s3\n        run: |\n          curl -v -H \"Range: bytes=0-1023\" -L https://axolotl-ci.b-cdn.net/hf-cache.tar.zst > /dev/null\n\n  pytest:\n    name: PyTest\n    runs-on: ubuntu-latest\n    if: ${{ !github.event.pull_request.draft }}\n    needs: [prime-cdn-s3-cache]\n    strategy:\n      fail-fast: false\n      matrix:\n        python_version: [\"3.12\"]  # TODO include py3.14 once https://github.com/mistralai/mistral-common/pull/194 is merged\n      
  pytorch_version: [\"2.8.0\", \"2.9.1\", \"2.10.0\"]\n#        exclude:\n#          - python_version: \"3.14\"\n#            pytorch_version: \"2.8.0\"\n#          - python_version: \"3.14\"\n#            pytorch_version: \"2.9.1\"\n    timeout-minutes: 20\n\n    steps:\n      - name: cleanup node\n        run: |\n          sudo rm -rf /usr/share/dotnet /usr/local/lib/android /opt/ghc /opt/hostedtoolcache/CodeQL\n\n      - name: Check out repository code\n        uses: actions/checkout@v4\n\n      - name: Restore Cache from S3\n        id: hf-cache-restore-s3\n        run: |\n          mkdir -p ~/.cache/huggingface/hub\n          curl -L https://axolotl-ci.b-cdn.net/hf-cache.tar.zst | tar -xpf - -C ~/.cache/huggingface/hub/  --use-compress-program unzstd --strip-components=1\n          ls -ltr ~/.cache/huggingface/hub/\n\n      - name: Setup Python\n        uses: actions/setup-python@v5\n        with:\n          python-version: ${{ matrix.python_version }}\n          cache: 'pip' # caching pip dependencies\n\n      - name: upgrade pip\n        run: |\n          pip3 install --upgrade pip\n          pip3 install --upgrade packaging==26.0 setuptools==75.8.0 wheel\n\n      - name: Install PyTorch\n        run: |\n          pip3 install --no-cache-dir torch==${{ matrix.pytorch_version }} torchvision\n\n      - name: Install dependencies\n        run: |\n          pip3 show torch\n          pip3 install --no-cache-dir --no-build-isolation -U -e .\n          python scripts/unsloth_install.py | sh\n          python scripts/cutcrossentropy_install.py | sh\n          pip3 install -r requirements-dev.txt -r requirements-tests.txt\n\n      - name: cleanup pip cache\n        run: |\n          find \"$(pip cache dir)/http-v2\" -type f -mtime +14 -exec rm {} \\;\n\n      - name: Make sure PyTorch version wasn't clobbered\n        run: |\n          python -c \"import torch; assert '${{ matrix.pytorch_version }}' in torch.__version__\"\n\n      - name: Ensure axolotl CLI was 
installed\n        run: |\n          axolotl --help\n\n      - name: Pre-Download dataset fixture\n        run: |\n          hf download --repo-type=dataset axolotl-ai-internal/axolotl-oss-dataset-fixtures\n\n      - name: Show HF cache\n        run: hf cache ls\n\n      - name: Run tests\n        run: |\n          df -h\n          pytest -v --durations=10 -n4 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ --ignore=tests/cli/ --ignore=tests/monkeypatch/ tests/ --cov=axolotl --cov-report=xml\n          df -h\n          pytest -v --durations=10 tests/monkeypatch/ --cov=axolotl --cov-append --cov-report=xml\n          df -h\n          pytest -v --durations=10 tests/patched/ --cov=axolotl --cov-append --cov-report=xml\n          df -h\n          pytest -v --durations=10 tests/cli/ --cov=axolotl --cov-append --cov-report=xml\n\n      - name: Show HF cache\n        run: hf cache ls\n\n      - name: Upload coverage to Codecov\n        uses: codecov/codecov-action@v5\n        with:\n          token: ${{ secrets.CODECOV_TOKEN }}\n          files: ./coverage.xml\n          flags: unittests,pytorch-${{ matrix.pytorch_version }}\n          fail_ci_if_error: false\n\n  pytest-sdist:\n    name: PyTest from Source Dist\n    runs-on: ubuntu-latest\n    if: ${{ !github.event.pull_request.draft }}\n    needs: [prime-cdn-s3-cache]\n    strategy:\n      fail-fast: false\n      matrix:\n        python_version: [\"3.12\"]  # TODO include py3.14 once https://github.com/mistralai/mistral-common/pull/194 is merged\n        pytorch_version: [\"2.8.0\", \"2.9.1\", \"2.10.0\"]\n#        exclude:\n#          - python_version: \"3.14\"\n#            pytorch_version: \"2.8.0\"\n#          - python_version: \"3.14\"\n#            pytorch_version: \"2.9.1\"\n    timeout-minutes: 30\n\n    steps:\n      - name: cleanup node\n        run: |\n          sudo rm -rf /usr/share/dotnet /usr/local/lib/android /opt/ghc /opt/hostedtoolcache/CodeQL\n\n      - name: Check out repository code\n    
    uses: actions/checkout@v4\n\n      - name: Restore Cache from S3\n        id: hf-cache-restore-s3\n        run: |\n          mkdir -p ~/.cache/huggingface/hub\n          curl -L https://axolotl-ci.b-cdn.net/hf-cache.tar.zst | tar -xpf - -C ~/.cache/huggingface/hub/  --use-compress-program unzstd --strip-components=1\n          ls -ltr ~/.cache/huggingface/hub/\n\n      - name: Setup Python\n        uses: actions/setup-python@v5\n        with:\n          python-version: ${{ matrix.python_version }}\n          cache: 'pip' # caching pip dependencies\n\n      - name: upgrade pip\n        run: |\n          pip3 install --upgrade pip\n          pip3 install --upgrade packaging==26.0 setuptools==75.8.0 setuptools_scm build wheel psutil\n\n      - name: Install PyTorch\n        run: |\n          pip3 install --no-cache-dir torch==${{ matrix.pytorch_version }} torchvision\n\n      - name: Install dependencies\n        run: |\n          pip3 show torch\n          python -m build --no-isolation --sdist\n          pip3 install --no-cache-dir --no-build-isolation dist/axolotl*.tar.gz\n          python scripts/unsloth_install.py | sh\n          python scripts/cutcrossentropy_install.py | sh\n          pip3 install -r requirements-dev.txt -r requirements-tests.txt\n\n      - name: cleanup pip cache\n        run: |\n          find \"$(pip cache dir)/http-v2\" -type f -mtime +14 -exec rm {} \\;\n\n      - name: Make sure PyTorch version wasn't clobbered\n        run: |\n          python -c \"import torch; assert '${{ matrix.pytorch_version }}' in torch.__version__\"\n\n      - name: Ensure axolotl CLI was installed\n        run: |\n          axolotl --help\n\n      - name: Show HF cache\n        run: hf cache ls\n\n      - name: Run tests\n        run: |\n          pytest -v --durations=10 -n4 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ --ignore=tests/cli/ --ignore=tests/monkeypatch/ tests/ --cov=axolotl --cov-report=xml\n          pytest -v --durations=10 
tests/monkeypatch/ --cov=axolotl --cov-append --cov-report=xml\n          pytest -v --durations=10 tests/cli/\n\n      - name: Show HF cache\n        run: hf cache ls\n\n  gate-skip-e2e:\n    needs: [pre-commit]\n    runs-on: ubuntu-latest\n    outputs:\n      skip: ${{ steps.compute.outputs.skip }}\n    steps:\n      - uses: actions/github-script@v7\n        id: compute\n        with:\n          script: |\n            const token = /\\[skip-e2e\\]/i;\n            let msg = '';\n            if (context.eventName === 'push') {\n              msg = context.payload.head_commit?.message || '';\n            } else if (context.eventName === 'pull_request') {\n              const { owner, repo } = context.repo;\n              const prNumber = context.payload.pull_request.number;\n              const commits = await github.paginate(\n                github.rest.pulls.listCommits,\n                { owner, repo, pull_number: prNumber, per_page: 100 }\n              );\n              msg = commits.at(-1)?.commit?.message || '';\n            }\n            const title = context.payload.pull_request?.title || '';\n            const body  = context.payload.pull_request?.body  || '';\n            const skip = token.test(msg) || token.test(title) || token.test(body);\n            core.setOutput('skip', String(skip));\n\n  docker-e2e-tests-1st:\n    # Run this job first as a gate for running the remainder of the test matrix\n    if: >\n      github.repository_owner == 'axolotl-ai-cloud' &&\n      (github.event_name != 'pull_request' || !github.event.pull_request.draft) &&\n      needs.gate-skip-e2e.outputs.skip != 'true'\n    # this job needs to be run on self-hosted GPU runners...\n    runs-on: [self-hosted, modal]\n    timeout-minutes: 120\n    needs: [pre-commit, pytest]\n\n    strategy:\n      fail-fast: false\n      matrix:\n        include:\n          - cuda: 130\n            cuda_version: 13.0.0\n            python_version: \"3.12\"\n            pytorch: 2.9.1\n            
num_gpus: 1\n            axolotl_extras:\n            dockerfile: \"Dockerfile-uv.jinja\"\n    steps:\n      - name: Checkout\n        uses: actions/checkout@v4\n      - name: Install Python\n        uses: actions/setup-python@v5\n        with:\n          python-version: \"3.11\"\n      - name: Install Modal\n        run: |\n          python -m pip install --upgrade pip\n          pip install modal==1.3.0.post1 jinja2\n      - name: Update env vars\n        run: |\n          echo \"BASE_TAG=main-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}\" >> $GITHUB_ENV\n          echo \"PYTORCH_VERSION=${{ matrix.pytorch}}\" >> $GITHUB_ENV\n          echo \"AXOLOTL_ARGS=${{ matrix.axolotl_args}}\" >> $GITHUB_ENV\n          echo \"AXOLOTL_EXTRAS=${{ matrix.axolotl_extras}}\" >> $GITHUB_ENV\n          echo \"CUDA=${{ matrix.cuda }}\" >> $GITHUB_ENV\n          echo \"MODAL_IMAGE_BUILDER_VERSION=2024.10\" >> $GITHUB_ENV\n          echo \"N_GPUS=${{ matrix.num_gpus }}\" >> $GITHUB_ENV\n          echo \"E2E_DOCKERFILE=${{ matrix.dockerfile || 'Dockerfile.jinja'}}\" >> $GITHUB_ENV\n      - name: Run tests job on Modal\n        env:\n          CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}\n        run: |\n          modal run cicd.e2e_tests\n\n  docker-e2e-tests:\n    if: >\n      github.repository_owner == 'axolotl-ai-cloud' &&\n      (github.event_name != 'pull_request' || !github.event.pull_request.draft) &&\n      needs.gate-skip-e2e.outputs.skip != 'true'\n    # this job needs to be run on self-hosted GPU runners...\n    runs-on: [self-hosted, modal]\n    timeout-minutes: 120\n    # Only run the remainder of the matrix if the first e2e check passed;\n    # this is to save on wasted compute costs for known failures that get caught in the first run\n    needs: [pre-commit, pytest, gate-skip-e2e, docker-e2e-tests-1st]\n\n    strategy:\n      fail-fast: false\n      matrix:\n        include:\n          - cuda: 128\n            cuda_version: 12.8.1\n      
      python_version: \"3.11\"\n            pytorch: 2.8.0\n            num_gpus: 1\n            gpu_type: \"B200\"\n            axolotl_extras: fbgemm-gpu\n          - cuda: 128\n            cuda_version: 12.8.1\n            python_version: \"3.11\"\n            pytorch: 2.9.1\n            num_gpus: 1\n            axolotl_extras:\n          - cuda: 128\n            cuda_version: 12.8.1\n            python_version: \"3.11\"\n            pytorch: 2.10.0\n            num_gpus: 1\n            axolotl_extras:\n          - cuda: 130\n            cuda_version: 13.0.0\n            python_version: \"3.11\"\n            pytorch: 2.9.1\n            num_gpus: 1\n            axolotl_extras:\n    steps:\n      - name: Checkout\n        uses: actions/checkout@v4\n      - name: Install Python\n        uses: actions/setup-python@v5\n        with:\n          python-version: \"3.11\"\n      - name: Install Modal\n        run: |\n          python -m pip install --upgrade pip\n          pip install modal==1.3.0.post1 jinja2\n      - name: Update env vars\n        run: |\n          echo \"BASE_TAG=main-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}\" >> $GITHUB_ENV\n          echo \"PYTORCH_VERSION=${{ matrix.pytorch}}\" >> $GITHUB_ENV\n          echo \"AXOLOTL_ARGS=${{ matrix.axolotl_args}}\" >> $GITHUB_ENV\n          echo \"AXOLOTL_EXTRAS=${{ matrix.axolotl_extras}}\" >> $GITHUB_ENV\n          echo \"CUDA=${{ matrix.cuda }}\" >> $GITHUB_ENV\n          echo \"MODAL_IMAGE_BUILDER_VERSION=2024.10\" >> $GITHUB_ENV\n          echo \"N_GPUS=${{ matrix.num_gpus }}\" >> $GITHUB_ENV\n          echo \"GPU_TYPE=${{ matrix.gpu_type || 'L40S'}}\" >> $GITHUB_ENV\n          echo \"E2E_DOCKERFILE=${{ matrix.dockerfile || 'Dockerfile.jinja'}}\" >> $GITHUB_ENV\n      - name: Run tests job on Modal\n        env:\n          CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}\n        run: |\n          modal run cicd.e2e_tests\n\n  docker-e2e-cleanup:\n    runs-on: [self-hosted, 
modal]\n    timeout-minutes: 90\n    needs: [docker-e2e-tests]\n    if: ${{ !github.event.pull_request.draft }}\n\n    strategy:\n      fail-fast: false\n      matrix:\n        include:\n          - cuda: 128\n            cuda_version: 12.8.1\n            python_version: \"3.11\"\n            pytorch: 2.9.1\n            num_gpus: 1\n            axolotl_extras:\n    steps:\n      - name: Checkout\n        uses: actions/checkout@v4\n      - name: Install Python\n        uses: actions/setup-python@v5\n        with:\n          python-version: \"3.11\"\n      - name: Install Modal\n        run: |\n          python -m pip install --upgrade pip\n          pip install modal==1.3.0.post1 jinja2\n      - name: Update env vars\n        run: |\n          echo \"BASE_TAG=main-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}\" >> $GITHUB_ENV\n          echo \"PYTORCH_VERSION=${{ matrix.pytorch}}\" >> $GITHUB_ENV\n          echo \"AXOLOTL_ARGS=${{ matrix.axolotl_args}}\" >> $GITHUB_ENV\n          echo \"AXOLOTL_EXTRAS=${{ matrix.axolotl_extras}}\" >> $GITHUB_ENV\n          echo \"CUDA=${{ matrix.cuda }}\" >> $GITHUB_ENV\n          echo \"MODAL_IMAGE_BUILDER_VERSION=2024.10\" >> $GITHUB_ENV\n          echo \"N_GPUS=${{ matrix.num_gpus }}\" >> $GITHUB_ENV\n      - name: Run tests job on Modal\n        run: |\n          modal run cicd.cleanup\n"
  },
  {
    "path": ".gitignore",
    "content": "**/axolotl.egg-info\nconfigs\nlast_run_prepared/\noutputs\n.vscode\n_site/\n\n# Byte-compiled / optimized / DLL files\n__pycache__/\n*.py[cod]\n*$py.class\n\n# C extensions\n*.so\n\n# Distribution / packaging\n.Python\nbuild/\ndevelop-eggs/\ndist/\ndownloads/\neggs/\n.eggs/\nlib/\nlib64/\nparts/\nsdist/\nvar/\nwheels/\nshare/python-wheels/\n*.egg-info/\n.installed.cfg\n*.egg\nMANIFEST\n\n# PyInstaller\n#  Usually these files are written by a python script from a template\n#  before PyInstaller builds the exe, so as to inject date/other infos into it.\n*.manifest\n*.spec\n\n# Installer logs\npip-log.txt\npip-delete-this-directory.txt\n\n# Unit test / coverage reports\nhtmlcov/\n.tox/\n.nox/\n.coverage\n.coverage.*\n.cache\nnosetests.xml\ncoverage.xml\n*.cover\n*.py,cover\n.hypothesis/\n.pytest_cache/\ncover/\n\n# Translations\n*.mo\n*.pot\n\n# Django stuff:\n*.log\nlocal_settings.py\ndb.sqlite3\ndb.sqlite3-journal\n\n# Flask stuff:\ninstance/\n.webassets-cache\n\n# Scrapy stuff:\n.scrapy\n\n# Sphinx documentation\ndocs/_build/\n\n# PyBuilder\n.pybuilder/\ntarget/\n\n# Jupyter Notebook\n.ipynb_checkpoints\n\n# IPython\nprofile_default/\nipython_config.py\n\n# pyenv\n#   For a library or package, you might want to ignore these files since the code is\n#   intended to run in multiple environments; otherwise, check them in:\n# .python-version\n\n# pipenv\n#   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.\n#   However, in case of collaboration, if having platform-specific dependencies or dependencies\n#   having no cross-platform support, pipenv may install dependencies that don't work, or not\n#   install all needed dependencies.\n#Pipfile.lock\n\n# poetry\n#   Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.\n#   This is especially recommended for binary packages to ensure reproducibility, and is more\n#   commonly ignored for libraries.\n#   
https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control\n#poetry.lock\n\n# pdm\n#   Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.\n#pdm.lock\n#   pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it\n#   in version control.\n#   https://pdm.fming.dev/#use-with-ide\n.pdm.toml\n\n# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm\n__pypackages__/\n\n# Celery stuff\ncelerybeat-schedule\ncelerybeat.pid\n\n# SageMath parsed files\n*.sage.py\n\n# Environments\n.env\n.venv\nenv/\nvenv/\nENV/\nenv.bak/\nvenv.bak/\nvenv3.10/\n\n# Spyder project settings\n.spyderproject\n.spyproject\n\n# Rope project settings\n.ropeproject\n\n# mkdocs documentation\n/site\n\n# mypy\n.mypy_cache/\n.dmypy.json\ndmypy.json\n\n# Pyre type checker\n.pyre/\n\n# pytype static type analyzer\n.pytype/\n\n# Cython debug symbols\ncython_debug/\n\n# PyCharm\n#  JetBrains specific template is maintained in a separate JetBrains.gitignore that can\n#  be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore\n#  and can be added to the global gitignore or merged into this file.  For a more nuclear\n#  option (not recommended) you can uncomment the following to ignore the entire idea folder.\n.idea/\n\n# WandB\n# wandb creates a folder to store logs for training runs\nwandb\n\n# Runs\nlora-out/*\nqlora-out/*\nmlruns/*\n\n/.quarto/\nprepared-datasets/\nsubmit.sh\n*.out*\n\n# Quartodoc generated files\nobjects.json\nsite_libs/\n\ntypings/\nout/\n\n# vim\n*.swp\n\n# scm auto-versioning\nsrc/axolotl/_version.py\n"
  },
  {
    "path": ".mypy.ini",
    "content": "[mypy]\nplugins = pydantic.mypy\nexclude = venv\n\n[mypy-alpaca_lora_4bit.*]\nignore_missing_imports = True\n\n[mypy-axolotl.monkeypatch.*]\nignore_errors = True\n\n[mypy-axolotl.models.mixtral.*]\nignore_errors = True\n\n[mypy-axolotl.integrations.liger.models.*]\nignore_errors = True\n\n[mypy-axolotl.models.phi.*]\nignore_errors = True\n\n[mypy-flash_attn.*]\nignore_missing_imports = True\n\n[mypy-huggingface_hub]\nignore_missing_imports = True\n\n[mypy-transformers.*]\nignore_missing_imports = True\n\n[mypy-peft]\nignore_missing_imports = True\n\n[mypy-wandb]\nignore_missing_imports = True\n\n[mypy-bitsandbytes]\nignore_missing_imports = True\n\n[mypy-requests]\nignore_missing_imports = True\n\n[mypy-datasets]\nignore_missing_imports = True\n\n[mypy-fire]\nignore_missing_imports = True\n\n[mypy-setuptools]\nignore_missing_imports = True\n\n[mypy-addict]\nignore_missing_imports = True\n\n[mypy-xformers.*]\nignore_missing_imports = True\n"
  },
  {
    "path": ".pre-commit-config.yaml",
    "content": "default_language_version:\n    python: python3\n\nrepos:\n-   repo: https://github.com/pre-commit/pre-commit-hooks\n    rev: v6.0.0\n    hooks:\n    -   id: check-yaml\n    -   id: end-of-file-fixer\n    -   id: trailing-whitespace\n    -   id: no-commit-to-branch\n        args: ['--branch', 'main']\n-   repo: https://github.com/astral-sh/ruff-pre-commit\n    rev: v0.15.4\n    hooks:\n    -   id: ruff\n        args: [--fix]\n    -   id: ruff-format\n-   repo: https://github.com/pre-commit/mirrors-mypy\n    rev: v1.19.1\n    hooks:\n    - id: mypy\n      additional_dependencies:\n        [\n            'types-PyYAML',\n            'pydantic>=2.5.3',\n        ]\n-   repo: https://github.com/PyCQA/bandit\n    rev: 1.9.4\n    hooks:\n    -   id: bandit\n        args: [\n            '--ini',\n            '.bandit',\n        ]\n"
  },
  {
    "path": ".runpod/.gitignore",
    "content": "# Byte-compiled / optimized / DLL files\n__pycache__/\n*.py[cod]\n*$py.class\n\n# C extensions\n*.so\n\n# Distribution / packaging\n.Python\nbuild/\ndevelop-eggs/\ndist/\ndownloads/\neggs/\n.eggs/\nlib/\nlib64/\nparts/\nsdist/\nvar/\nwheels/\nshare/python-wheels/\n*.egg-info/\n.installed.cfg\n*.egg\nMANIFEST\n\n# PyInstaller\n#  Usually these files are written by a python script from a template\n#  before PyInstaller builds the exe, so as to inject date/other infos into it.\n*.manifest\n*.spec\n\n# Installer logs\npip-log.txt\npip-delete-this-directory.txt\n\n# Unit test / coverage reports\nhtmlcov/\n.tox/\n.nox/\n.coverage\n.coverage.*\n.cache\nnosetests.xml\ncoverage.xml\n*.cover\n*.py,cover\n.hypothesis/\n.pytest_cache/\ncover/\n\n# Translations\n*.mo\n*.pot\n\n# Django stuff:\n*.log\nlocal_settings.py\ndb.sqlite3\ndb.sqlite3-journal\n\n# Flask stuff:\ninstance/\n.webassets-cache\n\n# Scrapy stuff:\n.scrapy\n\n# Sphinx documentation\ndocs/_build/\n\n# PyBuilder\n.pybuilder/\ntarget/\n\n# Jupyter Notebook\n.ipynb_checkpoints\n\n# IPython\nprofile_default/\nipython_config.py\n\n# pyenv\n#   For a library or package, you might want to ignore these files since the code is\n#   intended to run in multiple environments; otherwise, check them in:\n# .python-version\n\n# pipenv\n#   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.\n#   However, in case of collaboration, if having platform-specific dependencies or dependencies\n#   having no cross-platform support, pipenv may install dependencies that don't work, or not\n#   install all needed dependencies.\n#Pipfile.lock\n\n# poetry\n#   Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.\n#   This is especially recommended for binary packages to ensure reproducibility, and is more\n#   commonly ignored for libraries.\n#   
https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control\n#poetry.lock\n\n# pdm\n#   Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.\n#pdm.lock\n#   pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it\n#   in version control.\n#   https://pdm.fming.dev/#use-with-ide\n.pdm.toml\n\n# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm\n__pypackages__/\n\n# Celery stuff\ncelerybeat-schedule\ncelerybeat.pid\n\n# SageMath parsed files\n*.sage.py\n\n# Environments\n.env\n.venv\nenv/\nvenv/\nENV/\nenv.bak/\nvenv.bak/\n\n# Spyder project settings\n.spyderproject\n.spyproject\n\n# Rope project settings\n.ropeproject\n\n# mkdocs documentation\n/site\n\n# mypy\n.mypy_cache/\n.dmypy.json\ndmypy.json\n\n# Pyre type checker\n.pyre/\n\n# pytype static type analyzer\n.pytype/\n\n# Cython debug symbols\ncython_debug/\n\n# PyCharm\n#  JetBrains specific template is maintained in a separate JetBrains.gitignore that can\n#  be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore\n#  and can be added to the global gitignore or merged into this file.  For a more nuclear\n#  option (not recommended) you can uncomment the following to ignore the entire idea folder.\n#.idea/\npod/scripts/config.yaml\n"
  },
  {
    "path": ".runpod/Dockerfile",
    "content": "FROM axolotlai/axolotl-cloud:main-py3.11-cu124-2.6.0\n\nCOPY .runpod/requirements.txt /requirements.txt\nRUN --mount=type=cache,target=/root/.cache/pip \\\n    python3 -m pip install --upgrade pip && \\\n    python3 -m pip install --upgrade -r /requirements.txt\n\n# Environment settings\nARG BASE_VOLUME=\"/runpod-volume\"\nENV BASE_VOLUME=$BASE_VOLUME\nENV HF_DATASETS_CACHE=\"${BASE_VOLUME}/huggingface-cache/datasets\"\nENV HUGGINGFACE_HUB_CACHE=\"${BASE_VOLUME}/huggingface-cache/hub\"\nENV HF_HUB_CACHE=\"${BASE_VOLUME}/huggingface-cache/hub\"\nENV TRANSFORMERS_CACHE=\"${BASE_VOLUME}/huggingface-cache/hub\"\n\nCOPY .runpod/src /src\n\nWORKDIR /src\nCMD [\"python3\", \"/src/handler.py\"]\n"
  },
  {
    "path": ".runpod/README.md",
    "content": "<h1>LLM Post-Training: Full Fine-Tuning, LoRA, QLoRA, etc. for Llama, Mistral, Gemma, and More</h1>\n\n# Configuration Options\n\nThis document outlines all available configuration options for training models. The configuration can be provided as a JSON request.\n\n## Usage\n\nYou can provide these configuration options:\n\n1. As a JSON request body. The `//` comments below are for illustration only; remove them before sending the request, because the body must be valid JSON.\n\n```jsonc\n{\n  \"input\": {\n    \"user_id\": \"user\",\n    \"model_id\": \"model-name\",\n    \"run_id\": \"run-id\",\n    \"credentials\": {\n      \"wandb_api_key\": \"\", // add your Weights & Biases API key. TODO: you will be able to set this via environment variables.\n      \"hf_token\": \"\" // add your Hugging Face token. TODO: you will be able to set this via environment variables.\n    },\n    \"args\": {\n      \"base_model\": \"NousResearch/Llama-3.2-1B\"\n      // ... other options\n    }\n  }\n}\n```\n\n## Configuration Options\n\n### Model Configuration\n\n| Option              | Description                                                                                   | Default              |\n| ------------------- | --------------------------------------------------------------------------------------------- | -------------------- |\n| `base_model`        | Path to the base model (local or HuggingFace)                                                 | Required             |\n| `base_model_config` | Configuration path for the base model                                                         | Same as base_model   |\n| `revision_of_model` | Specific model revision from HuggingFace hub                                                  | Latest               |\n| `tokenizer_config`  | Custom tokenizer configuration path                                                           | Optional             |\n| `model_type`        | Type of model to load                                                                         | AutoModelForCausalLM |\n| `tokenizer_type`    | Type of tokenizer to use                     
                                                 | AutoTokenizer        |\n| `hub_model_id`      | Repository ID where the model will be pushed on Hugging Face Hub (format: username/repo-name) | Optional             |\n\n## Model Family Identification\n\n| Option                     | Default | Description                    |\n| -------------------------- | ------- | ------------------------------ |\n| `is_falcon_derived_model`  | `false` | Whether model is Falcon-based  |\n| `is_llama_derived_model`   | `false` | Whether model is LLaMA-based   |\n| `is_qwen_derived_model`    | `false` | Whether model is Qwen-based    |\n| `is_mistral_derived_model` | `false` | Whether model is Mistral-based |\n\n## Model Configuration Overrides\n\n| Option                                          | Default    | Description                        |\n| ----------------------------------------------- | ---------- | ---------------------------------- |\n| `overrides_of_model_config.rope_scaling.type`   | `\"linear\"` | RoPE scaling type (linear/dynamic) |\n| `overrides_of_model_config.rope_scaling.factor` | `1.0`      | RoPE scaling factor                |\n\n### Model Loading Options\n\n| Option         | Description                   | Default |\n| -------------- | ----------------------------- | ------- |\n| `load_in_8bit` | Load model in 8-bit precision | false   |\n| `load_in_4bit` | Load model in 4-bit precision | false   |\n| `bf16`         | Use bfloat16 precision        | false   |\n| `fp16`         | Use float16 precision         | false   |\n| `tf32`         | Use tensor float 32 precision | false   |\n\n## Memory and Device Settings\n\n| Option             | Default   | Description             |\n| ------------------ | --------- | ----------------------- |\n| `gpu_memory_limit` | `\"20GiB\"` | GPU memory limit        |\n| `lora_on_cpu`      | `false`   | Load LoRA on CPU        |\n| `device_map`       | `\"auto\"`  | Device mapping strategy |\n| `max_memory`       | 
`null`    | Max memory per device   |\n\n## Training Hyperparameters\n\n| Option                        | Default   | Description                 |\n| ----------------------------- | --------- | --------------------------- |\n| `gradient_accumulation_steps` | `1`       | Gradient accumulation steps |\n| `micro_batch_size`            | `2`       | Batch size per GPU          |\n| `eval_batch_size`             | `null`    | Evaluation batch size       |\n| `num_epochs`                  | `4`       | Number of training epochs   |\n| `warmup_steps`                | `100`     | Warmup steps                |\n| `warmup_ratio`                | `0.05`    | Warmup ratio                |\n| `learning_rate`               | `0.00003` | Learning rate               |\n| `lr_quadratic_warmup`         | `false`   | Quadratic warmup            |\n| `logging_steps`               | `null`    | Logging frequency           |\n| `eval_steps`                  | `null`    | Evaluation frequency        |\n| `evals_per_epoch`             | `null`    | Evaluations per epoch       |\n| `save_strategy`               | `\"epoch\"` | Checkpoint saving strategy  |\n| `save_steps`                  | `null`    | Saving frequency            |\n| `saves_per_epoch`             | `null`    | Saves per epoch             |\n| `save_total_limit`            | `null`    | Maximum checkpoints to keep |\n| `max_steps`                   | `null`    | Maximum training steps      |\n\n### Dataset Configuration\n\n```yaml\ndatasets:\n  - path: vicgalle/alpaca-gpt4 # HuggingFace dataset or TODO: You will be able to add the local path.\n    type: alpaca # Format type (alpaca, gpteacher, oasst, etc.)\n    ds_type: json # Dataset type\n    data_files: path/to/data # Source data files\n    train_on_split: train # Dataset split to use\n```\n\n## Chat Template Settings\n\n| Option                   | Default                          | Description            |\n| ------------------------ | 
-------------------------------- | ---------------------- |\n| `chat_template`          | `\"tokenizer_default\"`            | Chat template type     |\n| `chat_template_jinja`    | `null`                           | Custom Jinja template  |\n| `default_system_message` | `\"You are a helpful assistant.\"` | Default system message |\n\n## Dataset Processing\n\n| Option                            | Default                    | Description                         |\n| --------------------------------- | -------------------------- | ----------------------------------- |\n| `dataset_prepared_path`           | `\"data/last_run_prepared\"` | Path for prepared dataset           |\n| `push_dataset_to_hub`             | `\"\"`                       | Push dataset to HF hub              |\n| `dataset_num_proc`                | `4`                        | Number of preprocessing processes   |\n| `dataset_keep_in_memory`          | `false`                    | Keep dataset in memory              |\n| `shuffle_merged_datasets`         | `true`                     | Shuffle merged datasets             |\n| `shuffle_before_merging_datasets` | `false`                    | Shuffle each dataset before merging |\n| `dataset_exact_deduplication`     | `true`                     | Deduplicate datasets                |\n\n## LoRA Configuration\n\n| Option                     | Default                | Description                    |\n| -------------------------- | ---------------------- | ------------------------------ |\n| `adapter`                  | `\"lora\"`               | Adapter type (lora/qlora)      |\n| `lora_model_dir`           | `\"\"`                   | Directory with pretrained LoRA |\n| `lora_r`                   | `8`                    | LoRA attention dimension       |\n| `lora_alpha`               | `16`                   | LoRA alpha parameter           |\n| `lora_dropout`             | `0.05`                 | LoRA dropout                   |\n| 
`lora_target_modules`      | `[\"q_proj\", \"v_proj\"]` | Modules to apply LoRA          |\n| `lora_target_linear`       | `false`                | Target all linear modules      |\n| `peft_layers_to_transform` | `[]`                   | Layers to transform            |\n| `lora_modules_to_save`     | `[]`                   | Modules to save                |\n| `lora_fan_in_fan_out`      | `false`                | Fan in/out structure           |\n\n## Optimization Settings\n\n| Option                    | Default | Description                |\n| ------------------------- | ------- | -------------------------- |\n| `train_on_inputs`         | `false` | Train on input prompts     |\n| `group_by_length`         | `false` | Group by sequence length   |\n| `gradient_checkpointing`  | `false` | Use gradient checkpointing |\n| `early_stopping_patience` | `3`     | Early stopping patience    |\n\n## Learning Rate Scheduling\n\n| Option                     | Default    | Description          |\n| -------------------------- | ---------- | -------------------- |\n| `lr_scheduler`             | `\"cosine\"` | Scheduler type       |\n| `lr_scheduler_kwargs`      | `{}`       | Scheduler parameters |\n| `cosine_min_lr_ratio`      | `null`     | Minimum LR ratio     |\n| `cosine_constant_lr_ratio` | `null`     | Constant LR ratio    |\n| `lr_div_factor`            | `null`     | LR division factor   |\n\n## Optimizer Settings\n\n| Option                 | Default      | Description         |\n| ---------------------- | ------------ | ------------------- |\n| `optimizer`            | `\"adamw_hf\"` | Optimizer choice    |\n| `optim_args`           | `{}`         | Optimizer arguments |\n| `optim_target_modules` | `[]`         | Target modules      |\n| `weight_decay`         | `null`       | Weight decay        |\n| `adam_beta1`           | `null`       | Adam beta1          |\n| `adam_beta2`           | `null`       | Adam beta2          |\n| `adam_epsilon`         | `null`     
  | Adam epsilon        |\n| `max_grad_norm`        | `null`       | Gradient clipping   |\n\n## Attention Implementations\n\n| Option                     | Default | Description                   |\n| -------------------------- | ------- | ----------------------------- |\n| `flash_optimum`            | `false` | Use better transformers       |\n| `xformers_attention`       | `false` | Use xformers                  |\n| `flash_attention`          | `false` | Use flash attention           |\n| `flash_attn_cross_entropy` | `false` | Flash attention cross entropy |\n| `flash_attn_rms_norm`      | `false` | Flash attention RMS norm      |\n| `flash_attn_fuse_mlp`      | `false` | Fuse MLP operations           |\n| `sdp_attention`            | `false` | Use scaled dot product        |\n| `s2_attention`             | `false` | Use shifted sparse attention  |\n\n## Tokenizer Modifications\n\n| Option           | Default | Description                  |\n| ---------------- | ------- | ---------------------------- |\n| `special_tokens` | -       | Special tokens to add/modify |\n| `tokens`         | `[]`    | Additional tokens            |\n\n## Distributed Training\n\n| Option                  | Default | Description           |\n| ----------------------- | ------- | --------------------- |\n| `fsdp`                  | `null`  | FSDP configuration    |\n| `fsdp_config`           | `null`  | FSDP config options   |\n| `deepspeed`             | `null`  | Deepspeed config path |\n| `ddp_timeout`           | `null`  | DDP timeout           |\n| `ddp_bucket_cap_mb`     | `null`  | DDP bucket capacity   |\n| `ddp_broadcast_buffers` | `null`  | DDP broadcast buffers |\n\n<details>\n<summary><h3>Example Configuration Request:</h3></summary>\n\nHere's a complete example for fine-tuning a LLaMA model using LoRA:\n\n```json\n{\n  \"input\": {\n    \"user_id\": \"user\",\n    \"model_id\": \"llama-test\",\n    \"run_id\": \"test-run\",\n    \"credentials\": {\n      \"wandb_api_key\": 
\"\",\n      \"hf_token\": \"\"\n    },\n    \"args\": {\n      \"base_model\": \"NousResearch/Llama-3.2-1B\",\n      \"load_in_8bit\": false,\n      \"load_in_4bit\": false,\n      \"strict\": false,\n      \"datasets\": [\n        {\n          \"path\": \"teknium/GPT4-LLM-Cleaned\",\n          \"type\": \"alpaca\"\n        }\n      ],\n      \"dataset_prepared_path\": \"last_run_prepared\",\n      \"val_set_size\": 0.1,\n      \"output_dir\": \"./outputs/lora-out\",\n      \"adapter\": \"lora\",\n      \"sequence_len\": 2048,\n      \"sample_packing\": true,\n      \"eval_sample_packing\": true,\n      \"pad_to_sequence_len\": true,\n      \"lora_r\": 16,\n      \"lora_alpha\": 32,\n      \"lora_dropout\": 0.05,\n      \"lora_target_modules\": [\n        \"gate_proj\",\n        \"down_proj\",\n        \"up_proj\",\n        \"q_proj\",\n        \"v_proj\",\n        \"k_proj\",\n        \"o_proj\"\n      ],\n      \"gradient_accumulation_steps\": 2,\n      \"micro_batch_size\": 2,\n      \"num_epochs\": 1,\n      \"optimizer\": \"adamw_8bit\",\n      \"lr_scheduler\": \"cosine\",\n      \"learning_rate\": 0.0002,\n      \"train_on_inputs\": false,\n      \"group_by_length\": false,\n      \"bf16\": \"auto\",\n      \"tf32\": false,\n      \"gradient_checkpointing\": true,\n      \"logging_steps\": 1,\n      \"flash_attention\": true,\n      \"loss_watchdog_threshold\": 5,\n      \"loss_watchdog_patience\": 3,\n      \"warmup_steps\": 10,\n      \"evals_per_epoch\": 4,\n      \"saves_per_epoch\": 1,\n      \"weight_decay\": 0,\n      \"hub_model_id\": \"runpod/llama-fr-lora\",\n      \"wandb_name\": \"test-run-1\",\n      \"wandb_project\": \"test-run-1\",\n      \"wandb_entity\": \"axo-test\",\n      \"special_tokens\": {\n        \"pad_token\": \"<|end_of_text|>\"\n      }\n    }\n  }\n}\n```\n\n</details>\n\n### Advanced Features\n\n#### Wandb Integration\n\n- `wandb_project`: Project name for Weights & Biases\n- `wandb_entity`: Team name in W&B\n- `wandb_watch`: 
Monitor model with W&B\n- `wandb_name`: Name of the W&B run\n- `wandb_run_id`: ID for the W&B run\n\n#### Performance Optimization\n\n- `sample_packing`: Enable efficient sequence packing\n- `eval_sample_packing`: Use sequence packing during evaluation\n- `torch_compile`: Enable PyTorch 2.0 compilation\n- `flash_attention`: Use Flash Attention implementation\n- `xformers_attention`: Use xFormers attention implementation\n\n### Available Optimizers\n\nThe following optimizers are supported:\n\n- `adamw_hf`: HuggingFace's AdamW implementation\n- `adamw_torch`: PyTorch's AdamW\n- `adamw_torch_fused`: Fused AdamW implementation\n- `adamw_torch_xla`: XLA-optimized AdamW\n- `adamw_apex_fused`: NVIDIA Apex fused AdamW\n- `adafactor`: Adafactor optimizer\n- `adamw_anyprecision`: Anyprecision AdamW\n- `adamw_bnb_8bit`: 8-bit AdamW from bitsandbytes\n- `lion_8bit`: 8-bit Lion optimizer\n- `lion_32bit`: 32-bit Lion optimizer\n- `sgd`: Stochastic Gradient Descent\n- `adagrad`: Adagrad optimizer\n\n## Notes\n\n- Set `load_in_8bit: true` or `load_in_4bit: true` for memory-efficient training\n- Enable `flash_attention: true` for faster training on modern GPUs\n- Use `gradient_checkpointing: true` to reduce memory usage\n- Adjust `micro_batch_size` and `gradient_accumulation_steps` based on your GPU memory\n\nFor more detailed information, please refer to the [documentation](https://axolotl-ai-cloud.github.io/axolotl/docs/config-reference.html).\n\n### Errors\n\n- If you encounter issues with Flash Attention 2, delete your worker and restart it.\n"
  },
  {
    "path": ".runpod/hub.json",
    "content": "{\n  \"title\": \"Axolotl Fine-Tuning\",\n  \"description\": \"Serverless fine-tuning of open-source LLMs with Axolotl. Supports LoRA, QLoRA, DPO, and more using Hugging Face models and datasets.\",\n  \"type\": \"serverless\",\n  \"category\": \"language\",\n  \"iconUrl\": \"https://avatars.githubusercontent.com/u/167502477\",\n  \"config\": {\n    \"runsOn\": \"GPU\",\n    \"containerDiskInGb\": 200,\n    \"gpuCount\": 1,\n    \"allowedCudaVersions\": [\n      \"12.8\",\n      \"12.7\",\n      \"12.6\",\n      \"12.5\",\n      \"12.4\"\n    ],\n    \"presets\": [],\n    \"env\": [\n      {\n        \"key\": \"TOKENIZER\",\n        \"input\": {\n          \"name\": \"Tokenizer\",\n          \"type\": \"string\",\n          \"description\": \"Name or path of the Hugging Face tokenizer to use.\",\n          \"default\": \"\",\n          \"advanced\": true\n        }\n      },\n      {\n        \"key\": \"MAX_NUM_SEQS\",\n        \"input\": {\n          \"name\": \"Max Num Seqs\",\n          \"type\": \"number\",\n          \"description\": \"Maximum number of sequences per iteration.\",\n          \"default\": 256,\n          \"advanced\": true\n        }\n      },\n      {\n        \"key\": \"DISABLE_LOG_STATS\",\n        \"input\": {\n          \"name\": \"Disable Log Stats\",\n          \"type\": \"boolean\",\n          \"description\": \"Disable logging statistics.\",\n          \"default\": false,\n          \"trueValue\": \"true\",\n          \"falseValue\": \"false\"\n        }\n      },\n      {\n        \"key\": \"LOAD_FORMAT\",\n        \"input\": {\n          \"name\": \"Load Format\",\n          \"type\": \"string\",\n          \"description\": \"The format of the model weights to load.\",\n          \"default\": \"auto\",\n          \"options\": [\n            {\n              \"label\": \"auto\",\n              \"value\": \"auto\"\n            },\n            {\n              \"label\": \"pt\",\n              \"value\": \"pt\"\n         
   },\n            {\n              \"label\": \"safetensors\",\n              \"value\": \"safetensors\"\n            },\n            {\n              \"label\": \"npcache\",\n              \"value\": \"npcache\"\n            },\n            {\n              \"label\": \"dummy\",\n              \"value\": \"dummy\"\n            },\n            {\n              \"label\": \"tensorizer\",\n              \"value\": \"tensorizer\"\n            },\n            {\n              \"label\": \"bitsandbytes\",\n              \"value\": \"bitsandbytes\"\n            }\n          ],\n          \"advanced\": true\n        }\n      }\n    ]\n  }\n}\n"
  },
  {
    "path": ".runpod/requirements.txt",
    "content": "# Required Python packages get listed here, one per line.\n# It is recommended to pin version numbers to avoid unexpected changes.\n\n# You can also install packages from a git repository, e.g.:\n# git+https://github.com/runpod/runpod-python.git\n# To learn more, see https://pip.pypa.io/en/stable/reference/requirements-file-format/\nrunpod~=1.7.0\n"
  },
  {
    "path": ".runpod/src/config/config.yaml",
    "content": "# # This is the huggingface model that contains *.pt, *.safetensors, or *.bin files\n# # This can also be a relative path to a model on disk\n# base_model: ./llama-7b-hf\n# # You can specify an ignore pattern if the model repo contains more than 1 model type (*.pt, etc)\n# base_model_ignore_patterns:\n# # If the base_model repo on hf hub doesn't include configuration .json files,\n# # You can set that here, or leave this empty to default to base_model\n# base_model_config: ./llama-7b-hf\n# # You can specify to choose a specific model revision from huggingface hub\n# model_revision:\n# # Optional tokenizer configuration override in case you want to use a different tokenizer\n# # than the one defined in the base model\n# tokenizer_config:\n# # If you want to specify the type of model to load, AutoModelForCausalLM is a good choice too\n# model_type: AutoModelForCausalLM\n# # Corresponding tokenizer for the model AutoTokenizer is a good choice\n# tokenizer_type: AutoTokenizer\n# # Trust remote code for untrusted source\n# trust_remote_code:\n# # use_fast option for tokenizer loading from_pretrained, default to True\n# tokenizer_use_fast:\n# # Whether to use the legacy tokenizer setting, defaults to True\n# tokenizer_legacy:\n# # Resize the model embeddings when new tokens are added to multiples of 32\n# # This is reported to improve training speed on some models\n# resize_token_embeddings_to_32x:\n\n# # Used to identify which the model is based on\n# is_falcon_derived_model:\n# is_llama_derived_model:\n# # Please note that if you set this to true, `padding_side` will be set to \"left\" by default\n# is_mistral_derived_model:\n# is_qwen_derived_model:\n\n# # optional overrides to the base model configuration\n# model_config:\n#   # RoPE Scaling https://github.com/huggingface/transformers/pull/24653\n#   rope_scaling:\n#     type: # linear | dynamic\n#     factor: # float\n\n# # Whether you are training a 4-bit GPTQ quantized model\n# gptq: true\n# 
gptq_groupsize: 128 # group size\n# gptq_model_v1: false # v1 or v2\n\n# # This will attempt to quantize the model down to 8 bits and use adam 8 bit optimizer\n# load_in_8bit: true\n# # Use bitsandbytes 4 bit\n# load_in_4bit:\n\n# # Use CUDA bf16\n# bf16: true # bool or 'full' for `bf16_full_eval`. require >=ampere\n# # Use CUDA fp16\n# fp16: true\n# # Use CUDA tf32\n# tf32: true # require >=ampere\n\n# # No AMP (automatic mixed precision)\n# bfloat16: true # require >=ampere\n# float16: true\n\n# # A list of one or more datasets to finetune the model with\n# datasets:\n#   # HuggingFace dataset repo | s3://,gs:// path | \"json\" for local dataset, make sure to fill data_files\n#   - path: vicgalle/alpaca-gpt4\n#   # The type of prompt to use for training. [alpaca, sharegpt, gpteacher, oasst, reflection]\n#     type: alpaca # format | format:<prompt_style> (chat/instruct) | <prompt_strategies>.load_<load_fn>\n#     ds_type: # Optional[str] (json|arrow|parquet|text|csv) defines the datatype when path is a file\n#     data_files: # Optional[str] path to source data files\n#     shards: # Optional[int] number of shards to split data into\n#     name: # Optional[str] name of dataset configuration to load\n#     train_on_split: train # Optional[str] name of dataset split to load from\n\n#     # Optional[str] fastchat conversation type, only used with type: sharegpt\n#     conversation:  # Options (see Conversation 'name'): https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py\n#     field_human: # Optional[str]. Human key to use for conversation.\n#     field_model: # Optional[str]. Assistant key to use for conversation.\n\n#   # Custom user prompt\n#   - path: repo\n#     type:\n#       # The below are defaults. 
only set what's needed.\n#       system_prompt: \"\"\n#       system_format: \"{system}\"\n#       field_system: system\n#       field_instruction: instruction\n#       field_input: input\n#       field_output: output\n\n#       # Customizable to be single line or multi-line\n#       # 'format' can include {input}\n#       format: |-\n#         User: {instruction} {input}\n#         Assistant:\n#       # 'no_input_format' cannot include {input}\n#       no_input_format: \"{instruction} \"\n\n#       # For `completion` datasets only, uses the provided field instead of `text` column\n#       field:\n\n# # Axolotl attempts to save the dataset as an arrow after packing the data together so\n# # subsequent training attempts load faster, relative path\n# dataset_prepared_path: data/last_run_prepared\n# # Push prepared dataset to hub\n# push_dataset_to_hub: # repo path\n# # The maximum number of processes to use while preprocessing your input dataset. This defaults to `os.cpu_count()`\n# # if not set.\n# dataset_num_proc: # defaults to os.cpu_count() if not set\n# # push checkpoints to hub\n# hub_model_id: # repo path to push finetuned model\n# # how to push checkpoints to hub\n# # https://huggingface.co/docs/transformers/v4.31.0/en/main_classes/trainer#transformers.TrainingArguments.hub_strategy\n# hub_strategy:\n# # Whether to use hf `use_auth_token` for loading datasets. Useful for fetching private datasets\n# # Required to be true when used in combination with `push_dataset_to_hub`\n# hf_use_auth_token: # boolean\n# # How much of the dataset to set aside as evaluation. 1 = 100%, 0.50 = 50%, etc. 
0 for no eval.\n# val_set_size: 0.04\n# # Num shards for whole dataset\n# dataset_shard_num:\n# # Index of shard to use for whole dataset\n# dataset_shard_idx:\n\n# # The maximum length of an input to train with, this should typically be less than 2048\n# # as most models have a token/context limit of 2048\n# sequence_len: 2048\n# # Pad inputs so each step uses constant sized buffers\n# # This will reduce memory fragmentation and may prevent OOMs, by re-using memory more efficiently\n# pad_to_sequence_len:\n# # Max sequence length to concatenate training samples together up to\n# # Inspired by StackLLaMA. see https://huggingface.co/blog/stackllama#supervised-fine-tuning\n# # FutureWarning: This will soon be DEPRECATED\n# max_packed_sequence_len: 1024\n# # Use efficient multi-packing with block diagonal attention and per sequence position_ids. Recommend set to 'true'\n# sample_packing:\n# # Set to 'false' if getting errors during eval with sample_packing on.\n# eval_sample_packing:\n# # You can set these packing optimizations AFTER starting a training at least once.\n# # The trainer will provide recommended values for these values.\n# sample_packing_eff_est:\n# total_num_tokens:\n\n# # If you want to use 'lora' or 'qlora' or leave blank to train all parameters in original model\n# adapter: lora\n# # If you already have a lora model trained that you want to load, put that here.\n# # This means after training, if you want to test the model, you should set this to the value of `lora_out_dir`.\n# lora_model_dir:\n\n# # LoRA hyperparameters\n# # For more details about the following options, see:\n# # https://www.anyscale.com/blog/fine-tuning-llms-lora-or-full-parameter-an-in-depth-analysis-with-llama-2\n# lora_r: 8\n# lora_alpha: 16\n# lora_dropout: 0.05\n# lora_target_modules:\n#   - q_proj\n#   - v_proj\n# #  - k_proj\n# #  - o_proj\n# #  - gate_proj\n# #  - down_proj\n# #  - up_proj\n# lora_target_linear: # If true, will target all linear layers\n\n# # If you added 
new tokens to the tokenizer, you may need to save some LoRA modules because they need to know the new tokens.\n# # For LLaMA and Mistral, you need to save `embed_tokens` and `lm_head`. It may vary for other models.\n# # `embed_tokens` converts tokens to embeddings, and `lm_head` converts embeddings to token probabilities.\n# # https://github.com/huggingface/peft/issues/334#issuecomment-1561727994\n# lora_modules_to_save:\n# #  - embed_tokens\n# #  - lm_head\n\n# # Once you complete training, the model will be saved to the following directory.\n# # If you merge the adapter to the base model, a subdirectory `merged` will be created under this directory.\n# # Make sure `lora_model_dir` points to this directory if you want to use the trained model.\n# lora_out_dir:\n# lora_fan_in_fan_out: false\n\n# # ReLoRA configuration\n# # Must use either 'lora' or 'qlora' adapter, and does not support fsdp or deepspeed\n# relora_steps: # Number of steps per ReLoRA restart\n# relora_warmup_steps: # Number of per-restart warmup steps\n# relora_cpu_offload: # True to perform lora weight merges on cpu during restarts, for modest gpu memory savings\n\n# # wandb configuration if you're using it\n# wandb_mode: # \"offline\" to save run metadata locally and not sync to the server, \"disabled\" to turn off wandb\n# wandb_project: # Your wandb project name\n# wandb_entity: # A wandb Team name if using a Team\n# wandb_watch:\n# wandb_run_id: # Set the name of your wandb run\n# wandb_log_model: # \"checkpoint\" to log model to wandb Artifacts every `save_steps` or \"end\" to log only at the end of training\n\n# # Where to save the full-finetuned model to\n# output_dir: ./completed-model\n\n# # Whether to use torch.compile and which backend to use\n# torch_compile:  # bool\n# torch_compile_backend:  # Optional[str]\n\n# # Training hyperparameters\n\n# # If greater than 1, backpropagation will be skipped and the gradients will be accumulated for the given number of steps.\n# 
gradient_accumulation_steps: 1\n# # The number of samples to include in each batch. This is the number of samples sent to each GPU.\n# micro_batch_size: 2\n# eval_batch_size:\n# num_epochs: 4\n# warmup_steps: 100  # cannot use with warmup_ratio\n# warmup_ratio: 0.05  # cannot use with warmup_steps\n# learning_rate: 0.00003\n# lr_quadratic_warmup:\n# logging_steps:\n# save_strategy: # Set to `no` to skip checkpoint saves\n# save_steps: # Leave empty to save at each epoch\n# eval_steps: # Leave empty to eval at each epoch, integers for every N steps. decimal for fraction of total steps\n# save_total_limit: # Checkpoints saved at a time\n# # Maximum number of iterations to train for. It precedes num_epochs which means that\n# # if both are set, num_epochs will not be guaranteed.\n# # e.g., when 1 epoch is 1000 steps => `num_epochs: 2` and `max_steps: 100` will train for 100 steps\n# max_steps:\n\n# eval_table_size: # Approximate number of predictions sent to wandb depending on batch size. Enabled above 0. Default is 0\n# eval_table_max_new_tokens: # Total number of tokens generated for predictions sent to wandb. 
Default is 128\n\n# # Whether to mask out or include the human's prompt from the training labels\n# train_on_inputs: false\n# # Group similarly sized data to minimize padding.\n# # May be slower to start, as it must download and sort the entire dataset.\n# # Note that training loss may have an oscillating pattern with this enabled.\n# group_by_length: false\n\n# # Whether to use gradient checkpointing https://huggingface.co/docs/transformers/v4.18.0/en/performance#gradient-checkpointing\n# gradient_checkpointing: false\n\n# # Stop training after this many evaluation losses have increased in a row\n# # https://huggingface.co/transformers/v4.2.2/_modules/transformers/trainer_callback.html#EarlyStoppingCallback\n# early_stopping_patience: 3\n\n# # Specify a scheduler and kwargs to use with the optimizer\n# lr_scheduler: # 'one_cycle' | empty for cosine\n# lr_scheduler_kwargs:\n\n# # For one_cycle optim\n# lr_div_factor: # Learning rate div factor\n\n# # Specify optimizer\n# # Valid values are driven by the Transformers OptimizerNames class, see:\n# # https://github.com/huggingface/transformers/blob/95b374952dc27d8511541d6f5a4e22c9ec11fb24/src/transformers/training_args.py#L134\n# #\n# # Note that not all optimizers may be available in your environment, ex: 'adamw_anyprecision' is part of\n# # torchdistx, 'adamw_bnb_8bit' is part of bnb.optim.Adam8bit, etc. 
When in doubt, it is recommended to start with the optimizer used\n# # in the examples/ for your model and fine-tuning use case.\n# #\n# # Valid values for 'optimizer' include:\n# # - adamw_hf\n# # - adamw_torch\n# # - adamw_torch_fused\n# # - adamw_torch_xla\n# # - adamw_apex_fused\n# # - adafactor\n# # - adamw_anyprecision\n# # - sgd\n# # - adagrad\n# # - adamw_bnb_8bit\n# # - lion_8bit\n# # - lion_32bit\n# # - paged_adamw_32bit\n# # - paged_adamw_8bit\n# # - paged_lion_32bit\n# # - paged_lion_8bit\n# optimizer:\n# # Specify weight decay\n# weight_decay:\n# # adamw hyperparams\n# adam_beta1:\n# adam_beta2:\n# adam_epsilon:\n# # Gradient clipping max norm\n# max_grad_norm:\n\n# # Augmentation techniques\n# # NEFT https://arxiv.org/abs/2310.05914, set this to a number (paper default is 5) to add noise to embeddings\n# # currently only supported on Llama and Mistral\n# noisy_embedding_alpha:\n\n# # Whether to bettertransformers\n# flash_optimum:\n# # Whether to use xformers attention patch https://github.com/facebookresearch/xformers:\n# xformers_attention:\n# # Whether to use flash attention patch https://github.com/Dao-AILab/flash-attention:\n# flash_attention:\n# flash_attn_cross_entropy:  # Whether to use flash-attention cross entropy implementation - advanced use only\n# flash_attn_rms_norm:  # Whether to use flash-attention rms norm implementation - advanced use only\n# flash_attn_fuse_mlp: # Whether to fuse part of the MLP into a single operation\n# # Whether to use scaled-dot-product attention\n# # https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html\n# sdp_attention:\n# # Landmark attention (only llama)\n# landmark_attention:\n# # xpos RoPE see https://github.com/kaiokendev/cutoff-len-is-context-len/blob/main/util/xpos_rope_llama_monkey_patch.py\n# # LLaMA only\n# xpos_rope:\n\n# # Resume from a specific checkpoint dir\n# resume_from_checkpoint:\n# # If resume_from_checkpoint isn't set and you simply want it to 
start where it left off.\n# # Be careful with this being turned on between different models.\n# auto_resume_from_checkpoints: false\n\n# # Don't mess with this, it's here for accelerate and torchrun\n# local_rank:\n\n# # Add or change special tokens.\n# # If you add tokens here, you don't need to add them to the `tokens` list.\n# special_tokens:\n#   # bos_token: \"<s>\"\n#   # eos_token: \"</s>\"\n#   # unk_token: \"<unk>\"\n\n# # Add extra tokens.\n# tokens:\n\n# # FSDP\n# fsdp:\n# fsdp_config:\n\n# # Deepspeed config path. e.g., deepspeed/zero3.json\n# deepspeed:\n\n# # Advanced DDP Arguments\n# ddp_timeout:\n# ddp_bucket_cap_mb:\n# ddp_broadcast_buffers:\n\n# # Path to torch distx for optim 'adamw_anyprecision'\n# torchdistx_path:\n\n# # Set to HF dataset for type: 'completion' for streaming instead of pre-tokenize\n# pretraining_dataset:\n\n# # Debug mode\n# debug:\n\n# # Seed\n# seed:\n\n# # Allow overwrite yml config using from cli\n# strict:\n\nbase_model: ${BASE_MODEL}\nbase_model_ignore_patterns: ${BASE_MODEL_IGNORE_PATTERNS}\nbase_model_config: ${BASE_MODEL_CONFIG}\nrevision_of_model: ${REVISION_OF_MODEL}\ntokenizer_config: ${TOKENIZER_CONFIG}\nmodel_type: ${MODEL_TYPE}\ntokenizer_type: ${TOKENIZER_TYPE}\ntrust_remote_code: ${TRUST_REMOTE_CODE}\ntokenizer_use_fast: ${TOKENIZER_USE_FAST}\ntokenizer_legacy: ${TOKENIZER_LEGACY}\nresize_token_embeddings_to_32x: ${RESIZE_TOKEN_EMBEDDINGS_TO_32X}\n\nis_falcon_derived_model: ${IS_FALCON_DERIVED_MODEL}\nis_llama_derived_model: ${IS_LLAMA_DERIVED_MODEL}\nis_qwen_derived_model: ${IS_QWEN_DERIVED_MODEL}\nis_mistral_derived_model: ${IS_MISTRAL_DERIVED_MODEL}\n\noverrides_of_model_config:\n  rope_scaling:\n    type: ${ROPE_SCALING_TYPE}\n    factor: ${ROPE_SCALING_FACTOR}\n\nbnb_config_kwargs:\n  llm_int8_has_fp16_weight: ${BNB_LLM_INT8_HAS_FP16_WEIGHT}\n  bnb_4bit_quant_type: ${BNB_4BIT_QUANT_TYPE}\n  bnb_4bit_use_double_quant: ${BNB_4BIT_USE_DOUBLE_QUANT}\n\ngptq: ${GPTQ}\nload_in_8bit: 
${LOAD_IN_8BIT}\nload_in_4bit: ${LOAD_IN_4BIT}\nbf16: ${BF16}\nfp16: ${FP16}\ntf32: ${TF32}\nbfloat16: ${BFLOAT16}\nfloat16: ${FLOAT16}\n\ngpu_memory_limit: ${GPU_MEMORY_LIMIT}\nlora_on_cpu: ${LORA_ON_CPU}\n\ndatasets:\n  - path: ${DATASET_PATH}\n    type: ${DATASET_TYPE}\n    ds_type: ${DATASET_DS_TYPE}\n    data_files: ${DATASET_DATA_FILES}\n    shards: ${DATASET_SHARDS}\n    name: ${DATASET_NAME}\n    train_on_split: ${DATASET_TRAIN_ON_SPLIT}\n    revision: ${DATASET_REVISION}\n    trust_remote_code: ${DATASET_TRUST_REMOTE_CODE}\n\nrl: ${RL}\ndpo_use_weighting: ${DPO_USE_WEIGHTING}\n\nchat_template: ${CHAT_TEMPLATE}\nchat_template_jinja: ${CHAT_TEMPLATE_JINJA}\ndefault_system_message: ${DEFAULT_SYSTEM_MESSAGE}\ndataset_prepared_path: ${DATASET_PREPARED_PATH}\npush_dataset_to_hub: ${PUSH_DATASET_TO_HUB}\ndataset_num_proc: ${DATASET_NUM_PROC}\ndataset_keep_in_memory: ${DATASET_KEEP_IN_MEMORY}\nhub_model_id: ${HUB_MODEL_ID}\nhub_strategy: ${HUB_STRATEGY}\nhf_use_auth_token: ${HF_USE_AUTH_TOKEN}\nval_set_size: ${VAL_SET_SIZE}\ndataset_shard_num: ${DATASET_SHARD_NUM}\ndataset_shard_idx: ${DATASET_SHARD_IDX}\n\nsequence_len: ${SEQUENCE_LEN}\npad_to_sequence_len: ${PAD_TO_SEQUENCE_LEN}\nsample_packing: ${SAMPLE_PACKING}\neval_sample_packing: ${EVAL_SAMPLE_PACKING}\nsample_packing_eff_est: ${SAMPLE_PACKING_EFF_EST}\ntotal_num_tokens: ${TOTAL_NUM_TOKENS}\nsample_packing_group_size: ${SAMPLE_PACKING_GROUP_SIZE}\nsample_packing_bin_size: ${SAMPLE_PACKING_BIN_SIZE}\n\nbatch_flattening: ${BATCH_FLATTENING}\ndevice_map: ${DEVICE_MAP}\nmax_memory: ${MAX_MEMORY}\n\nadapter: ${ADAPTER}\nlora_model_dir: ${LORA_MODEL_DIR}\n\nlora_r: ${LORA_R}\nlora_alpha: ${LORA_ALPHA}\nlora_dropout: ${LORA_DROPOUT}\nlora_target_modules:\n  - ${LORA_TARGET_MODULES}\nlora_target_linear: ${LORA_TARGET_LINEAR}\npeft_layers_to_transform: ${PEFT_LAYERS_TO_TRANSFORM}\nlora_modules_to_save: ${LORA_MODULES_TO_SAVE}\nlora_fan_in_fan_out: ${LORA_FAN_IN_FAN_OUT}\n\nloraplus_lr_ratio: 
${LORAPLUS_LR_RATIO}\nloraplus_lr_embedding: ${LORAPLUS_LR_EMBEDDING}\n\npeft:\n  loftq_config:\n    loftq_bits: ${LOFTQ_BITS}\n\nrelora_steps: ${RELORA_STEPS}\nrelora_warmup_steps: ${RELORA_WARMUP_STEPS}\nrelora_anneal_steps: ${RELORA_ANNEAL_STEPS}\nrelora_prune_ratio: ${RELORA_PRUNE_RATIO}\nrelora_cpu_offload: ${RELORA_CPU_OFFLOAD}\n\nwandb_mode: ${WANDB_MODE}\nwandb_project: ${WANDB_PROJECT}\nwandb_entity: ${WANDB_ENTITY}\nwandb_watch: ${WANDB_WATCH}\nwandb_name: ${WANDB_NAME}\nwandb_run_id: ${WANDB_RUN_ID}\nwandb_log_model: ${WANDB_LOG_MODEL}\n\nmlflow_tracking_uri: ${MLFLOW_TRACKING_URI}\nmlflow_experiment_name: ${MLFLOW_EXPERIMENT_NAME}\nmlflow_run_name: ${MLFLOW_RUN_NAME}\nhf_mlflow_log_artifacts: ${HF_MLFLOW_LOG_ARTIFACTS}\n\nuse_comet: ${USE_COMET}\ncomet_api_key: ${COMET_API_KEY}\ncomet_workspace: ${COMET_WORKSPACE}\ncomet_project_name: ${COMET_PROJECT_NAME}\ncomet_experiment_key: ${COMET_EXPERIMENT_KEY}\ncomet_mode: ${COMET_MODE}\ncomet_online: ${COMET_ONLINE}\ncomet_experiment_config: ${COMET_EXPERIMENT_CONFIG}\n\noutput_dir: ${OUTPUT_DIR}\n\ntorch_compile: ${TORCH_COMPILE}\ntorch_compile_backend: ${TORCH_COMPILE_BACKEND}\n\ngradient_accumulation_steps: ${GRADIENT_ACCUMULATION_STEPS}\nmicro_batch_size: ${MICRO_BATCH_SIZE}\neval_batch_size: ${EVAL_BATCH_SIZE}\nnum_epochs: ${NUM_EPOCHS}\nwarmup_steps: ${WARMUP_STEPS}\nwarmup_ratio: ${WARMUP_RATIO}\nlearning_rate: ${LEARNING_RATE}\nlr_quadratic_warmup: ${LR_QUADRATIC_WARMUP}\nlogging_steps: ${LOGGING_STEPS}\neval_steps: ${EVAL_STEPS}\nevals_per_epoch: ${EVALS_PER_EPOCH}\nsave_strategy: ${SAVE_STRATEGY}\nsave_steps: ${SAVE_STEPS}\nsaves_per_epoch: ${SAVES_PER_EPOCH}\nsave_total_limit: ${SAVE_TOTAL_LIMIT}\nmax_steps: ${MAX_STEPS}\n\neval_table_size: ${EVAL_TABLE_SIZE}\neval_max_new_tokens: ${EVAL_MAX_NEW_TOKENS}\neval_causal_lm_metrics: ${EVAL_CAUSAL_LM_METRICS}\n\nprofiler_steps: ${PROFILER_STEPS}\nloss_watchdog_threshold: ${LOSS_WATCHDOG_THRESHOLD}\nloss_watchdog_patience: 
${LOSS_WATCHDOG_PATIENCE}\n\ntrain_on_inputs: ${TRAIN_ON_INPUTS}\ngroup_by_length: ${GROUP_BY_LENGTH}\ngradient_checkpointing: ${GRADIENT_CHECKPOINTING}\nearly_stopping_patience: ${EARLY_STOPPING_PATIENCE}\n\nlr_scheduler: ${LR_SCHEDULER}\nlr_scheduler_kwargs: ${LR_SCHEDULER_KWARGS}\ncosine_min_lr_ratio: ${COSINE_MIN_LR_RATIO}\ncosine_constant_lr_ratio: ${COSINE_CONSTANT_LR_RATIO}\nlr_div_factor: ${LR_DIV_FACTOR}\n\noptimizer: ${OPTIMIZER}\noptim_args: ${OPTIM_ARGS}\noptim_target_modules: ${OPTIM_TARGET_MODULES}\nweight_decay: ${WEIGHT_DECAY}\nadam_beta1: ${ADAM_BETA1}\nadam_beta2: ${ADAM_BETA2}\nadam_epsilon: ${ADAM_EPSILON}\nmax_grad_norm: ${MAX_GRAD_NORM}\n\nneftune_noise_alpha: ${NEFTUNE_NOISE_ALPHA}\n\nflash_optimum: ${FLASH_OPTIMUM}\nxformers_attention: ${XFORMERS_ATTENTION}\nflash_attention: ${FLASH_ATTENTION}\nflash_attn_cross_entropy: ${FLASH_ATTN_CROSS_ENTROPY}\nflash_attn_rms_norm: ${FLASH_ATTN_RMS_NORM}\nflash_attn_fuse_mlp: ${FLASH_ATTN_FUSE_MLP}\nsdp_attention: ${SDP_ATTENTION}\ns2_attention: ${S2_ATTENTION}\nresume_from_checkpoint: ${RESUME_FROM_CHECKPOINT}\nauto_resume_from_checkpoints: ${AUTO_RESUME_FROM_CHECKPOINTS}\n\nlocal_rank: ${LOCAL_RANK}\n\nspecial_tokens:\n  bos_token: ${SPECIAL_TOKEN_BOS}\n  eos_token: ${SPECIAL_TOKEN_EOS}\n  unk_token: ${SPECIAL_TOKEN_UNK}\n  pad_token: ${SPECIAL_TOKEN_PAD}\n\ntokens: ${TOKENS}\n\nfsdp: ${FSDP}\nfsdp_config: ${FSDP_CONFIG}\ndeepspeed: ${DEEPSPEED}\n\nddp_timeout: ${DDP_TIMEOUT}\nddp_bucket_cap_mb: ${DDP_BUCKET_CAP_MB}\nddp_broadcast_buffers: ${DDP_BROADCAST_BUFFERS}\n\ntorchdistx_path: ${TORCHDISTX_PATH}\npretraining_dataset: ${PRETRAINING_DATASET}\ndebug: ${DEBUG}\nseed: ${SEED}\nstrict: ${STRICT}\n"
  },
  {
    "path": ".runpod/src/handler.py",
    "content": "\"\"\"\nRunpod serverless entrypoint handler\n\"\"\"\n\nimport os\n\nimport runpod\nimport yaml\nfrom huggingface_hub._login import login\nfrom train import train\nfrom utils import get_output_dir\n\nBASE_VOLUME = os.environ.get(\"BASE_VOLUME\", \"/runpod-volume\")\nif not os.path.exists(BASE_VOLUME):\n    os.makedirs(BASE_VOLUME)\n\nlogger = runpod.RunPodLogger()\n\n\nasync def handler(job):\n    runpod_job_id = job[\"id\"]\n    inputs = job[\"input\"]\n    run_id = inputs.get(\"run_id\", \"default_run_id\")\n    args = inputs.get(\"args\", {})\n\n    # Set output directory\n    output_dir = os.path.join(BASE_VOLUME, get_output_dir(run_id))\n    args[\"output_dir\"] = output_dir\n\n    # First save args to a temporary config file\n    config_path = \"/workspace/test_config.yaml\"\n\n    # Add run_name and job_id to args before saving\n    args[\"run_name\"] = run_id\n    args[\"runpod_job_id\"] = runpod_job_id\n\n    yaml_data = yaml.dump(args, default_flow_style=False)\n    with open(config_path, \"w\", encoding=\"utf-8\") as file:\n        file.write(yaml_data)\n\n    # Handle credentials\n    credentials = inputs.get(\"credentials\", {})\n\n    if \"wandb_api_key\" in credentials:\n        os.environ[\"WANDB_API_KEY\"] = credentials[\"wandb_api_key\"]\n    if \"hf_token\" in credentials:\n        os.environ[\"HF_TOKEN\"] = credentials[\"hf_token\"]\n\n    if os.environ.get(\"HF_TOKEN\"):\n        login(token=os.environ[\"HF_TOKEN\"])\n    else:\n        logger.info(\"No HF_TOKEN provided. 
Skipping login.\")\n\n    logger.info(\"Starting Training.\")\n    async for result in train(config_path):  # Pass the config path instead of args\n        logger.info(result)\n    logger.info(\"Training Complete.\")\n\n    # Cleanup\n    if \"WANDB_API_KEY\" in os.environ:\n        del os.environ[\"WANDB_API_KEY\"]\n    if \"HF_TOKEN\" in os.environ:\n        del os.environ[\"HF_TOKEN\"]\n\n\nrunpod.serverless.start({\"handler\": handler, \"return_aggregate_stream\": True})\n"
  },
  {
    "path": ".runpod/src/test_input.json",
    "content": "{\n  \"input\": {\n    \"user_id\": \"user\",\n    \"model_id\": \"llama-test\",\n    \"run_id\": \"llama-test\",\n    \"credentials\": {\n      \"wandb_api_key\": \"\",\n      \"hf_token\": \"\"\n    },\n    \"args\": {\n      \"base_model\": \"NousResearch/Meta-Llama-3-8B\",\n      \"model_type\": \"LlamaForCausalLM\",\n      \"tokenizer_type\": \"AutoTokenizer\",\n      \"load_in_8bit\": true,\n      \"load_in_4bit\": false,\n      \"strict\": false,\n      \"datasets\": [\n        {\n          \"path\": \"mhenrichsen/alpaca_2k_test\",\n          \"type\": \"alpaca\"\n        }\n      ],\n      \"val_set_size\": 0.05,\n      \"output_dir\": \"./outputs/lora-out\",\n      \"sequence_len\": 4096,\n      \"sample_packing\": true,\n      \"eval_sample_packing\": false,\n      \"pad_to_sequence_len\": true,\n      \"adapter\": \"lora\",\n      \"lora_r\": 32,\n      \"lora_alpha\": 16,\n      \"lora_dropout\": 0.05,\n      \"lora_target_linear\": true,\n      \"lora_modules_to_save\": [\n        \"embed_tokens\",\n        \"lm_head\"\n      ],\n      \"gradient_accumulation_steps\": 4,\n      \"micro_batch_size\": 2,\n      \"num_epochs\": 1,\n      \"optimizer\": \"adamw_bnb_8bit\",\n      \"lr_scheduler\": \"cosine\",\n      \"learning_rate\": 0.0002,\n      \"train_on_inputs\": false,\n      \"group_by_length\": false,\n      \"bf16\": \"auto\",\n      \"tf32\": false,\n      \"gradient_checkpointing\": true,\n      \"logging_steps\": 1,\n      \"flash_attention\": true,\n      \"warmup_steps\": 1,\n      \"evals_per_epoch\": 1,\n      \"eval_max_new_tokens\": 128,\n      \"saves_per_epoch\": 1,\n      \"weight_decay\": 0.0,\n      \"special_tokens\": {\n        \"pad_token\": \"<|end_of_text|>\"\n      }\n    }\n  }\n}\n"
  },
  {
    "path": ".runpod/src/train.py",
    "content": "\"\"\"\nRunpod train entrypoint\n\"\"\"\n\nimport asyncio\n\n\nasync def train(config_path: str, gpu_id: str = \"0\", preprocess: bool = True):\n    \"\"\"\n    Run preprocessing (if enabled) and training with the given config file\n    :param config_path: Path to the YAML config file\n    :param gpu_id: GPU ID to use (default: \"0\")\n    :param preprocess: Whether to run preprocessing (default: True)\n\n    \"\"\"\n    # First check if preprocessing is needed\n    if preprocess:\n        # Preprocess command\n        preprocess_cmd = (\n            f\"CUDA_VISIBLE_DEVICES={gpu_id} axolotl preprocess {config_path}\"\n        )\n        process = await asyncio.create_subprocess_shell(\n            preprocess_cmd,\n            stdout=asyncio.subprocess.PIPE,\n            stderr=asyncio.subprocess.STDOUT,\n        )\n\n        if process.stdout is not None:\n            async for line in process.stdout:\n                yield f\"Preprocessing: {line.decode().strip()}\"\n        await process.wait()\n        yield \"Preprocessing completed.\"\n    else:\n        yield \"Skipping preprocessing step.\"\n\n    # Training command\n    train_cmd = f\"axolotl train {config_path}\"\n    process = await asyncio.create_subprocess_shell(\n        train_cmd, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.STDOUT\n    )\n\n    if process.stdout is not None:\n        async for line in process.stdout:\n            yield f\"Training: {line.decode().strip()}\"\n    await process.wait()\n"
  },
  {
    "path": ".runpod/src/utils.py",
    "content": "\"\"\"\nRunpod launcher utils\n\"\"\"\n\nimport os\n\nimport yaml\n\n\ndef get_output_dir(run_id):\n    path = f\"fine-tuning/{run_id}\"\n    return path\n\n\ndef make_valid_config(input_args):\n    \"\"\"\n    Creates and saves updated config file, returns the path to the new config\n    :param input_args: dict of input args\n    :return: str, path to the updated config file\n    \"\"\"\n    # Load default config\n    with open(\"config/config.yaml\", \"r\", encoding=\"utf-8\") as fin:\n        all_args = yaml.safe_load(fin)\n\n    if not input_args:\n        print(\"No args provided, using defaults\")\n    else:\n        all_args.update(input_args)\n\n    # Create updated config path\n    updated_config_path = \"config/updated_config.yaml\"\n\n    # Save updated config to new file\n    with open(updated_config_path, \"w\", encoding=\"utf-8\") as f:\n        yaml.dump(all_args, f)\n\n    return updated_config_path\n\n\ndef set_config_env_vars(args: dict):\n    \"\"\"\n    Convert API arguments into environment variables.\n    Handles nested dictionaries, lists, and special values.\n\n    Args:\n        args (dict): The arguments dictionary from the API request\n    \"\"\"\n\n    def process_value(value):\n        \"\"\"Convert Python values to string format for environment variables\"\"\"\n        if value is None:\n            return \"\"\n        if isinstance(value, bool):\n            return str(value).lower()\n        if isinstance(value, (list, dict)):\n            return str(value)\n        return str(value)\n\n    def set_env_vars(data, prefix=\"\"):\n        \"\"\"Recursively set environment variables from nested dictionary\"\"\"\n        for key, value in data.items():\n            env_key = prefix + key.upper()\n\n            # Handle special cases\n            if isinstance(value, dict):\n                # For nested dictionaries (like special_tokens)\n                set_env_vars(value, f\"{env_key}_\")\n            elif 
isinstance(value, list):\n                # Handle list of dictionaries (like datasets)\n                if value and isinstance(value[0], dict):\n                    for i, item in enumerate(value):\n                        set_env_vars(item, f\"{env_key}_{i}_\")\n                else:\n                    # For simple lists (like lora_target_modules)\n                    os.environ[env_key] = process_value(value)\n            else:\n                # Handle all other cases\n                os.environ[env_key] = process_value(value)\n\n    # Clear any existing related environment variables\n    # This prevents old values from persisting\n    for key in list(os.environ.keys()):\n        if key.startswith(\n            (\"BASE_MODEL\", \"MODEL_TYPE\", \"TOKENIZER_TYPE\", \"DATASET\", \"LORA_\", \"WANDB_\")\n        ):\n            del os.environ[key]\n\n    # Set new environment variables\n    set_env_vars(args)\n"
  },
  {
    "path": ".runpod/test-input.json",
    "content": "{\n  \"input\": {\n    \"name\": \"quick_smoke_test_sft\",\n    \"user_id\": \"user\",\n    \"model_id\": \"llama-test\",\n    \"run_id\": \"llama-test\",\n    \"credentials\": {\n      \"wandb_api_key\": \"\",\n      \"hf_token\": \"\"\n    },\n    \"args\": {\n      \"base_model\": \"HuggingFaceTB/SmolLM2-135M\",\n      \"model_type\": \"AutoModelForCausalLM\",\n      \"tokenizer_type\": \"AutoTokenizer\",\n      \"load_in_4bit\": true,\n      \"strict\": false,\n      \"datasets\": [\n        {\n          \"path\": \"mhenrichsen/alpaca_2k_test\",\n          \"type\": \"alpaca\",\n          \"split\": \"train[:10%]\"\n        }\n      ],\n      \"val_set_size\": 0.02,\n      \"output_dir\": \"./outputs/lora-out\",\n      \"sequence_len\": 4096,\n      \"sample_packing\": true,\n      \"eval_sample_packing\": false,\n      \"pad_to_sequence_len\": true,\n      \"adapter\": \"qlora\",\n      \"lora_r\": 32,\n      \"lora_alpha\": 64,\n      \"lora_dropout\": 0.05,\n      \"lora_target_linear\": true,\n      \"lora_modules_to_save\": [\n        \"embed_tokens\",\n        \"lm_head\"\n      ],\n      \"gradient_accumulation_steps\": 2,\n      \"micro_batch_size\": 1,\n      \"num_epochs\": 1,\n      \"optimizer\": \"adamw_torch_fused\",\n      \"lr_scheduler\": \"cosine\",\n      \"learning_rate\": 0.0002,\n      \"train_on_inputs\": false,\n      \"group_by_length\": false,\n      \"bf16\": \"auto\",\n      \"tf32\": true,\n      \"gradient_checkpointing\": true,\n      \"logging_steps\": 1,\n      \"flash_attention\": true,\n      \"warmup_steps\": 1,\n      \"evals_per_epoch\": 1,\n      \"eval_max_new_tokens\": 128,\n      \"saves_per_epoch\": 1,\n      \"weight_decay\": 0.0,\n      \"special_tokens\": {\n        \"pad_token\": \"<|endoftext|>\"\n      },\n      \"max_steps\": 20\n    },\n    \"timeout\": 100000\n  },\n  \"config\": {\n    \"gpuTypeId\": \"NVIDIA GeForce RTX 4090\",\n    \"gpuCount\": 1,\n    \"containerDiskInGb\": 200,\n    
\"env\": [\n      {\n        \"key\": \"TOKENIZER\",\n        \"value\": \"\"\n      },\n      {\n        \"key\": \"DISABLE_LOG_STATS\",\n        \"value\": \"true\"\n      }\n    ],\n    \"allowedCudaVersions\": [\n      \"12.8\",\n      \"12.7\",\n      \"12.6\",\n      \"12.5\",\n      \"12.4\"\n    ]\n  }\n}\n"
  },
  {
    "path": ".runpod/tests.json",
    "content": "{\n  \"tests\": [\n    {\n      \"name\": \"quick_smoke_test_sft\",\n      \"input\": {\n        \"user_id\": \"user\",\n        \"model_id\": \"llama-test\",\n        \"run_id\": \"llama-test\",\n        \"credentials\": {\n          \"wandb_api_key\": \"\",\n          \"hf_token\": \"\"\n        },\n        \"args\": {\n          \"base_model\": \"HuggingFaceTB/SmolLM2-135M\",\n          \"model_type\": \"AutoModelForCausalLM\",\n          \"tokenizer_type\": \"AutoTokenizer\",\n          \"load_in_4bit\": true,\n          \"strict\": false,\n          \"datasets\": [\n            {\n              \"path\": \"mhenrichsen/alpaca_2k_test\",\n              \"type\": \"alpaca\",\n              \"split\": \"train[:10%]\"\n            }\n          ],\n          \"val_set_size\": 0.02,\n          \"output_dir\": \"./outputs/lora-out\",\n          \"sequence_len\": 4096,\n          \"sample_packing\": true,\n          \"eval_sample_packing\": false,\n          \"pad_to_sequence_len\": true,\n          \"adapter\": \"qlora\",\n          \"lora_r\": 32,\n          \"lora_alpha\": 64,\n          \"lora_dropout\": 0.05,\n          \"lora_target_linear\": true,\n          \"lora_modules_to_save\": [\n            \"embed_tokens\",\n            \"lm_head\"\n          ],\n          \"gradient_accumulation_steps\": 2,\n          \"micro_batch_size\": 1,\n          \"num_epochs\": 1,\n          \"optimizer\": \"adamw_torch_fused\",\n          \"lr_scheduler\": \"cosine\",\n          \"learning_rate\": 0.0002,\n          \"train_on_inputs\": false,\n          \"group_by_length\": false,\n          \"bf16\": \"auto\",\n          \"tf32\": true,\n          \"gradient_checkpointing\": true,\n          \"logging_steps\": 1,\n          \"flash_attention\": true,\n          \"warmup_steps\": 1,\n          \"evals_per_epoch\": 1,\n          \"eval_max_new_tokens\": 128,\n          \"saves_per_epoch\": 1,\n          \"weight_decay\": 0.0,\n          \"special_tokens\": {\n  
          \"pad_token\": \"<|endoftext|>\"\n          },\n          \"max_steps\": 20\n        }\n      },\n      \"timeout\": 100000\n    }\n  ],\n  \"config\": {\n    \"gpuTypeId\": \"NVIDIA GeForce RTX 4090\",\n    \"gpuCount\": 1,\n    \"containerDiskInGb\": 200,\n    \"env\": [\n      {\n        \"key\": \"TOKENIZER\",\n        \"value\": \"\"\n      },\n      {\n        \"key\": \"DISABLE_LOG_STATS\",\n        \"value\": \"true\"\n      }\n    ],\n    \"allowedCudaVersions\": [\n      \"12.8\",\n      \"12.7\",\n      \"12.6\",\n      \"12.5\",\n      \"12.4\"\n    ]\n  }\n}\n"
  },
  {
    "path": "CITATION.cff",
    "content": "cff-version: 1.2.0\ntype: software\ntitle: \"Axolotl: Open Source LLM Post-Training\"\nmessage: \"If you use this software, please cite it as below.\"\nauthors:\n  - name: \"Axolotl maintainers and contributors\"\nrepository-code: \"https://github.com/axolotl-ai-cloud/axolotl\"\nurl: \"https://axolotl.ai/\"\nlicense: Apache-2.0\ndate-released: \"2023-05-30\"\n"
  },
  {
    "path": "CNAME",
    "content": "docs.axolotl.ai\n"
  },
  {
    "path": "FAQS.md",
    "content": "# FAQs\n\n- Can you train StableLM with this? Yes, but only on a single GPU at the moment. Multi-GPU support is coming soon; it is waiting on this [PR](https://github.com/huggingface/transformers/pull/22874).\n- Will this work with DeepSpeed? Support is still a work in progress, but setting `export ACCELERATE_USE_DEEPSPEED=true` should work in some cases.\n- `Error invalid argument at line 359 in file /workspace/bitsandbytes/csrc/pythonInterface.c`\n`/arrow/cpp/src/arrow/filesystem/s3fs.cc:2598:  arrow::fs::FinalizeS3 was not called even though S3 was initialized.`\nThis could lead to a segmentation fault at exit. If you see these errors, try reinstalling bitsandbytes and transformers from source.\n"
  },
  {
    "path": "LICENSE",
    "content": "\n                                 Apache License\n                           Version 2.0, January 2004\n                        http://www.apache.org/licenses/\n\n   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION\n\n   1. Definitions.\n\n      \"License\" shall mean the terms and conditions for use, reproduction,\n      and distribution as defined by Sections 1 through 9 of this document.\n\n      \"Licensor\" shall mean the copyright owner or entity authorized by\n      the copyright owner that is granting the License.\n\n      \"Legal Entity\" shall mean the union of the acting entity and all\n      other entities that control, are controlled by, or are under common\n      control with that entity. For the purposes of this definition,\n      \"control\" means (i) the power, direct or indirect, to cause the\n      direction or management of such entity, whether by contract or\n      otherwise, or (ii) ownership of fifty percent (50%) or more of the\n      outstanding shares, or (iii) beneficial ownership of such entity.\n\n      \"You\" (or \"Your\") shall mean an individual or Legal Entity\n      exercising permissions granted by this License.\n\n      \"Source\" form shall mean the preferred form for making modifications,\n      including but not limited to software source code, documentation\n      source, and configuration files.\n\n      \"Object\" form shall mean any form resulting from mechanical\n      transformation or translation of a Source form, including but\n      not limited to compiled object code, generated documentation,\n      and conversions to other media types.\n\n      \"Work\" shall mean the work of authorship, whether in Source or\n      Object form, made available under the License, as indicated by a\n      copyright notice that is included in or attached to the work\n      (an example is provided in the Appendix below).\n\n      \"Derivative Works\" shall mean any work, whether in Source or Object\n      
form, that is based on (or derived from) the Work and for which the\n      editorial revisions, annotations, elaborations, or other modifications\n      represent, as a whole, an original work of authorship. For the purposes\n      of this License, Derivative Works shall not include works that remain\n      separable from, or merely link (or bind by name) to the interfaces of,\n      the Work and Derivative Works thereof.\n\n      \"Contribution\" shall mean any work of authorship, including\n      the original version of the Work and any modifications or additions\n      to that Work or Derivative Works thereof, that is intentionally\n      submitted to Licensor for inclusion in the Work by the copyright owner\n      or by an individual or Legal Entity authorized to submit on behalf of\n      the copyright owner. For the purposes of this definition, \"submitted\"\n      means any form of electronic, verbal, or written communication sent\n      to the Licensor or its representatives, including but not limited to\n      communication on electronic mailing lists, source code control systems,\n      and issue tracking systems that are managed by, or on behalf of, the\n      Licensor for the purpose of discussing and improving the Work, but\n      excluding communication that is conspicuously marked or otherwise\n      designated in writing by the copyright owner as \"Not a Contribution.\"\n\n      \"Contributor\" shall mean Licensor and any individual or Legal Entity\n      on behalf of whom a Contribution has been received by Licensor and\n      subsequently incorporated within the Work.\n\n   2. Grant of Copyright License. 
Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      copyright license to reproduce, prepare Derivative Works of,\n      publicly display, publicly perform, sublicense, and distribute the\n      Work and such Derivative Works in Source or Object form.\n\n   3. Grant of Patent License. Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      (except as stated in this section) patent license to make, have made,\n      use, offer to sell, sell, import, and otherwise transfer the Work,\n      where such license applies only to those patent claims licensable\n      by such Contributor that are necessarily infringed by their\n      Contribution(s) alone or by combination of their Contribution(s)\n      with the Work to which such Contribution(s) was submitted. If You\n      institute patent litigation against any entity (including a\n      cross-claim or counterclaim in a lawsuit) alleging that the Work\n      or a Contribution incorporated within the Work constitutes direct\n      or contributory patent infringement, then any patent licenses\n      granted to You under this License for that Work shall terminate\n      as of the date such litigation is filed.\n\n   4. Redistribution. 
You may reproduce and distribute copies of the\n      Work or Derivative Works thereof in any medium, with or without\n      modifications, and in Source or Object form, provided that You\n      meet the following conditions:\n\n      (a) You must give any other recipients of the Work or\n          Derivative Works a copy of this License; and\n\n      (b) You must cause any modified files to carry prominent notices\n          stating that You changed the files; and\n\n      (c) You must retain, in the Source form of any Derivative Works\n          that You distribute, all copyright, patent, trademark, and\n          attribution notices from the Source form of the Work,\n          excluding those notices that do not pertain to any part of\n          the Derivative Works; and\n\n      (d) If the Work includes a \"NOTICE\" text file as part of its\n          distribution, then any Derivative Works that You distribute must\n          include a readable copy of the attribution notices contained\n          within such NOTICE file, excluding those notices that do not\n          pertain to any part of the Derivative Works, in at least one\n          of the following places: within a NOTICE text file distributed\n          as part of the Derivative Works; within the Source form or\n          documentation, if provided along with the Derivative Works; or,\n          within a display generated by the Derivative Works, if and\n          wherever such third-party notices normally appear. The contents\n          of the NOTICE file are for informational purposes only and\n          do not modify the License. 
You may add Your own attribution\n          notices within Derivative Works that You distribute, alongside\n          or as an addendum to the NOTICE text from the Work, provided\n          that such additional attribution notices cannot be construed\n          as modifying the License.\n\n      You may add Your own copyright statement to Your modifications and\n      may provide additional or different license terms and conditions\n      for use, reproduction, or distribution of Your modifications, or\n      for any such Derivative Works as a whole, provided Your use,\n      reproduction, and distribution of the Work otherwise complies with\n      the conditions stated in this License.\n\n   5. Submission of Contributions. Unless You explicitly state otherwise,\n      any Contribution intentionally submitted for inclusion in the Work\n      by You to the Licensor shall be under the terms and conditions of\n      this License, without any additional terms or conditions.\n      Notwithstanding the above, nothing herein shall supersede or modify\n      the terms of any separate license agreement you may have executed\n      with Licensor regarding such Contributions.\n\n   6. Trademarks. This License does not grant permission to use the trade\n      names, trademarks, service marks, or product names of the Licensor,\n      except as required for reasonable and customary use in describing the\n      origin of the Work and reproducing the content of the NOTICE file.\n\n   7. Disclaimer of Warranty. Unless required by applicable law or\n      agreed to in writing, Licensor provides the Work (and each\n      Contributor provides its Contributions) on an \"AS IS\" BASIS,\n      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or\n      implied, including, without limitation, any warranties or conditions\n      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A\n      PARTICULAR PURPOSE. 
You are solely responsible for determining the\n      appropriateness of using or redistributing the Work and assume any\n      risks associated with Your exercise of permissions under this License.\n\n   8. Limitation of Liability. In no event and under no legal theory,\n      whether in tort (including negligence), contract, or otherwise,\n      unless required by applicable law (such as deliberate and grossly\n      negligent acts) or agreed to in writing, shall any Contributor be\n      liable to You for damages, including any direct, indirect, special,\n      incidental, or consequential damages of any character arising as a\n      result of this License or out of the use or inability to use the\n      Work (including but not limited to damages for loss of goodwill,\n      work stoppage, computer failure or malfunction, or any and all\n      other commercial damages or losses), even if such Contributor\n      has been advised of the possibility of such damages.\n\n   9. Accepting Warranty or Additional Liability. While redistributing\n      the Work or Derivative Works thereof, You may choose to offer,\n      and charge a fee for, acceptance of support, warranty, indemnity,\n      or other liability obligations and/or rights consistent with this\n      License. However, in accepting such obligations, You may act only\n      on Your own behalf and on Your sole responsibility, not on behalf\n      of any other Contributor, and only if You agree to indemnify,\n      defend, and hold each Contributor harmless for any liability\n      incurred by, or claims asserted against, such Contributor by reason\n      of your accepting any such warranty or additional liability.\n\n   END OF TERMS AND CONDITIONS\n\n   APPENDIX: How to apply the Apache License to your work.\n\n      To apply the Apache License to your work, attach the following\n      boilerplate notice, with the fields enclosed by brackets \"[]\"\n      replaced with your own identifying information. 
(Don't include\n      the brackets!)  The text should be enclosed in the appropriate\n      comment syntax for the file format. We also recommend that a\n      file or class name and description of purpose be included on the\n      same \"printed page\" as the copyright notice for easier\n      identification within third-party archives.\n\n   Copyright [yyyy] [name of copyright owner]\n\n   Licensed under the Apache License, Version 2.0 (the \"License\");\n   you may not use this file except in compliance with the License.\n   You may obtain a copy of the License at\n\n       http://www.apache.org/licenses/LICENSE-2.0\n\n   Unless required by applicable law or agreed to in writing, software\n   distributed under the License is distributed on an \"AS IS\" BASIS,\n   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n   See the License for the specific language governing permissions and\n   limitations under the License.\n"
  },
  {
    "path": "MANIFEST.in",
    "content": "include requirements.txt\ninclude README.md\ninclude LICENSE\ninclude src/setuptools_axolotl_dynamic_dependencies.py\ninclude src/axolotl/utils/chat_templates/templates/*.jinja\nrecursive-include axolotl *.py\n"
  },
  {
    "path": "README.md",
    "content": "<p align=\"center\">\n    <picture>\n        <source media=\"(prefers-color-scheme: dark)\" srcset=\"https://raw.githubusercontent.com/axolotl-ai-cloud/axolotl/887513285d98132142bf5db2a74eb5e0928787f1/image/axolotl_logo_digital_white.svg\">\n        <source media=\"(prefers-color-scheme: light)\" srcset=\"https://raw.githubusercontent.com/axolotl-ai-cloud/axolotl/887513285d98132142bf5db2a74eb5e0928787f1/image/axolotl_logo_digital_black.svg\">\n        <img alt=\"Axolotl\" src=\"https://raw.githubusercontent.com/axolotl-ai-cloud/axolotl/887513285d98132142bf5db2a74eb5e0928787f1/image/axolotl_logo_digital_black.svg\" width=\"400\" height=\"104\" style=\"max-width: 100%;\">\n    </picture>\n</p>\n  <p align=\"center\">\n      <strong>A Free and Open Source LLM Fine-tuning Framework</strong><br>\n  </p>\n\n<p align=\"center\">\n    <img src=\"https://img.shields.io/github/license/axolotl-ai-cloud/axolotl.svg?color=blue\" alt=\"GitHub License\">\n    <img src=\"https://github.com/axolotl-ai-cloud/axolotl/actions/workflows/tests.yml/badge.svg\" alt=\"tests\">\n    <a href=\"https://codecov.io/gh/axolotl-ai-cloud/axolotl\"><img src=\"https://codecov.io/gh/axolotl-ai-cloud/axolotl/branch/main/graph/badge.svg\" alt=\"codecov\"></a>\n    <a href=\"https://github.com/axolotl-ai-cloud/axolotl/releases\"><img src=\"https://img.shields.io/github/release/axolotl-ai-cloud/axolotl.svg\" alt=\"Releases\"></a>\n    <br/>\n    <a href=\"https://github.com/axolotl-ai-cloud/axolotl/graphs/contributors\"><img src=\"https://img.shields.io/github/contributors-anon/axolotl-ai-cloud/axolotl?color=yellow&style=flat-square\" alt=\"contributors\" style=\"height: 20px;\"></a>\n    <img src=\"https://img.shields.io/github/stars/axolotl-ai-cloud/axolotl\" alt=\"GitHub Repo stars\">\n    <br/>\n    <a href=\"https://discord.com/invite/HhrNrHJPRb\"><img src=\"https://img.shields.io/badge/discord-7289da.svg?style=flat-square&logo=discord\" alt=\"discord\" style=\"height: 20px;\"></a>\n 
   <a href=\"https://twitter.com/axolotl_ai\"><img src=\"https://img.shields.io/twitter/follow/axolotl_ai?style=social\" alt=\"twitter\" style=\"height: 20px;\"></a>\n    <a href=\"https://colab.research.google.com/github/axolotl-ai-cloud/axolotl/blob/main/examples/colab-notebooks/colab-axolotl-example.ipynb\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"google-colab\" style=\"height: 20px;\"></a>\n    <br/>\n    <img src=\"https://github.com/axolotl-ai-cloud/axolotl/actions/workflows/tests-nightly.yml/badge.svg\" alt=\"tests-nightly\">\n    <img src=\"https://github.com/axolotl-ai-cloud/axolotl/actions/workflows/multi-gpu-e2e.yml/badge.svg\" alt=\"multigpu-semi-weekly tests\">\n</p>\n\n\n## 🎉 Latest Updates\n\n- 2026/03:\n  - New model support has been added in Axolotl for [Mistral Small 4](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/mistral4), [Qwen3.5, Qwen3.5 MoE](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/qwen3.5), [GLM-4.7-Flash](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/glm47-flash), [GLM-4.6V](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/glm46v), and [GLM-4.5-Air](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/glm45).\n  - [MoE expert quantization](https://docs.axolotl.ai/docs/expert_quantization.html) support (via `quantize_moe_experts: true`) greatly reduces VRAM when training MoE models (FSDP2 compat).\n- 2026/02:\n  - [ScatterMoE LoRA](https://github.com/axolotl-ai-cloud/axolotl/pull/3410) support. 
LoRA fine-tuning directly on MoE expert weights using custom Triton kernels.\n  - Axolotl now has support for [SageAttention](https://github.com/axolotl-ai-cloud/axolotl/pull/2823) and [GDPO](https://github.com/axolotl-ai-cloud/axolotl/pull/3353) (Generalized DPO).\n- 2026/01:\n  - New integrations for [EAFT](https://github.com/axolotl-ai-cloud/axolotl/pull/3366) (Entropy-Aware Focal Training), which weights the loss by the entropy of the top-k logit distribution, and [Scalable Softmax](https://github.com/axolotl-ai-cloud/axolotl/pull/3338), which improves long-context attention.\n- 2025/12:\n  - Axolotl now includes support for [Kimi-Linear](https://docs.axolotl.ai/docs/models/kimi-linear.html), [Plano-Orchestrator](https://docs.axolotl.ai/docs/models/plano.html), [MiMo](https://docs.axolotl.ai/docs/models/mimo.html), [InternVL 3.5](https://docs.axolotl.ai/docs/models/internvl3_5.html), [Olmo3](https://docs.axolotl.ai/docs/models/olmo3.html), [Trinity](https://docs.axolotl.ai/docs/models/trinity.html), and [Ministral3](https://docs.axolotl.ai/docs/models/ministral3.html).\n  - [Distributed Muon Optimizer](https://github.com/axolotl-ai-cloud/axolotl/pull/3264) support has been added for FSDP2 pretraining.\n- 2025/10: New model support has been added in Axolotl for: [Qwen3 Next](https://docs.axolotl.ai/docs/models/qwen3-next.html), [Qwen2.5-vl, Qwen3-vl](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/qwen2_5-vl), [Qwen3, Qwen3MoE](https://docs.axolotl.ai/docs/models/qwen3.html), [Granite 4](https://docs.axolotl.ai/docs/models/granite4.html), [HunYuan](https://docs.axolotl.ai/docs/models/hunyuan.html), [Magistral 2509](https://docs.axolotl.ai/docs/models/magistral/vision.html), [Apertus](https://docs.axolotl.ai/docs/models/apertus.html), and [Seed-OSS](https://docs.axolotl.ai/docs/models/seed-oss.html).\n\n<details>\n\n<summary>Expand older updates</summary>\n\n- 2025/09: Axolotl now has text diffusion training. 
Read more [here](https://github.com/axolotl-ai-cloud/axolotl/tree/main/src/axolotl/integrations/diffusion).\n- 2025/08: QAT has been updated to include NVFP4 support. See [PR](https://github.com/axolotl-ai-cloud/axolotl/pull/3107).\n- 2025/07:\n  - ND Parallelism support has been added to Axolotl. Compose Context Parallelism (CP), Tensor Parallelism (TP), and Fully Sharded Data Parallelism (FSDP) within a single node and across multiple nodes. Check out the [blog post](https://huggingface.co/blog/accelerate-nd-parallel) for more info.\n  - Axolotl adds more models: [GPT-OSS](https://docs.axolotl.ai/docs/models/gpt-oss.html), [Gemma 3n](https://docs.axolotl.ai/docs/models/gemma3n.html), [Liquid Foundation Model 2 (LFM2)](https://docs.axolotl.ai/docs/models/LiquidAI.html), and [Arcee Foundation Models (AFM)](https://docs.axolotl.ai/docs/models/arcee.html).\n  - FP8 fine-tuning with the FP8 gather op is now possible in Axolotl via `torchao`. Get started [here](https://docs.axolotl.ai/docs/mixed_precision.html#sec-fp8)!\n  - [Voxtral](https://docs.axolotl.ai/docs/models/voxtral.html), [Magistral 1.1](https://docs.axolotl.ai/docs/models/magistral.html), and [Devstral](https://docs.axolotl.ai/docs/models/devstral.html) with mistral-common tokenizer support have been integrated into Axolotl!\n  - TiledMLP support for single-GPU and multi-GPU training (with DDP, DeepSpeed, and FSDP) has been added to enable Arctic Long Sequence Training (ALST). See [examples](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/alst) for using ALST with Axolotl!\n- 2025/06: Magistral with mistral-common tokenizer support has been added to Axolotl. See [docs](https://docs.axolotl.ai/docs/models/magistral.html) to start training your own Magistral models with Axolotl!\n- 2025/05: Quantization Aware Training (QAT) support has been added to Axolotl. Explore the [docs](https://docs.axolotl.ai/docs/qat.html) to learn more!\n- 2025/04: Llama 4 support has been added in Axolotl. 
See [docs](https://docs.axolotl.ai/docs/models/llama-4.html) to start training your own Llama 4 models with Axolotl's linearized version!\n- 2025/03: Axolotl has implemented Sequence Parallelism (SP) support. Read the [blog](https://huggingface.co/blog/axolotl-ai-co/long-context-with-sequence-parallelism-in-axolotl) and [docs](https://docs.axolotl.ai/docs/sequence_parallelism.html) to learn how to scale your context length when fine-tuning.\n- 2025/03: (Beta) Fine-tuning Multimodal models is now supported in Axolotl. Check out the [docs](https://docs.axolotl.ai/docs/multimodal.html) to fine-tune your own!\n- 2025/02: Axolotl has added LoRA optimizations to reduce memory usage and improve training speed for LoRA and QLoRA in single GPU and multi-GPU training (DDP and DeepSpeed). Jump into the [docs](https://docs.axolotl.ai/docs/lora_optims.html) to give it a try.\n- 2025/02: Axolotl has added GRPO support. Dive into our [blog](https://huggingface.co/blog/axolotl-ai-co/training-llms-w-interpreter-feedback-wasm) and [GRPO example](https://github.com/axolotl-ai-cloud/grpo_code) and have some fun!\n- 2025/01: Axolotl has added Reward Modelling / Process Reward Modelling fine-tuning support. 
See [docs](https://docs.axolotl.ai/docs/reward_modelling.html).\n\n</details>\n\n## ✨ Overview\n\nAxolotl is a free and open-source tool designed to streamline post-training and fine-tuning for the latest large language models (LLMs).\n\nFeatures:\n\n- **Multiple Model Support**: Train various models like GPT-OSS, LLaMA, Mistral, Mixtral, Pythia, and many more models available on the Hugging Face Hub.\n- **Multimodal Training**: Fine-tune vision-language models (VLMs) including LLaMA-Vision, Qwen2-VL, Pixtral, LLaVA, SmolVLM2, GLM-4.6V, InternVL 3.5, Gemma 3n, and audio models like Voxtral with image, video, and audio support.\n- **Training Methods**: Full fine-tuning, LoRA, QLoRA, GPTQ, QAT, Preference Tuning (DPO, IPO, KTO, ORPO), RL (GRPO, GDPO), and Reward Modelling (RM) / Process Reward Modelling (PRM).\n- **Easy Configuration**: Re-use a single YAML configuration file across the full fine-tuning pipeline: dataset preprocessing, training, evaluation, quantization, and inference.\n- **Performance Optimizations**: [Multipacking](https://docs.axolotl.ai/docs/multipack.html), [Flash Attention 2/3/4](https://docs.axolotl.ai/docs/attention.html#flash-attention), [Xformers](https://docs.axolotl.ai/docs/attention.html#xformers), [Flex Attention](https://docs.axolotl.ai/docs/attention.html#flex-attention), [SageAttention](https://docs.axolotl.ai/docs/attention.html#sageattention), [Liger Kernel](https://docs.axolotl.ai/docs/custom_integrations.html#liger-kernels), [Cut Cross Entropy](https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy), [ScatterMoE](https://docs.axolotl.ai/docs/custom_integrations.html#kernels-integration), [Sequence Parallelism (SP)](https://docs.axolotl.ai/docs/sequence_parallelism.html), [LoRA optimizations](https://docs.axolotl.ai/docs/lora_optims.html), [Multi-GPU training (FSDP1, FSDP2, DeepSpeed)](https://docs.axolotl.ai/docs/multi-gpu.html), [Multi-node training (Torchrun, Ray)](https://docs.axolotl.ai/docs/multi-node.html), 
and many more!\n- **Flexible Dataset Handling**: Load from local, HuggingFace, and cloud (S3, Azure, GCP, OCI) datasets.\n- **Cloud Ready**: We ship [Docker images](https://hub.docker.com/u/axolotlai) and also [PyPI packages](https://pypi.org/project/axolotl/) for use on cloud platforms and local hardware.\n\n\n\n## 🚀 Quick Start - LLM Fine-tuning in Minutes\n\n**Requirements**:\n\n- NVIDIA GPU (Ampere or newer for `bf16` and Flash Attention) or AMD GPU\n- Python 3.11\n- PyTorch ≥2.8.0\n\n### Google Colab\n\n[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/axolotl-ai-cloud/axolotl/blob/main/examples/colab-notebooks/colab-axolotl-example.ipynb#scrollTo=msOCO4NRmRLa)\n\n### Installation\n\n#### Using pip\n\n```bash\npip3 install -U packaging==26.0 setuptools==75.8.0 wheel ninja\npip3 install --no-build-isolation axolotl[flash-attn,deepspeed]\n\n# Download example axolotl configs, deepspeed configs\naxolotl fetch examples\naxolotl fetch deepspeed_configs  # OPTIONAL\n```\n\n#### Using Docker\n\nInstalling with Docker can be less error prone than installing in your own environment.\n```bash\ndocker run --gpus '\"all\"' --rm -it axolotlai/axolotl:main-latest\n```\n\nOther installation approaches are described [here](https://docs.axolotl.ai/docs/installation.html).\n\n#### Cloud Providers\n\n<details>\n\n- [RunPod](https://runpod.io/gsc?template=v2ickqhz9s&ref=6i7fkpdz)\n- [Vast.ai](https://cloud.vast.ai?ref_id=62897&template_id=bdd4a49fa8bce926defc99471864cace&utm_source=github&utm_medium=developer_community&utm_campaign=template_launch_axolotl&utm_content=readme)\n- [PRIME Intellect](https://app.primeintellect.ai/dashboard/create-cluster?image=axolotl&location=Cheapest&security=Cheapest&show_spot=true)\n- [Modal](https://www.modal.com?utm_source=github&utm_medium=github&utm_campaign=axolotl)\n- [Novita](https://novita.ai/gpus-console?templateId=311)\n- 
[JarvisLabs.ai](https://jarvislabs.ai/templates/axolotl)\n- [Latitude.sh](https://latitude.sh/blueprint/989e0e79-3bf6-41ea-a46b-1f246e309d5c)\n\n</details>\n\n### Your First Fine-tune\n\n```bash\n# Fetch axolotl examples\naxolotl fetch examples\n\n# Or, specify a custom path\naxolotl fetch examples --dest path/to/folder\n\n# Train a model using LoRA\naxolotl train examples/llama-3/lora-1b.yml\n```\n\nThat's it! Check out our [Getting Started Guide](https://docs.axolotl.ai/docs/getting-started.html) for a more detailed walkthrough.\n\n\n## 📚 Documentation\n\n- [Installation Options](https://docs.axolotl.ai/docs/installation.html) - Detailed setup instructions for different environments\n- [Configuration Guide](https://docs.axolotl.ai/docs/config-reference.html) - Full configuration options and examples\n- [Dataset Loading](https://docs.axolotl.ai/docs/dataset_loading.html) - Loading datasets from various sources\n- [Dataset Guide](https://docs.axolotl.ai/docs/dataset-formats/) - Supported formats and how to use them\n- [Multi-GPU Training](https://docs.axolotl.ai/docs/multi-gpu.html)\n- [Multi-Node Training](https://docs.axolotl.ai/docs/multi-node.html)\n- [Multipacking](https://docs.axolotl.ai/docs/multipack.html)\n- [API Reference](https://docs.axolotl.ai/docs/api/) - Auto-generated code documentation\n- [FAQ](https://docs.axolotl.ai/docs/faq.html) - Frequently asked questions\n\n## 🤝 Getting Help\n\n- Join our [Discord community](https://discord.gg/HhrNrHJPRb) for support\n- Check out our [Examples](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/) directory\n- Read our [Debugging Guide](https://docs.axolotl.ai/docs/debugging.html)\n- Need dedicated support? Please contact [✉️wing@axolotl.ai](mailto:wing@axolotl.ai) for options\n\n## 🌟 Contributing\n\nContributions are welcome! 
Please see our [Contributing Guide](https://github.com/axolotl-ai-cloud/axolotl/blob/main/.github/CONTRIBUTING.md) for details.\n\n## 📈 Telemetry\n\nAxolotl has opt-out telemetry that helps us understand how the project is being used\nand prioritize improvements. We collect basic system information, model types, and\nerror rates—never personal data or file paths. Telemetry is enabled by default. To\ndisable it, set AXOLOTL_DO_NOT_TRACK=1. For more details, see our [telemetry documentation](https://docs.axolotl.ai/docs/telemetry.html).\n\n## ❤️ Sponsors\n\nInterested in sponsoring? Contact us at [wing@axolotl.ai](mailto:wing@axolotl.ai)\n\n## 📝 Citing Axolotl\n\nIf you use Axolotl in your research or projects, please cite it as follows:\n\n```bibtex\n@software{axolotl,\n  title = {Axolotl: Open Source LLM Post-Training},\n  author = {{Axolotl maintainers and contributors}},\n  url = {https://github.com/axolotl-ai-cloud/axolotl},\n  license = {Apache-2.0},\n  year = {2023}\n}\n```\n\n## 📜 License\n\nThis project is licensed under the Apache 2.0 License - see the [LICENSE](LICENSE) file for details.\n"
  },
  {
    "path": "VERSION",
    "content": "0.16.0.dev0\n"
  },
  {
    "path": "_quarto.yml",
    "content": "project:\n  type: website\n  pre-render:\n   - docs/scripts/generate_config_docs.py\n   - docs/scripts/generate_examples_docs.py\n\nquartodoc:\n  dir: docs/api\n  package: axolotl\n  title: API Reference\n  parser: google\n\n  sections:\n    - title: Core\n      desc: Core functionality for training\n      contents:\n        - train\n        - evaluate\n        - datasets\n        - convert\n        - prompt_tokenizers\n        - logging_config\n        - core.builders.base\n        - core.builders.causal\n        - core.builders.rl\n        - core.training_args\n        - core.chat.messages\n        - core.chat.format.chatml\n        - core.chat.format.llama3x\n        - core.chat.format.shared\n        - core.datasets.chat\n        - core.datasets.transforms.chat_builder\n    - title: CLI\n      desc: Command-line interface\n      contents:\n        - cli.main\n        - cli.train\n        - cli.evaluate\n        - cli.args\n        - cli.art\n        - cli.checks\n        - cli.config\n        - cli.delinearize_llama4\n        - cli.inference\n        - cli.merge_lora\n        - cli.merge_sharded_fsdp_weights\n        - cli.preprocess\n        - cli.quantize\n        - cli.vllm_serve\n        - cli.cloud.base\n        - cli.cloud.modal_\n        - cli.utils\n        - cli.utils.args\n        - cli.utils.fetch\n        - cli.utils.load\n        - cli.utils.sweeps\n        - cli.utils.train\n    - title: Trainers\n      desc: Training implementations\n      contents:\n        - core.trainers.base\n        - core.trainers.trl\n        - core.trainers.mamba\n        - core.trainers.dpo.trainer\n        - core.trainers.grpo.trainer\n        - core.trainers.grpo.sampler\n        - core.trainers.utils\n    - title: Model Loading\n      desc: Functionality for loading and patching models, tokenizers, etc.\n      contents:\n        - loaders.model\n        - loaders.tokenizer\n        - loaders.processor\n        - loaders.adapter\n        - 
loaders.patch_manager\n        - loaders.constants\n    - title: Mixins\n      desc: Mixin classes for augmenting trainers\n      contents:\n        - core.trainers.mixins.optimizer\n        - core.trainers.mixins.rng_state_loader\n        - core.trainers.mixins.scheduler\n    - title: Context Managers\n      desc: Context managers for altering trainer behaviors\n      contents:\n        - utils.ctx_managers.sequence_parallel\n    - title: Prompt Strategies\n      desc: Prompt formatting strategies\n      contents:\n        - prompt_strategies.base\n        - prompt_strategies.chat_template\n        - prompt_strategies.alpaca_chat\n        - prompt_strategies.alpaca_instruct\n        - prompt_strategies.alpaca_w_system\n        - prompt_strategies.user_defined\n        - prompt_strategies.llama2_chat\n        - prompt_strategies.completion\n        - prompt_strategies.input_output\n        - prompt_strategies.stepwise_supervised\n        - prompt_strategies.metharme\n        - prompt_strategies.orcamini\n        - prompt_strategies.pygmalion\n        - prompt_strategies.messages.chat\n        - prompt_strategies.dpo.chat_template\n        - prompt_strategies.dpo.llama3\n        - prompt_strategies.dpo.chatml\n        - prompt_strategies.dpo.zephyr\n        - prompt_strategies.dpo.user_defined\n        - prompt_strategies.dpo.passthrough\n        - prompt_strategies.kto.llama3\n        - prompt_strategies.kto.chatml\n        - prompt_strategies.kto.user_defined\n        - prompt_strategies.orpo.chat_template\n        - prompt_strategies.bradley_terry.llama3\n    - title: Kernels\n      desc: Low-level performance optimizations\n      contents:\n        - kernels.lora\n        - kernels.geglu\n        - kernels.swiglu\n        - kernels.quantize\n        - kernels.utils\n    - title: Monkey Patches\n      desc: Runtime patches for model optimizations\n      contents:\n        - monkeypatch.llama_attn_hijack_flash\n        - monkeypatch.llama_attn_hijack_xformers\n    
    - monkeypatch.mistral_attn_hijack_flash\n        - monkeypatch.multipack\n        - monkeypatch.relora\n        - monkeypatch.lora_kernels\n        - monkeypatch.utils\n        - monkeypatch.btlm_attn_hijack_flash\n        - monkeypatch.stablelm_attn_hijack_flash\n        - monkeypatch.trainer_fsdp_optim\n        - monkeypatch.transformers_fa_utils\n        - monkeypatch.unsloth_\n        - monkeypatch.data.batch_dataset_fetcher\n        - monkeypatch.mixtral\n        - monkeypatch.gradient_checkpointing.offload_cpu\n        - monkeypatch.gradient_checkpointing.offload_disk\n    - title: Utils\n      desc: Utility functions\n      contents:\n        - utils.tokenization\n        - utils.chat_templates\n        - utils.lora\n        - utils.model_shard_quant\n        - utils.bench\n        - utils.freeze\n        - utils.trainer\n        - utils.schedulers\n        - utils.distributed\n        - utils.dict\n        - utils.optimizers.adopt\n        - utils.data.streaming\n        - utils.data.sft\n        - utils.quantization\n    - title: Schemas\n      desc: Pydantic data models for Axolotl config\n      contents:\n        - utils.schemas.config\n        - utils.schemas.model\n        - utils.schemas.training\n        - utils.schemas.datasets\n        - utils.schemas.peft\n        - utils.schemas.trl\n        - utils.schemas.multimodal\n        - utils.schemas.integrations\n        - utils.schemas.enums\n        - utils.schemas.utils\n    - title: Integrations\n      desc: Third-party integrations and extensions\n      contents:\n        - integrations.base\n        - integrations.cut_cross_entropy.args\n        - integrations.grokfast.optimizer\n        - integrations.kd.trainer\n        - integrations.liger.args\n        - integrations.lm_eval.args\n        - integrations.spectrum.args\n    - title: Common\n      desc: Common utilities and shared functionality\n      contents:\n        - common.architectures\n        - common.const\n        - 
common.datasets\n    - title: Models\n      desc: Custom model implementations\n      contents:\n        - models.mamba.modeling_mamba\n    - title: Data Processing\n      desc: Data processing utilities\n      contents:\n        - utils.collators.core\n        - utils.collators.batching\n        - utils.collators.mamba\n        - utils.collators.mm_chat\n        - utils.samplers.multipack\n    - title: Callbacks\n      desc: Training callbacks\n      contents:\n        - utils.callbacks.perplexity\n        - utils.callbacks.profiler\n        - utils.callbacks.lisa\n        - utils.callbacks.mlflow_\n        - utils.callbacks.comet_\n        - utils.callbacks.qat\nwebsite:\n  title: \"Axolotl\"\n  description: \"We make fine-tuning accessible, scalable, and fun\"\n  favicon: favicon.jpg\n\n  google-analytics: \"G-9KYCVJBNMQ\"\n\n  navbar:\n    logo: image/axolotl_logo_digital_white.svg\n    title: false\n    background: dark\n    pinned: false\n    collapse: false\n    tools:\n    - icon: twitter\n      href: https://twitter.com/axolotl_ai\n    - icon: github\n      href: https://github.com/axolotl-ai-cloud/axolotl/\n    - icon: discord\n      href: https://discord.gg/7m9sfhzaf3\n\n  sidebar:\n      pinned: true\n      collapse-level: 2\n      style: docked\n      contents:\n        - text: Home\n          href: index.qmd\n\n        - section: \"Getting Started\"\n          contents:\n            - docs/getting-started.qmd\n            - docs/installation.qmd\n            - docs/inference.qmd\n            - section: \"Model Guides\"\n              contents:\n                - docs/models/kimi-linear.qmd\n                - docs/models/plano.qmd\n                - docs/models/mimo.qmd\n                - docs/models/internvl3_5.qmd\n                - docs/models/olmo3.qmd\n                - docs/models/trinity.qmd\n                - docs/models/arcee.qmd\n                - section: \"Ministral3\"\n                  contents:\n                    - 
docs/models/ministral3.qmd\n                    - docs/models/ministral3/think.qmd\n                    - docs/models/ministral3/vision.qmd\n                - section: \"Magistral\"\n                  contents:\n                    - docs/models/magistral.qmd\n                    - docs/models/magistral/think.qmd\n                    - docs/models/magistral/vision.qmd\n                - docs/models/ministral.qmd\n                - docs/models/mistral-small.qmd\n                - docs/models/voxtral.qmd\n                - docs/models/devstral.qmd\n                - docs/models/mistral.qmd\n                - docs/models/llama-4.qmd\n                - docs/models/llama-2.qmd\n                - docs/models/qwen3-next.qmd\n                - docs/models/qwen3.qmd\n                - docs/models/gemma3n.qmd\n                - docs/models/apertus.qmd\n                - docs/models/gpt-oss.qmd\n                - docs/models/seed-oss.qmd\n                - docs/models/phi.qmd\n                - docs/models/smolvlm2.qmd\n                - docs/models/granite4.qmd\n                - docs/models/LiquidAI.qmd\n                - docs/models/hunyuan.qmd\n                - docs/models/jamba.qmd\n                - docs/models/orpheus.qmd\n\n            - docs/cli.qmd\n            - docs/telemetry.qmd\n            - docs/config-reference.qmd\n            - text: \"API Reference\"\n              href: docs/api\n\n        - section: \"Dataset Formats\"\n          contents: docs/dataset-formats/*\n\n        - section: \"Deployments\"\n          contents:\n            - docs/docker.qmd\n            - docs/multi-gpu.qmd\n            - docs/multi-node.qmd\n            - docs/ray-integration.qmd\n            - docs/amd_hpc.qmd\n            - docs/mac.qmd\n\n        - section: \"How To Guides\"\n          contents:\n            - docs/multimodal.qmd\n            - docs/rlhf.qmd\n            - docs/reward_modelling.qmd\n            - docs/lr_groups.qmd\n            - docs/lora_optims.qmd\n     
       - docs/dataset_loading.qmd\n            - docs/qat.qmd\n            - docs/quantize.qmd\n            - docs/optimizations.qmd\n\n        - section: \"Core Concepts\"\n          contents:\n            - docs/batch_vs_grad.qmd\n            - docs/dataset_preprocessing.qmd\n            - docs/streaming.qmd\n            - docs/multipack.qmd\n            - docs/mixed_precision.qmd\n            - docs/optimizers.qmd\n            - docs/attention.qmd\n\n        - section: \"Advanced Features\"\n          contents:\n            - docs/fsdp_qlora.qmd\n            - docs/unsloth.qmd\n            - docs/torchao.qmd\n            - docs/custom_integrations.qmd\n            - docs/sequence_parallelism.qmd\n            - docs/gradient_checkpointing.qmd\n            - docs/nd_parallelism.qmd\n            - docs/expert_quantization.qmd\n\n        - section: \"Troubleshooting\"\n          contents:\n            - docs/faq.qmd\n            - docs/debugging.qmd\n            - docs/nccl.qmd\n\nformat:\n  html:\n    theme: darkly\n    css: styles.css\n    toc: true\n    # Enable better handling of line breaks in markdown\n    preserve-tabs: true\n    html-math-method: mathjax\n    # Improved markdown processing options\n    md-extensions:\n      - markdown_it\n      - def_list\n      - attr_list\n      - fenced_divs\n      - tables\n      - html_admonition\n      - lineblocks\n      - fancy_lists\n    # Control whitespace handling\n    whitespace: preserve\n    # Process newlines in paragraphs\n    wrap: preserve\n    # Better line break handling\n    preserve-linebreaks: true\n"
  },
  {
    "path": "benchmarks/bench_entropy.py",
    "content": "\"\"\"Benchmark for entropy_from_logits Triton kernel vs original chunked implementation.\n\nUsage: CUDA_VISIBLE_DEVICES=0 python benchmarks/bench_entropy.py\n\"\"\"\n\nimport gc\nimport statistics\n\nimport torch\nimport torch.nn.functional as F\n\nfrom axolotl.monkeypatch.trainer.utils import entropy_from_logits\n\nV = 151936  # Qwen vocab\nWARMUP = 5\nBENCH_ITERS = 20\nMEM_ITERS = 10\n\n\ndef entropy_from_logits_original(logits: torch.Tensor, chunk_size: int = 128):\n    \"\"\"Original chunked implementation (reference).\"\"\"\n    original_shape = logits.shape[:-1]\n    num_classes = logits.shape[-1]\n    flat_logits = logits.reshape(-1, num_classes)\n    entropies = []\n    for chunk in flat_logits.split(chunk_size, dim=0):\n        logps = F.log_softmax(chunk, dim=-1)\n        chunk_entropy = -(torch.exp(logps) * logps).sum(-1)\n        entropies.append(chunk_entropy)\n    return torch.cat(entropies, dim=0).reshape(original_shape)\n\n\ndef _clean_gpu():\n    gc.collect()\n    torch.cuda.empty_cache()\n    torch.cuda.reset_peak_memory_stats()\n    torch.cuda.reset_accumulated_memory_stats()\n    torch.cuda.synchronize()\n\n\ndef profile_time(fn, logits, n_iters=BENCH_ITERS):\n    for _ in range(WARMUP):\n        out = fn(logits, chunk_size=128)\n        del out\n    torch.cuda.synchronize()\n\n    times = []\n    for _ in range(n_iters):\n        s = torch.cuda.Event(enable_timing=True)\n        e = torch.cuda.Event(enable_timing=True)\n        s.record()\n        out = fn(logits, chunk_size=128)\n        e.record()\n        torch.cuda.synchronize()\n        times.append(s.elapsed_time(e))\n        del out\n    return times\n\n\ndef profile_memory(fn, logits, n_iters=MEM_ITERS):\n    for _ in range(WARMUP):\n        out = fn(logits, chunk_size=128)\n        del out\n    torch.cuda.synchronize()\n\n    peaks = []\n    for _ in range(n_iters):\n        _clean_gpu()\n        base = torch.cuda.max_memory_allocated()\n        out = fn(logits, 
chunk_size=128)\n        torch.cuda.synchronize()\n        peaks.append(torch.cuda.max_memory_allocated() - base)\n        del out\n    return [p / 1e6 for p in peaks]\n\n\ndef fmt(values, unit=\"\"):\n    mean = statistics.mean(values)\n    std = statistics.stdev(values) if len(values) > 1 else 0.0\n    return f\"{mean:8.2f} ± {std:5.2f} {unit}  [min={min(values):.2f}, max={max(values):.2f}]\"\n\n\ndef benchmark_contiguous():\n    print(\"=\" * 60)\n    print(\n        f\"CONTIGUOUS BENCHMARK  (warmup={WARMUP}, time={BENCH_ITERS}, mem={MEM_ITERS})\"\n    )\n    print(\"=\" * 60)\n\n    configs = [\n        (1, 2048),\n        (1, 8192),\n        (1, 16384),\n        (4, 4096),\n        (8, 2048),\n        (16, 2048),\n        (16, 4096),\n    ]\n\n    for B, L in configs:\n        mem_gb = B * L * V * 2 / 1e9\n        if mem_gb > 28:\n            print(f\"\\n  skip B={B}, L={L} ({mem_gb:.1f} GB)\")\n            continue\n\n        N = B * L\n        print(f\"\\n{'─' * 60}\")\n        print(f\"B={B:2d}, L={L:5d}  ({N:6d} rows, logits {mem_gb:.2f} GB)\")\n        print(f\"{'─' * 60}\")\n\n        torch.manual_seed(42)\n        logits = torch.randn(B, L, V, device=\"cuda\", dtype=torch.bfloat16)\n\n        t_orig = profile_time(entropy_from_logits_original, logits)\n        t_triton = profile_time(entropy_from_logits, logits)\n        orig_mean = statistics.mean(t_orig)\n        triton_mean = statistics.mean(t_triton)\n\n        print(\"  TIME (ms):\")\n        print(f\"    original: {fmt(t_orig, 'ms')}\")\n        print(f\"    triton:   {fmt(t_triton, 'ms')}\")\n        print(f\"    speedup:  {orig_mean / triton_mean:.2f}x\")\n\n        m_orig = profile_memory(entropy_from_logits_original, logits)\n        m_triton = profile_memory(entropy_from_logits, logits)\n        orig_peak = statistics.mean(m_orig)\n        triton_peak = statistics.mean(m_triton)\n\n        print(\"  MEMORY (peak overhead):\")\n        print(f\"    original: {fmt(m_orig, 'MB')}\")\n        
print(f\"    triton:   {fmt(m_triton, 'MB')}\")\n        print(f\"    saved:    {orig_peak - triton_peak:.1f} MB\")\n\n        del logits\n        _clean_gpu()\n\n\ndef benchmark_noncontiguous():\n    print(\"\\n\" + \"=\" * 60)\n    print(\n        f\"NON-CONTIGUOUS BENCHMARK  (warmup={WARMUP}, time={BENCH_ITERS}, mem={MEM_ITERS})\"\n    )\n    print(\"=\" * 60)\n\n    configs = [\n        (4, 2048, \"transpose\"),\n        (4, 8192, \"transpose\"),\n        (8, 2048, \"transpose\"),\n        (4, 4096, \"slice_batch\"),\n    ]\n\n    for B, L, method in configs:\n        torch.manual_seed(42)\n\n        if method == \"transpose\":\n            raw = torch.randn(L, B, V, device=\"cuda\", dtype=torch.bfloat16)\n            logits_nc = raw.transpose(0, 1)\n            raw_gb = L * B * V * 2 / 1e9\n        elif method == \"slice_batch\":\n            raw = torch.randn(B * 2, L, V, device=\"cuda\", dtype=torch.bfloat16)\n            logits_nc = raw[::2]\n            raw_gb = B * 2 * L * V * 2 / 1e9\n        else:\n            continue\n\n        if raw_gb > 28:\n            print(f\"\\n  skip B={B}, L={L}, {method} ({raw_gb:.1f} GB)\")\n            del raw, logits_nc\n            torch.cuda.empty_cache()\n            continue\n\n        N = B * L\n        print(f\"\\n{'─' * 60}\")\n        print(f\"B={B}, L={L}  {method}  ({N} rows, raw {raw_gb:.2f} GB)\")\n        print(f\"{'─' * 60}\")\n\n        def original_with_copy(logits, chunk_size=128):\n            return entropy_from_logits_original(\n                logits.contiguous(), chunk_size=chunk_size\n            )\n\n        t_orig = profile_time(original_with_copy, logits_nc)\n        t_triton = profile_time(entropy_from_logits, logits_nc)\n        orig_mean = statistics.mean(t_orig)\n        triton_mean = statistics.mean(t_triton)\n\n        print(\"  TIME (ms):\")\n        print(f\"    orig+copy:     {fmt(t_orig, 'ms')}\")\n        print(f\"    triton-strided:{fmt(t_triton, 'ms')}\")\n        print(f\"    
speedup:       {orig_mean / triton_mean:.2f}x\")\n\n        m_orig = profile_memory(original_with_copy, logits_nc)\n        m_triton = profile_memory(entropy_from_logits, logits_nc)\n        orig_peak = statistics.mean(m_orig)\n        triton_peak = statistics.mean(m_triton)\n\n        print(\"  MEMORY (peak overhead):\")\n        print(f\"    orig+copy:     {fmt(m_orig, 'MB')}\")\n        print(f\"    triton-strided:{fmt(m_triton, 'MB')}\")\n        print(f\"    saved:         {orig_peak - triton_peak:.1f} MB\")\n\n        del raw, logits_nc\n        _clean_gpu()\n\n\nif __name__ == \"__main__\":\n    benchmark_contiguous()\n    benchmark_noncontiguous()\n"
  },
  {
    "path": "benchmarks/bench_scattermoe_lora.py",
    "content": "\"\"\"Benchmark for ScatterMoE LoRA Triton kernels.\n\nMeasures forward, backward dX, and backward dA/dB kernels at common MoE\nmodel shapes. Reports per-kernel timings, LoRA overhead vs base scatter2scatter,\nand full fwd+bwd autograd throughput.\n\nUsage:\n  CUDA_VISIBLE_DEVICES=0 python benchmarks/bench_scattermoe_lora.py\n  CUDA_VISIBLE_DEVICES=0 python benchmarks/bench_scattermoe_lora.py --ranks 16 64\n  CUDA_VISIBLE_DEVICES=0 python benchmarks/bench_scattermoe_lora.py --models Qwen/Qwen3.5-35B-A3B\n\"\"\"\n\nimport argparse\nimport gc\nimport time\nfrom functools import partial\n\nimport torch\n\nfrom axolotl.integrations.kernels.libs.scattermoe_lora.kernels import (\n    lora_ops,\n    ops as base_ops,\n)\nfrom axolotl.integrations.kernels.libs.scattermoe_lora.parallel_experts import (\n    flatten_sort_count,\n)\nfrom axolotl.integrations.kernels.libs.scattermoe_lora.parallel_linear_lora import (\n    ScatterMoELoRA,\n)\n\nDEVICE = \"cuda\"\nDTYPE = torch.bfloat16\nWARMUP = 5\nITERS = 20\n\n# ─── Model configs ──────────────────────────────────────────────────────────\n\nBUILTIN_CONFIGS = {\n    \"Qwen3.5-35B-A3B\": (256, 2048, 512, 8),  # E, H, I, k\n    \"Qwen3-30B-A3B\": (128, 2048, 768, 8),\n    \"OLMoE-1B-7B\": (64, 2048, 1024, 8),\n    \"Mixtral-8x7B\": (8, 4096, 14336, 2),\n}\n\n\ndef _resolve_config(spec):\n    \"\"\"Resolve a model spec to (E, H, I, k). 
Accepts builtin names or HF IDs.\"\"\"\n    key = spec.lower().replace(\"/\", \"-\")\n    for name, cfg in BUILTIN_CONFIGS.items():\n        if key in name.lower() or name.lower() in key:\n            return name, cfg\n\n    from transformers import AutoConfig\n\n    hf_cfg = AutoConfig.from_pretrained(spec, trust_remote_code=True)\n    if callable(getattr(hf_cfg, \"get_text_config\", None)):\n        tc = hf_cfg.get_text_config()\n        if hasattr(tc, \"model_type\") and tc.model_type != hf_cfg.model_type:\n            hf_cfg = tc\n    hidden = hf_cfg.hidden_size\n    inter = getattr(hf_cfg, \"moe_intermediate_size\", None) or hf_cfg.intermediate_size\n    experts = (\n        getattr(hf_cfg, \"num_experts\", None)\n        or getattr(hf_cfg, \"num_local_experts\", None)\n        or getattr(hf_cfg, \"n_routed_experts\", None)\n    )\n    top_k = (\n        getattr(hf_cfg, \"num_experts_per_tok\", None)\n        or getattr(hf_cfg, \"num_experts_per_token\", None)\n        or 2\n    )\n    name = spec.split(\"/\")[-1]\n    return name, (experts, hidden, inter, top_k)\n\n\n# ─── Benchmark helpers ──────────────────────────────────────────────────────\n\n\ndef _clean():\n    gc.collect()\n    torch.cuda.empty_cache()\n    torch.cuda.synchronize()\n\n\ndef _bench(fn, warmup=WARMUP, iters=ITERS):\n    for _ in range(warmup):\n        fn()\n    torch.cuda.synchronize()\n    times = []\n    for _ in range(iters):\n        torch.cuda.synchronize()\n        t0 = time.perf_counter()\n        fn()\n        torch.cuda.synchronize()\n        times.append((time.perf_counter() - t0) * 1000)\n    times.sort()\n    return times[len(times) // 2]\n\n\ndef _setup(num_experts, K, N, T, top_k, R):\n    torch.manual_seed(42)\n    x = torch.randn(T, K, device=DEVICE, dtype=DTYPE)\n    W = torch.randn(num_experts, K, N, device=DEVICE, dtype=DTYPE) * 0.02\n    lora_A = torch.randn(R * num_experts, K, device=DEVICE, dtype=DTYPE) * 0.01\n    lora_B = torch.randn(N, R * num_experts, 
device=DEVICE, dtype=DTYPE) * 0.01\n    logits = torch.randn(T, num_experts, device=DEVICE)\n    _, top_idx = torch.topk(torch.softmax(logits, dim=-1), top_k, dim=-1)\n    sei, ssi, eo = flatten_sort_count(top_idx, num_experts)\n    gx = base_ops.group(x, ssi, fan_out=top_k)\n    dy = torch.randn(gx.size(0), N, device=DEVICE, dtype=DTYPE)\n    return x, W, lora_A, lora_B, sei, ssi, eo, gx, dy\n\n\n# ─── Kernel wrappers (avoid B023 loop-variable capture) ──────────────────────\n\n\ndef _call_fwd(x, W, sei, ssi, top_k, lA, lB):\n    return lora_ops.scatter2scatter_lora(\n        X=x,\n        W=W,\n        sorted_expert_idxs=sei,\n        sorted_scattered_idxs=ssi,\n        k=top_k,\n        lora_A=lA,\n        lora_B=lB,\n        scaling=2.0,\n    )\n\n\ndef _call_base(x, W, sei, ssi, top_k):\n    return base_ops.scatter2scatter(\n        X=x,\n        W=W,\n        sorted_expert_idxs=sei,\n        sorted_scattered_idxs=ssi,\n        k=top_k,\n    )\n\n\ndef _call_dx(dy, W, sei, ssi, lA, lB):\n    return lora_ops.scatter2scatter_lora_dX(\n        DY=dy,\n        W=W,\n        sorted_expert_idxs=sei,\n        sorted_scattered_idxs=ssi,\n        k=1,\n        lora_A=lA,\n        lora_B=lB,\n        scaling=2.0,\n        dy_grouped=True,\n        dx_grouped=False,\n    )\n\n\ndef _call_bwd(dy, gx, lA, lB, eo, num_experts):\n    return lora_ops.group_bwd_lora(\n        DY=dy,\n        X=gx,\n        lora_A=lA,\n        lora_B=lB,\n        expert_offsets=eo,\n        E=num_experts,\n        scaling=2.0,\n    )\n\n\n# ─── Main ────────────────────────────────────────────────────────────────────\n\n\ndef main():\n    parser = argparse.ArgumentParser(description=\"ScatterMoE LoRA kernel benchmark\")\n    parser.add_argument(\n        \"--models\",\n        \"-m\",\n        nargs=\"+\",\n        help=\"Model names or HF IDs (default: all builtins)\",\n    )\n    parser.add_argument(\"--ranks\", \"-r\", nargs=\"+\", type=int, default=[16, 32, 64])\n    
parser.add_argument(\"--seq-len\", \"-T\", type=int, default=2048)\n    args = parser.parse_args()\n\n    T = args.seq_len\n    print(f\"GPU: {torch.cuda.get_device_name()}\")\n    print(f\"T={T}, ranks={args.ranks}\\n\")\n\n    if args.models:\n        configs = [_resolve_config(m) for m in args.models]\n    else:\n        configs = list(BUILTIN_CONFIGS.items())\n\n    for model_name, (num_experts, hidden, inter, top_k) in configs:\n        print(f\"{'=' * 70}\")\n        print(f\"  {model_name}: E={num_experts}, H={hidden}, I={inter}, k={top_k}\")\n        print(f\"{'=' * 70}\")\n\n        for R in args.ranks:\n            for proj, K, N in [(\"gate_up\", hidden, 2 * inter), (\"down\", inter, hidden)]:\n                _clean()\n                x, W, lA, lB, sei, ssi, eo, gx, dy = _setup(\n                    num_experts, K, N, T, top_k, R\n                )\n\n                # Forward with LoRA (auto-dispatched: fused or split)\n                dispatch = (\n                    \"split\"\n                    if (\n                        num_experts <= lora_ops._SPLIT_LORA_FWD_MAX_EXPERTS\n                        and K * N >= lora_ops._SPLIT_LORA_FWD_THRESHOLD\n                    )\n                    else \"fused\"\n                )\n                t_fwd = _bench(partial(_call_fwd, x, W, sei, ssi, top_k, lA, lB))\n                t_base = _bench(partial(_call_base, x, W, sei, ssi, top_k))\n                t_dx = _bench(partial(_call_dx, dy, W, sei, ssi, lA, lB))\n                t_bwd = _bench(partial(_call_bwd, dy, gx, lA, lB, eo, num_experts))\n\n                total = t_fwd + t_dx + t_bwd\n                overhead = t_fwd / t_base - 1 if t_base > 0 else 0\n\n                print(\n                    f\"  R={R:>2} {proj:<8}  \"\n                    f\"fwd={t_fwd:>6.2f}ms [{dispatch}]  \"\n                    f\"base={t_base:>6.2f}ms \"\n                    f\"(+{overhead * 100:.0f}%)  \"\n                    f\"dx={t_dx:>6.2f}ms  bwd={t_bwd:>6.2f}ms  
\"\n                    f\"total={total:>6.2f}ms\"\n                )\n\n                # Full autograd fwd+bwd with memory measurement\n                x_ag = x.clone().requires_grad_(True)\n                lA_ag = lA.clone().requires_grad_(True)\n                lB_ag = lB.clone().requires_grad_(True)\n\n                def _run_autograd(\n                    _x=x_ag,\n                    _W=W,\n                    _k=top_k,\n                    _sei=sei,\n                    _ssi=ssi,\n                    _eo=eo,\n                    _lA=lA_ag,\n                    _lB=lB_ag,\n                ):\n                    out = ScatterMoELoRA.apply(\n                        _x,\n                        _W,\n                        _k,\n                        _sei,\n                        _ssi,\n                        _eo,\n                        _lA,\n                        _lB,\n                        2.0,\n                        None,\n                        None,\n                        False,\n                        False,\n                        True,\n                        False,\n                    )\n                    out.sum().backward()\n                    _x.grad = None\n                    _lA.grad = None\n                    _lB.grad = None\n\n                t_full = _bench(_run_autograd)\n\n                _clean()\n                torch.cuda.reset_peak_memory_stats()\n                mem_before = torch.cuda.memory_allocated()\n                _run_autograd()\n                torch.cuda.synchronize()\n                mem_peak = torch.cuda.max_memory_allocated() - mem_before\n\n                print(\n                    f\"         full_fwd_bwd={t_full:>6.2f}ms  \"\n                    f\"peak_delta={mem_peak / 1e6:>6.1f}MB\"\n                )\n\n        print()\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "benchmarks/bench_selective_logsoftmax.py",
    "content": "\"\"\"Benchmark for selective_log_softmax Triton kernel vs original implementation.\n\nUsage: CUDA_VISIBLE_DEVICES=0 python benchmarks/bench_selective_logsoftmax.py\n\"\"\"\n\nimport gc\nimport statistics\n\nimport torch\n\nfrom axolotl.monkeypatch.trainer.utils import (\n    selective_log_softmax,\n    selective_log_softmax_original,\n)\n\nV = 151936  # Qwen vocab\nWARMUP = 5\nBENCH_ITERS = 20\nMEM_ITERS = 10\n\n\ndef _clean_gpu():\n    gc.collect()\n    torch.cuda.empty_cache()\n    torch.cuda.reset_peak_memory_stats()\n    torch.cuda.reset_accumulated_memory_stats()\n    torch.cuda.synchronize()\n\n\ndef profile_time(fn, args, n_iters=BENCH_ITERS):\n    for _ in range(WARMUP):\n        fn(*args)\n    torch.cuda.synchronize()\n\n    times = []\n    for _ in range(n_iters):\n        s = torch.cuda.Event(enable_timing=True)\n        e = torch.cuda.Event(enable_timing=True)\n        s.record()\n        fn(*args)\n        e.record()\n        torch.cuda.synchronize()\n        times.append(s.elapsed_time(e))\n    return times\n\n\ndef profile_memory(fn, args, n_iters=MEM_ITERS):\n    for _ in range(WARMUP):\n        out = fn(*args)\n        del out\n    torch.cuda.synchronize()\n\n    peaks = []\n    for _ in range(n_iters):\n        _clean_gpu()\n        base = torch.cuda.max_memory_allocated()\n        out = fn(*args)\n        torch.cuda.synchronize()\n        peaks.append(torch.cuda.max_memory_allocated() - base)\n        del out\n    return [p / 1e6 for p in peaks]\n\n\ndef fmt(values, unit=\"\"):\n    mean = statistics.mean(values)\n    std = statistics.stdev(values) if len(values) > 1 else 0.0\n    return f\"{mean:8.2f} ± {std:5.2f} {unit}  [min={min(values):.2f}, max={max(values):.2f}]\"\n\n\ndef benchmark_forward():\n    print(\"=\" * 60)\n    print(f\"FORWARD BENCHMARK  (warmup={WARMUP}, time={BENCH_ITERS}, mem={MEM_ITERS})\")\n    print(\"=\" * 60)\n\n    configs = [\n        (1, 2048),\n        (1, 8192),\n        (4, 4096),\n        (8, 
2048),\n        (16, 2048),\n        (16, 4096),\n    ]\n\n    for B, L in configs:\n        mem_gb = B * L * V * 2 / 1e9\n        if mem_gb > 28:\n            print(f\"\\n  skip B={B}, L={L} ({mem_gb:.1f} GB)\")\n            continue\n\n        N = B * L\n        print(f\"\\n{'─' * 60}\")\n        print(f\"B={B:2d}, L={L:5d}  ({N:6d} rows, logits {mem_gb:.2f} GB)\")\n        print(f\"{'─' * 60}\")\n\n        torch.manual_seed(42)\n        logits = torch.randn(B, L, V, device=\"cuda\", dtype=torch.bfloat16)\n        index = torch.randint(0, V, (B, L), device=\"cuda\")\n\n        t_orig = profile_time(selective_log_softmax_original, (logits, index))\n        t_triton = profile_time(selective_log_softmax, (logits, index))\n        orig_mean = statistics.mean(t_orig)\n        triton_mean = statistics.mean(t_triton)\n\n        print(\"  TIME (ms):\")\n        print(f\"    original: {fmt(t_orig, 'ms')}\")\n        print(f\"    triton:   {fmt(t_triton, 'ms')}\")\n        print(f\"    speedup:  {orig_mean / triton_mean:.2f}x\")\n\n        m_orig = profile_memory(selective_log_softmax_original, (logits, index))\n        m_triton = profile_memory(selective_log_softmax, (logits, index))\n        orig_peak = statistics.mean(m_orig)\n        triton_peak = statistics.mean(m_triton)\n\n        print(\"  MEMORY (peak overhead):\")\n        print(f\"    original: {fmt(m_orig, 'MB')}\")\n        print(f\"    triton:   {fmt(m_triton, 'MB')}\")\n        print(f\"    saved:    {orig_peak - triton_peak:.1f} MB\")\n\n        del logits, index\n        _clean_gpu()\n\n\ndef benchmark_backward():\n    print(\"\\n\" + \"=\" * 60)\n    print(f\"FWD+BWD BENCHMARK  (warmup={WARMUP}, time={BENCH_ITERS}, mem={MEM_ITERS})\")\n    print(\"=\" * 60)\n\n    configs = [\n        (1, 2048),\n        (1, 8192),\n        (4, 4096),\n        (8, 2048),\n        (16, 2048),\n        (16, 4096),\n    ]\n\n    def fwd_bwd_original(logits, index):\n        logits.grad = None\n        out = 
selective_log_softmax_original(logits, index)\n        out.sum().backward()\n\n    def fwd_bwd_triton(logits, index):\n        logits.grad = None\n        out = selective_log_softmax(logits, index)\n        out.sum().backward()\n\n    for B, L in configs:\n        mem_gb = B * L * V * 2 / 1e9\n        if mem_gb > 20:\n            print(f\"\\n  skip B={B}, L={L} ({mem_gb:.1f} GB, need room for grads)\")\n            continue\n\n        N = B * L\n        print(f\"\\n{'─' * 60}\")\n        print(f\"B={B:2d}, L={L:5d}  ({N:6d} rows, logits {mem_gb:.2f} GB)\")\n        print(f\"{'─' * 60}\")\n\n        torch.manual_seed(42)\n        logits_orig = torch.randn(\n            B, L, V, device=\"cuda\", dtype=torch.bfloat16, requires_grad=True\n        )\n        logits_tri = logits_orig.detach().clone().requires_grad_(True)\n        index = torch.randint(0, V, (B, L), device=\"cuda\")\n\n        t_orig = profile_time(fwd_bwd_original, (logits_orig, index))\n        t_triton = profile_time(fwd_bwd_triton, (logits_tri, index))\n        orig_mean = statistics.mean(t_orig)\n        triton_mean = statistics.mean(t_triton)\n\n        print(\"  FWD+BWD TIME (ms):\")\n        print(f\"    original: {fmt(t_orig, 'ms')}\")\n        print(f\"    triton:   {fmt(t_triton, 'ms')}\")\n        print(f\"    speedup:  {orig_mean / triton_mean:.2f}x\")\n\n        m_orig = profile_memory(fwd_bwd_original, (logits_orig, index))\n        m_triton = profile_memory(fwd_bwd_triton, (logits_tri, index))\n        orig_peak = statistics.mean(m_orig)\n        triton_peak = statistics.mean(m_triton)\n\n        print(\"  FWD+BWD MEMORY (peak overhead):\")\n        print(f\"    original: {fmt(m_orig, 'MB')}\")\n        print(f\"    triton:   {fmt(m_triton, 'MB')}\")\n        print(f\"    saved:    {orig_peak - triton_peak:.1f} MB\")\n\n        del logits_orig, logits_tri, index\n        _clean_gpu()\n\n\nif __name__ == \"__main__\":\n    benchmark_forward()\n    benchmark_backward()\n"
  },
  {
    "path": "cicd/Dockerfile-uv.jinja",
    "content": "FROM axolotlai/axolotl-base-uv:{{ BASE_TAG }}\n\nENV TORCH_CUDA_ARCH_LIST=\"7.0 7.5 8.0 8.6 9.0+PTX\"\nENV AXOLOTL_EXTRAS=\"{{ AXOLOTL_EXTRAS }}\"\nENV AXOLOTL_ARGS=\"{{ AXOLOTL_ARGS }}\"\nENV CUDA=\"{{ CUDA }}\"\nENV PYTORCH_VERSION=\"{{ PYTORCH_VERSION }}\"\nENV GITHUB_REF=\"{{ GITHUB_REF }}\"\nENV GITHUB_SHA=\"{{ GITHUB_SHA }}\"\nENV NIGHTLY_BUILD=\"{{ NIGHTLY_BUILD }}\"\nENV HF_HOME=\"{{ HF_HOME }}\"\n\nRUN apt-get update && \\\n    apt-get install -y --allow-change-held-packages vim curl nano zstd libnccl2 libnccl-dev ibverbs-providers ibverbs-utils infiniband-diags librdmacm-dev librdmacm1 rdmacm-utils slurm-wlm\n\nWORKDIR /workspace\n\nRUN git clone --depth=1 https://github.com/axolotl-ai-cloud/axolotl.git\n\nWORKDIR /workspace/axolotl\n\nRUN git fetch origin +$GITHUB_REF && \\\n    git checkout FETCH_HEAD\n\n# If AXOLOTL_EXTRAS is set, append it in brackets\nRUN if [ \"$NIGHTLY_BUILD\" = \"true\" ] ; then \\\n        sed -i 's#^transformers.*#transformers @ git+https://github.com/huggingface/transformers.git@main#' requirements.txt; \\\n        sed -i 's#^peft.*#peft @ git+https://github.com/huggingface/peft.git@main#' requirements.txt; \\\n        sed -i 's#^accelerate.*#accelerate @ git+https://github.com/huggingface/accelerate.git@main#' requirements.txt; \\\n        sed -i 's#^trl.*#trl @ git+https://github.com/huggingface/trl.git@main#' requirements.txt; \\\n        sed -i 's#^datasets.*#datasets @ git+https://github.com/huggingface/datasets.git@main#' requirements.txt; \\\n    fi\n\nRUN uv pip install packaging==26.0 setuptools==78.1.1\nRUN uv pip install torchvision\nRUN uv pip uninstall causal_conv1d\nRUN if [ \"$AXOLOTL_EXTRAS\" != \"\" ] ; then \\\n        uv pip install --no-build-isolation -e .[deepspeed,flash-attn,ring-flash-attn,optimizers,ray,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \\\n    else \\\n        uv pip install --no-build-isolation -e .[deepspeed,flash-attn,ring-flash-attn,optimizers,ray] $AXOLOTL_ARGS; \\\n    fi\n\nRUN 
python scripts/unsloth_install.py --uv | sh\nRUN python scripts/cutcrossentropy_install.py --uv | sh\n\n# So we can test the Docker image\nRUN uv pip install -r requirements-dev.txt -r requirements-tests.txt\n\n# fix so that git fetch/pull from remote works\nRUN git config remote.origin.fetch \"+refs/heads/*:refs/remotes/origin/*\" && \\\n    git config --get remote.origin.fetch\n\n# helper for huggingface-login cli\nRUN git config --global credential.helper store\n"
  },
  {
    "path": "cicd/Dockerfile.jinja",
    "content": "FROM axolotlai/axolotl-base:{{ BASE_TAG }}\n\nENV TORCH_CUDA_ARCH_LIST=\"7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX\"\nENV AXOLOTL_EXTRAS=\"{{ AXOLOTL_EXTRAS }}\"\nENV AXOLOTL_ARGS=\"{{ AXOLOTL_ARGS }}\"\nENV CUDA=\"{{ CUDA }}\"\nENV PYTORCH_VERSION=\"{{ PYTORCH_VERSION }}\"\nENV GITHUB_REF=\"{{ GITHUB_REF }}\"\nENV GITHUB_SHA=\"{{ GITHUB_SHA }}\"\nENV NIGHTLY_BUILD=\"{{ NIGHTLY_BUILD }}\"\nENV HF_HOME=\"{{ HF_HOME }}\"\nENV AXOLOTL_DATASET_NUM_PROC=\"8\"\n\nRUN apt-get update && \\\n    apt-get install -y --allow-change-held-packages vim curl nano zstd libnccl2 libnccl-dev ibverbs-providers ibverbs-utils infiniband-diags librdmacm-dev librdmacm1 rdmacm-utils slurm-wlm\n\nWORKDIR /workspace\n\nRUN git clone --depth=1 https://github.com/axolotl-ai-cloud/axolotl.git\n\nWORKDIR /workspace/axolotl\n\nRUN git fetch origin +$GITHUB_REF && \\\n    git checkout FETCH_HEAD\n\n# If AXOLOTL_EXTRAS is set, append it in brackets\nRUN if [ \"$NIGHTLY_BUILD\" = \"true\" ] ; then \\\n        sed -i 's#^transformers.*#transformers @ git+https://github.com/huggingface/transformers.git@main#' requirements.txt; \\\n        sed -i 's#^peft.*#peft @ git+https://github.com/huggingface/peft.git@main#' requirements.txt; \\\n        sed -i 's#^accelerate.*#accelerate @ git+https://github.com/huggingface/accelerate.git@main#' requirements.txt; \\\n        sed -i 's#^trl.*#trl @ git+https://github.com/huggingface/trl.git@main#' requirements.txt; \\\n        sed -i 's#^datasets.*#datasets @ git+https://github.com/huggingface/datasets.git@main#' requirements.txt; \\\n    fi\n\nRUN pip install packaging==26.0 setuptools==78.1.1 psutil\nRUN pip uninstall -y causal_conv1d\nRUN if [ \"$AXOLOTL_EXTRAS\" != \"\" ] ; then \\\n        pip install --no-build-isolation -e .[deepspeed,flash-attn,ring-flash-attn,optimizers,ray,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \\\n    else \\\n        pip install --no-build-isolation -e .[deepspeed,flash-attn,ring-flash-attn,optimizers,ray] $AXOLOTL_ARGS; \\\n    
fi\n\nRUN python scripts/unsloth_install.py | sh\nRUN python scripts/cutcrossentropy_install.py | sh\n\n# So we can test the Docker image\nRUN pip install -r requirements-dev.txt -r requirements-tests.txt\n\n# fix so that git fetch/pull from remote works\nRUN git config remote.origin.fetch \"+refs/heads/*:refs/remotes/origin/*\" && \\\n    git config --get remote.origin.fetch\n\n# helper for huggingface-login cli\nRUN git config --global credential.helper store\n"
  },
  {
    "path": "cicd/__init__.py",
    "content": ""
  },
  {
    "path": "cicd/cicd.sh",
    "content": "#!/bin/bash\nset -e\n\npython -c \"import torch; assert '$PYTORCH_VERSION' in torch.__version__\"\n\ncurl -L https://axolotl-ci.b-cdn.net/hf-cache.tar.zst | tar -xpf - -C \"${HF_HOME}/hub/\"  --use-compress-program unzstd --strip-components=1\n# hf download \"NousResearch/Meta-Llama-3-8B\"\n# hf download \"NousResearch/Meta-Llama-3-8B-Instruct\"\n# hf download \"microsoft/Phi-4-reasoning\"\n# hf download \"microsoft/Phi-3.5-mini-instruct\"\n# hf download \"microsoft/Phi-3-medium-128k-instruct\"\n\n# Run unit tests with initial coverage report\npytest -v --durations=10 -n8 \\\n  --ignore=tests/e2e/ \\\n  --ignore=tests/patched/ \\\n  --ignore=tests/cli \\\n  /workspace/axolotl/tests/ \\\n  --cov=axolotl\n\n# Run lora kernels tests with coverage append\npytest -v --durations=10 \\\n  /workspace/axolotl/tests/e2e/patched/lora_kernels \\\n  --cov=axolotl \\\n  --cov-append\n\n# Run patched tests excluding lora kernels with coverage append\npytest --full-trace -vvv --durations=10 \\\n  --ignore=tests/e2e/patched/lora_kernels \\\n  /workspace/axolotl/tests/e2e/patched \\\n  --cov=axolotl \\\n  --cov-append\n\n# Run solo tests with coverage append\npytest -v --durations=10 -n1 \\\n  /workspace/axolotl/tests/e2e/solo/ \\\n  --cov=axolotl \\\n  --cov-append\n\n# Run integration tests with coverage append\npytest -v --durations=10 \\\n  /workspace/axolotl/tests/e2e/integrations/ \\\n  --cov=axolotl \\\n  --cov-append\n\npytest -v --durations=10 /workspace/axolotl/tests/cli \\\n  --cov=axolotl \\\n  --cov-append\n\n# Run remaining e2e tests with coverage append and final report\npytest -v --durations=10 \\\n  --ignore=tests/e2e/solo/ \\\n  --ignore=tests/e2e/patched/ \\\n  --ignore=tests/e2e/multigpu/ \\\n  --ignore=tests/e2e/integrations/ \\\n  --ignore=tests/cli \\\n  /workspace/axolotl/tests/e2e/ \\\n  --cov=axolotl \\\n  --cov-append \\\n  --cov-report=xml:e2e-coverage.xml\n\ncodecov upload-process -t $CODECOV_TOKEN -f e2e-coverage.xml -F 
e2e,pytorch-${PYTORCH_VERSION} || true\n"
  },
  {
    "path": "cicd/cleanup.py",
    "content": "\"\"\"Modal app to run axolotl GPU cleanup\"\"\"\n\nfrom .single_gpu import VOLUME_CONFIG, app, cicd_image, run_cmd\n\n\n@app.function(\n    image=cicd_image,\n    timeout=60 * 60,\n    cpu=8.0,\n    memory=131072,\n    volumes=VOLUME_CONFIG,\n)\ndef cleanup():\n    run_cmd(\"./cicd/cleanup.sh\", \"/workspace/axolotl\")\n\n\n@app.local_entrypoint()\ndef main():\n    cleanup.remote()\n"
  },
  {
    "path": "cicd/cleanup.sh",
    "content": "#!/bin/bash\nset -e\n\n# cleanup old cache files for datasets processing and intermediate mappings\nfind /workspace/data/huggingface-cache/hub/datasets -name \"cache-*\" -type f -mtime +1 -exec rm {} \\;\nfind /workspace/data/huggingface-cache/hub/datasets -name \"*.lock\" -type f -mtime +1 -exec rm {} \\;\n"
  },
  {
    "path": "cicd/e2e_tests.py",
    "content": "\"\"\"Modal app to run axolotl GPU tests\"\"\"\n\nfrom .single_gpu import GPU_CONFIG, VOLUME_CONFIG, app, cicd_image, run_cmd\n\n\n@app.function(\n    image=cicd_image,\n    gpu=GPU_CONFIG,\n    timeout=120 * 60,  # 90 min\n    cpu=8.0,\n    memory=131072,\n    volumes=VOLUME_CONFIG,\n)\ndef cicd_pytest():\n    run_cmd(\"./cicd/cicd.sh\", \"/workspace/axolotl\")\n\n\n@app.local_entrypoint()\ndef main():\n    cicd_pytest.remote()\n"
  },
  {
    "path": "cicd/multigpu.py",
    "content": "\"\"\"\nmodal application to run axolotl gpu tests in Modal\n\"\"\"\n\nimport os\nimport pathlib\nimport tempfile\n\nimport jinja2\nimport modal\nfrom jinja2 import select_autoescape\nfrom modal import App, Image\n\ncicd_path = pathlib.Path(__file__).parent.resolve()\n\ntemplate_loader = jinja2.FileSystemLoader(searchpath=cicd_path)\ntemplate_env = jinja2.Environment(\n    loader=template_loader, autoescape=select_autoescape()\n)\ndockerfile = os.environ.get(\"E2E_DOCKERFILE\", \"Dockerfile.jinja\")\ndf_template = template_env.get_template(dockerfile)\n\ndf_args = {\n    \"AXOLOTL_EXTRAS\": os.environ.get(\"AXOLOTL_EXTRAS\", \"\"),\n    \"AXOLOTL_ARGS\": os.environ.get(\"AXOLOTL_ARGS\", \"\"),\n    \"PYTORCH_VERSION\": os.environ.get(\"PYTORCH_VERSION\", \"2.6.0\"),\n    \"BASE_TAG\": os.environ.get(\"BASE_TAG\", \"main-base-py3.11-cu126-2.6.0\"),\n    \"CUDA\": os.environ.get(\"CUDA\", \"126\"),\n    \"GITHUB_REF\": os.environ.get(\"GITHUB_REF\", \"refs/heads/main\"),\n    \"GITHUB_SHA\": os.environ.get(\"GITHUB_SHA\", \"\"),\n    \"NIGHTLY_BUILD\": os.environ.get(\"NIGHTLY_BUILD\", \"\"),\n    \"CODECOV_TOKEN\": os.environ.get(\"CODECOV_TOKEN\", \"\"),\n    \"HF_HOME\": \"/workspace/data/huggingface-cache/hub\",\n    \"PYTHONUNBUFFERED\": os.environ.get(\"PYTHONUNBUFFERED\", \"1\"),\n    \"DEEPSPEED_LOG_LEVEL\": os.environ.get(\"DEEPSPEED_LOG_LEVEL\", \"WARNING\"),\n}\n\ndockerfile_contents = df_template.render(**df_args)\n\ntemp_dir = tempfile.mkdtemp()\nwith open(pathlib.Path(temp_dir) / \"Dockerfile\", \"w\", encoding=\"utf-8\") as f:\n    f.write(dockerfile_contents)\n\ncicd_image = Image.from_dockerfile(\n    pathlib.Path(temp_dir) / \"Dockerfile\",\n    force_build=True,\n    gpu=\"A10G\",\n).env(df_args)\n\napp = App(\"Axolotl CI/CD\", secrets=[])\n\nhf_cache_volume = modal.Volume.from_name(\n    \"axolotl-ci-hf-hub-cache\", create_if_missing=True\n)\nVOLUME_CONFIG = {\n    \"/workspace/data/huggingface-cache/hub\": 
hf_cache_volume,\n}\n\nN_GPUS = int(os.environ.get(\"N_GPUS\", 2))\nGPU_CONFIG = f\"H100:{N_GPUS}\"\n\n\ndef run_cmd(cmd: str, run_folder: str):\n    import subprocess  # nosec\n\n    # Propagate errors from subprocess.\n    if exit_code := subprocess.call(cmd.split(), cwd=run_folder):  # nosec\n        exit(exit_code)\n\n\n@app.function(\n    image=cicd_image,\n    gpu=GPU_CONFIG,\n    timeout=120 * 60,\n    cpu=16.0,\n    memory=131072 * N_GPUS,\n    volumes=VOLUME_CONFIG,\n)\ndef cicd_pytest():\n    run_cmd(\"./cicd/multigpu.sh\", \"/workspace/axolotl\")\n\n\n@app.local_entrypoint()\ndef main():\n    cicd_pytest.remote()\n"
  },
  {
    "path": "cicd/multigpu.sh",
    "content": "#!/bin/bash\nset -e\n\n# Only run two tests at a time to avoid OOM on GPU (with coverage collection)\npytest -v --durations=10 -n2 --maxfail=3 \\\n  --ignore=/workspace/axolotl/tests/e2e/multigpu/solo/ \\\n  --ignore=/workspace/axolotl/tests/e2e/multigpu/patched/ \\\n  /workspace/axolotl/tests/e2e/multigpu/ \\\n  --cov=axolotl\n\n# Run solo tests with coverage append\npytest -v --durations=10 -n1 \\\n  /workspace/axolotl/tests/e2e/multigpu/solo/ \\\n  --cov=axolotl \\\n  --cov-append\n\npytest -v  --durations=10 -n1 /workspace/axolotl/tests/e2e/multigpu/patched/ \\\n  --cov=axolotl \\\n  --cov-append \\\n  --cov-report=xml:multigpu-coverage.xml\n\n# Upload coverage to Codecov if CODECOV_TOKEN is available\nif [ -n \"$CODECOV_TOKEN\" ]; then\n  codecov upload-process -t \"${CODECOV_TOKEN}\" -f multigpu-coverage.xml -F multigpu,docker-tests,pytorch-${PYTORCH_VERSION} || true\nfi\n"
  },
  {
    "path": "cicd/single_gpu.py",
    "content": "\"\"\"Modal app to run axolotl GPU tests\"\"\"\n\nimport os\nimport pathlib\nimport tempfile\n\nimport jinja2\nimport modal\nimport modal.experimental\nfrom jinja2 import select_autoescape\nfrom modal import App\n\ncicd_path = pathlib.Path(__file__).parent.resolve()\n\ntemplate_loader = jinja2.FileSystemLoader(searchpath=cicd_path)\ntemplate_env = jinja2.Environment(\n    loader=template_loader, autoescape=select_autoescape()\n)\ndockerfile = os.environ.get(\"E2E_DOCKERFILE\", \"Dockerfile.jinja\")\ndf_template = template_env.get_template(dockerfile)\n\ndf_args = {\n    \"AXOLOTL_EXTRAS\": os.environ.get(\"AXOLOTL_EXTRAS\", \"\"),\n    \"AXOLOTL_ARGS\": os.environ.get(\"AXOLOTL_ARGS\", \"\"),\n    \"PYTORCH_VERSION\": os.environ.get(\"PYTORCH_VERSION\", \"2.6.0\"),\n    \"BASE_TAG\": os.environ.get(\"BASE_TAG\", \"main-base-py3.11-cu126-2.6.0\"),\n    \"CUDA\": os.environ.get(\"CUDA\", \"126\"),\n    \"GITHUB_REF\": os.environ.get(\"GITHUB_REF\", \"refs/heads/main\"),\n    \"GITHUB_SHA\": os.environ.get(\"GITHUB_SHA\", \"\"),\n    \"NIGHTLY_BUILD\": os.environ.get(\"NIGHTLY_BUILD\", \"\"),\n    \"CODECOV_TOKEN\": os.environ.get(\"CODECOV_TOKEN\", \"\"),\n    \"HF_HOME\": \"/workspace/data/huggingface-cache/hub\",\n    \"PYTHONUNBUFFERED\": os.environ.get(\"PYTHONUNBUFFERED\", \"1\"),\n    \"DEEPSPEED_LOG_LEVEL\": os.environ.get(\"DEEPSPEED_LOG_LEVEL\", \"WARNING\"),\n}\n\ndockerfile_contents = df_template.render(**df_args)\n\ntemp_dir = tempfile.mkdtemp()\nwith open(pathlib.Path(temp_dir) / \"Dockerfile\", \"w\", encoding=\"utf-8\") as f:\n    f.write(dockerfile_contents)\n\ncicd_image = modal.experimental.raw_dockerfile_image(\n    pathlib.Path(temp_dir) / \"Dockerfile\",\n    # context_mount=None,\n    force_build=True,\n    # gpu=\"A10G\",\n).env(df_args)\n\napp = App(\"Axolotl CI/CD\", secrets=[])\n\nhf_cache_volume = modal.Volume.from_name(\n    \"axolotl-ci-hf-hub-cache\", create_if_missing=True\n)\nVOLUME_CONFIG = {\n    
\"/workspace/data/huggingface-cache/hub\": hf_cache_volume,\n}\n\nN_GPUS = int(os.environ.get(\"N_GPUS\", 1))\nGPU_TYPE = os.environ.get(\"GPU_TYPE\", \"L40S\")\nGPU_CONFIG = f\"{GPU_TYPE}:{N_GPUS}\"\n\n\ndef run_cmd(cmd: str, run_folder: str):\n    import subprocess  # nosec\n\n    sp_env = os.environ.copy()\n    sp_env[\"AXOLOTL_DATASET_NUM_PROC\"] = \"8\"\n\n    # Propagate errors from subprocess.\n    exit_code = subprocess.call(cmd.split(), cwd=run_folder, env=sp_env)  # nosec\n    if exit_code:\n        raise RuntimeError(f\"Command '{cmd}' failed with exit code {exit_code}\")\n"
  },
  {
    "path": "codecov.yml",
    "content": "codecov:\n  require_ci_to_pass: yes\n  notify:\n    wait_for_ci: true\n\ncoverage:\n  precision: 2\n  round: down\n  range: \"70...100\"\n  status:\n    project:\n      default:\n        # basic\n        target: auto\n        threshold: 1%\n        base: auto\n        # advanced\n        branches: null\n        if_no_uploads: error\n        if_not_found: success\n        if_ci_failed: error\n        only_pulls: true\n        flags: null\n        paths: null\n        informational: true\n    patch:\n      default:\n        # basic\n        target: auto\n        threshold: 1%\n        base: auto\n        # advanced\n        branches: null\n        if_no_uploads: error\n        if_not_found: success\n        if_ci_failed: error\n        only_pulls: false\n        flags: null\n        paths: null\n\nparsers:\n  gcov:\n    branch_detection:\n      conditional: yes\n      loop: yes\n      method: no\n      macro: no\n\ncomment:\n  layout: \"reach,diff,flags,files,footer\"\n  behavior: default\n  require_changes: no\n  require_base: no\n  require_head: yes\n\ngithub_checks:\n  annotations: false\n"
  },
  {
    "path": "deepspeed_configs/zero1.json",
    "content": "{\n  \"zero_optimization\": {\n    \"stage\": 1,\n    \"overlap_comm\": true\n  },\n  \"bf16\": {\n    \"enabled\": \"auto\"\n  },\n  \"fp16\": {\n    \"enabled\": \"auto\",\n    \"auto_cast\": false,\n    \"loss_scale\": 0,\n    \"initial_scale_power\": 32,\n    \"loss_scale_window\": 1000,\n    \"hysteresis\": 2,\n    \"min_loss_scale\": 1\n  },\n  \"gradient_accumulation_steps\": \"auto\",\n  \"gradient_clipping\": \"auto\",\n  \"train_batch_size\": \"auto\",\n  \"train_micro_batch_size_per_gpu\": \"auto\",\n  \"wall_clock_breakdown\": false\n}\n"
  },
  {
    "path": "deepspeed_configs/zero1_torch_compile.json",
    "content": "{\n  \"zero_optimization\": {\n    \"stage\": 1,\n    \"overlap_comm\": true\n  },\n  \"bf16\": {\n    \"enabled\": \"auto\"\n  },\n  \"fp16\": {\n    \"enabled\": \"auto\",\n    \"auto_cast\": false,\n    \"loss_scale\": 0,\n    \"initial_scale_power\": 32,\n    \"loss_scale_window\": 1000,\n    \"hysteresis\": 2,\n    \"min_loss_scale\": 1\n  },\n  \"compile\": {\n    \"disable\": false,\n    \"backend\": \"inductor\"\n  },\n  \"gradient_accumulation_steps\": \"auto\",\n  \"gradient_clipping\": \"auto\",\n  \"train_batch_size\": \"auto\",\n  \"train_micro_batch_size_per_gpu\": \"auto\",\n  \"wall_clock_breakdown\": false\n}\n"
  },
  {
    "path": "deepspeed_configs/zero2.json",
    "content": "{\n  \"zero_optimization\": {\n    \"stage\": 2,\n    \"offload_optimizer\": {\n      \"device\": \"cpu\"\n    },\n    \"contiguous_gradients\": true,\n    \"overlap_comm\": true\n  },\n  \"bf16\": {\n    \"enabled\": \"auto\"\n  },\n  \"fp16\": {\n    \"enabled\": \"auto\",\n    \"auto_cast\": false,\n    \"loss_scale\": 0,\n    \"initial_scale_power\": 32,\n    \"loss_scale_window\": 1000,\n    \"hysteresis\": 2,\n    \"min_loss_scale\": 1\n  },\n  \"gradient_accumulation_steps\": \"auto\",\n  \"gradient_clipping\": \"auto\",\n  \"train_batch_size\": \"auto\",\n  \"train_micro_batch_size_per_gpu\": \"auto\",\n  \"wall_clock_breakdown\": false\n}\n"
  },
  {
    "path": "deepspeed_configs/zero2_torch_compile.json",
    "content": "{\n  \"compile\": {\n    \"disable\": false,\n    \"backend\": \"inductor\"\n  },\n  \"zero_optimization\": {\n    \"stage\": 2,\n    \"offload_optimizer\": {\n      \"device\": \"cpu\"\n    },\n    \"contiguous_gradients\": true,\n    \"overlap_comm\": true\n  },\n  \"bf16\": {\n    \"enabled\": \"auto\"\n  },\n  \"fp16\": {\n    \"enabled\": \"auto\",\n    \"auto_cast\": false,\n    \"loss_scale\": 0,\n    \"initial_scale_power\": 32,\n    \"loss_scale_window\": 1000,\n    \"hysteresis\": 2,\n    \"min_loss_scale\": 1\n  },\n  \"gradient_accumulation_steps\": \"auto\",\n  \"gradient_clipping\": \"auto\",\n  \"train_batch_size\": \"auto\",\n  \"train_micro_batch_size_per_gpu\": \"auto\",\n  \"wall_clock_breakdown\": false\n}\n"
  },
  {
    "path": "deepspeed_configs/zero3.json",
    "content": "{\n  \"zero_optimization\": {\n    \"stage\": 3,\n    \"overlap_comm\": true,\n    \"contiguous_gradients\": true,\n    \"sub_group_size\": 0,\n    \"reduce_bucket_size\": \"auto\",\n    \"stage3_prefetch_bucket_size\": \"auto\",\n    \"stage3_param_persistence_threshold\": \"auto\",\n    \"max_live_parameters\": 0,\n    \"max_reuse_distance\": 0,\n    \"gather_16bit_weights_on_model_save\": true\n  },\n  \"bf16\": {\n    \"enabled\": \"auto\"\n  },\n  \"fp16\": {\n    \"enabled\": \"auto\",\n    \"auto_cast\": false,\n    \"loss_scale\": 0,\n    \"initial_scale_power\": 32,\n    \"loss_scale_window\": 1000,\n    \"hysteresis\": 2,\n    \"min_loss_scale\": 1\n  },\n  \"gradient_accumulation_steps\": \"auto\",\n  \"gradient_clipping\": \"auto\",\n  \"train_batch_size\": \"auto\",\n  \"train_micro_batch_size_per_gpu\": \"auto\",\n  \"wall_clock_breakdown\": false\n}\n"
  },
  {
    "path": "deepspeed_configs/zero3_bf16.json",
    "content": "{\n  \"zero_optimization\": {\n    \"stage\": 3,\n    \"overlap_comm\": true,\n    \"contiguous_gradients\": true,\n    \"sub_group_size\": 0,\n    \"reduce_bucket_size\": \"auto\",\n    \"stage3_prefetch_bucket_size\": \"auto\",\n    \"stage3_param_persistence_threshold\": \"auto\",\n    \"max_live_parameters\": 0,\n    \"max_reuse_distance\": 0,\n    \"gather_16bit_weights_on_model_save\": true\n  },\n  \"bf16\": {\n    \"enabled\": true\n  },\n  \"gradient_accumulation_steps\": \"auto\",\n  \"gradient_clipping\": \"auto\",\n  \"train_batch_size\": \"auto\",\n  \"train_micro_batch_size_per_gpu\": \"auto\",\n  \"wall_clock_breakdown\": false\n}\n"
  },
  {
    "path": "deepspeed_configs/zero3_bf16_cpuoffload_all.json",
    "content": "{\n  \"zero_force_ds_cpu_optimizer\": false,\n  \"zero_allow_untested_optimizer\": true,\n  \"zero_optimization\": {\n    \"stage\": 3,\n    \"offload_optimizer\": {\n      \"device\": \"cpu\",\n      \"pin_memory\": true\n    },\n    \"offload_param\": {\n      \"device\": \"cpu\",\n      \"pin_memory\": true\n    },\n    \"overlap_comm\": true,\n    \"contiguous_gradients\": true,\n    \"sub_group_size\": 0,\n    \"reduce_bucket_size\": \"auto\",\n    \"stage3_prefetch_bucket_size\": \"auto\",\n    \"stage3_param_persistence_threshold\": \"auto\",\n    \"max_live_parameters\": 0,\n    \"max_reuse_distance\": 0,\n    \"gather_16bit_weights_on_model_save\": true\n  },\n  \"bf16\": {\n    \"enabled\": true\n  },\n  \"gradient_accumulation_steps\": \"auto\",\n  \"gradient_clipping\": \"auto\",\n  \"train_batch_size\": \"auto\",\n  \"train_micro_batch_size_per_gpu\": \"auto\",\n  \"wall_clock_breakdown\": false\n}\n"
  },
  {
    "path": "deepspeed_configs/zero3_bf16_cpuoffload_params.json",
    "content": "{\n  \"zero_force_ds_cpu_optimizer\": false,\n  \"zero_allow_untested_optimizer\": true,\n  \"zero_optimization\": {\n    \"stage\": 3,\n    \"offload_param\": {\n      \"device\": \"cpu\",\n      \"pin_memory\": true\n    },\n    \"overlap_comm\": true,\n    \"contiguous_gradients\": true,\n    \"sub_group_size\": 0,\n    \"reduce_bucket_size\": \"auto\",\n    \"stage3_prefetch_bucket_size\": \"auto\",\n    \"stage3_param_persistence_threshold\": \"auto\",\n    \"max_live_parameters\": 0,\n    \"max_reuse_distance\": 0,\n    \"gather_16bit_weights_on_model_save\": true\n  },\n  \"bf16\": {\n    \"enabled\": true\n  },\n  \"gradient_accumulation_steps\": \"auto\",\n  \"gradient_clipping\": \"auto\",\n  \"train_batch_size\": \"auto\",\n  \"train_micro_batch_size_per_gpu\": \"auto\",\n  \"wall_clock_breakdown\": false\n}\n"
  },
  {
    "path": "devtools/README.md",
    "content": "This directory contains example config files that might be useful for debugging. Please see [docs/debugging.qmd](../docs/debugging.qmd) for more information.\n"
  },
  {
    "path": "devtools/dev_chat_template.yml",
    "content": "# Example config for debugging the chat_template prompt format\nbase_model: TinyLlama/TinyLlama-1.1B-Chat-v1.0\nmodel_type: LlamaForCausalLM\ntokenizer_type: LlamaTokenizer\n\nload_in_8bit: true\nload_in_4bit: false\n\ndatasets:\n  - path: fozziethebeat/alpaca_messages_2k_test\n    type: chat_template\n    shards: 10\nval_set_size: 0\noutput_dir: temp_debug/axolotl_outputs/model\ndataset_prepared_path: temp_debug/axolotl_outputs/data\ndataset_num_proc: 1\n\nsequence_len: 4096\nsample_packing: false\npad_to_sequence_len: true\n\nadapter: lora\nlora_model_dir:\nlora_r: 32\nlora_alpha: 16\nlora_dropout: 0.05\nlora_target_linear: true\nlora_fan_in_fan_out:\n\nmicro_batch_size: 1\nnum_epochs: 1\nmax_steps: 10\noptimizer: adamw_bnb_8bit\nlr_scheduler: cosine\nlearning_rate: 0.0002\n\ntrain_on_inputs: false\ngroup_by_length: false\nbf16: false\nfp16: true\ntf32: false\n\ngradient_checkpointing: true\nlogging_steps: 1\nflash_attention: true\n\nwarmup_steps: 10\nweight_decay: 0.0\n"
  },
  {
    "path": "docker/Dockerfile",
    "content": "ARG BASE_TAG=main-base\nFROM axolotlai/axolotl-base:$BASE_TAG\n\nARG TORCH_CUDA_ARCH_LIST=\"7.0 7.5 8.0 8.6+PTX\"\nARG AXOLOTL_EXTRAS=\"\"\nARG AXOLOTL_ARGS=\"\"\nARG CUDA=\"118\"\nARG PYTORCH_VERSION=\"2.1.2\"\nARG TARGETARCH\n\nENV PYTORCH_VERSION=$PYTORCH_VERSION\n\nRUN apt-get update && \\\n    apt-get install -y --allow-change-held-packages vim curl nano libnccl2 libnccl-dev rsync s3fs && \\\n    rm -rf /var/cache/apt/archives && \\\n    rm -rf /var/lib/apt/lists/*\n\nWORKDIR /workspace\n\nRUN git clone --depth=1 https://github.com/axolotl-ai-cloud/axolotl.git\n\nWORKDIR /workspace/axolotl\n\n# If AXOLOTL_EXTRAS is set, append it in brackets; don't install deepspeed with arm64\nRUN pip uninstall -y causal_conv1d\nRUN if [ \"$TARGETARCH\" = \"arm64\" ]; then \\\n        BASE_EXTRAS=\"flash-attn,ring-flash-attn,optimizers,ray\"; \\\n    else \\\n        BASE_EXTRAS=\"deepspeed,flash-attn,ring-flash-attn,optimizers,ray\"; \\\n    fi && \\\n    if [ \"$AXOLOTL_EXTRAS\" != \"\" ]; then \\\n        pip install --no-build-isolation -e .[$BASE_EXTRAS,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \\\n    else \\\n        pip install --no-build-isolation -e .[$BASE_EXTRAS] $AXOLOTL_ARGS; \\\n    fi && \\    python scripts/unsloth_install.py | sh && \\\n    python scripts/cutcrossentropy_install.py | sh && \\\n    pip install pytest && \\\n    pip cache purge\n\n# fix so that git fetch/pull from remote works with shallow clone\nRUN git config remote.origin.fetch \"+refs/heads/*:refs/remotes/origin/*\" && \\\n    git config --get remote.origin.fetch && \\\n    git config --global credential.helper store\n\nCOPY .axolotl-complete.bash /root/.axolotl-complete.bash\nRUN chmod +x /root/.axolotl-complete.bash && \\\n    echo 'source /root/.axolotl-complete.bash' >> ~/.bashrc\n"
  },
  {
    "path": "docker/Dockerfile-base",
    "content": "ARG CUDA_VERSION=\"11.8.0\"\nARG CUDNN_VERSION=\"8\"\nARG UBUNTU_VERSION=\"22.04\"\nARG MAX_JOBS=4\nARG TARGETARCH\n\nFROM nvidia/cuda:$CUDA_VERSION-cudnn$CUDNN_VERSION-devel-ubuntu$UBUNTU_VERSION AS base-builder\n\nENV PATH=\"/root/miniconda3/bin:${PATH}\"\n\nARG TARGETARCH\nARG PYTHON_VERSION=\"3.11\"\nARG PYTORCH_VERSION=\"2.1.2\"\nARG CUDA=\"128\"\nARG TORCH_CUDA_ARCH_LIST=\"7.0 7.5 8.0 8.6 9.0+PTX\"\n\nENV PYTHON_VERSION=$PYTHON_VERSION\nENV TORCH_CUDA_ARCH_LIST=$TORCH_CUDA_ARCH_LIST\n\nRUN apt-get update \\\n    && apt-get install -y --no-install-recommends \\\n        wget git build-essential ninja-build git-lfs libaio-dev pkg-config \\\n        ibverbs-providers ibverbs-utils infiniband-diags  \\\n        librdmacm-dev librdmacm1 rdmacm-utils slurm-wlm \\\n    && rm -rf /var/cache/apt/archives \\\n    && rm -rf /var/lib/apt/lists/* \\\n    && if [ \"$TARGETARCH\" = \"amd64\" ]; then \\\n        MINICONDA_ARCH=\"x86_64\"; \\\n    elif [ \"$TARGETARCH\" = \"arm64\" ]; then \\\n        MINICONDA_ARCH=\"aarch64\"; \\\n    else \\\n        echo \"Unsupported architecture: $TARGETARCH\"; exit 1; \\\n    fi \\\n    && wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-${MINICONDA_ARCH}.sh \\\n    && mkdir /root/.conda \\\n    && bash Miniconda3-latest-Linux-${MINICONDA_ARCH}.sh -b \\\n    && rm -f Miniconda3-latest-Linux-${MINICONDA_ARCH}.sh \\\n    && conda tos accept --override-channels --channel https://repo.anaconda.com/pkgs/main \\\n    && conda tos accept --override-channels --channel https://repo.anaconda.com/pkgs/r \\\n    && conda create -n \"py${PYTHON_VERSION}\" python=\"${PYTHON_VERSION}\"\n\nENV PATH=\"/root/miniconda3/envs/py${PYTHON_VERSION}/bin:${PATH}\"\n\nWORKDIR /workspace\n\nRUN python3 -m pip install --upgrade pip && pip3 install -U packaging==26.0 setuptools==75.8.0 wheel psutil && \\\n    python3 -m pip install --no-cache-dir -U torch==${PYTORCH_VERSION}+cu${CUDA} torchvision --extra-index-url 
https://download.pytorch.org/whl/cu$CUDA && \\\n    python3 -m pip cache purge\n\nRUN if [ \"$CUDA\" != \"130\" ] ; then \\\n        CAUSAL_CONV1D_FORCE_CXX11_ABI=TRUE CAUSAL_CONV1D_FORCE_BUILD=TRUE python3 -m pip install --no-cache-dir \"causal_conv1d @ git+https://github.com/Dao-AILab/causal-conv1d.git@v1.5.4\"; \\\n        python3 -m pip install --no-cache-dir \"mamba_ssm @ git+https://github.com/state-spaces/mamba.git@main\"; \\\n        python3 -m pip cache purge; \\\n    fi\n\nRUN git lfs install --skip-repo && \\\n    pip3 install awscli && \\\n    # The base image ships with `pydantic==1.8.2` which is not working\n    pip3 install -U --no-cache-dir pydantic==1.10.10 && \\\n    pip3 cache purge\n\n# Map Python version (e.g., 3.12 -> cp312)\nRUN PYTHON_CP=\"cp$(echo $PYTHON_VERSION | tr -d '.')\" && \\\n    # Map PyTorch version (e.g., 2.9.1 -> torch2.9, 2.10.0 -> torch2.10)\n    TORCH_TAG=\"torch$(echo $PYTORCH_VERSION | grep -oP '^\\d+\\.\\d+')\" && \\\n    # Map architecture\n    case \"$TARGETARCH\" in \\\n        amd64) ARCH_TAG=\"x86_64\" ;; \\\n        arm64) ARCH_TAG=\"aarch64\" ;; \\\n        *) echo \"Unsupported architecture: $TARGETARCH\"; exit 1 ;; \\\n    esac && \\\n    WHL_VERSION=\"v0.7.16\" && \\\n    WHL_FILE=\"flash_attn-2.8.3+cu${CUDA}${TORCH_TAG}-${PYTHON_CP}-${PYTHON_CP}-linux_${ARCH_TAG}.whl\" && \\\n    wget -nv \"https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/${WHL_VERSION}/${WHL_FILE}\" && \\\n    pip3 install --no-cache-dir \"${WHL_FILE}\" && \\\n    rm \"${WHL_FILE}\"\n"
  },
  {
    "path": "docker/Dockerfile-base-next",
    "content": "ARG CUDA_VERSION=\"12.8.1\"\nARG CUDNN_VERSION=\"8\"\nARG UBUNTU_VERSION=\"22.04\"\nARG MAX_JOBS=4\n\nFROM nvidia/cuda:$CUDA_VERSION-cudnn$CUDNN_VERSION-devel-ubuntu$UBUNTU_VERSION AS base-builder\n\nENV PATH=\"/root/miniconda3/bin:${PATH}\"\n\nARG PYTHON_VERSION=\"3.11\"\nARG PYTORCH_VERSION=\"next\"\nARG CUDA=\"128\"\nARG TORCH_CUDA_ARCH_LIST=\"7.0 7.5 8.0 8.6 9.0+PTX\"\n\nENV PYTHON_VERSION=$PYTHON_VERSION\nENV TORCH_CUDA_ARCH_LIST=$TORCH_CUDA_ARCH_LIST\n\nRUN apt-get update \\\n    && apt-get install -y wget git build-essential ninja-build git-lfs libaio-dev pkg-config && rm -rf /var/lib/apt/lists/* \\\n    && wget \\\n    https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh \\\n    && mkdir /root/.conda \\\n    && bash Miniconda3-latest-Linux-x86_64.sh -b \\\n    && rm -f Miniconda3-latest-Linux-x86_64.sh \\\n    && conda create -n \"py${PYTHON_VERSION}\" python=\"${PYTHON_VERSION}\"\n\nENV PATH=\"/root/miniconda3/envs/py${PYTHON_VERSION}/bin:${PATH}\"\n\nWORKDIR /workspace\n\nRUN python3 -m pip install --upgrade pip && pip3 install packaging && \\\n    python3 -m pip install --no-cache-dir -U torch==2.7.1 --extra-index-url https://download.pytorch.org/whl/test/cu$CUDA && \\\n    python3 -m pip install --no-cache-dir \"causal_conv1d @ git+https://github.com/Dao-AILab/causal-conv1d.git@main\" && \\\n    python3 -m pip install --no-cache-dir \"mamba_ssm @ git+https://github.com/state-spaces/mamba.git@main\"\n\nRUN git lfs install --skip-repo && \\\n    pip3 install awscli && \\\n    pip3 install -U --no-cache-dir pydantic==2.10.6\n"
  },
  {
    "path": "docker/Dockerfile-base-nightly",
    "content": "ARG CUDA_VERSION=\"12.8.1\"\nARG CUDNN_VERSION=\"8\"\nARG UBUNTU_VERSION=\"22.04\"\nARG MAX_JOBS=4\n\nFROM nvidia/cuda:$CUDA_VERSION-cudnn$CUDNN_VERSION-devel-ubuntu$UBUNTU_VERSION AS base-builder\n\nENV PATH=\"/root/miniconda3/bin:${PATH}\"\n\nARG PYTHON_VERSION=\"3.11\"\nARG PYTORCH_VERSION=\"nightly\"\nARG CUDA=\"128\"\nARG TORCH_CUDA_ARCH_LIST=\"7.0 7.5 8.0 8.6 9.0+PTX\"\n\nENV PYTHON_VERSION=$PYTHON_VERSION\nENV TORCH_CUDA_ARCH_LIST=$TORCH_CUDA_ARCH_LIST\n\nRUN apt-get update \\\n    && apt-get install -y wget git build-essential ninja-build git-lfs libaio-dev pkg-config && rm -rf /var/lib/apt/lists/* \\\n    && wget \\\n    https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh \\\n    && mkdir /root/.conda \\\n    && bash Miniconda3-latest-Linux-x86_64.sh -b \\\n    && rm -f Miniconda3-latest-Linux-x86_64.sh \\\n    && conda tos accept --override-channels --channel https://repo.anaconda.com/pkgs/main \\\n    && conda tos accept --override-channels --channel https://repo.anaconda.com/pkgs/r \\\n    && conda create -n \"py${PYTHON_VERSION}\" python=\"${PYTHON_VERSION}\"\n\nENV PATH=\"/root/miniconda3/envs/py${PYTHON_VERSION}/bin:${PATH}\"\n\nWORKDIR /workspace\n\nRUN python3 -m pip install --upgrade pip && pip3 install -U packaging==26.0 setuptools==75.8.0 wheel && \\\n    python3 -m pip install --no-cache-dir -U torch --extra-index-url https://download.pytorch.org/whl/nightly/cu$CUDA && \\\n    python3 -m pip install --no-cache-dir \"causal_conv1d @ git+https://github.com/Dao-AILab/causal-conv1d.git@main\" && \\\n    python3 -m pip install --no-cache-dir \"mamba_ssm @ git+https://github.com/state-spaces/mamba.git@main\" && \\\n    python3 -m pip cache purge\n\nRUN git lfs install --skip-repo && \\\n    pip3 install awscli && \\\n    # The base image ships with `pydantic==1.8.2` which is not working\n    pip3 install -U --no-cache-dir pydantic==1.10.10 && \\\n    pip3 cache purge\n"
  },
  {
    "path": "docker/Dockerfile-cloud",
    "content": "ARG BASE_TAG=main\nFROM axolotlai/axolotl:$BASE_TAG\n\nENV HF_DATASETS_CACHE=\"/workspace/data/huggingface-cache/datasets\"\nENV HF_HUB_CACHE=\"/workspace/data/huggingface-cache/hub\"\nENV HF_HOME=\"/workspace/data/huggingface-cache/hub\"\nENV HF_HUB_ENABLE_HF_TRANSFER=\"1\"\n\nEXPOSE 8888\nEXPOSE 22\n\nCOPY scripts/cloud-entrypoint.sh /root/cloud-entrypoint.sh\nCOPY scripts/motd /etc/motd\n\nRUN pip install jupyterlab notebook ipywidgets && \\\n    jupyter lab clean\nRUN apt update && \\\n    apt install --yes --no-install-recommends openssh-server tmux iproute2 nvtop && \\\n    rm -rf /var/cache/apt/archives && \\\n    rm -rf /var/lib/apt/lists/* && \\\n    mkdir -p ~/.ssh && \\\n    chmod 700 ~/.ssh && \\\n    printf \"\\n[[ -z \\\"\\$TMUX\\\"  ]] && { tmux attach-session -t ssh_tmux || tmux new-session -s ssh_tmux; exit; }\\n\" >> ~/.bashrc && \\\n    printf \"[ ! -z \\\"\\$TERM\\\" -a -r /etc/motd ] && cat /etc/motd\\n\" >> ~/.bashrc && \\\n    chmod +x /workspace/axolotl/scripts/cloud-entrypoint.sh && \\\n    chmod +x /root/cloud-entrypoint.sh && \\\n    echo 'set-option -g history-limit 5000' >> ~/.tmux.conf\n\nENTRYPOINT [\"/root/cloud-entrypoint.sh\"]\nCMD [\"sleep\", \"infinity\"]\n"
  },
  {
    "path": "docker/Dockerfile-cloud-no-tmux",
    "content": "ARG BASE_TAG=main\nFROM axolotlai/axolotl:$BASE_TAG\n\nENV HF_DATASETS_CACHE=\"/workspace/data/huggingface-cache/datasets\"\nENV HF_HUB_CACHE=\"/workspace/data/huggingface-cache/hub\"\nENV HF_HOME=\"/workspace/data/huggingface-cache/hub\"\nENV HF_HUB_ENABLE_HF_TRANSFER=\"1\"\n\nEXPOSE 8888\nEXPOSE 22\n\nCOPY scripts/cloud-entrypoint.sh /root/cloud-entrypoint.sh\nCOPY scripts/motd /etc/motd\n\nRUN pip install jupyterlab notebook ipywidgets && \\\n    jupyter lab clean\nRUN apt update && \\\n    apt install --yes --no-install-recommends openssh-server tmux iproute2 nvtop ibverbs-providers ibverbs-utils infiniband-diags librdmacm-dev librdmacm1 rdmacm-utils slurm-wlm && \\\n    rm -rf /var/cache/apt/archives && \\\n    rm -rf /var/lib/apt/lists/* && \\\n    mkdir -p ~/.ssh && \\\n    chmod 700 ~/.ssh && \\\n    printf \"[ ! -z \\\"\\$TERM\\\" -a -r /etc/motd ] && cat /etc/motd\\n\" >> ~/.bashrc && \\\n    chmod +x /workspace/axolotl/scripts/cloud-entrypoint.sh && \\\n    chmod +x /root/cloud-entrypoint.sh\n\nENTRYPOINT [\"/root/cloud-entrypoint.sh\"]\nCMD [\"sleep\", \"infinity\"]\n"
  },
  {
    "path": "docker/Dockerfile-cloud-uv",
    "content": "ARG BASE_TAG=main\nFROM axolotlai/axolotl-uv:$BASE_TAG\n\nENV HF_DATASETS_CACHE=\"/workspace/data/huggingface-cache/datasets\"\nENV HF_HUB_CACHE=\"/workspace/data/huggingface-cache/hub\"\nENV HF_HOME=\"/workspace/data/huggingface-cache/hub\"\nENV HF_HUB_ENABLE_HF_TRANSFER=\"1\"\n\nEXPOSE 8888\nEXPOSE 22\n\nCOPY scripts/cloud-entrypoint.sh /root/cloud-entrypoint.sh\nCOPY scripts/motd /etc/motd\n\nRUN uv pip install jupyterlab notebook ipywidgets && \\\n    jupyter lab clean\nRUN apt update && \\\n    apt install --yes --no-install-recommends openssh-server tmux iproute2 nvtop && \\\n    rm -rf /var/cache/apt/archives && \\\n    rm -rf /var/lib/apt/lists/* && \\\n    mkdir -p ~/.ssh && \\\n    chmod 700 ~/.ssh && \\\n    printf \"\\n[[ -z \\\"\\$TMUX\\\"  ]] && { tmux attach-session -t ssh_tmux || tmux new-session -s ssh_tmux; exit; }\\n\" >> ~/.bashrc && \\\n    printf \"[ ! -z \\\"\\$TERM\\\" -a -r /etc/motd ] && cat /etc/motd\\n\" >> ~/.bashrc && \\\n    chmod +x /workspace/axolotl/scripts/cloud-entrypoint.sh && \\\n    chmod +x /root/cloud-entrypoint.sh && \\\n    echo 'set-option -g history-limit 5000' >> ~/.tmux.conf\n\nENTRYPOINT [\"/root/cloud-entrypoint.sh\"]\nCMD [\"sleep\", \"infinity\"]\n"
  },
  {
    "path": "docker/Dockerfile-tests",
    "content": "ARG BASE_TAG=main-base\nFROM axolotlai/axolotl-base:$BASE_TAG\n\nARG TORCH_CUDA_ARCH_LIST=\"7.0 7.5 8.0 8.6+PTX\"\nARG AXOLOTL_EXTRAS=\"\"\nARG AXOLOTL_ARGS=\"\"\nARG CUDA=\"118\"\nARG PYTORCH_VERSION=\"2.1.2\"\nARG GITHUB_REF=\"main\"\n\nENV PYTORCH_VERSION=$PYTORCH_VERSION\n\nRUN apt-get update && \\\n    apt-get install -y --allow-change-held-packages vim curl nano libnccl2 libnccl-dev\n\nWORKDIR /workspace\n\nRUN git clone --depth=1 https://github.com/axolotl-ai-cloud/axolotl.git\n\nWORKDIR /workspace/axolotl\n\nRUN git fetch origin +$GITHUB_REF && \\\n    git checkout FETCH_HEAD\n\n# If AXOLOTL_EXTRAS is set, append it in brackets\nRUN if [ \"$AXOLOTL_EXTRAS\" != \"\" ] ; then \\\n        pip install --no-build-isolation -e .[deepspeed,flash-attn,mamba-ssm,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \\\n    else \\\n        pip install --no-build-isolation -e .[deepspeed,flash-attn,mamba-ssm] $AXOLOTL_ARGS; \\\n    fi\n\n# So we can test the Docker image\nRUN pip install pytest\n\n# fix so that git fetch/pull from remote works\nRUN git config remote.origin.fetch \"+refs/heads/*:refs/remotes/origin/*\" && \\\n    git config --get remote.origin.fetch\n\n# helper for huggingface-login cli\nRUN git config --global credential.helper store\n"
  },
  {
    "path": "docker/Dockerfile-uv",
    "content": "ARG BASE_TAG=main-base\nFROM axolotlai/axolotl-base-uv:$BASE_TAG\n\nARG TORCH_CUDA_ARCH_LIST=\"7.0 7.5 8.0 8.6+PTX\"\nARG AXOLOTL_EXTRAS=\"\"\nARG AXOLOTL_ARGS=\"\"\nARG CUDA=\"118\"\nARG PYTORCH_VERSION=\"2.1.2\"\nARG TARGETARCH\n\nENV PYTORCH_VERSION=$PYTORCH_VERSION\n\nRUN apt-get update && \\\n    apt-get install -y --allow-change-held-packages vim curl nano libnccl2 libnccl-dev rsync s3fs && \\\n    rm -rf /var/cache/apt/archives && \\\n    rm -rf /var/lib/apt/lists/*\n\nWORKDIR /workspace\n\nRUN git clone --depth=1 https://github.com/axolotl-ai-cloud/axolotl.git\n\nWORKDIR /workspace/axolotl\n\n# If AXOLOTL_EXTRAS is set, append it in brackets; don't install deepspeed with arm64\nRUN uv pip uninstall causal_conv1d\nRUN if [ \"$TARGETARCH\" = \"arm64\" ]; then \\\n        BASE_EXTRAS=\"flash-attn,ring-flash-attn,optimizers,ray\"; \\\n    else \\\n        BASE_EXTRAS=\"deepspeed,flash-attn,ring-flash-attn,optimizers,ray\"; \\\n    fi && \\\n    if [ \"$AXOLOTL_EXTRAS\" != \"\" ]; then \\\n        uv pip install --no-build-isolation -e .[$BASE_EXTRAS,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \\\n    else \\\n        uv pip install --no-build-isolation -e .[$BASE_EXTRAS] $AXOLOTL_ARGS; \\\n    fi && \\\n    python scripts/unsloth_install.py --uv | sh && \\\n    python scripts/cutcrossentropy_install.py --uv | sh && \\\n    uv pip install pytest && \\\n    uv cache clean\n\n# fix so that git fetch/pull from remote works with shallow clone\nRUN git config remote.origin.fetch \"+refs/heads/*:refs/remotes/origin/*\" && \\\n    git config --get remote.origin.fetch && \\\n    git config --global credential.helper store\n\nCOPY .axolotl-complete.bash /root/.axolotl-complete.bash\nRUN chmod +x /root/.axolotl-complete.bash && \\\n    echo 'source /root/.axolotl-complete.bash' >> ~/.bashrc\n"
  },
  {
    "path": "docker/Dockerfile-uv-base",
    "content": "ARG CUDA_VERSION=\"12.6.3\"\nARG CUDNN_VERSION=\"\"\nARG UBUNTU_VERSION=\"22.04\"\nARG MAX_JOBS=4\nARG TARGETARCH\n\nFROM nvidia/cuda:$CUDA_VERSION-cudnn$CUDNN_VERSION-devel-ubuntu$UBUNTU_VERSION AS base-builder\n\nARG TARGETARCH\nARG PYTHON_VERSION=\"3.11\"\nARG PYTORCH_VERSION=\"2.6.0\"\nARG CUDA=\"126\"\nARG TORCH_CUDA_ARCH_LIST=\"7.0 7.5 8.0 8.6 9.0+PTX\"\n\nENV PYTHON_VERSION=$PYTHON_VERSION\nENV TORCH_CUDA_ARCH_LIST=$TORCH_CUDA_ARCH_LIST\nENV UV_TORCH_BACKEND=\"cu${CUDA}\"\n\nRUN apt-get update \\\n    && apt-get install -y wget git build-essential ninja-build git-lfs libaio-dev pkg-config curl && rm -rf /var/lib/apt/lists/* \\\n    && git lfs install --skip-repo \\\n    && curl -LsSf https://astral.sh/uv/install.sh | sh\n\nENV PATH=\"/root/.local/bin:${PATH}\"\n\nRUN uv python install ${PYTHON_VERSION}\n\nWORKDIR /workspace\n\nRUN uv venv --no-project --relocatable axolotl-venv\n\nENV PATH=\"/workspace/axolotl-venv/bin:${PATH}\"\n\nRUN uv pip install packaging setuptools wheel psutil \\\n    && uv pip install torch==${PYTORCH_VERSION} torchvision \\\n    && uv pip install awscli pydantic\n\nRUN if [ \"$TARGETARCH\" = \"amd64\" ]; then \\\n        uv pip install --no-build-isolation \"causal_conv1d @ git+https://github.com/Dao-AILab/causal-conv1d.git@main\"; \\\n        uv pip install \"mamba_ssm @ git+https://github.com/state-spaces/mamba.git@main\"; \\\n    fi\n\n# Map Python version (e.g., 3.12 -> cp312)\nRUN PYTHON_CP=\"cp$(echo $PYTHON_VERSION | tr -d '.')\" && \\\n    # Map PyTorch version (e.g., 2.9.1 -> torch2.9, 2.10.0 -> torch2.10)\n    TORCH_TAG=\"torch$(echo $PYTORCH_VERSION | grep -oP '^\\d+\\.\\d+')\" && \\\n    # Map architecture\n    case \"$TARGETARCH\" in \\\n        amd64) ARCH_TAG=\"x86_64\" ;; \\\n        arm64) ARCH_TAG=\"aarch64\" ;; \\\n        *) echo \"Unsupported architecture: $TARGETARCH\"; exit 1 ;; \\\n    esac && \\\n    WHL_VERSION=\"v0.7.16\" && \\\n    
WHL_FILE=\"flash_attn-2.8.3+cu${CUDA}${TORCH_TAG}-${PYTHON_CP}-${PYTHON_CP}-linux_${ARCH_TAG}.whl\" && \\\n    wget -nv \"https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/${WHL_VERSION}/${WHL_FILE}\" && \\\n    uv pip install --no-cache-dir \"${WHL_FILE}\" && \\\n    rm \"${WHL_FILE}\"\n"
  },
  {
    "path": "docker-compose.yaml",
    "content": "# version: '3.8'\nservices:\n  axolotl:\n    build:\n      context: .\n      dockerfile: ./docker/Dockerfile\n    volumes:\n      - .:/workspace/axolotl\n      - ~/.cache/huggingface/:/root/.cache/huggingface/\n    # set environment variables\n    environment:\n      # Set environment variables\n      - GIT_AUTHOR_NAME=${GIT_AUTHOR_NAME}\n      - GIT_AUTHOR_EMAIL=${GIT_AUTHOR_EMAIL}\n      - GIT_COMMITTER_NAME=${GIT_COMMITTER_NAME}\n      - GIT_COMMITTER_EMAIL=${GIT_COMMITTER_EMAIL}\n      - WANDB_API_KEY=${WANDB_API_KEY}\n    deploy:\n      resources:\n        reservations:\n          devices:\n            - driver: nvidia\n              # count: 1\n              capabilities: [gpu]\n    command: tail -f /dev/null\n"
  },
  {
    "path": "docs/.gitignore",
    "content": "/.quarto/\n_site/\n/api/*.qmd\n/api/*.html\nconfig-reference.qmd\nmodels/**/*.qmd\nmodels/**/*.html\n"
  },
  {
    "path": "docs/amd_hpc.qmd",
    "content": "---\ntitle: AMD GPUs on HPC Systems\ndescription: A comprehensive guide for using Axolotl on distributed systems with AMD GPUs\n---\n\nThis guide provides step-by-step instructions for installing and configuring Axolotl on a High-Performance Computing (HPC) environment equipped with AMD GPUs.\n\n## Setup\n\n### 1. Install Python\n\nWe recommend using Miniforge, a minimal conda-based Python distribution:\n\n```bash\ncurl -L -O \"https://github.com/conda-forge/miniforge/releases/latest/download/Miniforge3-$(uname)-$(uname -m).sh\"\nbash Miniforge3-$(uname)-$(uname -m).sh\n```\n\n### 2. Configure Python Environment\nAdd Python to your PATH and ensure it's available at login:\n\n```bash\necho 'export PATH=~/miniforge3/bin:$PATH' >> ~/.bashrc\necho 'if [ -f ~/.bashrc ]; then . ~/.bashrc; fi' >> ~/.bash_profile\n```\n\n### 3. Load AMD GPU Software\n\nLoad the ROCm module:\n\n```bash\nmodule load rocm/5.7.1\n```\n\nNote: The specific module name and version may vary depending on your HPC system. Consult your system documentation for the correct module name.\n\n### 4. Install PyTorch\n\nInstall PyTorch with ROCm support:\n\n```bash\npip install -U torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm5.7 --force-reinstall\n```\n\n### 5. Install Flash Attention\n\nClone and install the Flash Attention repository:\n\n```bash\ngit clone --recursive https://github.com/ROCmSoftwarePlatform/flash-attention.git\nexport GPU_ARCHS=\"gfx90a\"\ncd flash-attention\nexport PYTHON_SITE_PACKAGES=$(python -c 'import site; print(site.getsitepackages()[0])')\npatch \"${PYTHON_SITE_PACKAGES}/torch/utils/hipify/hipify_python.py\" hipify_patch.patch\npip install --no-build-isolation .\n```\n\n### 6. Install Axolotl\n\nClone and install Axolotl:\n\n```bash\ngit clone https://github.com/axolotl-ai-cloud/axolotl\ncd axolotl\npip install packaging ninja\npip install --no-build-isolation -e .\n```\n\n### 7. 
Apply xformers Workaround\n\nxformers appears to be incompatible with ROCm. Apply the following workarounds:\n - Edit $HOME/packages/axolotl/src/axolotl/monkeypatch/llama_attn_hijack_flash.py, modifying the code to always return `False` for SwiGLU availability from xformers.\n - Edit $HOME/miniforge3/lib/python3.10/site-packages/xformers/ops/swiglu_op.py, replacing the \"SwiGLU\" function with a pass statement.\n\n### 8. Prepare Job Submission Script\n\nCreate a job submission script for your HPC system's scheduler (e.g. Slurm, PBS). Include the necessary environment setup and the command to run Axolotl training. If the compute nodes do not have internet access, add the following to the script so that the Hugging Face libraries run in offline mode:\n\n```bash\nexport TRANSFORMERS_OFFLINE=1\nexport HF_DATASETS_OFFLINE=1\n```\n\n### 9. Download Base Model\n\nDownload a base model using the Hugging Face CLI:\n\n```bash\nhf download meta-llama/Meta-Llama-3.1-8B --local-dir ~/hfdata/llama3.1-8B\n```\n\n### 10. Create Axolotl Configuration\n\nCreate an Axolotl configuration file (YAML format) tailored to your specific training requirements and dataset. Use FSDP for multi-node training.\n\nNote: DeepSpeed did not work at the time of testing. If you manage to get it working, please let us know.\n\n### 11. Preprocess Data\n\nRun preprocessing on the login node:\n\n```bash\nCUDA_VISIBLE_DEVICES=\"\" python -m axolotl.cli.preprocess /path/to/your/config.yaml\n```\n\n### 12. Train\n\nYou are now ready to submit your previously prepared job script. 🚂\n"
  },
  {
    "path": "docs/attention.qmd",
    "content": "---\ntitle: Attention\ndescription: Supported attention modules in Axolotl\n---\n\n## SDP Attention\n\nThis is the default built-in attention in PyTorch.\n\n```yaml\nsdp_attention: true\n```\n\nFor more details: [PyTorch docs](https://docs.pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html)\n\n## Flash Attention\n\nAxolotl supports Flash Attention 2, 3, and 4. The best available version is used automatically\nbased on your installed packages and GPU.\n\n```yaml\nflash_attention: true\n```\n\nFor more details: [Flash Attention](https://github.com/Dao-AILab/flash-attention/)\n\n### Flash Attention 2\n\nRequirements: Ampere, Ada, or Hopper GPUs (Turing or older are not supported)\n\n```bash\npip install flash-attn --no-build-isolation\n```\n\n::: {.callout-tip}\n\nIf you get an `undefined symbol` error while training, ensure you installed PyTorch before Axolotl.\nAlternatively, try reinstalling flash-attn or downgrading to an earlier version.\n\n:::\n\n### Flash Attention 3\n\nRequirements: Hopper only and CUDA 12.8 (recommended)\n\n```bash\ngit clone https://github.com/Dao-AILab/flash-attention.git\ncd flash-attention/hopper\n\npython setup.py install\n```\n\n### Flash Attention 4\n\nRequirements: Hopper or Blackwell GPUs\n\n```bash\npip install flash-attn-4\n```\n\nOr from source:\n\n```bash\ngit clone https://github.com/Dao-AILab/flash-attention.git\ncd flash-attention/flash_attn/cute\n\npip install -e .\n\n# FA2's flash_attn package includes a cute/ stub that shadows FA4.\n# Remove it so Python can find the real FA4 module:\nrm -r $(python -c \"import flash_attn; print(flash_attn.__path__[0])\")/cute\n```\n\n::: {.callout-note}\n\n**Hopper (SM90) users**: The backward kernel is not yet included in the pip package. To use FA4\nfor training on Hopper, install from source using the instructions above.\n\n:::\n\n::: {.callout-warning}\n\nFA4 only supports head dimensions up to 128 (`d ≤ 128`). 
The DeepSeek shape `(192, 128)` is\nalso supported but only on Blackwell. Axolotl automatically detects incompatible head dimensions\nand falls back to FA2/3.\n\n:::\n\nFor more details: [flash-attention/flash_attn/cute](https://github.com/Dao-AILab/flash-attention/tree/main/flash_attn/cute)\n\n### AMD\n\nRequirements: ROCm 6.0 and above.\n\nSee [Flash Attention AMD docs](https://github.com/Dao-AILab/flash-attention/tree/main?tab=readme-ov-file#amd-rocm-support).\n\n## Flex Attention\n\nA flexible PyTorch API for attention used in combination with `torch.compile`.\n\n```yaml\nflex_attention: true\n\n# recommended\ntorch_compile: true\n```\n\n::: {.callout-note}\n\nWe recommend using the latest stable version of PyTorch for best performance.\n\n:::\n\nFor more details: [PyTorch docs](https://pytorch.org/blog/flexattention/)\n\n## SageAttention\n\nAttention kernels with QK Int8 and PV FP16 accumulator.\n\n```yaml\nsage_attention: true\n```\n\nRequirements: Ampere, Ada, or Hopper GPUs\n\n```bash\npip install sageattention==2.2.0 --no-build-isolation\n```\n\n::: {.callout-warning}\n\nOnly LoRA/QLoRA is recommended at the moment. We found that the loss drops to 0 during full fine-tuning. See [GitHub Issue](https://github.com/thu-ml/SageAttention/issues/198).\n\n:::\n\nFor more details: [Sage Attention](https://github.com/thu-ml/SageAttention)\n\n::: {.callout-note}\n\nWe do not support SageAttention 3 at the moment. If you are interested in adding this or improving the SageAttention implementation, please open an issue.\n\n:::\n\n\n## xFormers\n\n```yaml\nxformers_attention: true\n```\n\n::: {.callout-tip}\n\nWe recommend using this with Turing GPUs or older (such as on Colab).\n\n:::\n\nFor more details: [xFormers](https://github.com/facebookresearch/xformers)\n\n## Shifted Sparse Attention\n\n::: {.callout-warning}\n\nWe plan to deprecate this! 
If you use this feature, we recommend switching to one of the methods above.\n\n:::\n\nRequirements: LLaMA model architecture\n\n```yaml\nflash_attention: true\ns2_attention: true\n```\n\n::: {.callout-tip}\n\nSample packing is not supported!\n\n:::\n"
  },
  {
    "path": "docs/batch_vs_grad.qmd",
    "content": "---\ntitle: Batch size vs Gradient accumulation\ndescription: Understanding of batch size and gradient accumulation steps\n---\n\nGradient accumulation means accumulating gradients over several mini-batches and updating the model weights afterward. When the samples in each batch are diverse, this technique doesn't significantly impact learning.\n\nThis method allows for effective training with larger effective batch sizes without needing proportionally larger memory. Here's why:\n\n1. **Memory Consumption with Batch Size**: The primary reason increasing the batch size impacts memory is due to the storage requirements for intermediate activations. When you forward propagate a batch through a network, you have to store the activations at each layer for each sample in the batch, because these activations are used during backpropagation to compute gradients. Therefore, larger batches mean more activations, leading to greater GPU memory consumption.\n\n2. **Gradient Accumulation**: With gradient accumulation, you're effectively simulating a larger batch size by accumulating gradients over several smaller batches (or micro-batches). However, at any given time, you're only forward and backward propagating a micro-batch. This means you only store activations for the micro-batch, not the full accumulated batch. 
As a result, you can simulate the effect of a larger batch size without the memory cost of storing activations for a large batch.\n\nIn the examples below, `Sn` denotes the n-th sample and `en` the gradient of the loss on sample `Sn` with respect to the weight `w1`.\n\n**Example 1:**\nMicro batch size: 3\nGradient accumulation steps: 2\nNumber of GPUs: 3\nTotal batch size = 3 * 2 * 3 = 18\n\n```\n| GPU 1          | GPU 2          | GPU 3          |\n|----------------|----------------|----------------|\n| S1, S2, S3     | S4, S5, S6     | S7, S8, S9     |\n| e1, e2, e3     | e4, e5, e6     | e7, e8, e9     |\n|----------------|----------------|----------------|\n| → (accumulate) | → (accumulate) | → (accumulate) |\n|----------------|----------------|----------------|\n| S10, S11, S12  | S13, S14, S15  | S16, S17, S18  |\n| e10, e11, e12  | e13, e14, e15  | e16, e17, e18  |\n|----------------|----------------|----------------|\n| → (apply)      | → (apply)      | → (apply)      |\n\nAccumulated gradient for the weight w1 after the second iteration (considering all GPUs):\nTotal gradient for w1 = e1 + e2 + e3 + e4 + e5 + e6 + e7 + e8 + e9 + e10 + e11 + e12 + e13 + e14 + e15 + e16 + e17 + e18\n\nWeight update for w1:\nw1_new = w1_old - learning rate × (Total gradient for w1 / 18)\n```\n\n**Example 2:**\nMicro batch size: 2\nGradient accumulation steps: 1\nNumber of GPUs: 3\nTotal batch size = 2 * 1 * 3 = 6\n\n```\n| GPU 1     | GPU 2     | GPU 3     |\n|-----------|-----------|-----------|\n| S1, S2    | S3, S4    | S5, S6    |\n| e1, e2    | e3, e4    | e5, e6    |\n|-----------|-----------|-----------|\n| → (apply) | → (apply) | → (apply) |\n\nAccumulated gradient for the weight w1 (considering all GPUs):\nTotal gradient for w1 = e1 + e2 + e3 + e4 + e5 + e6\n\nWeight update for w1:\nw1_new = w1_old - learning rate × (Total gradient for w1 / 6)\n```\n"
  },
  {
    "path": "docs/checkpoint_saving.qmd",
    "content": "---\ntitle: \"Checkpoint Saving\"\nformat:\n  html:\n    toc: true\n    toc-depth: 2\n    number-sections: true\nexecute:\n  enabled: false\n---\n\n## Overview\n\nAxolotl supports on-demand checkpoint saving during training. You can trigger checkpoints via file-based triggers (for programmatic control) or Control+C (for interactive use).\n\n## File-Based Checkpoint Trigger\n\n### Configuration\n\nEnable in your config:\n\n```yaml\ndynamic_checkpoint:\n  enabled: true\n  check_interval: 100  # Optional: check every N steps (default: 100)\n  trigger_file_path: \"axolotl_checkpoint.save\"  # Optional: custom filename\n```\n\n**Options:**\n- `enabled`: `true` to enable (required)\n- `check_interval`: Steps between file checks. Default: 100. Lower = faster response, higher I/O overhead.\n- `trigger_file_path`: Custom trigger filename. Default: `axolotl_checkpoint.save`\n\n### How It Works\n\n1. Rank 0 checks for trigger file every `check_interval` steps in `output_dir`\n2. When detected, file is deleted and checkpoint is saved\n3. In distributed training, rank 0 broadcasts to synchronize all ranks\n\n### Usage\n\n**Command line:**\n```bash\ntouch /path/to/output_dir/axolotl_checkpoint.save\n```\n\n**Programmatic:**\n```python\nfrom pathlib import Path\nPath(\"/path/to/output_dir/axolotl_checkpoint.save\").touch()\n```\n\nCheckpoint saves within the next `check_interval` steps. The trigger file is auto-deleted after detection, so you can create it multiple times.\n\n**Custom filename:**\n```yaml\ndynamic_checkpoint:\n  enabled: true\n  trigger_file_path: \"my_trigger.save\"\n```\n```bash\ntouch /path/to/output_dir/my_trigger.save\n```\n\n## Control+C (SIGINT) Checkpoint\n\nPressing `Ctrl+C` during training saves the model state and exits gracefully. **Note:** This saves only the model weights, not optimizer state. 
For resumable checkpoints, use the file-based trigger.\n\n## Best Practices\n\n- **Check interval**: Lower values (10-50) for fast training, default 100 for slower training\n- **Distributed training**: Create trigger file once; rank 0 handles synchronization\n- **Resume**: Dynamic checkpoints can be resumed like regular checkpoints via `resume_from_checkpoint`\n\n## Example\n\n```yaml\noutput_dir: ./outputs/lora-out\nsave_steps: 500  # Scheduled checkpoints\n\ndynamic_checkpoint:\n  enabled: true\n  check_interval: 50\n```\n\nThis enables scheduled checkpoints every 500 steps plus on-demand saves via file trigger (checked every 50 steps).\n"
  },
  {
    "path": "docs/cli.qmd",
    "content": "---\ntitle: \"Command Line Interface (CLI)\"\nformat:\n  html:\n    toc: true\n    toc-expand: 2\n    toc-depth: 3\nexecute:\n  enabled: false\n---\n\nThe Axolotl CLI provides a streamlined interface for training and fine-tuning large language models. This guide covers\nthe CLI commands, their usage, and common examples.\n\n\n## Basic Commands\n\nAll Axolotl commands follow this general structure:\n\n```bash\naxolotl <command> [config.yml] [options]\n```\n\nThe config file can be local or a URL to a raw YAML file.\n\n### Launcher Arguments\n\nFor commands that support multi-GPU (`train`, `evaluate`, ...), you can pass launcher-specific arguments using the `--` separator:\n\n```bash\n# Pass torchrun arguments\naxolotl train config.yml --launcher torchrun -- --nproc_per_node=2 --nnodes=1\n\n# Pass accelerate arguments\naxolotl train config.yml --launcher accelerate -- --config_file=accelerate_config.yml --num_processes=4\n```\n\nArguments after `--` are passed directly to the launcher (torchrun, accelerate launch, etc.).\n\n## Command Reference\n\n### fetch\n\nDownloads example configurations and deepspeed configs to your local machine.\n\n```bash\n# Get example YAML files\naxolotl fetch examples\n\n# Get deepspeed config files\naxolotl fetch deepspeed_configs\n\n# Specify custom destination\naxolotl fetch examples --dest path/to/folder\n```\n\n### preprocess\n\nPreprocesses and tokenizes your dataset before training. 
This is recommended for large datasets.\n\n```bash\n# Basic preprocessing\naxolotl preprocess config.yml\n\n# Preprocessing with one GPU\nCUDA_VISIBLE_DEVICES=\"0\" axolotl preprocess config.yml\n\n# Debug mode to see processed examples\naxolotl preprocess config.yml --debug\n\n# Debug with limited examples\naxolotl preprocess config.yml --debug --debug-num-examples 5\n```\n\nConfiguration options:\n\n```yaml\ndataset_prepared_path: Local folder for saving preprocessed data\npush_dataset_to_hub: HuggingFace repo to push preprocessed data (optional)\n```\n\n### train\n\nTrains or fine-tunes a model using the configuration specified in your YAML file.\n\n```bash\n# Basic training\naxolotl train config.yml\n\n# Train and set/override specific options\naxolotl train config.yml \\\n    --learning-rate 1e-4 \\\n    --micro-batch-size 2 \\\n    --num-epochs 3\n\n# Training without accelerate\naxolotl train config.yml --launcher python\n\n# Pass launcher-specific arguments using -- separator\naxolotl train config.yml --launcher torchrun -- --nproc_per_node=2 --nnodes=1\naxolotl train config.yml --launcher accelerate -- --config_file=accelerate_config.yml\n\n# Resume training from checkpoint\naxolotl train config.yml --resume-from-checkpoint path/to/checkpoint\n```\n\nIt is possible to run sweeps over multiple hyperparameters by passing in a sweeps config.\n\n```bash\n# Basic training with sweeps\naxolotl train config.yml --sweep path/to/sweep.yaml\n```\n\nExample sweep config:\n```yaml\n_:\n  # This section is for dependent variables we need to fix\n  - load_in_8bit: false\n    load_in_4bit: false\n    adapter: lora\n  - load_in_8bit: true\n    load_in_4bit: false\n    adapter: lora\n\n# These are independent variables\nlearning_rate: [0.0003, 0.0006]\nlora_r:\n  - 16\n  - 32\nlora_alpha:\n  - 16\n  - 32\n  - 64\n```\n\n\n\n### inference\n\nRuns inference using your trained model in either CLI or Gradio interface mode.\n\n```bash\n# CLI inference with LoRA\naxolotl 
inference config.yml --lora-model-dir=\"./outputs/lora-out\"\n\n# CLI inference with full model\naxolotl inference config.yml --base-model=\"./completed-model\"\n\n# Gradio web interface\naxolotl inference config.yml --gradio \\\n    --lora-model-dir=\"./outputs/lora-out\"\n\n# Inference with input from file\ncat prompt.txt | axolotl inference config.yml \\\n    --base-model=\"./completed-model\"\n```\n\n### merge-lora\n\nMerges trained LoRA adapters into the base model.\n\n```bash\n# Basic merge\naxolotl merge-lora config.yml\n\n# Specify LoRA directory (usually used with checkpoints)\naxolotl merge-lora config.yml --lora-model-dir=\"./lora-output/checkpoint-100\"\n\n# Merge using CPU (if out of GPU memory)\nCUDA_VISIBLE_DEVICES=\"\" axolotl merge-lora config.yml\n```\n\nConfiguration options:\n\n```yaml\ngpu_memory_limit: Limit GPU memory usage\nlora_on_cpu: Load LoRA weights on CPU\n```\n\n### merge-sharded-fsdp-weights\n\nMerges sharded FSDP model checkpoints into a single combined checkpoint.\n\n```bash\n# Basic merge\naxolotl merge-sharded-fsdp-weights config.yml\n```\n\n### evaluate\n\nEvaluates a model's performance (loss etc) on the train and eval datasets.\n\n```bash\n# Basic evaluation\naxolotl evaluate config.yml\n\n# Evaluation with launcher arguments\naxolotl evaluate config.yml --launcher torchrun -- --nproc_per_node=2\n```\n\n### lm-eval\n\nRuns LM Evaluation Harness on your model.\n\n```bash\n# Basic evaluation\naxolotl lm-eval config.yml\n```\n\nConfiguration options:\n\n```yaml\nlm_eval_model: # model to evaluate (local or hf path)\n\n# List of tasks to evaluate\nlm_eval_tasks:\n  - arc_challenge\n  - hellaswag\nlm_eval_batch_size: # Batch size for evaluation\noutput_dir: # Directory to save evaluation results\n```\n\nSee [LM Eval Harness integration docs](https://docs.axolotl.ai/docs/custom_integrations.html#language-model-evaluation-harness-lm-eval) for full configuration details.\n\n### delinearize-llama4\n\nDelinearizes a Llama 4 linearized 
model into a regular HuggingFace Llama 4 model. This only works with the non-quantized linearized model.\n\n```bash\naxolotl delinearize-llama4 --model path/to/model_dir --output path/to/output_dir\n```\n\nThis is necessary to use the model with other frameworks. If you have an adapter, merge it with the non-quantized linearized model before delinearizing.\n\n### quantize\n\nQuantizes a model using the quantization configuration specified in your YAML file.\n\n```bash\naxolotl quantize config.yml\n```\n\nSee [Quantization](./quantize.qmd) for more details.\n\n\n## Legacy CLI Usage\n\nWhile the new Click-based CLI is preferred, Axolotl still supports the legacy module-based CLI:\n\n```bash\n# Preprocess\npython -m axolotl.cli.preprocess config.yml\n\n# Train\naccelerate launch -m axolotl.cli.train config.yml\n\n# Inference\naccelerate launch -m axolotl.cli.inference config.yml \\\n    --lora_model_dir=\"./outputs/lora-out\"\n\n# Gradio interface\naccelerate launch -m axolotl.cli.inference config.yml \\\n    --lora_model_dir=\"./outputs/lora-out\" --gradio\n```\n\n::: {.callout-important}\nWhen overriding CLI parameters in the legacy CLI, use the same underscore notation as in the YAML file (e.g., `--lora_model_dir`).\n\n**Note:** This differs from the new Click-based CLI, which uses dash notation (e.g., `--lora-model-dir`). Keep this in mind if you're referencing newer documentation or switching between CLI versions.\n:::\n\n## Remote Compute with Modal Cloud\n\nAxolotl supports running training and inference workloads on Modal cloud infrastructure. 
This is configured using a\ncloud YAML file alongside your regular Axolotl config.\n\n### Cloud Configuration\n\nCreate a cloud config YAML with your Modal settings:\n\n```yaml\n# cloud_config.yml\nprovider: modal\ngpu: a100       # Supported: l40s, a100-40gb, a100-80gb, a10g, h100, t4, l4\ngpu_count: 1    # Number of GPUs to use\ntimeout: 86400  # Maximum runtime in seconds (24 hours)\nbranch: main    # Git branch to use (optional)\n\nvolumes:        # Persistent storage volumes\n  - name: axolotl-cache\n    mount: /workspace/cache\n  - name: axolotl-data\n    mount: /workspace/data\n  - name: axolotl-artifacts\n    mount: /workspace/artifacts\n\nsecrets:        # Secrets to inject\n  - WANDB_API_KEY\n  - HF_TOKEN\n```\n\n### Running on Modal Cloud\n\nCommands that support the --cloud flag:\n\n```bash\n# Preprocess on cloud\naxolotl preprocess config.yml --cloud cloud_config.yml\n\n# Train on cloud\naxolotl train config.yml --cloud cloud_config.yml\n\n# Run lm-eval on cloud\naxolotl lm-eval config.yml --cloud cloud_config.yml\n```\n\n### Cloud Configuration Options\n\n```yaml\nprovider:    # compute provider, currently only `modal` is supported\ngpu:         # GPU type to use\ngpu_count:   # Number of GPUs (default: 1)\nmemory:      # RAM in GB (default: 128)\ntimeout:     # Maximum runtime in seconds\ntimeout_preprocess: # Preprocessing timeout\nbranch:      # Git branch to use\ndocker_tag:  # Custom Docker image tag\nvolumes:     # List of persistent storage volumes\n\n# Environment variables to pass. Can be specified in two ways:\n# 1. As a string: Will load the value from the host computer's environment variables\n# 2. As a key-value pair: Will use the specified value directly\n# Example:\n# env:\n#   - CUSTOM_VAR  # Loads from host's $CUSTOM_VAR\n#   - {CUSTOM_VAR: \"value\"}  # Uses \"value\" directly\nenv:\n\n# Secrets to inject. Same input format as `env` but for sensitive data.\nsecrets:\n  # - HF_TOKEN\n  # - WANDB_API_KEY\n```\n"
  },
  {
    "path": "docs/custom_integrations.qmd",
    "content": "---\ntitle: Custom Integrations\ntoc: true\ntoc-depth: 3\n---\n\n```{python}\n#| echo: false\n\nimport os\nimport re\n\ndef process_readme(integration_name):\n    try:\n        path = f'../src/axolotl/integrations/{integration_name}/README.md'\n        with open(path, 'r') as f:\n            txt = f.read()\n            # Remove h1 headings\n            txt = re.sub(r'^# .*\\n?', '', txt, flags=re.MULTILINE)\n            # Convert h2 to h3\n            txt = re.sub(r'^## ', '### ', txt, flags=re.MULTILINE)\n            return txt\n    except FileNotFoundError:\n        return None\n\ndef print_section(name, folder_name):\n    output = f\"\\n## {name}\\n\"\n    content = process_readme(folder_name)\n    if content:\n        output += content\n    output += f\"\\nPlease see reference [here](https://github.com/axolotl-ai-cloud/axolotl/tree/main/src/axolotl/integrations/{folder_name})\\n\"\n    return output\n```\n\n```{python}\n#| output: asis\n#| echo: false\n\n# Introduction text\nprint(\"\"\"\nAxolotl adds custom features through `integrations`. 
They are located within the `src/axolotl/integrations` directory.\n\nTo enable them, please check the respective documentations.\n\"\"\")\n\n# Sections\nsections = [\n    (\"Cut Cross Entropy\", \"cut_cross_entropy\"),\n    (\"Grokfast\", \"grokfast\"),\n    (\"Knowledge Distillation (KD)\", \"kd\"),\n    (\"Liger Kernels\", \"liger\"),\n    (\"Language Model Evaluation Harness (LM Eval)\", \"lm_eval\"),\n    (\"Spectrum\", \"spectrum\"),\n    (\"LLMCompressor\", \"llm_compressor\")\n]\n\nfor folder_name in os.listdir(\"../src/axolotl/integrations/\"):\n    if folder_name in [path for name, path in sections]:\n        # skip if already in sections\n        continue\n    if os.path.exists(f\"../src/axolotl/integrations/{folder_name}/README.md\"):\n        # grab the first heading in README.md as the section name\n        with open(f\"../src/axolotl/integrations/{folder_name}/README.md\", \"r\") as f:\n            txt = f.read()\n            matches = re.search(r'^# (.*)\\n?', txt, flags=re.MULTILINE)\n            if matches:\n                name = matches.group(1)\n            else:\n                continue\n            sections.append((name, folder_name))\n\n# sort sections by name\nsections = sorted(sections, key=lambda x: x[0])\n\nfor section_name, folder_name in sections:\n    print(print_section(section_name, folder_name))\n```\n\n## Adding a new integration\n\nPlugins can be used to customize the behavior of the training pipeline through [hooks](https://en.wikipedia.org/wiki/Hooking). See [`axolotl.integrations.BasePlugin`](https://github.com/axolotl-ai-cloud/axolotl/blob/main/src/axolotl/integrations/base.py) for the possible hooks.\n\nTo add a new integration, please follow these steps:\n\n1. Create a new folder in the `src/axolotl/integrations` directory.\n2. Add any relevant files (`LICENSE`, `README.md`, `ACKNOWLEDGEMENTS.md`, etc.) to the new folder.\n3. 
Add `__init__.py` and `args.py` files to the new folder.\n  - `__init__.py` should import the integration and hook into the appropriate functions.\n  - `args.py` should define the arguments for the integration.\n4. (If applicable) Add CPU tests under `tests/integrations` or GPU tests under `tests/e2e/integrations`.\n\n::: {.callout-tip}\n\nSee [src/axolotl/integrations/cut_cross_entropy](https://github.com/axolotl-ai-cloud/axolotl/tree/main/src/axolotl/integrations/cut_cross_entropy) for a minimal integration example.\n\n:::\n\n::: {.callout-warning}\n\nIf you cannot load your integration, ensure that you installed the package in editable mode:\n\n```bash\npip install -e .\n```\n\nand that the integration name is spelled correctly in the config file:\n\n```yaml\nplugins:\n  - axolotl.integrations.your_integration_name.YourIntegrationPlugin\n```\n\n:::\n\n::: {.callout-note}\n\nIt is not necessary to place your integration in the `integrations` folder. It can be in any location, as long as it is installed as a package in your Python environment.\n\nSee this repo for an example: [https://github.com/axolotl-ai-cloud/diff-transformer](https://github.com/axolotl-ai-cloud/diff-transformer)\n\n:::\n"
  },
  {
    "path": "docs/dataset-formats/conversation.qmd",
    "content": "---\ntitle: Conversation\ndescription: Conversation format for supervised fine-tuning.\norder: 3\n---\n\n## chat_template\n\nThe chat template strategy uses a Jinja2 template to convert a list of messages into a prompt. It supports the tokenizer's own template, a built-in supported template, or a custom Jinja2 template.\n\n```{.json filename=\"data.jsonl\"}\n{\"messages\": [{\"role\": \"...\", \"content\": \"...\"}, {\"role\": \"...\", \"content\": \"...\"}, ...]}\n```\n\nSee [configs](../config-reference.qmd) for full configs and supported templates.\n\n### Migrating from sharegpt\n\nMost configs can be adapted as follows:\n\n```yaml\n# old\nchat_template: chatml\ndatasets:\n  - path: ...\n    type: sharegpt\n    conversation: chatml\n\n# new (if using tokenizer's chat_template)\ndatasets:\n  - path: ...\n    type: chat_template\n\n    field_messages: conversations\n    message_property_mappings:\n      role: from\n      content: value\n\n# new (if setting a new chat_template like chatml, gemma, etc)\nchat_template: chatml\ndatasets:\n  - path: ...\n    type: chat_template\n\n    field_messages: conversations\n    message_property_mappings:\n      role: from\n      content: value\n```\n\nWe recommend checking the examples below for other use cases.\n\n### Examples\n\n#### Training on last message\n\n(Legacy) Using the default chat template in the tokenizer_config.json on OpenAI messages format, training on only the last message.\n\n```yaml\ndatasets:\n  - path: ...\n    type: chat_template\n    roles_to_train:\n    train_on_eos:\n```\n\n::: {.callout-tip}\nIf you receive an error like \"`chat_template` choice is `tokenizer_default` but tokenizer's `chat_template` is null.\", it means the tokenizer does not have a default `chat_template`. 
Follow the examples below instead to set a custom `chat_template`.\n:::\n\n#### Overriding default chat template\n\nUsing the `gemma` chat template to override the tokenizer_config.json's chat template on OpenAI messages format, training on all assistant messages.\n\n```yaml\nchat_template: gemma # this overwrites the tokenizer's chat_template\ndatasets:\n  - path: ...\n    type: chat_template\n    roles_to_train: [\"assistant\"]  # default value\n```\n\n::: {.callout-note}\nIf you want to use built-in chat_template, use `chat_template: tokenizer_default` (this is set by default).\n:::\n\n#### Using default chat template with fallback\n\nUsing the tokenizer_config.json's chat template or `chatml` as fallback if the former's chat template does not exist, on OpenAI messages format, training on all assistant messages.\n\n```yaml\nchat_template: tokenizer_default_fallback_chatml # this overwrites the tokenizer's chat_template\ndatasets:\n  - path: ...\n    type: chat_template\n```\n\n#### Custom Jinja template\n\nUsing a custom jinja template on OpenAI messages format, training on all assistant messages.\n\n```yaml\n# chat_template: jinja # `jinja` will be implied if the `chat_template_jinja` is set and this field is empty\nchat_template_jinja: \"{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'system') %}{{'<|system|>' + '\\n' + message['content'] + '<|end|>' + '\\n'}}{% elif (message['role'] == 'user') %}{{'<|user|>' + '\\n' + message['content'] + '<|end|>' + '\\n' + '<|assistant|>' + '\\n'}}{% elif message['role'] == 'assistant' %}{{message['content'] + '<|end|>' + '\\n'}}{% endif %}{% endfor %}\"\n\ndatasets:\n  - path: ...\n    type: chat_template\n```\n\n::: {.callout-important}\nPlease make sure that your `tokenizer.eos_token` is same as EOS (End-of-Sequence) token in template. 
Otherwise, set `eos_token` under `special_tokens: `.\n:::\n\n#### Using template with different token for EOT and EOS\n\n- If you are using a template that has a different EOT (End-of-Turn) token from EOS token or multiple EOT tokens (like Mistral V7 Tekken), set the `eot_tokens: ` config. The handling of EOT tokens follows `train_on_eos: ` which defaults to turn.\n\n```yaml\neot_tokens:\n  - \"[/INST]\"\n  # - \"[/SYSTEM_PROMPT]\"\n\ndatasets:\n  - path: ...\n    type: chat_template\n\n    # optional\n    train_on_eot: turn  # defaults read from train_on_eos (which defaults to turn)\n```\n\n::: {.callout-tip}\nSee [config documentation](../config-reference.qmd) for detailed explanations of \"turn\", \"last\", and \"all\" options for training on tokens.\n:::\n\n::: {.callout-note}\nUsing `eot_tokens` requires each token that exists in `chat_template` to be a single token in the tokenizer. Otherwise, the tokenizer will split the token and cause unexpected behavior.\n\nYou can add those tokens as new tokens under `tokens: ` or (recommended) override unused added_tokens via `added_tokens_overrides: `. See [config](../config-reference.qmd) for more details.\n:::\n\n- Continuing from the previous example, if you want to train on all EOT token trainable turns but only last EOS token, set `train_on_eos: last`.\n\n```yaml\neot_tokens:\n  - \"[/INST]\"\n  # ...\n\ndatasets:\n  - path: ...\n    type: chat_template\n\n    train_on_eos: last\n    train_on_eot: turn\n```\n\n::: {.callout-tip}\nIf EOS token only appears at the end of a prompt, `train_on_eos: last` is equivalent to `train_on_eos: turn`. 
Therefore, generally, you can leave them to their defaults and omit them.\n:::\n\n\n#### Using tool use\n\nInstead of passing `tools` via the system prompt, an alternative method would be to have the `tools` in a separate column and loaded via `chat_template` to let the template dynamically build it.\n\n```json\n{\n    \"tools\": [\n        {\n            \"type\": \"...\",\n            \"function\": {\n                \"name\": \"...\",\n                \"description\": \"...\",\n                \"parameters\": {\n                    \"type\": \"...\",\n                    \"properties\": {\n                        // ...\n                    },\n                    \"required\": [\"...\"],\n                },\n            },\n        },\n    ],\n    \"messages\": [\n        // ...\n        {\n            \"role\": \"assistant\", // call the function via assistant\n            \"tool_calls\": [\n                {\n                    \"id\": \"...\",  // required only for mistral\n                    \"type\": \"function\",\n                    \"function\": {\n                        \"name\": \"...\",\n                        \"arguments\": {\n                            \"...\": \"...\",\n                        }\n                    }\n                }\n            ]\n        },\n        {\n            \"role\": \"tool\",\n            \"tool_call_id\": \"...\",  // required only for mistral\n            \"name\": \"...\",\n            \"content\": \"...\"\n        },\n    ],\n}\n```\n\n::: {.callout-note}\nTools need to follow [JSON schema](https://json-schema.org/learn/getting-started-step-by-step).\n:::\n\n::: {.callout-warning}\nIf you have tool arguments with same name but different dtypes (like `\"time\": string` and `\"time\": number`), please save `arguments: ` as JSON string to prevent `datasets` from having casting issues.\n\n```\n\"arguments\": \"{\\\"...\\\": \\\"...\\\"}\"\n```\n\nThe same is applicable for tool 
parameters.\n\n```\n\"parameters\": \"{\\\"...\\\": \\\"...\\\"}\"\n```\n\n:::\n\nExample config for Llama4:\n```yaml\nchat_template: llama4\ndatasets:\n  - path: Nanobit/text-tools-2k-test\n    type: chat_template\n    # field_tools: tools # default is `tools`\n```\n\n::: {.callout-tip}\nLook into the `chat_template` you are using to see if it supports `tools` and what the expected role is for the tool answer. In the example above, the tool answer is expected to be in the `tool` or `ipython` role for `llama4` template.\n:::\n\n\n#### Using fine-grained control over token masking\n\n(Advanced) Using fine-grained control over tokens and turns to train in a conversation\n\nFor a data sample that looks like:\n\n```{.json filename=\"data.jsonl\"}\n{\n  \"conversations\": [\n    {\"from\": \"system\", \"value\": \"You are an AI assistant.\", \"train\": false},\n    {\"from\": \"human\", \"value\": \"Hello\", \"train\": false},\n    {\"from\": \"assistant\", \"value\": \"Hello\", \"train\": true},\n    {\"from\": \"human\", \"value\": \"How are you?\", \"train\": true},\n    {\n      \"from\": \"assistant\",\n      \"value\": \"I'm doing very well, thank you!\",\n      \"train_detail\": [\n        {\"begin_offset\": 0, \"end_offset\": 8, \"train\": false},\n        {\"begin_offset\": 9, \"end_offset\": 18, \"train\": true},\n        {\"begin_offset\": 19, \"end_offset\": 30, \"train\": false},\n      ],\n    },\n    {\n        \"from\": \"human\",\n        \"value\": \"I'm doing very well, thank you!\",\n        \"train\": true,\n    },\n    {\"from\": \"assistant\", \"value\": \"Hi there!\", \"train\": true}\n  ]\n}\n```\n\nThe configuration would look like:\n\n```yaml\ndatasets:\n  - path: ...\n    type: chat_template\n    chat_template: tokenizer_default\n    field_messages: conversations\n    message_property_mappings:\n      role: from\n      content: value\n    roles_to_train: []\n    train_on_eos: turn\n    message_field_training: train\n    
message_field_training_detail: train_detail\n```\n\n::: {.callout-tip}\nIt is not necessary to set both `message_field_training` and `message_field_training_detail` at once.\n:::\n\n#### Reasoning split\n\n(For Qwen3 template only) Enable reasoning split, where the reasoning is split from the content and passed as a separate field into the template.\n\n```yaml\ndatasets:\n  - path: ...\n    type: chat_template\n    chat_template: qwen3\n    split_thinking: true\n```\n\nFor example, the content can look like:\n\n```json\n{\n  \"content\": \"<think>Some thinking outputs</think>Output after thinking.\"\n}\n```\n\nAfter the split, it will look like:\n\n```json\n{\n  \"reasoning_content\": \"Some thinking outputs\",\n  \"content\": \"Output after thinking.\"\n}\n```\n\n\n## sharegpt\n\n::: {.callout-important}\nShareGPT is deprecated! Please see the [chat_template](#chat_template) section.\n:::\n\n## pygmalion\n\n```{.json filename=\"data.jsonl\"}\n{\"conversations\": [{\"role\": \"...\", \"value\": \"...\"}]}\n```\n"
  },
  {
    "path": "docs/dataset-formats/index.qmd",
    "content": "---\ntitle: Dataset Formats\ndescription: Guide to Dataset Formats in Axolotl\nback-to-top-navigation: true\ntoc: true\ntoc-depth: 5\n---\n\n\nAxolotl is a training framework that aims to make the process convenient yet flexible for users by simply passing a config yaml file.\n\nAs there are a lot of available options in Axolotl, this guide aims to simplify choosing the right one for your data.\n\nAxolotl supports 3 kinds of training methods: pre-training, supervised fine-tuning, and preference-based post-training (e.g. DPO, ORPO, PRMs). Each method has its own dataset format, described below.\n\n::: {.callout-tip}\n\nThis guide will mainly use JSONL as an introduction. Please refer to the [dataset loading docs](../dataset_loading.qmd) to understand how to load datasets from other sources.\n\nFor `pretraining_dataset:` specifically, please refer to the [Pre-training section](#pre-training).\n:::\n\n## Pre-training\n\nWhen aiming to train on large corpora of text datasets, pre-training is your go-to choice. Due to the size of these datasets, downloading the entire dataset before beginning training would be prohibitively time-consuming. 
Axolotl supports [streaming](https://huggingface.co/docs/datasets/en/stream) to load only a batch of data into memory at a time.\n\nA sample format for a pre-training dataset is as follows:\n\n```json\n{\"text\": \"first row\"}\n{\"text\": \"second row\"}\n...\n```\n\nIt is typically recommended to save your dataset as `.jsonl` due to its flexibility and simplicity.\n\nAxolotl supports loading from a Hugging Face hub repo or from local files.\n\n### Pre-training from Hugging Face hub datasets\n\nAs an example, to train using a Hugging Face dataset `hf_org/name`, you can pass the following config:\n\n```yaml\npretraining_dataset: hf_org/name\n```\n\n### Pre-training from local dataset files\n\nGiven a few corpus files: `A.jsonl`, `B.jsonl`, and `C.jsonl`, your config will look like the below:\n\n```yaml\npretraining_dataset:\n  - path: json\n    data_files:\n      - A.jsonl\n      - B.jsonl\n      - C.jsonl\n```\n\nWhile we recommend `.jsonl`, you can also use the other formats (`csv`, `parquet`, `arrow`, `SQL`, `Webdataset`) that are supported by [`datasets.load_dataset`](https://huggingface.co/docs/datasets/loading#local-and-remote-files).\n\n### Pre-training without streaming\n\nIn the case that the dataset is small and can be loaded entirely into memory, another approach to running pre-training is to use the `completion` format. This means that the entire dataset is pre-tokenized up front instead of being tokenized on demand while streaming.\n\nOne benefit of this is that the tokenization can be performed separately on a CPU-only machine, and then transferred to a GPU machine for training to save costs.\n\nFrom Hugging Face:\n\n```yaml\ndatasets:\n  - path: hf_org/name\n    type: completion\n```\n\nFrom local files:\n\n```yaml\ndatasets:\n  - path: A.jsonl\n    type: completion\n\n  - path: B.jsonl\n    type: completion\n```\n\n::: {.callout-important}\nFor `completion` only, Axolotl splits texts that exceed the context length into multiple smaller prompts. 
If you are interested in having this for `pretraining_dataset` too, please let us know or help make a PR!\n:::\n\n### Pre-training dataset configuration tips\n\n#### Setting max_steps\n\nWhen using streaming for large datasets, Axolotl does not know in advance how large the dataset is and does not know when to stop.\n\nTherefore, it is necessary to set `max_steps: int` in your config for pre-training to run, so that Axolotl knows when to stop training.\n\nOne step is equal to `sequence_len * micro_batch_size * gradient_accumulation_steps * total_num_gpus` tokens.\n\n#### Group_by_length\n\nIt is recommended to leave this off if downloading from Hugging Face hub as it would download the entire dataset which can be very large.\n\n### Reference\n\nPlease see docs [here](pretraining.qmd).\n\n## Supervised fine-tuning (SFT)\n\nSupervised fine-tuning is the process of training models to respond to an instruction or chat input.\n\nAs there are a wide variety of dataset formats, Axolotl tries to support a majority of the formats available in public datasets.\n\nAxolotl provides four approaches for loading datasets, however, it's easier to work backwards from the dataset you have available to figure out which approach to use.\n\nA flow chart is as follows:\n\n1. Do you already have the dataset tokenized? If yes, check [Pre-Tokenized Dataset](#pre-tokenized-dataset).\n\n2. Do you want to format the dataset yourself and manually choose each section to mask? If yes, check [Template Free Dataset](#template-free-dataset)\n\n3. Is your dataset in a \"conversation\" format, containing a `list[messages]`? If yes, check [Conversation Dataset](#conversation-dataset)\n\n4. Is your dataset in an \"instruct\" format, containing `{ instruction, response }`? 
If yes, check [Instruction Dataset](#instruction-dataset)\n\nIf you went through the flow chart and did not find one that matches, it is recommended to preprocess your dataset into one of the above formats or create a thread on GitHub Discussions.\n\n::: {.callout-tip}\nYou can mix and match within each approach or across approaches to train a model on a variety of datasets.\n:::\n\n### Pre-Tokenized Dataset\n\nWe suggest this approach when you want to bring your own tokenized dataset.\n\nAxolotl expects the dataset to have three keys:\n\n- `input_ids`: from tokenizing the formatted prompt\n- `attention_mask`: for masking padding. If you don't add padding, it would be equal to `len(input_ids) * [1]`\n- `labels`: this is the same as `input_ids`; however, if you want to mask certain tokens, you would set those indices to `-100`.\n\n::: {.callout-tip}\nMake sure to add BOS/EOS tokens to your prompt and mask it appropriately.\n:::\n\nA config for this would look like:\n\n```yaml\ndatasets:\n  - path: A.jsonl\n    type:\n```\n\n::: {.callout-note}\n`type: ` is empty!\n:::\n\nReference: [Pre-Tokenized Dataset Documentation](tokenized.qmd).\n\n### Template Free Dataset\n\nWe recommend this approach when you want granular control over the prompt formatting, special tokens, and masking, whilst letting Axolotl handle the tokenization. This is very useful if your dataset has unique prompts that differ across samples and where one single general template wouldn't suffice.\n\nIn the example below, you can see that there is no fixed structure. At the same time, it's very flexible as there are no constraints on how your prompt can look.\n\n```json\n{\n    \"segments\": [\n        {\n            \"label\": true,\n            \"text\": \"<s>Hello\\n\"\n        },\n        {\n            \"label\": true,\n            \"text\": \"hi there!. 
\"\n        },\n        {\n            \"label\": false,\n            \"text\": \"goodbye \"\n        },\n        {\n            \"label\": true,\n            \"text\": \"farewell</s>\"\n        }\n    ]\n}\n```\n\nEach prompt must have a key called `segments`, which is a list of `{ text, label }` objects.\n\n```yaml\ndatasets:\n  - path: A.jsonl\n    type: input_output\n```\n\nReference: [Template Free Documentation](template_free.qmd).\n\n### Conversation Dataset\n\n`conversation` messages are a list of messages which usually contain `role` and `content` keys.\n\n::: {.callout-tip}\nFun fact: Axolotl synonymously refers to \"chat\" messages as `conversation` messages due to how FastChat initially used this term to build a widely used [fastchat conversation](https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py) method for formatting chat messages prior to the creation of `chat_templates`.\n:::\n\n#### What are `chat_templates`?\n\nCurrently, the most popular and convenient method for inference is to use `chat_templates` for formatting prompts. 
Axolotl supports using `chat_templates` for training to ensure that the model performs in the same environment as in inference.\n\nHere's a quick rundown on `chat_template`: A `chat_template` is a Jinja2 template which formats a list of messages into a prompt.\n\nAn example of a prompt formatted into a popular template called ChatML can be seen below:\n\nSingle prompt (pretty-printed):\n```json\n{\n    \"messages\": [\n        {\n            \"role\": \"user\",\n            \"content\": \"Hi\"\n        },\n        {\n            \"role\": \"assistant\",\n            \"content\": \"How can I help you?\"\n        },\n        {\n            \"role\": \"user\",\n            \"content\": \"Can you add 3+5?\"\n        },\n        {\n            \"role\": \"assistant\",\n            \"content\": \"The answer is 8.\"\n        }\n    ]\n}\n```\n\nThe ChatML template is as follows:\n```jinja2\n{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|im_start|>' + message['role'] + '\\n' + message['content'] + '<|im_end|>' + '\\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\\n' }}{% endif %}\n```\n\nThe above prompt formatted into this template will result in:\n\n```\n<|im_start|>user\nHi<|im_end|>\n<|im_start|>assistant\nHow can I help you?<|im_end|>\n<|im_start|>user\nCan you add 3+5?<|im_end|>\n<|im_start|>assistant\nThe answer is 8.<|im_end|>\n```\n\nBy using delimiters (`<|im_start|>` and `<|im_end|>`), a prompt separates different speakers which helps the model identify which portion belongs to whom.\n\n#### Common Conversation Dataset formats\n\nOlder conversation datasets with the following format are colloquially called `sharegpt` datasets.\n\n```json\n{\"conversations\": [{\"from\": \"...\", \"value\": \"...\"}]}\n```\n\nNewer conversation datasets usually follow the OpenAI format.\n\n```json\n{\"messages\": [{\"role\": \"...\", \"content\": \"...\"}]}\n```\n\nAxolotl 
supports both as well as allowing customization of any kind of key.\n\n#### Chat Template Usage\n\nTo properly use this method, it is important to identify three things:\n\n1. Which `chat_template` would you use?\n\n2. What are the keys in your dataset, and what are the possible roles? For example, in OpenAI format, the keys would be `messages`, `role`, and `content`, respectively, whereas the possible roles are `system`, `user`, and `assistant`.\n\n3. What do you want to mask? For instance, only assistant messages, only last message, or nothing.\n\n##### Choosing a `chat_template`\n\nThere are a lot of `chat_templates` out there. Axolotl supports the common ones: [supported chat templates](https://github.com/axolotl-ai-cloud/axolotl/blob/860609392184cf62a7e0ca676658b170e059ce6c/src/axolotl/utils/chat_templates.py#L17). For example, to use ChatML, it would be `chat_template: chatml`.\n\nHowever, it is also possible to use the already configured template within the tokenizer by specifying `chat_template: tokenizer_default`. If you want a fallback (in case some tokenizer does not have it pre-configured), you can do `chat_template: tokenizer_default_fallback_chatml` to fallback to the ChatML template if a tokenizer template was not found.\n\nOne last but powerful approach is to bring your own template. This can be set via:\n\n```yaml\nchat_template_jinja: # your template\n```\n\n##### Setting `chat_template` dataset keys\n\nWe currently default to OpenAI format for dataset keys, so if that's your current dataset format, there's nothing to do here.\n\nIf your dataset format is different, here are the keys you should check (with their defaults):\n\n```yaml\ndatasets:\n    ...\n    field_messages: messages  # this should point to the key containing the list of conversations\n    message_property_mappings:  # this is a mapping from keys in your dataset to keys in chat_template\n      role: role\n      content: content\n```\n\nIn some `chat_templates` (e.g. 
[Gemma](https://huggingface.co/google/gemma-2b-it/blob/main/tokenizer_config.json#L1507)), the roles are hardcoded to `user` and `assistant`. Consequently, you may find it necessary to map the roles in your dataset to these roles. We currently have some defaults that should work for common datasets, but if you get a `KeyError`, it would be necessary to add a mapping for your roles. Here is an example of what it would look like:\n\n```yaml\ndatasets:\n    ...\n    roles:\n      assistant:\n        - gpt\n        - model\n      user:\n        - human\n```\n\nIn the example above, all `gpt` and `model` values are converted to `assistant`. All `human` values are converted to `user`.\n\n##### Handling masking\n\nThe common use case for `chat_template` is chat messages; therefore, it is common to mask all non-assistant messages. Assistant messages refer to the bot messages that you want the model to learn from.\n\nTo train on all `assistant` messages, you would set the following configs.\n\n```yaml\ndatasets:\n    ...\n    roles_to_train: [\"assistant\"]\n    train_on_eos: \"turn\"\n```\n\nSetting `train_on_eos: \"turn\"` means that the EOS tokens of all turns that aren't trained on (here, the non-assistant turns) are masked. The other options, `all` and `last`, choose which EOS tokens to train on.\n\nIf you want to train on both the `assistant` and `narrator` roles, you can simply add `narrator` to the list of `roles_to_train`. You would also need to add it to the mapping of `roles` above.\n\n```yaml\ndatasets:\n    ...\n    roles_to_train: [\"assistant\", \"narrator\"]\n    roles:\n      assistant:\n        - gpt\n        - model\n      user:\n        - human\n      narrator: [\"narrator\"]\n```\n\n::: {.callout-tip}\nAs chat_templates may use hardcoded EOS/EOT tokens that are different from the tokenizer's EOS, it is highly recommended to set them. 
For example, `ChatML` uses `<|im_end|>` to end turns.\n\n```yaml\nspecial_tokens:\n  eos_token: <|im_end|>\n```\n\n:::\n\n##### Applying `chat_template`\n\nOnce all the above steps are completed, you can combine all these configs together to form a bespoke configuration for your custom dataset.\n\n```yaml\ndatasets:\n  - path: A.jsonl\n    type: chat_template\n\n    # step 1\n    chat_template: chatml\n\n    # step 2\n    field_messages: messages\n    message_property_mappings:\n      role: role\n      content: content\n\n    roles:\n      assistant:\n        - gpt\n        - model\n        - assistant\n      user:\n        - human\n        - user\n\n    # step 3\n    roles_to_train: [\"assistant\"]\n    train_on_eos: \"turn\"\n\nspecial_tokens:\n  eos_token: <|im_end|>\n```\n\nIf this config were to be applied to the sample dataset above, the output would look like this (which can be retrieved via `axolotl preprocess config.yaml --debug`):\n\n```\n<|im_start|>(-100, 128256) user(-100, 882)\n(-100, 198) Hi(-100, 13347) <|im_end|>(-100, 128257)\n(-100, 198) <|im_start|>(-100, 128256) assistant(-100, 78191)\n(-100, 198) How(4438, 4438)  can(649, 649)  I(358, 358)  help(1520, 1520)  you(499, 499) ?(30, 30) <|im_end|>(128257, 128257)\n(-100, 198) <|im_start|>(-100, 128256) user(-100, 882)\n(-100, 198) Can(-100, 6854)  you(-100, 499)  add(-100, 923)  (-100, 220) 3(-100, 18) +(-100, 10) 5(-100, 20) ?(-100, 30) <|im_end|>(-100, 128257)\n(-100, 198) <|im_start|>(-100, 128256) assistant(-100, 78191)\n(-100, 198) The(791, 791)  answer(4320, 4320)  is(374, 374)  (220, 220) 8(23, 23) .(13, 13) <|im_end|>(128257, 128257)\n(-100, 198)\n```\n\nThe first number refers to the label, the second refers to the `token_id`. For example, `-100` labels appear on non-assistant portions, meaning that they are masked during training. 
For assistant portions, the label is the same as the `token_id`.\n\n::: {.callout-note}\n\nIf you see a lot of `Could not find content __ boundary` warnings during `preprocess`, please check the FAQ section for [chat_templates](../faq.qmd#chat-templates).\n\n:::\n\n#### Reference\n\nPlease see docs [here](conversation.qmd).\n\n### Instruction Dataset\n\nInstruction datasets are used to train instruction-following models and comprise a prompt, containing an instruction, and a single response. In contrast to chat datasets which may be multi-turn, instruct datasets are typically single-turn.\n\nAn example of a common format is Alpaca:\n```json\n{\"instruction\": \"...\", \"input\": \"...\", \"output\": \"...\"}\n```\n\nUsing those keys, a prompt can be built as follows:\n```\nBelow is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n{output}\n```\n\nThis can be configured as follows:\n```yaml\ndatasets:\n  - path: A.jsonl\n    type: alpaca\n```\n\nAxolotl supports many kinds of instruction datasets. 
All of them can be found in the [Instruction Dataset Documentation](inst_tune.qmd) with their respective type and sample row format.\n\n#### Custom Instruct Prompt Format\n\nDue to the myriad possibilities of instruction formats, Axolotl allows customizing your own instruction format without having to dive into the code directly.\n\nIn the example below, a sample row is formatted into the `mistral_v1` format.\n```json\n{\"input\": \"...\", \"output\": \"...\"}\n```\n\n```yaml\ndatasets:\n  - path: repo\n    type:\n      system_prompt: \"\"\n\n      field_system:\n      field_instruction: input\n      field_input:\n      field_output: output\n\n      # multi-line example with input\n      format: |-\n        [INST] {instruction} {input} [/INST]\n\n      # single-line example without input\n      no_input_format: \"[INST] {instruction} [/INST]\"\n```\n\nThe config maps `field_instruction` to the `input` key of the sample, and leaves `field_input` empty because this sample has no separate input field. Generally, `instruction` can be thought of as the question to the model, `input` as additional context, and `output` as the response. Neither `input` nor `system` is required. In the end, the most important part is to understand what you want the format to look like and how you can customize it for your use case.\n\nReference: [Custom Instruct Prompt Format Documentation](inst_tune.qmd#how-to-add-custom-prompt-format).\n\n## Reinforcement Learning from Human Feedback (RLHF)\n\nThere are multiple RLHF methods, each with its own dataset requirements. Please see the [RLHF documentation](../rlhf.qmd) for more details.\n"
  },
  {
    "path": "docs/dataset-formats/inst_tune.qmd",
    "content": "---\ntitle: Instruction Tuning\ndescription: Instruction tuning formats for supervised fine-tuning.\norder: 2\n---\n\n## alpaca\n\ninstruction; input(optional)\n\n```{.json filename=\"data.jsonl\"}\n{\"instruction\": \"...\", \"input\": \"...\", \"output\": \"...\"}\n```\n\n## jeopardy\n\nquestion and answer\n\n```{.json filename=\"data.jsonl\"}\n{\"question\": \"...\", \"category\": \"...\", \"answer\": \"...\"}\n```\n\n## oasst\n\ninstruction\n\n```{.json filename=\"data.jsonl\"}\n{\"INSTRUCTION\": \"...\", \"RESPONSE\": \"...\"}\n```\n\n## gpteacher\n\ninstruction; input(optional)\n\n```{.json filename=\"data.jsonl\"}\n{\"instruction\": \"...\", \"input\": \"...\", \"response\": \"...\"}\n```\n\n## reflection\n\ninstruction with reflect; input(optional)\n\n```{.json filename=\"data.jsonl\"}\n{\"instruction\": \"...\", \"input\": \"...\", \"output\": \"...\", \"reflection\": \"...\", \"corrected\": \"...\"}\n```\n\n## explainchoice\n\nquestion, choices, (solution OR explanation)\n\n```{.json filename=\"data.jsonl\"}\n{\"question\": \"...\", \"choices\": [\"...\"], \"solution\": \"...\", \"explanation\": \"...\"}\n```\n\n## concisechoice\n\nquestion, choices, (solution OR explanation)\n\n```{.json filename=\"data.jsonl\"}\n{\"question\": \"...\", \"choices\": [\"...\"], \"solution\": \"...\", \"explanation\": \"...\"}\n```\n\n## summarizetldr\n\narticle and summary\n\n```{.json filename=\"data.jsonl\"}\n{\"article\": \"...\", \"summary\": \"...\"}\n```\n\n## alpaca_chat\n\nbasic instruct for alpaca chat\n\n```{.json filename=\"data.jsonl\"}\n{\"instruction\": \"...\", \"input\": \"...\", \"response\": \"...\"}\n```\n\n## alpaca_chat.load_qa\n\nquestion and answer for alpaca chat\n\n```{.json filename=\"data.jsonl\"}\n{\"question\": \"...\", \"answer\": \"...\"}\n```\n\n## alpaca_chat.load_concise\n\nquestion and answer for alpaca chat, for concise answers\n\n```{.json filename=\"data.jsonl\"}\n{\"instruction\": \"...\", \"input\": \"...\", 
\"response\": \"...\"}\n```\n\n## alpaca_chat.load_camel_ai\n\nquestion and answer for alpaca chat, for load_camel_ai\n\n```{.json filename=\"data.jsonl\"}\n{\"message_1\": \"...\", \"message_2\": \"...\"}\n```\n\n## alpaca_w_system.load_open_orca\n\nsupport for open orca datasets with included system prompts, instruct\n\n```{.json filename=\"data.jsonl\"}\n{\"system_prompt\": \"...\", \"question\": \"...\", \"response\": \"...\"}\n```\n\n## context_qa\n\nin context question answering from an article\n\n```{.json filename=\"data.jsonl\"}\n{\"article\": \"...\", \"question\": \"...\", \"answer\": \"...\"}\n```\n\n## context_qa.load_v2\n\nin context question answering (alternate)\n\n```{.json filename=\"data.jsonl\"}\n{\"context\": \"...\", \"question\": \"...\", \"answer\": \"...\"}\n```\n\n## context_qa.load_404\n\nin context question answering from an article, with default response for no answer from context\n\n```{.json filename=\"data.jsonl\"}\n{\"article\": \"...\", \"unanswerable_question\": \"...\"}\n```\n\n## creative_acr.load_answer\n\ninstruction and revision\n\n```{.json filename=\"data.jsonl\"}\n{\"instruction\": \"...\", \"revision\": \"...\"}\n```\n\n## creative_acr.load_critique\n\ncritique\n\n```{.json filename=\"data.jsonl\"}\n{\"scores\": \"...\", \"critiques\": \"...\", \"instruction\": \"...\", \"answer\": \"...\"}\n```\n\n## creative_acr.load_revise\n\ncritique and revise\n\n```{.json filename=\"data.jsonl\"}\n{\"scores\": \"...\", \"critiques\": \"...\", \"instruction\": \"...\", \"answer\": \"...\", \"revision\": \"...\"}\n```\n\n## metharme\n\ninstruction, adds additional eos tokens\n\n```{.json filename=\"data.jsonl\"}\n{\"prompt\": \"...\", \"generation\": \"...\"}\n```\n\n## How to add custom prompt format\n\nFor a dataset that is preprocessed for instruction purposes:\n\n```{.json filename=\"data.jsonl\"}\n{\"input\": \"...\", \"output\": \"...\"}\n```\n\nYou can use this example in your YAML config:\n\n```{.yaml 
filename=\"config.yaml\"}\ndatasets:\n  - path: repo\n    type:\n      system_prompt: \"\"\n      field_system: system\n      field_instruction: input\n      field_output: output\n      format: \"[INST] {instruction} [/INST]\"\n      no_input_format: \"[INST] {instruction} [/INST]\"\n```\n\nSee full config options under [here](../config-reference.qmd).\n"
  },
  {
    "path": "docs/dataset-formats/pretraining.qmd",
    "content": "---\ntitle: Pre-training\ndescription: Data format for a pre-training completion task.\norder: 1\n---\n\nFor pretraining, there is no prompt template or roles.  The only required field is `text`:\n\n```{.json filename=\"data.jsonl\"}\n{\"text\": \"first row\"}\n{\"text\": \"second row\"}\n...\n```\n\n:::{.callout-note}\n\n### Streaming is recommended for large datasets\n\nAxolotl usually loads the entire dataset into memory. This will be challenging for large datasets. Use the following config to enable streaming:\n\n```{.yaml filename=\"config.yaml\"}\npretraining_dataset:\n  - name:\n    path:\n    split:\n    text_column: # column in dataset with the data, usually `text`\n    type: pretrain\n    trust_remote_code:\n    skip: # number of rows of data to skip over from the beginning\n```\n\n:::\n"
  },
  {
    "path": "docs/dataset-formats/stepwise_supervised.qmd",
    "content": "---\ntitle: Stepwise Supervised Format\ndescription: Format for datasets with stepwise completions and labels\norder: 3\n---\n\n## Stepwise Supervised\n\nThe stepwise supervised format is designed for chain-of-thought (COT) reasoning\ndatasets where each example contains multiple completion steps and a preference label\nfor each step.\n\n### Example\n\nHere's a simple example of a stepwise supervised dataset entry:\n\n```json\n{\n  \"prompt\": \"Which number is larger, 9.8 or 9.11?\",\n  \"completions\": [\n    \"The fractional part of 9.8 is 0.8, while the fractional part of 9.11 is 0.11.\",\n    \"Since 0.11 is greater than 0.8, the number 9.11 is larger than 9.8.\"\n  ],\n  \"labels\": [true, false]\n}\n```\n"
  },
  {
    "path": "docs/dataset-formats/template_free.qmd",
    "content": "---\ntitle: Template-Free\ndescription: Construct prompts without a template.\ntoc: true\ntoc-depth: 3\norder: 4\n---\n\n## Background {#sec-background}\n\n### Masking Inputs {#masking-inputs}\n\nOne of the most popular features of\n[axolotl](https://github.com/axolotl-ai-cloud/axolotl) is\nsetting the following configuration value:\n\n\n```yaml\ntrain_on_inputs: false\n```\n\nIf you declare a [dataset format](https://github.com/axolotl-ai-cloud/axolotl?tab=readme-ov-file#dataset)\nsuch as `alpaca` or `chatml`, axolotl knows which part is an input\n(i.e. human) vs. an output (i.e. the assistant) and masks the input\nlabels so that your model can focus on predicting the outputs only.\n\n### You may not want prompt templates {#sec-you-may-not-want-prompt-templates}\n\nHowever, there are many situations where you don't want to use one of\nthese formats or templates. This is because they can:\n\n-   Add unnecessary boilerplate to your prompts.\n-   Create artifacts like special delimiters `<|im_start|>` that can\n    quickly become footguns if you don't include them correctly at\n    inference time.\n-   Enforce a *chat* interface when you do not want one. Sometimes you\n    just want to fine-tune a model to a very specific task and do NOT\n    want multi-turn conversations, roles, etc.\n-   Limit you to only certain roles that the template allows.\n\n### The `input_output` format {#sec-the-inputoutput-format}\n\nYou can construct your prompts without a template by using the\n`input_output` format, by setting `type: input_output` in your\nconfiguration file like this:\n\n**config.yml**\n\n```yaml\ntrain_on_inputs: false # Mask segments of your data\ndatasets:\n  - path: output.jsonl\n    type: input_output  # use template free prompt construction\n```\n\nUnlike `type: completion`, which is also template-free,\n`type: input_output` allows you to mask segments of your text. 
More\ndetails on how this works are described below.\n\n## Usage {#sec-usage}\n\nThis is how you can use the `input_output` format:\n\n### 1. Prepare Data {#sec-1-prepare-data}\n\nTo use the `input_output` format, collect your data in the following\nformat into a jsonl file (below is the first row from the file\n`output.jsonl`, pretty printed):\n\n```bash\n$ head -n1 output.jsonl | python -m json.tool\n```\n\n:::{.cell-output .cell-output-stdout}\n    {\n        \"segments\": [\n            {\n                \"label\": true,\n                \"text\": \"<s>Hello\\n\"\n            },\n            {\n                \"label\": true,\n                \"text\": \"hi there!. \"\n            },\n            {\n                \"label\": false,\n                \"text\": \"goodbye \"\n            },\n            {\n                \"label\": true,\n                \"text\": \"farewell</s>\"\n            }\n        ]\n    }\n:::\n\nSet `label:false` when you want to mask a segment of text so that the\nmodel isn't trained on it. Some things to keep in mind:\n\n> [!IMPORTANT]\n> 1.  **EOS, BOS, spaces, newlines etc. are entirely up to you. Axolotl\n    concatenates all the segments as-is.** The tokenizer doesn't add\n    anything additional. Notice how I added spaces, newlines, `<s>`\n    (BOS), and `</s>` (EOS) myself.\n> 2.  Make sure you check the materialized output to validate that the\n    prompt is getting assembled how you like.\n\n### 2. 
Use `type: input_output` {#sec-2-use-type-inputoutput}\n\nLet's materialize data with our `output.jsonl` file by setting\n`type: input_output` in our axolotl config:\n\n```yaml\n# training_config.yaml\nbase_model: mistralai/Mistral-7B-v0.1\ndata_seed: 49\nseed: 49\n\ndatasets:\n  - path: output.jsonl\n    type: input_output\nval_set_size: 0.1\n\nsequence_len: 896\nsample_packing: false\n\nmicro_batch_size: 2\ngradient_accumulation_steps: 3\neval_batch_size: 2\nnum_epochs: 1\nlearning_rate: 0.0002\n\ntrain_on_inputs: false\nspecial_tokens:\n  bos_token: \"<s>\"\n  eos_token: \"</s>\"\n  unk_token: \"<unk>\"\n```\n\nYou can use the following command to materialize your data. The\n`--debug` flag will print the tokens, along with the labels so you can\nverify that the correct items are being ignored:\n\n```bash\naxolotl preprocess training_config.yaml --debug\n\n...\n[2024-03-05 23:36:46,969] [INFO] [axolotl.check_example_labels:35] [PID:607731] [RANK:0] <s>(1, 1) Hello(22557, 22557)\n(13, 13) hi(12014, 12014) there(736, 736) !(28808, 28808) .(28723, 28723) (28705, 28705) good(-100, 1179) bye(-100, 17664) (-100, 28705) fare(19111, 19111) well(5458, 5458) </s>(2, 2)\n\n```\n\nThe format is `decoded_token`(`label`, `token_id`), for example,\n`<s>(1, 1)` means that the token is `<s>`, the label is `1` and the\ntoken_id is `1`. When the label is `-100` then that token is ignored for\ntraining.\n\n### 3. Check the prompts {#sec-3-check-the-prompts}\n\nHere is another way to check the materialized output:\n\n```python\nfrom transformers import AutoTokenizer\nfrom datasets import load_from_disk\nimport yaml\n\ndirectory = !ls last_run_prepared/\nwith open('training_config.yaml', 'r') as f:\n    cfg = yaml.safe_load(f)\nmodel_id = cfg['base_model']\ntok = AutoTokenizer.from_pretrained(model_id)\nds = load_from_disk(f'last_run_prepared/{directory[0]}/')\n```\n\n```python\n>>> row = ds[0]\n>>> print(tok.decode(row['input_ids']))\n<s> Hello\n    hi there!.  
goodbye  farewell</s>\n```\n\nWe can check that the right tokens are ignored by comparing the labels\nto each token:\n\n```python\nimport pandas as pd\npd.DataFrame([{'token': tok.decode(i), 'label': l, 'id':i} for i,l in\n              zip(row['input_ids'], row['labels'])])\n```\n\n|       | token | label |\n|-------|-------|-------|\n| 0     | \\<s\\> | 1     |\n| 1     | Hello | 22557 |\n| 2     | \\\\n   | 13    |\n| 3     | hi    | 12014 |\n| 4     | there | 736   |\n| 5     | !     | 28808 |\n| 6     | .     | 28723 |\n| 7     |       | 28705 |\n| 8     | good  | -100  |\n| 9     | bye   | -100  |\n| 10    |       | -100  |\n| 11    | fare  | 19111 |\n| 12    | well  | 5458  |\n| 13    | \\</s\\>| 2     |\n\n\n\nIf we look at the input data, the above table seems correct! (The jsonl\nversion is repeated below for reference):\n\n\n```bash\n$ head -n1 output.jsonl | python -m json.tool\n```\n\n:::{.cell-output .cell-output-stdout}\n    {\n        \"segments\": [\n            {\n                \"label\": true,\n                \"text\": \"<s>Hello\\n\"\n            },\n            {\n                \"label\": true,\n                \"text\": \"hi there!. \"\n            },\n            {\n                \"label\": false,\n                \"text\": \"goodbye \"\n            },\n            {\n                \"label\": true,\n                \"text\": \"farewell</s>\"\n            }\n        ]\n    }\n:::\n"
  },
  {
    "path": "docs/dataset-formats/tokenized.qmd",
    "content": "---\ntitle: Custom Pre-Tokenized Dataset\ndescription: How to use a custom pre-tokenized dataset.\norder: 5\n---\n\n- Pass an empty `type:` in your axolotl config.\n- Columns in Dataset must be exactly `input_ids`, `attention_mask`, `labels`\n- To indicate that a token should be ignored during training, set its corresponding label to `-100`.\n- You must add BOS and EOS, and make sure that you are training on EOS by not setting its label to -100.\n- For pretraining, do not truncate/pad documents to the context window length.\n- For instruction training, documents must be truncated/padded as desired.\n\nSample config:\n\n```{.yaml filename=\"config.yml\"}\ndatasets:\n  - path: /path/to/your/file.jsonl\n    ds_type: json\n    type:\n```\n\nSample jsonl:\n\n```jsonl\n{\"input_ids\":[271,299,99],\"attention_mask\":[1,1,1],\"labels\":[271,-100,99]}\n{\"input_ids\":[87,227,8383,12],\"attention_mask\":[1,1,1,1],\"labels\":[87,227,8383,12]}\n```\n"
  },
  {
    "path": "docs/dataset_loading.qmd",
    "content": "---\ntitle: Dataset Loading\ndescription: Understanding how to load datasets from different sources\nback-to-top-navigation: true\ntoc: true\ntoc-depth: 5\n---\n\n## Overview\n\nDatasets can be loaded in a number of different ways depending on how they are saved (the extension of the file) and where they are stored.\n\n## Loading Datasets\n\nWe use the `datasets` library to load datasets and a mix of `load_dataset` and `load_from_disk` to load them.\n\nYou may recognize the similarly named configs between `load_dataset` and the `datasets` section of the config file.\n\n```yaml\ndatasets:\n  - path:\n    name:\n    data_files:\n    split:\n    revision:\n    trust_remote_code:\n```\n\n::: {.callout-tip}\n\nDo not feel overwhelmed by the number of options here. A lot of them are optional. In fact, the most common configs to use are `path` and sometimes `data_files`.\n\n:::\n\nThis matches the API of [`datasets.load_dataset`](https://github.com/huggingface/datasets/blob/0b5998ac62f08e358f8dcc17ec6e2f2a5e9450b6/src/datasets/load.py#L1838-L1858), so if you're familiar with that, you will feel right at home.\n\nFor HuggingFace's guide on loading different dataset types, see [here](https://huggingface.co/docs/datasets/loading).\n\nFor full details on the config, see [config-reference.qmd](config-reference.qmd).\n\n::: {.callout-note}\n\nYou can set multiple datasets in the config file by adding more than one entry under `datasets`.\n\n```yaml\ndatasets:\n  - path: /path/to/your/dataset\n  - path: /path/to/your/other/dataset\n```\n\n:::\n\n### Local dataset\n\n#### Files\n\nTo load a JSON file, you would do something like this:\n\n```python\nfrom datasets import load_dataset\n\ndataset = load_dataset(\"json\", data_files=\"data.json\")\n```\n\nThis translates to the following config:\n\n```yaml\ndatasets:\n  - path: data.json\n    ds_type: json\n```\n\nIn the example above, it can be seen that we can just point the `path` to the file or directory along with the 
`ds_type` to load the dataset.\n\nThis works for CSV, JSON, Parquet, and Arrow files.\n\n::: {.callout-tip}\n\nIf `path` points to a file and `ds_type` is not specified, we will automatically infer the dataset type from the file extension, so you could omit `ds_type` if you'd like.\n\n:::\n\n#### Directory\n\nIf you're loading a directory, you can point the `path` to the directory.\n\nThen, you have two options:\n\n##### Loading entire directory\n\nYou do not need any additional configs.\n\nWe will attempt to load in the following order:\n- datasets saved with `datasets.save_to_disk`\n- loading entire directory of files (such as with parquet/arrow files)\n\n```yaml\ndatasets:\n  - path: /path/to/your/directory\n```\n\n##### Loading specific files in directory\n\nProvide `data_files` with a list of files to load.\n\n```yaml\ndatasets:\n    # single file\n  - path: /path/to/your/directory\n    ds_type: csv\n    data_files: file1.csv\n\n    # multiple files\n  - path: /path/to/your/directory\n    ds_type: json\n    data_files:\n      - file1.jsonl\n      - file2.jsonl\n\n    # multiple files for parquet\n  - path: /path/to/your/directory\n    ds_type: parquet\n    data_files:\n      - file1.parquet\n      - file2.parquet\n\n```\n\n### HuggingFace Hub\n\nThe method you use to load the dataset depends on how the dataset was created, whether a folder was uploaded directly or a HuggingFace Dataset was pushed.\n\n::: {.callout-note}\n\nIf you're using a private dataset, you will need to enable the `hf_use_auth_token` flag in the root-level of the config file.\n\n:::\n\n#### Folder uploaded\n\nThis would mean that the dataset is a single file or file(s) uploaded to the Hub.\n\n```yaml\ndatasets:\n  - path: org/dataset-name\n    data_files:\n      - file1.jsonl\n      - file2.jsonl\n```\n\n#### HuggingFace Dataset\n\nThis means that the dataset is created as a HuggingFace Dataset and pushed to the Hub via `datasets.push_to_hub`.\n\n```yaml\ndatasets:\n  - path: 
org/dataset-name\n```\n\n::: {.callout-note}\n\nThere are some other configs which may be required like `name`, `split`, `revision`, `trust_remote_code`, etc depending on the dataset.\n\n:::\n\n### Remote Filesystems\n\nVia the `storage_options` config under `load_dataset`, you can load datasets from remote filesystems like S3, GCS, Azure, and OCI.\n\n::: {.callout-warning}\n\nThis is currently experimental. Please let us know if you run into any issues!\n\n:::\n\nThe only difference between the providers is that you need to prepend the path with the respective protocols.\n\n```yaml\ndatasets:\n    # Single file\n  - path: s3://bucket-name/path/to/your/file.jsonl\n\n    # Directory\n  - path: s3://bucket-name/path/to/your/directory\n```\n\nFor directory, we load via `load_from_disk`.\n\n#### S3\n\nPrepend the path with `s3://`.\n\nThe credentials are pulled in the following order:\n\n- `AWS_ACCESS_KEY_ID`, `AWS_SECRET_ACCESS_KEY`, and `AWS_SESSION_TOKEN` environment variables\n- from the `~/.aws/credentials` file\n- for nodes on EC2, the IAM metadata provider\n\n::: {.callout-note}\n\nWe assume you have credentials setup and not using anonymous access. If you want to use anonymous access, let us know! 
We may need to add a config option for this.\n\n:::\n\nOther environment variables that can be set are listed in the [boto3 docs](https://boto3.amazonaws.com/v1/documentation/api/latest/guide/configuration.html#using-environment-variables).\n\n#### GCS\n\nPrepend the path with `gs://` or `gcs://`.\n\nThe credentials are loaded in the following order:\n\n- gcloud credentials\n- for nodes on GCP, the google metadata service\n- anonymous access\n\n#### Azure\n\n##### Gen 1\n\nPrepend the path with `adl://`.\n\nEnsure you have the following environment variables set:\n\n- `AZURE_STORAGE_TENANT_ID`\n- `AZURE_STORAGE_CLIENT_ID`\n- `AZURE_STORAGE_CLIENT_SECRET`\n\n##### Gen 2\n\nPrepend the path with `abfs://` or `az://`.\n\nEnsure you have the following environment variables set:\n\n- `AZURE_STORAGE_ACCOUNT_NAME`\n- `AZURE_STORAGE_ACCOUNT_KEY`\n\nOther environment variables that can be set are listed in the [adlfs docs](https://github.com/fsspec/adlfs?tab=readme-ov-file#setting-credentials).\n\n#### OCI\n\nPrepend the path with `oci://`.\n\nCredentials are read in the following order:\n\n- `OCIFS_IAM_TYPE`, `OCIFS_CONFIG_LOCATION`, and `OCIFS_CONFIG_PROFILE` environment variables\n- when running on an OCI resource, the resource principal\n\nOther environment variables:\n\n- `OCI_REGION_METADATA`\n\nPlease see the [ocifs docs](https://ocifs.readthedocs.io/en/latest/getting-connected.html#Using-Environment-Variables).\n\n### HTTPS\n\nThe path should start with `https://`.\n\n```yaml\ndatasets:\n  - path: https://path/to/your/dataset/file.jsonl\n```\n\nThe file must be publicly accessible.\n\n## Next steps\n\nNow that you know how to load datasets, you can learn how to load your specific dataset format into your target output format in the [dataset formats docs](dataset-formats).\n"
  },
  {
    "path": "docs/dataset_preprocessing.qmd",
    "content": "---\ntitle: Dataset Preprocessing\ndescription: How datasets are processed\n---\n\n## Overview\n\nDataset pre-processing is the step where Axolotl takes each dataset you've configured alongside\nthe [dataset format](dataset-formats) and prompt strategies to:\n\n - parse the dataset based on the *dataset format*\n - transform the dataset to how you would interact with the model based on the *prompt strategy*\n - tokenize the dataset based on the configured model & tokenizer\n - shuffle and merge multiple datasets together if using more than one\n\nThe processing of the datasets can happen in one of two ways:\n\n1. Before kicking off training by calling `axolotl preprocess config.yaml --debug`\n2. When training is started\n\n### What are the benefits of pre-processing?\n\nWhen training interactively or for sweeps\n(e.g. you are restarting the trainer often), processing the datasets can oftentimes be frustratingly\nslow. Pre-processing will cache the tokenized/formatted datasets according to a hash of dependent\ntraining parameters so that it will intelligently pull from its cache when possible.\n\nThe path of the cache is controlled by `dataset_prepared_path:` and is often left blank in example\nYAMLs as this leads to a more robust solution that prevents unexpectedly reusing cached data.\n\nIf `dataset_prepared_path:` is left empty, when training, the processed dataset will be cached in a\ndefault path of `./last_run_prepared/`, but will ignore anything already cached there. By explicitly\nsetting `dataset_prepared_path: ./last_run_prepared`, the trainer will use whatever pre-processed\ndata is in the cache.\n\n### What are the edge cases?\n\nLet's say you are writing a custom prompt strategy or using a user-defined\nprompt template. 
Because the trainer cannot readily detect these changes, we cannot change the\ncalculated hash value for the pre-processed dataset.\n\nIf you have `dataset_prepared_path: ...` set\nand change your prompt templating logic, it may not pick up the changes you made and you will be\ntraining over the old prompt.\n"
  },
  {
    "path": "docs/debugging.qmd",
    "content": "---\ntitle: Debugging\ndescription: How to debug Axolotl\n---\n\n\nThis document provides some tips and tricks for debugging Axolotl.  It also provides an example configuration for debugging with VSCode.  A good debugging setup is essential to understanding how Axolotl code works behind the scenes.\n\n## Table of Contents\n\n- [General Tips](#general-tips)\n- [Debugging with VSCode](#debugging-with-vscode)\n    - [Background](#background)\n    - [Configuration](#configuration)\n    - [Customizing your debugger](#customizing-your-debugger)\n    - [Video Tutorial](#video-tutorial)\n- [Debugging With Docker](#debugging-with-docker)\n    - [Setup](#setup)\n    - [Attach To Container](#attach-to-container)\n    - [Video - Attaching To Docker On Remote Host](#video---attaching-to-docker-on-remote-host)\n\n## General Tips\n\nWhile debugging it's helpful to simplify your test scenario as much as possible.  Here are some tips for doing so:\n\n> [!Important]\n> All of these tips are incorporated into the [example configuration](#configuration) for debugging with VSCode below.\n\n1. **Make sure you are using the latest version of axolotl**:  This project changes often and bugs get fixed fast.  Check your git branch and make sure you have pulled the latest changes from `main`.\n1. **Eliminate concurrency**: Restrict the number of processes to 1 for both training and data preprocessing:\n    - Set `CUDA_VISIBLE_DEVICES` to a single GPU, ex: `export CUDA_VISIBLE_DEVICES=0`.\n    - Set `dataset_num_proc: 1` in your axolotl config or run the training command with `--dataset_num_proc=1`.\n2. **Use a small dataset**: Construct or use a small dataset from HF Hub. When using a small dataset, you will often have to make sure `sample_packing: False` and `eval_sample_packing: False` to avoid errors.  
If you are in a pinch and don't have time to construct a small dataset but want to use from the HF Hub, you can shard the data (this will still tokenize the entire dataset, but will only use a fraction of the data for training.  For example, to shard the dataset into 20 pieces, add the following to your axolotl config):\n\n    ```yaml\n    datasets:\n        ...\n        shards: 20\n    ```\n\n3. **Use a small model**: A good example of a small model is [TinyLlama/TinyLlama-1.1B-Chat-v1.0](https://huggingface.co/TinyLlama/TinyLlama-1.1B-Chat-v1.0).\n4. **Minimize iteration time**: Make sure the training loop finishes as fast as possible, with these settings.\n    - `micro_batch_size: 1`\n    - `max_steps: 1`\n    - `val_set_size: 0`\n5. **Clear Caches:** Axolotl caches certain steps and so does the underlying HuggingFace trainer.  You may want to clear some of these caches when debugging.\n    - Data preprocessing: When debugging data preprocessing, which includes prompt template formation, you may want to delete the directory set in `dataset_prepared_path:` in your axolotl config.  If you didn't set this value, the default is `last_run_prepared`.\n    - HF Hub: If you are debugging data preprocessing, you should clear the relevant HF cache [HuggingFace cache](https://huggingface.co/docs/datasets/cache), by deleting the appropriate `~/.cache/huggingface/datasets/...` folder(s).\n    - **The recommended approach is to redirect all outputs and caches to a temporary folder and delete selected subfolders before each run.  This is demonstrated in the example configuration below.**\n\n\n## Debugging with VSCode\n\n### Background\n\nThe below example shows how to configure VSCode to debug data preprocessing of the `chat_template` format.  
This is the format used when you have the following in your axolotl config:\n\n```yaml\ndatasets:\n  - path: <path to your chat_template formatted dataset> # example on HF Hub: fozziethebeat/alpaca_messages_2k_test\n    type: chat_template\n```\n\n>[!Important]\n> If you are already familiar with advanced VSCode debugging, you can skip the below explanation and look at the files [.vscode/launch.json](../.vscode/launch.json) and [.vscode/tasks.json](../.vscode/tasks.json) for an example configuration.\n\n>[!Tip]\n> If you prefer to watch a video, rather than read, you can skip to the [video tutorial](#video-tutorial) below (but doing both is recommended).\n\n### Setup\n\nMake sure you have an [editable install](https://setuptools.pypa.io/en/latest/userguide/development_mode.html) of Axolotl, which ensures that changes you make to the code are reflected at runtime.  Run the following commands from the root of this project:\n\n```bash\npip3 install packaging\npip3 install --no-build-isolation -e '.[flash-attn,deepspeed]'\n```\n\n#### Remote Hosts\n\nIf you are developing on a remote host, you can easily use VSCode to debug remotely.  To do so, you will need to follow this [remote - SSH guide](https://code.visualstudio.com/docs/remote/ssh).  You can also see the video below on [Docker and Remote SSH debugging](#video---attaching-to-docker-on-remote-host).\n\n\n### Configuration\n\nThe easiest way to get started is to modify the [.vscode/launch.json](../.vscode/launch.json) file in this project.  This is just an example configuration, so you may need to modify or copy it to suit your needs.\n\nFor example, to mimic the command `cd devtools && CUDA_VISIBLE_DEVICES=0 accelerate launch -m axolotl.cli.train dev_chat_template.yml`, you would use the below configuration[^1].  Note that we add additional flags that override the axolotl config and incorporate the tips above (see the comments).  
We also set the working directory to `devtools` and set the `env` variable `HF_HOME` to a temporary folder that is later partially deleted.  This is because we want to delete the HF dataset cache before each run in order to ensure that the data preprocessing code is run from scratch.\n\n```json\n// .vscode/launch.json\n{\n    \"version\": \"0.2.0\",\n    \"configurations\": [\n        {\n            \"name\": \"Debug axolotl prompt - chat_template\",\n            \"type\": \"python\",\n            \"module\": \"accelerate.commands.launch\",\n            \"request\": \"launch\",\n            \"args\": [\n                \"-m\", \"axolotl.cli.train\", \"dev_chat_template.yml\",\n                // The flags below simplify debugging by overriding the axolotl config\n                // with the debugging tips above.  Modify as needed.\n                \"--dataset_num_proc=1\",      // limits data preprocessing to one process\n                \"--max_steps=1\",              // limits training to just one step\n                \"--batch_size=1\",             // minimizes batch size\n                \"--micro_batch_size=1\",       // minimizes batch size\n                \"--val_set_size=0\",           // disables validation\n                \"--sample_packing=False\",     // disables sample packing which is necessary for small datasets\n                \"--eval_sample_packing=False\",// disables sample packing on eval set\n                \"--dataset_prepared_path=temp_debug/axolotl_outputs/data\", // send data outputs to a temp folder\n                \"--output_dir=temp_debug/axolotl_outputs/model\" // send model outputs to a temp folder\n                ],\n            \"console\": \"integratedTerminal\",      // show output in the integrated terminal\n            \"cwd\": \"${workspaceFolder}/devtools\", // set working directory to devtools from the root of the project\n            \"justMyCode\": true,                   // step through only axolotl code\n            
\"env\": {\"CUDA_VISIBLE_DEVICES\": \"0\",  // Since we aren't doing distributed training, we need to limit to one GPU\n                    \"HF_HOME\": \"${workspaceFolder}/devtools/temp_debug/.hf-cache\"}, // send HF cache to a temp folder\n            \"preLaunchTask\": \"cleanup-for-dataprep\", // delete temp folders (see below)\n        }\n    ]\n}\n```\n\n**Additional notes about this configuration:**\n\n- The argument `justMyCode` is set to `true` such that you step through only the axolotl code.  If you want to step into dependencies, set this to `false`.\n- The `preLaunchTask`: `cleanup-for-dataprep` is defined in [.vscode/tasks.json](../.vscode/tasks.json) and is used to delete the following folders before debugging, which is essential to ensure that the data pre-processing code is run from scratch:\n    - `./devtools/temp_debug/axolotl_outputs`\n    - `./devtools/temp_debug/.hf-cache/datasets`\n\n>[!Tip]\n> You may not want to delete these folders. For example, if you are debugging model training instead of data pre-processing, you may NOT want to delete the cache or output folders. You may also need to add additional tasks to the `tasks.json` file depending on your use case.\n\nBelow is the [.vscode/tasks.json](../.vscode/tasks.json) file that defines the `cleanup-for-dataprep` task.  This task is run before each debugging session when you use the above configuration.  Note how there are two tasks that delete the two folders mentioned above.  The third task `cleanup-for-dataprep` is a composite task that combines the two tasks.  
A composite task is necessary because VSCode does not allow you to specify multiple tasks in the `preLaunchTask` argument of the `launch.json` file.\n\n```json\n// .vscode/tasks.json\n// this file is used by launch.json\n{\n    \"version\": \"2.0.0\",\n    \"tasks\": [\n      // this task changes into the devtools directory and deletes the temp_debug/axolotl_outputs folder\n      {\n        \"label\": \"delete-outputs\",\n        \"type\": \"shell\",\n        \"command\": \"rm -rf temp_debug/axolotl_outputs\",\n        \"options\":{ \"cwd\": \"${workspaceFolder}/devtools\"},\n        \"problemMatcher\": []\n      },\n      // this task changes into the devtools directory and deletes the `temp_debug/.hf-cache/datasets` folder\n      {\n        \"label\": \"delete-temp-hf-dataset-cache\",\n        \"type\": \"shell\",\n        \"command\": \"rm -rf temp_debug/.hf-cache/datasets\",\n        \"options\":{ \"cwd\": \"${workspaceFolder}/devtools\"},\n        \"problemMatcher\": []\n      },\n        // this task combines the two tasks above\n      {\n       \"label\": \"cleanup-for-dataprep\",\n       \"dependsOn\": [\"delete-outputs\", \"delete-temp-hf-dataset-cache\"],\n      }\n    ]\n}\n```\n\n### Customizing your debugger\n\nYour debugging use case may differ from the example above.  The easiest thing to do is to put your own axolotl config in the `devtools` folder and modify the `launch.json` file to use your config.  
You may also want to modify the `preLaunchTask` to delete different folders or not delete anything at all.\n\n### Video Tutorial\n\nThe following video tutorial walks through the above configuration and demonstrates how to debug with VSCode (click the image below to watch):\n\n<div style=\"text-align: center; line-height: 0;\">\n\n<a href=\"https://youtu.be/xUUB11yeMmc\" target=\"_blank\"\ntitle=\"How to debug Axolotl (for fine tuning LLMs)\"><img\nsrc=\"https://i.ytimg.com/vi/xUUB11yeMmc/maxresdefault.jpg\"\nstyle=\"border-radius: 10px; display: block; margin: auto;\" width=\"560\" height=\"315\" /></a>\n\n<figcaption style=\"font-size: smaller;\"><a href=\"https://hamel.dev\">Hamel Husain's</a> tutorial: <a href=\"https://www.youtube.com/watch?v=xUUB11yeMmc\">Debugging Axolotl w/VSCode</a></figcaption>\n\n</div>\n<br>\n\n## Debugging With Docker\n\nUsing [official Axolotl Docker images](https://hub.docker.com/r/axolotlai/axolotl/tags) is a great way to debug your code, and is a very popular way to use Axolotl.  Attaching VSCode to Docker takes a few more steps.\n\n### Setup\n\nOn the host that is running axolotl (ex: if you are using a remote host), clone the axolotl repo and change your current directory to the root:\n\n```bash\ngit clone https://github.com/axolotl-ai-cloud/axolotl\ncd axolotl\n```\n\n>[!Tip]\n> If you already have axolotl cloned on your host, make sure you have the latest changes and change into the root of the project.\n\nNext, run the desired docker image and mount the current directory.  
Below is a docker command you can run to do this:[^2]\n\n```bash\ndocker run --privileged --gpus '\"all\"' --shm-size 10g --rm -it --name axolotl --ipc=host --ulimit memlock=-1 --ulimit stack=67108864 --mount type=bind,src=\"${PWD}\",target=/workspace/axolotl -v ${HOME}/.cache/huggingface:/root/.cache/huggingface axolotlai/axolotl:main-py3.10-cu118-2.0.1\n```\n\n>[!Tip]\n> To understand which containers are available, see the [Docker section of the README](../README.md#docker) and the [DockerHub repo](https://hub.docker.com/r/axolotlai/axolotl/tags).  For details of how the Docker containers are built, see axolotl's [Docker CI builds](../.github/workflows/main.yml).\n\nYou will now be in the container.  Next, perform an editable install of Axolotl:\n\n```bash\npip3 install packaging\npip3 install --no-build-isolation -e '.[flash-attn,deepspeed]'\n```\n\n### Attach To Container\n\nNext, if you are using a remote host, [Remote into this host with VSCode](https://code.visualstudio.com/docs/remote/ssh).  If you are using a local host, you can skip this step.\n\nNext, select `Dev Containers: Attach to Running Container...` using the command palette (`CMD + SHIFT + P`) in VSCode.  You will be prompted to select a container to attach to.  Select the container you just created.  You will now be in the container with a working directory that is at the root of the project.  
Any changes you make to the code will be reflected both in the container and on the host.\n\nNow you are ready to debug as described above (see [Debugging with VSCode](#debugging-with-vscode)).\n\n### Video - Attaching To Docker On Remote Host\n\nHere is a short video that demonstrates how to attach to a Docker container on a remote host:\n\n<div style=\"text-align: center; line-height: 0;\">\n\n<a href=\"https://youtu.be/0AuoR7QnHR0\" target=\"_blank\"\ntitle=\"Debugging Axolotl Part 2: Attaching to Docker on a Remote Host\"><img\nsrc=\"https://i.ytimg.com/vi/0AuoR7QnHR0/hqdefault.jpg\"\nstyle=\"border-radius: 10px; display: block; margin: auto;\" width=\"560\" height=\"315\" /></a>\n\n<figcaption style=\"font-size: smaller;\"><a href=\"https://hamel.dev\">Hamel Husain's</a> tutorial: <a href=\"https://youtu.be/0AuoR7QnHR0\">Debugging Axolotl Part 2: Attaching to Docker on a Remote Host\n</a></figcaption>\n\n</div>\n<br>\n\n[^1]: The config actually mimics the command `cd devtools && CUDA_VISIBLE_DEVICES=0 python -m accelerate.commands.launch -m axolotl.cli.train dev_chat_template.yml`, but this is the same thing.\n\n[^2]: Many of the flags in the above command are recommended best practices from Nvidia when using nvidia-container-toolkit.  You can read more about these flags [here](https://docs.nvidia.com/deeplearning/frameworks/user-guide/index.html).\n"
  },
  {
    "path": "docs/docker.qmd",
    "content": "---\ntitle: \"Docker\"\nformat:\n  html:\n    toc: true\n    toc-depth: 4\n---\n\nThis section describes the different Docker images that are released by AxolotlAI at [Docker Hub](https://hub.docker.com/u/axolotlai).\n\n::: {.callout-important}\nFor Blackwell GPUs, please use the tags with PyTorch 2.7.1 and CUDA 12.8.\n:::\n\n## Base\n\nThe base image is the most minimal image that can install Axolotl. It is based on the `nvidia/cuda` image. It includes python, torch, git, git-lfs, awscli, pydantic, and more.\n\n#### Image\n\n```\naxolotlai/axolotl-base\n```\n\nLink: [Docker Hub](https://hub.docker.com/r/axolotlai/axolotl-base)\n\n#### Tags format\n\n```bash\nmain-base-py{python_version}-cu{cuda_version}-{pytorch_version}\n```\n\nTags examples:\n\n- `main-base-py3.11-cu128-2.8.0`\n- `main-base-py3.11-cu128-2.9.1`\n\n## Main\n\nThe main image is the image that is used to run Axolotl. It is based on the `axolotlai/axolotl-base` image and includes the Axolotl codebase, dependencies, and more.\n\n#### Image\n\n```\naxolotlai/axolotl\n```\n\nLink: [Docker Hub](https://hub.docker.com/r/axolotlai/axolotl)\n\n#### Tags format {#sec-main-tags}\n\n```bash\n# on push to main\nmain-py{python_version}-cu{cuda_version}-{pytorch_version}\n\n# latest main (currently torch 2.6.0, python 3.11, cuda 12.4)\nmain-latest\n\n# nightly build\n{branch}-{date_in_YYYYMMDD}-py{python_version}-cu{cuda_version}-{pytorch_version}\n\n# tagged release\n{version}\n```\n\n:::{.callout-tip}\n\nThere may be some extra tags appended to the image, like `-vllm` which installs those packages.\n\n:::\n\nTags examples:\n\n- `main-py3.11-cu128-2.8.0`\n- `main-py3.11-cu128-2.9.1`\n- `main-latest`\n- `main-20250303-py3.11-cu124-2.6.0`\n- `main-20250303-py3.11-cu126-2.6.0`\n- `0.12.0`\n\n## Cloud\n\nThe cloud image is the image that is used to run Axolotl in the cloud. 
It is based on the `axolotlai/axolotl` image and sets ENV variables like HuggingFace cache directories for volume mounts, tmux, and more for different cloud providers.\n\n:::{.callout-tip}\n\nJupyter lab is run by default. Set `JUPYTER_DISABLE=1` in the environment variables to disable it.\n\n:::\n\n#### Image\n\n```\naxolotlai/axolotl-cloud\n```\n\nLink: [Docker Hub](https://hub.docker.com/r/axolotlai/axolotl-cloud)\n\n#### Tags format {#sec-cloud-tags}\n\nThis uses the same tags as the [`main` image](#sec-main-tags).\n\n#### Environment variables\n\n- `JUPYTER_DISABLE`: Disable Jupyter lab.\n- `JUPYTER_PASSWORD`: Set a password for the Jupyter lab.\n- `PUBLIC_KEY` / `SSH_KEY`: Add a public key for the SSH service.\n\n#### Volume mounts\n\n:::{.callout-tip}\n\nWe recommend mounting volumes to `/workspace/data` for data persistence. `/workspace/axolotl` contains the source code and is ephemeral.\n\n:::\n\n- `/workspace/data/axolotl-artifacts`: Directory to store Axolotl artifacts.\n- `/workspace/data/huggingface-cache`: Directory to store HuggingFace cache.\n\n## Cloud-no-tmux\n\nThis is the same as the [`cloud` image](#cloud) but without tmux.\n\n#### Image\n\n```\naxolotlai/axolotl-cloud-term\n```\n\nLink: [Docker Hub](https://hub.docker.com/r/axolotlai/axolotl-cloud-term)\n\n:::{.callout-note}\n\nThe naming may be a bit confusing as it has `-term` appended to the end.\n\n:::\n\n#### Tags format\n\nThis uses the same tags as the [`cloud` image](#sec-cloud-tags).\n"
  },
  {
    "path": "docs/expert_quantization.qmd",
    "content": "---\ntitle: \"MoE Expert Quantization\"\ndescription: \"Reduce VRAM usage when training MoE model adapters by quantizing expert weights on load\"\n---\n\nTransformers v5 changed MoE expert layers from `nn.Linear` to fused `nn.Parameter` (3D+ tensors).\nThis means `bitsandbytes` can no longer quantize them during model loading, resulting in all expert\nweights being loaded in full bf16 precision and causing massive VRAM usage.\n\n`quantize_moe_experts` solves this by quantizing expert weights during model loading.\nIt intercepts the weight loading process, quantizes each expert tensor on the fly, and\nimmediately frees the original bf16 tensor from VRAM. This dramatically reduces peak memory.\nFor example, GLM-4.7-Flash QLoRA drops from ~127GiB to ~23GiB reserved memory.\n\n## Usage\n\nEnable expert quantization in your Axolotl config:\n\n```yaml\nquantize_moe_experts: true\n```\n\nThis works with both 4-bit (QLoRA) and 8-bit (LoRA) quantization.\n\n### Expert LoRA targeting\n\nYou can optionally apply LoRA adapters directly to expert weights using `lora_target_parameters`:\n\n```yaml\nlora_target_parameters:\n  - mlp.experts.gate_up_proj\n  - mlp.experts.down_proj\n  # - mlp.gate.weight  # router\n```\n\n::: {.callout-note}\n`lora_dropout` must be `0` when using `lora_target_parameters`.\n:::\n\n## Requirements\n\n- Requires (`adapter: lora` and `load_in_8bit: true`) or (`adapter: qlora` and `load_in_4bit: true`)\n- CUDA GPUs only (not tested with ROCm or other backends)\n- FSDP2 compatible for distributed training\n\n## Limitations\n\n- `lora_target_linear` is not compatible with `quantize_moe_experts`. See [Expert LoRA targeting](#expert-lora-targeting) instead.\n- `cpu_ram_efficient_loading` hangs / takes a long time with FSDP2 + QLoRA.\n- Total model parameter count may display incorrectly (trainable param count is correct).\n- FSDP LoRA (8-bit) may have a large initial VRAM spike during the first 1-2 steps, which then drops. 
QLoRA does not exhibit this.\n- FSDP2 may use more VRAM per GPU than single GPU training due to not all layers being properly sharded across ranks.\n- Model loading takes longer due to on-demand quantization, even on consecutive runs.\n- DeepSpeed has not been tested.\n\n## Implementation details\n\nThe quantization is applied by patching transformers to intercept weight loading.\nWhen a 3D+ CUDA tensor with \"expert\" in its name is detected:\n\n- **4-bit mode:** Uses bitsandbytes NF4 parametrization (configurable via `bnb_4bit_quant_type`).\n- **8-bit mode:** Uses a custom row-wise int8 parametrization with bitsandbytes dequantization.\n\nThe original bf16 tensor is freed immediately after quantization. Multiple sub-patches are applied to\ntransformers, PEFT and accelerate FSDP2 to support these parametrized expert modules.\n\nFor full implementation details, see [PR #3439](https://github.com/axolotl-ai-cloud/axolotl/pull/3439).\n"
  },
  {
    "path": "docs/faq.qmd",
    "content": "---\ntitle: FAQ\ndescription: Frequently asked questions\n---\n\n### General\n\n**Q: The trainer stopped and hasn't progressed in several minutes.**\n\n> A: Usually an issue with the GPUs communicating with each other. See the [NCCL doc](nccl.qmd)\n\n**Q: exitcode: -9**\n\n> A: This usually happens when you run out of system RAM.\n\n**Q: exitcode: -7 while using deepspeed**\n\n> A: Try upgrading deepspeed with: `pip install -U deepspeed`\n\n**Q: AttributeError: 'DummyOptim' object has no attribute 'step'**\n\n**Q: ModuleNotFoundError: No module named 'mpi4py' using single GPU with deepspeed**\n\n> A: You may be using deepspeed with a single GPU. Please remove the `deepspeed:` section from the yaml file or the `--deepspeed` CLI flag.\n\n**Q: The code is stuck on saving preprocessed datasets.**\n\n> A: This is usually an issue with the GPU. This can be resolved by setting the environment variable `CUDA_VISIBLE_DEVICES=0`. If you are on runpod, this is usually a pod issue. Starting a new pod should take care of it.\n\n**Q: Received mismatch error on merge adapters / loading adapters between torch.Size of checkpoint and model.**\n\n> A: This is likely due to a vocab size mismatch. By default, Axolotl expands the model's embeddings if the tokenizer has more tokens than the model. Please use the `axolotl merge-lora` command to merge the adapters instead of using your own scripts.\n\n> On the other hand, if the model has more tokens than the tokenizer, Axolotl does not shrink the model's embeddings unless `shrink_embeddings: true` is set in the config.\n\n**Q: How to call Axolotl via custom python scripts?**\n\n> A: Since Axolotl is just Python, please see `src/axolotl/cli/main.py` on how each command is called.\n\n**Q: How to know the value to use for `fsdp_transformer_layer_cls_to_wrap`?**\n\n> A: This is the class name of the transformer layer to wrap with FSDP. For example, for `LlamaForCausalLM`, the value is `LlamaDecoderLayer`. 
To find this for a specific model, check the model's `PreTrainedModel` definition and look for the `_no_split_modules` variable in the `modeling_<model_name>.py` file within the `transformers` library.\n\n**Q: ValueError: Asking to pad but the tokenizer does not have a padding token. Please select a token to use as pad_token**\n\n> A: This is because the tokenizer does not have a padding token. Please add a padding token to the tokenizer via:\n\n> ```yaml\n> special_tokens:\n>   # str. If you're not sure, set to same as `eos_token`.\n>   pad_token: \"...\"\n> ```\n\n**Q: `IterableDataset error` or `KeyError: 'input_ids'` when using `preprocess` CLI**\n\n> A: This is because you may be using the `preprocess` CLI with `pretraining_dataset:` or `skip_prepare_dataset: true` respectively. Please use the `axolotl train` CLI directly instead, as these datasets are prepared on demand.\n\n**Q: vLLM is not working with Axolotl**\n\n> A: We currently recommend torch 2.6.0 for use with `vllm`. Please ensure you use the right version. For Docker, please use the `main-py3.11-cu124-2.6.0` tag.\n\n**Q: FA2 2.8.0 `undefined symbol` runtime error on CUDA 12.4**\n\n> A: There seems to be a wheel issue with FA2 2.8.0 on CUDA 12.4. Try CUDA 12.6 instead or downgrade to FA2 2.7.4. Please refer to the upstream issue: https://github.com/Dao-AILab/flash-attention/issues/1717.\n\n**Q: Can we mix text and text+image datasets for VLM training?**\n\n> A: Yes, you can for newer VLM architectures. The ones that do not work are the LLaVA / Pixtral architectures. If you notice one not working, please let us know!\n\n**Q: Why is `memory/max_*` different from `nvidia-smi`?**\n\n> A: We use `torch` APIs to retrieve this information. 
See https://docs.pytorch.org/docs/stable/notes/cuda.html#cuda-memory-management for more information.\n\n### Chat templates\n\n**Q: `jinja2.exceptions.UndefinedError: 'dict object' has no attribute 'content' / 'role' / ____`**\n\n> A: This means that the property mapping for the stated attribute does not exist when building the `chat_template` prompt. For example, if you see `no attribute 'content'`, please check that you have added the correct mapping for `content` under `message_property_mappings`.\n\n**Q: `Empty template generated for turn ___`**\n\n> A: The `content` is empty for that turn.\n\n**Q: `Could not find content start/end boundary for turn __`**\n\n> A: The specific turn's start/end could not be detected. Please ensure you have set the `eos_token` following your `chat_template`. Otherwise, this could be a `chat_template` which doesn't use proper boundaries for each turn (like system). As a rare edge case, also make sure your content is not `[[dummy_message]]`. Please let us know about this.\n\n**Q: `Content end boundary is before start boundary for turn ___`**\n\n> A: This is an edge case which should not occur. Please create an issue if this happens.\n\n**Q: `Content end boundary is the same as start boundary for turn ___. This is likely an empty turn.`**\n\n> A: This is likely an empty turn.\n\n**Q: The EOS token is incorrectly being masked or not being masked / `EOS token __ not found in chat template`.**\n\n> A: There can be two reasons:\n\n> 1. There is a mismatch between `tokenizer.eos_token` and the EOS token in the template. Please make sure to set `eos_token: ` under `special_tokens: ` to the same EOS token as in the template.\n\n> 2. The EOS token is not in the template. Please check if your template is correct. As an example, the `phi_35` template does not use its dedicated EOS token `<|endoftext|>` at the end.\n\n**Q: \"`chat_template` choice is `tokenizer_default` but tokenizer's `chat_template` is null. 
Please add a `chat_template` in tokenizer config\"**\n\n> A: This is because the tokenizer does not have a chat template. Please add a chat template in the tokenizer config. See [chat_template](dataset-formats/conversation.qmd#chat-template) for more details.\n\n**Q: The EOT token(s) are incorrectly being masked or not being masked / `EOT token __ not found in chat template`.**\n\n> A: There can be two reasons:\n\n> 1. The EOT token is different from the EOS token and was not specified under `eot_tokens: `. Please set `eot_tokens: ` to the same EOT token(s) as in the template.\n\n> 2. There is more than one EOT token per turn in the template. Please raise an issue with examples, as we consider this an edge case.\n\n**Q: `EOT token encoding failed. Please check if the token is valid and can be encoded.`**\n\n> A: There could be an issue with the tokenizer or unicode encoding. Please raise an issue with examples of the EOT token & tokenizer causing the issue.\n\n**Q: `EOT token __ is encoded as multiple tokens.`**\n\n> A: This is because the EOT token is encoded as multiple tokens, which can cause unexpected behavior. Please add it under `tokens: ` or (recommended) override unused added_tokens via `added_tokens_overrides: `.\n\n**Q: `Conflict between train_on_eos and train_on_eot. eos_token is in eot_tokens and train_on_eos != train_on_eot`**\n\n> A: This is because the EOS token is listed in `eot_tokens: ` while `train_on_eos: ` and `train_on_eot: ` are set to different values. This will cause one to override the other. Please ensure that `train_on_eos: ` and `train_on_eot: ` are the same or remove the EOS token from `eot_tokens: `.\n\n**Q: If `eot_tokens: ` is not provided, what happens?**\n\n> A: If `eot_tokens: ` is not provided, the default behavior is the same as before. 
EOS tokens used to delimit turns are masked/unmasked depending on whether the turn is trainable.\n\n> Internally, this is equivalent to setting `eot_tokens: tokenizer.eos_token` and `train_on_eot: train_on_eos` (which defaults to `turn`). This transition helps clarify the naming and behavior of EOT/EOS tokens.\n\n**Q: `Data processing error: CAS service error`**\n\n> A: Try disabling XET with `export HF_HUB_DISABLE_XET=1`.\n\n**Q: `torch._inductor.exc.LoweringException: NoValidChoicesError: No choices to select, please consider adding ATEN into max_autotune_gemm_backends config (defined in torch/_inductor/config.py) to allow at least one choice. `**\n\n> A: Depending on the version of torch, you may need to include this in your YAML:\n\n> ```yaml\n> flex_attn_compile_kwargs:\n>   dynamic: false\n>   mode: max-autotune-no-cudagraphs\n> ```\n\n**Q: `ValueError(\"Backward pass should have cleared tracker of all tensors\")`**\n\n> A: This may happen due to edge cases in the default, CUDA-stream-based OffloadActivations context manager. If you encounter this error, you may have success using the naive implementation with `activation_offloading: legacy` in your YAML.\n\n**Q: `Error parsing tool_calls arguments as JSON.`**\n\n> A: An error occurred while parsing the string arguments into a dict. Please check your dataset and the error message for more details.\n"
  },
  {
    "path": "docs/fsdp_qlora.qmd",
    "content": "---\ntitle: \"FSDP + QLoRA\"\ndescription: Use FSDP with QLoRA to fine-tune large LLMs on consumer GPUs.\nformat:\n  html:\n    toc: true\n---\n\n## Background\n\nUsing FSDP with QLoRA is essential for **fine-tuning larger (70b+ parameter) LLMs on consumer GPUs.**  For example, you can use FSDP + QLoRA to train a 70b model on two 24GB GPUs[^1].\n\nBelow, we describe how to use this feature in Axolotl.\n\n## Usage\n\nTo enable `QLoRA` with `FSDP`, you need to perform the following steps:\n\n::: {.callout-tip}\nSee the [example config](#example-config) file in addition to reading these instructions.\n:::\n\n1. Set `adapter: qlora` in your axolotl config file.\n2. Enable FSDP in your axolotl config, as [described here](multi-gpu.qmd#sec-fsdp).\n3. Use one of the supported model types: `llama`, `mistral` or `mixtral`.\n\n## Enabling Swap for FSDP2\n\nIf available memory is insufficient even after FSDP's CPU offloading, you can enable swap memory usage by setting `cpu_offload_pin_memory: false` alongside `offload_params: true` in the FSDP config.\n\nThis disables memory pinning, allowing FSDP to use disk swap space as a fallback. 
Disabling memory pinning itself incurs performance overhead, and actually having to use swap adds more, but it may enable training larger models that would otherwise cause OOM errors on resource-constrained systems.\n\n## Example Config\n\n[examples/llama-2/qlora-fsdp.yml](../examples/llama-2/qlora-fsdp.yml) contains an example of how to enable QLoRA + FSDP in axolotl.\n\n## References\n\n- [PR #1378](https://github.com/axolotl-ai-cloud/axolotl/pull/1378) enabling QLoRA in FSDP in Axolotl.\n- [Blog Post](https://www.answer.ai/posts/2024-03-06-fsdp-qlora.html) from the [Answer.AI](https://www.answer.ai/) team describing the work that enabled QLoRA in FSDP.\n- Related HuggingFace PRs enabling FSDP + QLoRA:\n    - Accelerate [PR#2544](https://github.com/huggingface/accelerate/pull/2544)\n    - Transformers [PR#29587](https://github.com/huggingface/transformers/pull/29587)\n    - TRL [PR#1416](https://github.com/huggingface/trl/pull/1416)\n    - PEFT [PR#1550](https://github.com/huggingface/peft/pull/1550)\n\n\n\n\n[^1]: This was enabled by [this work](https://www.answer.ai/posts/2024-03-06-fsdp-qlora.html) from the Answer.AI team.\n"
  },
  {
    "path": "docs/getting-started.qmd",
    "content": "---\ntitle: \"Quickstart\"\nformat:\n  html:\n    toc: true\n    toc-depth: 3\n    number-sections: true\nexecute:\n  enabled: false\n---\n\nThis guide will walk you through your first model fine-tuning project with Axolotl.\n\n## Quick Example {#sec-quick-example}\n\nLet's start by fine-tuning a small language model using LoRA. This example uses a 1B-parameter model to ensure it runs on most GPUs.\nAssuming `axolotl` is installed (if not, see our [Installation Guide](installation.qmd)), follow these steps:\n\n1. Download example configs:\n```bash\naxolotl fetch examples\n```\n\n2. Run the training:\n```bash\naxolotl train examples/llama-3/lora-1b.yml\n```\n\nThat's it! Let's understand what just happened.\n\n## Understanding the Process {#sec-understanding}\n\n### The Configuration File {#sec-config}\n\nThe YAML configuration file controls everything about your training. Here's what (part of) our example config looks like:\n\n```yaml\nbase_model: NousResearch/Llama-3.2-1B\n\nload_in_8bit: true\nadapter: lora\n\ndatasets:\n  - path: teknium/GPT4-LLM-Cleaned\n    type: alpaca\ndataset_prepared_path: last_run_prepared\nval_set_size: 0.1\noutput_dir: ./outputs/lora-out\n```\n\n::: {.callout-tip}\n`load_in_8bit: true` and `adapter: lora` enable LoRA adapter fine-tuning.\n\n- To perform full fine-tuning, remove these two lines.\n- To perform QLoRA fine-tuning, replace them with `load_in_4bit: true` and `adapter: qlora`.\n:::\n\nSee our [config options](config-reference.qmd) for more details.\n\n### Training {#sec-training}\n\nWhen you run `axolotl train`, Axolotl:\n\n1. Downloads the base model\n2. (If specified) applies QLoRA/LoRA adapter layers\n3. Loads and processes the dataset\n4. Runs the training loop\n5. Saves the trained model and / or LoRA weights\n\n## Your First Custom Training {#sec-custom}\n\nLet's modify the example for your own data:\n\n1. 
Create a new config file `my_training.yml`:\n\n```yaml\nbase_model: NousResearch/Nous-Hermes-llama-1b-v1\n\nload_in_8bit: true\nadapter: lora\n\n# Training settings\nmicro_batch_size: 2\nnum_epochs: 3\nlearning_rate: 0.0003\n\n# Your dataset\ndatasets:\n  - path: my_data.jsonl        # Your local data file\n    type: alpaca               # Or other format\n```\n\nThis specific config is for LoRA fine-tuning a model with instruction tuning data using\nthe `alpaca` dataset format, which has the following format:\n\n```json\n{\n    \"instruction\": \"Write a description of alpacas.\",\n    \"input\": \"\",\n    \"output\": \"Alpacas are domesticated South American camelids...\"\n}\n```\n\nPlease see our [Dataset Formats](dataset-formats) for more dataset formats and how to\nformat them.\n\n2. Prepare your JSONL data in the specified format (in this case, the expected `alpaca`\nformat):\n\n```json\n{\"instruction\": \"Classify this text\", \"input\": \"I love this!\", \"output\": \"positive\"}\n{\"instruction\": \"Classify this text\", \"input\": \"Not good at all\", \"output\": \"negative\"}\n```\n\n3. 
Run the training:\n\n```bash\naxolotl train my_training.yml\n```\n\n## Common Tasks {#sec-common-tasks}\n\n::: {.callout-tip}\n\nThe same yaml file is used for training, inference, and merging.\n\n:::\n\n### Testing Your Model {#sec-testing}\n\nAfter training, test your model:\n\n```bash\naxolotl inference my_training.yml --lora-model-dir=\"./outputs/lora-out\"\n```\n\nMore details can be found in [Inference](inference.qmd).\n\n### Using a UI {#sec-ui}\n\nLaunch a Gradio interface:\n\n```bash\naxolotl inference my_training.yml --lora-model-dir=\"./outputs/lora-out\" --gradio\n```\n\n### Preprocessing Data {#sec-preprocessing}\n\nFor large datasets, preprocess first:\n\n```bash\naxolotl preprocess my_training.yml\n```\n\nPlease make sure to set `dataset_prepared_path: ` in your config to set the path to save the prepared dataset.\n\nMore details can be found in [Dataset Preprocessing](dataset_preprocessing.qmd).\n\n### Merging LoRA weights {#sec-merging-lora}\n\nTo merge the LoRA weights back into the base model, run:\n\n```bash\naxolotl merge-lora my_training.yml --lora-model-dir=\"./outputs/lora-out\"\n```\n\nThe merged model will be saved in the `{output_dir}/merged` directory.\n\nMore details can be found in [Merging LoRA weights](inference.qmd#sec-merging).\n\n## Next Steps {#sec-next-steps}\n\nNow that you have the basics, you might want to:\n\n- Try different model architectures\n- Experiment with hyperparameters\n- Use more advanced training methods\n- Scale up to larger models\n\nCheck our other guides for details on these topics:\n\n- [Configuration Guide](config-reference.qmd) - Full configuration options\n- [Dataset Loading](dataset_loading.qmd) - Loading datasets from various sources\n- [Dataset Formats](dataset-formats) - Working with different data formats\n- [Multi-GPU Training](multi-gpu.qmd)\n- [Multi-Node Training](multi-node.qmd)\n"
  },
  {
    "path": "docs/gradient_checkpointing.qmd",
    "content": "---\ntitle: Gradient Checkpointing and Activation Offloading\n---\n\nGradient checkpointing and activation offloading are techniques that reduce the GPU memory footprint of training deep learning\nmodels, at the cost of extra computation (recomputing activations) or data movement (transferring activations off the GPU).\n\n### Enabling Gradient Checkpointing\n\n```yaml\ngradient_checkpointing: true\n```\n\n### Enabling Activation Offloading\n\n```yaml\ngradient_checkpointing: true  # required for activation offloading\nactivation_offloading: true\n```\n\nActivation offloading variants:\n\nThe default `activation_offloading: true` offloads activations to CPU and uses CUDA streams\nto overlap communication and computation when offloading.\n\nThe `activation_offloading: legacy` option naively offloads activations to CPU without additional optimizations.\n\nFor resource-constrained environments with limited CPU memory, `activation_offloading: disk` offloads\nactivations to disk instead of CPU RAM so that much larger context lengths can be trained with minimal memory.\n"
  },
  {
    "path": "docs/inference.qmd",
    "content": "---\ntitle: \"Inference and Merging\"\nformat:\n  html:\n    toc: true\n    toc-depth: 3\n    number-sections: true\nexecute:\n  enabled: false\n---\n\nThis guide covers how to use your trained models for inference, including model loading, interactive testing, merging adapters, and common troubleshooting steps.\n\n## Quick Start {#sec-quickstart}\n\n::: {.callout-tip}\nUse the same config used for training on inference/merging.\n:::\n\n### Basic Inference {#sec-basic}\n\n::: {.panel-tabset}\n\n## LoRA Models\n\n```{.bash}\naxolotl inference your_config.yml --lora-model-dir=\"./lora-output-dir\"\n```\n\n## Full Fine-tuned Models\n\n```{.bash}\naxolotl inference your_config.yml --base-model=\"./completed-model\"\n```\n\n:::\n\n## Advanced Usage {#sec-advanced}\n\n### Gradio Interface {#sec-gradio}\n\nLaunch an interactive web interface:\n\n```{.bash}\naxolotl inference your_config.yml --gradio\n```\n\n### File-based Prompts {#sec-file-prompts}\n\nProcess prompts from a text file:\n\n```{.bash}\ncat /tmp/prompt.txt | axolotl inference your_config.yml \\\n  --base-model=\"./completed-model\" --prompter=None\n```\n\n### Memory Optimization {#sec-memory}\n\nFor large models or limited memory:\n\n```{.bash}\naxolotl inference your_config.yml --load-in-8bit=True\n```\n\n## Merging LoRA Weights {#sec-merging}\n\nMerge LoRA adapters with the base model:\n\n```{.bash}\naxolotl merge-lora your_config.yml --lora-model-dir=\"./completed-model\"\n```\n\n### Memory Management for Merging {#sec-memory-management}\n\n::: {.panel-tabset}\n\n## Configuration Options\n\n```{.yaml}\ngpu_memory_limit: 20GiB  # Adjust based on your GPU\nlora_on_cpu: true        # Process on CPU if needed\n```\n\n## Force CPU Merging\n\n```{.bash}\nCUDA_VISIBLE_DEVICES=\"\" axolotl merge-lora ...\n```\n\n:::\n\n## Tokenization {#sec-tokenization}\n\n### Common Issues {#sec-tokenization-issues}\n\n::: {.callout-warning}\nTokenization mismatches between training and inference are a common 
source of problems.\n:::\n\nTo debug:\n\n1. Check training tokenization:\n```{.bash}\naxolotl preprocess your_config.yml --debug\n```\n\n2. Verify inference tokenization by decoding tokens before model input\n\n3. Compare token IDs between training and inference\n\n### Special Tokens {#sec-special-tokens}\n\nConfigure special tokens in your YAML:\n\n```{.yaml}\nspecial_tokens:\n  bos_token: \"<s>\"\n  eos_token: \"</s>\"\n  unk_token: \"<unk>\"\ntokens:\n  - \"<|im_start|>\"\n  - \"<|im_end|>\"\n```\n\n## Troubleshooting {#sec-troubleshooting}\n\n### Common Problems {#sec-common-problems}\n\n::: {.panel-tabset}\n\n## Memory Issues\n\n- Use 8-bit loading\n- Reduce batch sizes\n- Try CPU offloading\n\n## Token Issues\n\n- Verify special tokens\n- Check tokenizer settings\n- Compare training and inference preprocessing\n\n## Performance Issues\n\n- Verify model loading\n- Check prompt formatting\n- Ensure temperature/sampling settings are appropriate\n\n:::\n\nFor more details, see our [debugging guide](debugging.qmd).\n"
  },
  {
    "path": "docs/input_output.qmd",
    "content": "---\ntitle: Template-free prompt construction\ndescription: \"Template-free prompt construction with the `input_output` format\"\n---\n\nThis documentation has moved [here](dataset-formats/template_free.qmd).\n"
  },
  {
    "path": "docs/installation.qmd",
    "content": "---\ntitle: \"Installation\"\nformat:\n  html:\n    toc: true\n    toc-depth: 3\n    number-sections: true\nexecute:\n  enabled: false\n---\n\nThis guide covers all the ways you can install and set up Axolotl for your environment.\n\n## Requirements {#sec-requirements}\n\n- NVIDIA GPU (Ampere architecture or newer for `bf16` and Flash Attention) or AMD GPU\n- Python ≥3.11\n- PyTorch ≥2.6.0\n\n## Installation Methods {#sec-installation-methods}\n\n::: {.callout-important}\nPlease make sure to have PyTorch installed before installing Axolotl in your local environment.\n\nFollow the instructions at: [https://pytorch.org/get-started/locally/](https://pytorch.org/get-started/locally/)\n:::\n\n::: {.callout-important}\nFor Blackwell GPUs, please use PyTorch 2.9.1 and CUDA 12.8.\n:::\n\n### PyPI Installation (Recommended) {#sec-pypi}\n\n```{.bash}\npip3 install -U packaging setuptools wheel ninja\npip3 install --no-build-isolation axolotl[flash-attn,deepspeed]\n```\n\nWe use `--no-build-isolation` so that the installed PyTorch version (if\nany) is detected and not clobbered, and so that we set the correct version of\ndependencies that are specific to the PyTorch version or other installed\nco-dependencies.\n\n### uv Installation {#sec-uv}\n\nuv is a fast, reliable Python package installer and resolver written in Rust. It offers significant performance improvements over pip and provides better dependency resolution, making it an excellent choice for complex environments.\n\nInstall uv if not already installed:\n```{.bash}\ncurl -LsSf https://astral.sh/uv/install.sh | sh\nsource $HOME/.local/bin/env\n```\n\nChoose your CUDA version to use with PyTorch; e.g. 
`cu124`, `cu126`, `cu128`,\nthen create the venv and activate\n```{.bash}\nexport UV_TORCH_BACKEND=cu126\nuv venv --no-project --relocatable\nsource .venv/bin/activate\n```\n\nInstall PyTorch\n- PyTorch 2.6.0 recommended\n```{.bash}\nuv pip install packaging setuptools wheel\nuv pip install torch==2.6.0\nuv pip install awscli pydantic\n```\n\nInstall axolotl from PyPi\n```{.bash}\nuv pip install --no-build-isolation axolotl[deepspeed,flash-attn]\n\n# optionally install with vLLM if you're using torch==2.6.0 and want to train w/ GRPO\nuv pip install --no-build-isolation axolotl[deepspeed,flash-attn,vllm]\n```\n\n### Edge/Development Build {#sec-edge-build}\n\nFor the latest features between releases:\n\n```{.bash}\ngit clone https://github.com/axolotl-ai-cloud/axolotl.git\ncd axolotl\npip3 install -U packaging setuptools wheel ninja\npip3 install --no-build-isolation -e '.[flash-attn,deepspeed]'\n```\n\n### Docker {#sec-docker}\n\n```{.bash}\ndocker run --gpus '\"all\"' --rm -it axolotlai/axolotl:main-latest\n```\n\nFor development with Docker:\n\n```{.bash}\ndocker compose up -d\n```\n\n::: {.callout-tip}\n### Advanced Docker Configuration\n```{.bash}\ndocker run --privileged --gpus '\"all\"' --shm-size 10g --rm -it \\\n  --name axolotl --ipc=host \\\n  --ulimit memlock=-1 --ulimit stack=67108864 \\\n  --mount type=bind,src=\"${PWD}\",target=/workspace/axolotl \\\n  -v ${HOME}/.cache/huggingface:/root/.cache/huggingface \\\n  axolotlai/axolotl:main-latest\n```\n:::\n\n::: {.callout-important}\nFor Blackwell GPUs, please use `axolotlai/axolotl:main-py3.11-cu128-2.9.1` or the cloud variant `axolotlai/axolotl-cloud:main-py3.11-cu128-2.9.1`.\n:::\n\nPlease refer to the [Docker documentation](docker.qmd) for more information on the different Docker images that are available.\n\n## Cloud Environments {#sec-cloud}\n\n### Cloud GPU Providers {#sec-cloud-gpu}\n\nFor providers supporting Docker:\n\n- Use `axolotlai/axolotl-cloud:main-latest`\n- Available on:\n    - 
[RunPod](https://runpod.io/gsc?template=v2ickqhz9s&ref=6i7fkpdz)\n    - [Vast.ai](https://cloud.vast.ai?ref_id=62897&template_id=bdd4a49fa8bce926defc99471864cace&utm_source=axolotl&utm_medium=partner&utm_campaign=template_launch_july2025&utm_content=docs_link)\n    - [PRIME Intellect](https://app.primeintellect.ai/dashboard/create-cluster?image=axolotl&location=Cheapest&security=Cheapest&show_spot=true)\n    - [Modal](https://www.modal.com?utm_source=github&utm_medium=github&utm_campaign=axolotl)\n    - [Novita](https://novita.ai/gpus-console?templateId=311)\n    - [JarvisLabs.ai](https://jarvislabs.ai/templates/axolotl)\n    - [Latitude.sh](https://latitude.sh/blueprint/989e0e79-3bf6-41ea-a46b-1f246e309d5c)\n\n### Google Colab {#sec-colab}\n\n[![](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/axolotl-ai-cloud/axolotl/blob/main/examples/colab-notebooks/colab-axolotl-example.ipynb#scrollTo=msOCO4NRmRLa)\n\n## Platform-Specific Instructions {#sec-platform-specific}\n\n### macOS {#sec-macos}\n\n```{.bash}\npip3 install --no-build-isolation -e '.'\n```\n\nSee @sec-troubleshooting for Mac-specific issues.\n\n### Windows {#sec-windows}\n\n::: {.callout-important}\nWe recommend using WSL2 (Windows Subsystem for Linux) or Docker.\n:::\n\n## Environment Managers {#sec-env-managers}\n\n### Conda/Pip venv {#sec-conda}\n\n1. Install Python ≥3.11\n2. Install PyTorch: https://pytorch.org/get-started/locally/\n3. Install Axolotl:\n   ```{.bash}\n   pip3 install -U packaging setuptools wheel ninja\n   pip3 install --no-build-isolation -e '.[flash-attn,deepspeed]'\n   ```\n4. (Optional) Login to Hugging Face:\n   ```{.bash}\n   hf auth login\n   ```\n\n## Troubleshooting {#sec-troubleshooting}\n\nIf you encounter installation issues, see our [FAQ](faq.qmd) and [Debugging Guide](debugging.qmd).\n"
  },
  {
    "path": "docs/lora_optims.qmd",
    "content": "---\ntitle: \"LoRA Optimizations\"\ndescription: \"Custom autograd functions and Triton kernels in Axolotl for optimized LoRA fine-tuning\"\n---\n\nInspired by [Unsloth](https://github.com/unslothai/unsloth), we've implemented two\noptimizations for LoRA and QLoRA fine-tuning, supporting both single GPU and multi-GPU\n(including the DDP, DeepSpeed, and FSDP2 settings) training. These include (1) SwiGLU\nand GEGLU activation function Triton kernels, and (2) LoRA MLP and attention custom\nautograd functions. Our goal was to leverage operator fusion and tensor re-use in order\nto improve speed and reduce memory usage during the forward and backward passes of\nthese calculations.\n\nWe currently support several common model architectures, including (but not limited to):\n\n- `llama`\n- `mistral`\n- `qwen2`\n- `gemma`\n- `gemma2`\n- `gemma3`\n\n<details>\n\nThe set of models we support is currently limited by our attention patching strategy,\nwhich assumes (and replaces) specific code blocks for query / key / value and output\nprojections:\n\n```python\nORIGINAL_QKV_CODE = \"\"\"\n    query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)\n    key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)\n    value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)\n\"\"\".lstrip(\n    \"\\n\"\n)\n\nORIGINAL_O_CODE = \"\"\"\n    attn_output = self.o_proj(attn_output)\n\"\"\".lstrip(\n    \"\\n\"\n)\n```\n\nIs replaced with:\n\n```python\nPATCHED_QKV_CODE = \"\"\"\n    query_states, key_states, value_states = self.apply_qkv(hidden_states)\n    query_states = query_states.view(hidden_shape).transpose(1, 2)\n    key_states = key_states.view(hidden_shape).transpose(1, 2)\n    value_states = value_states.view(hidden_shape).transpose(1, 2)\n\"\"\".lstrip(\n    \"\\n\"\n)\n\nPATCHED_O_CODE = \"\"\"\n    attn_output = self.apply_o(attn_output)\n\"\"\".lstrip(\n    \"\\n\"\n)\n```\n\nWhere `apply_qkv` and 
`apply_o` are defined in the `axolotl.kernels.lora` module.\n\nWe welcome testing of other model architectures and / or PRs to expand our patching\nlogic to be compatible with more of them.\n\n</details>\n\n::: {.callout-tip}\nCheck out our [LoRA optimizations blog](https://axolotlai.substack.com/p/accelerating-lora-fine-tuning-with).\n:::\n\n## Usage\n\nThese optimizations can be enabled in your Axolotl config YAML file. The\n`lora_mlp_kernel` option enables the optimized MLP path, while `lora_qkv_kernel` and\n`lora_o_kernel` enable the fused query-key-value projection and optimized output\nprojection, respectively.\n\n```yaml\nlora_mlp_kernel: true\nlora_qkv_kernel: true\nlora_o_kernel: true\n```\n\n::: {.callout-note}\nCurrently, LoRA kernels are not supported for RLHF training, only SFT.\n:::\n\n::: {.callout-warning}\nLoRA kernels do not support remote modeling code.\n:::\n\n## Requirements\n\n- One or more NVIDIA or AMD GPUs (in order to use the Triton kernels)\n    - Note: Set `TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1` to enable [memory-efficient attention on AMD GPUs](https://github.com/ROCm/aotriton/issues/16#issuecomment-2346675491)\n- Targeted LoRA adapters cannot use Dropout\n    - This may limit model expressivity / cause overfitting\n- Targeted LoRA adapters cannot have bias terms\n    - This may limit model expressivity\n\nModels with pre-existing LoRA adapters that use Dropout or have bias terms may need to\nbe re-finetuned without these features in order to be useful.\n\n## Implementation details\n\n### Custom autograd functions\n\nThe LoRA MLP autograd function optimizes the entire MLP computation path. It fuses the\nLoRA and base weight computations together and provides a single, efficient backward\npass for the entire MLP block.\n\nFor attention components, similar optimizations are provided through a function that\nhandles the query, key, and value projections, and a function that handles the output\nprojection. 
They are designed to work with the existing `transformers` attention\nimplementation via some monkey-patching logic.\n\n### Triton kernels\n\nTwo activation functions (SwiGLU and GeGLU) are implemented with Triton kernels for\nimproved speed and memory performance. These kernels handle both the forward and\nbackward passes.\n\n### Integration\n\nThe custom autograd functions and Triton kernels are designed to work together. The\nautograd function manages the high-level computation flow and gradient tracking, while\ncalling the Triton kernels for the activation function computation. During the backward\npass, the kernel computes both the activation output and the required gradients, which\nthe autograd function then uses to compute the final gradients for the entire\ncomputation path.\n\n## Future Work\n\n- Support for additional model architectures\n- Support for dropout and bias\n- Additional operator fusions\n"
  },
  {
    "path": "docs/lr_groups.qmd",
    "content": "---\ntitle: Learning Rate Groups\ndescription: \"Setting different learning rates by module name\"\n---\n\n## Background\n\nInspired by LoRA+, Axolotl allows practitioners to specify separate learning rates for each module or group of\nmodules in a model.\n\n## Example\n\n```yaml\nlr_groups:\n  - name: o_proj\n    modules:\n      - self_attn.o_proj.weight\n    lr: 1e-6\n  - name: q_proj\n    modules:\n      - model.layers.2.self_attn.q_proj.weight\n    lr: 1e-5\n\nlearning_rate: 2e-5\n```\n\nIn this example, we have a default learning rate of 2e-5 across the entire model, but we have a separate learning rate\nof 1e-6 for all the self-attention `o_proj` modules across all layers, and a learning rate of 1e-5 for the 3rd layer's\nself-attention `q_proj` module.\n\n::: {.callout-note}\n\nWe currently only support varying `lr`. If you're interested in adding support for other hyperparameters (e.g. `weight_decay`), we welcome PRs. See https://github.com/axolotl-ai-cloud/axolotl/blob/613bcf90e58f3ab81d3827e7fc572319908db9fb/src/axolotl/core/trainers/mixins/optimizer.py#L17\n\n:::\n"
  },
  {
    "path": "docs/mac.qmd",
    "content": "---\ntitle: Mac M-series\ndescription: Mac M-series support\n---\n\nAxolotl is currently only partially usable on Mac; many of its dependencies, including PyTorch, do not support MPS or have incomplete support.\n\nCurrent support:\n\n- [x] Support for all models\n- [x] Full training of models\n- [x] LoRA training\n- [x] Sample packing\n- [ ] FP16 and BF16 (awaiting AMP support for MPS in PyTorch)\n- [ ] Tri Dao's flash-attn (until it is supported, use `sdp_attention` as an alternative)\n- [ ] xformers\n- [ ] bitsandbytes (meaning no 4/8-bit loading and no bnb optimizers)\n- [ ] qlora\n- [ ] DeepSpeed\n\nUntested:\n\n- FSDP\n"
  },
  {
    "path": "docs/mixed_precision.qmd",
    "content": "---\ntitle: \"Mixed Precision Training\"\nformat:\n  html:\n    toc: true\n    toc-depth: 3\n    number-sections: true\n    code-tools: true\nexecute:\n  enabled: false\n---\n\nMixed precision training uses lower precision data types to reduce memory usage and increase training speed while maintaining model quality. Axolotl supports several mixed precision formats:\n\n- **FP16** - Half precision 16-bit (Pascal generation+)\n- **BF16** - Brain Float 16-bit (Ampere generation+)\n- **FP8** - 8-bit floating point (Hopper generation+)\n\n## FP16 Mixed Precision {#sec-fp16}\n\n### Overview {#sec-fp16-overview}\n\nFP16 is the traditional half-precision format, supported on older GPUs but can be less numerically stable than BF16.\n\n### Configuration {#sec-fp16-config}\n\n```{.yaml}\nfp16: true\n```\n\n### FP16 Considerations {#sec-fp16-considerations}\n\n- May require gradient scaling to prevent underflow\n- Less numerically stable than BF16\n- Can cause training instability with some model architectures\n- Consider using BF16 if your hardware supports it\n\n## BF16 Mixed Precision {#sec-bf16}\n\n### Overview {#sec-bf16-overview}\n\nBF16 (Brain Float 16) offers better numerical stability than FP16 and is the recommended mixed precision format for modern GPUs. It provides the same dynamic range as FP32 while using half the memory.\n\n### Configuration {#sec-bf16-config}\n\n```{.yaml}\n# Automatic BF16 detection (recommended)\nbf16: auto\n\n# Or explicitly enable\nbf16: true\n\n# For evaluation with BF16\nbf16: full  # Equivalent to bf16_full_eval in the HF trainer\n```\n\n## FP8 Mixed Precision {#sec-fp8}\n\n::: {.callout-note}\nFP8 support is experimental and requires compatible hardware (H100, H200) and recent PyTorch versions with TorchAO.\n:::\n\n### What is FP8? {#sec-fp8-overview}\n\nFP8 (8-bit floating point) can provide significant time savings compared to FP16/BF16 while maintaining training stability. 
Axolotl's implementation uses PyTorch's TorchAO library with the \"tensorwise\" scaling strategy.\n\n### Requirements {#sec-fp8-software}\n\n- Hopper+ GPUs (H100/H200)\n- PyTorch 2.7+ (+ compatible TorchAO version)\n- CUDA 12.4+\n\n### Configuration {#sec-fp8-config}\n\nAdd to your YAML config:\n\n```{.yaml}\n# Enable FP8 mixed precision\nfp8: true\n\n# Optional: Enable FP8 for FSDP all-gather operations\nfp8_enable_fsdp_float8_all_gather: true\n\n# Enable torch.compile (almost always necessary for FP8 speedups)\ntorch_compile: true\n```\n\n::: {.callout-important}\n**torch.compile is critical for FP8 performance**\n\nFP8 training requires `torch_compile: true` to see meaningful speedups. Without compilation, FP8 may actually be slower and use more memory than FP16/BF16.\n:::\n\n### Advanced FP8 Configs {#sec-fp8-advanced}\n\nFor [FSDP](multi-gpu.qmd#sec-fsdp) (Fully Sharded Data Parallel) training:\n\n```{.yaml}\nfp8: true\nfp8_enable_fsdp_float8_all_gather: true\n\ntorch_compile: true\n\n# FSDP configuration\nfsdp_version: 2\nfsdp_config:\n  offload_params: false\n  cpu_ram_efficient_loading: true\n  auto_wrap_policy: TRANSFORMER_BASED_WRAP\n  transformer_layer_cls_to_wrap: LlamaDecoderLayer\n  state_dict_type: FULL_STATE_DICT\n  reshard_after_forward: true\n```\n\n## Best Practices {#sec-best-practices}\n\n### Choosing Precision Format {#sec-choosing-format}\n\n- **Start with automatic detection**: `bf16: auto`\n- **For Hopper+ (H100/H200)**: Try FP8 + torch.compile for maximum speed\n- **For Ampere/Ada (A100, RTX 30/40 series)**: Use BF16\n- **For older Pascal/Turing GPUs**: Use FP16 with caution\n- **For very old or unsupported GPUs**: Use FP32\n\n### Validation and Testing {#sec-validation}\n\nAlways validate your mixed precision setup:\n\n- **Start with a small dataset** to verify stability\n- **Monitor loss curves** for irregularities\n- **Compare with an FP32 baseline** when possible\n- **Check that evaluation metrics** match expectations\n\n### FP8 Particulars 
{#sec-fp8-details}\n\n- Use cases\n  - Single-GPU training\n  - Multi-GPU training with FSDP2 or DeepSpeed\n- Speedups\n  - Please refer to the [TorchAO FP8 training benchmarks](https://github.com/pytorch/ao/tree/main/torchao/float8#rowwise-scaling) for expected matmul speedups for different (M, K, N) settings\n  - Concrete numbers for LLaMA 3 8B training can be found [here](https://github.com/pytorch/ao/tree/main/torchao/float8#training-benchmarks)\n- Known issues:\n  - FP8 + DDP + `torch.compile` raises an [error](https://gist.github.com/djsaunde/0c1664c32e44a64d31b5e01b4aafe5c4)\n  - FP8 + FSDP2 + `torch.compile` + FSDP2 activation checkpointing tends to be _slower_ than equivalent BF16 training\n  - Flash Attention 2 does not play nicely with `torch.compile`\n\nSee `examples/llama-3/3b-fp8-fsdp2.yaml` for an optimized example config. Enabling FP8 mixed precision + FP8 all-gather results in ~10% more iterations per second than BF16 for a relatively small (3B param) model.\n\nFor more information on multi-GPU training, see our [Multi-GPU guide](multi-gpu.qmd).\n"
  },
  {
    "path": "docs/multi-gpu.qmd",
    "content": "---\ntitle: \"Multi-GPU\"\nformat:\n  html:\n    toc: true\n    toc-depth: 3\n    # number-sections: true\n    code-tools: true\nexecute:\n  enabled: false\n---\n\nThis guide covers advanced training configurations for multi-GPU setups using Axolotl.\n\n## Overview {#sec-overview}\n\nWhen training on multiple GPUs, Axolotl supports 3 sharding/parallelism strategies. Additionally, you can layer specific optimization features on top of these strategies.\n\nYou generally cannot combine these strategies; they are mutually exclusive.\n\n1.  **DeepSpeed**: Powerful optimization library, supports ZeRO stages 1-3.\n2.  **FSDP (Fully Sharded Data Parallel)**: PyTorch's native sharding implementation (Recommended).\n3.  **DDP (Distributed Data Parallel)**: PyTorch's native parallelism implementation (Default if neither of the above is selected).\n\nThese features can often be combined with the strategies above:\n\n*   **Sequence Parallelism**: Splits long sequences across GPUs (Compatible with DDP, DeepSpeed, and FSDP).\n*   **FSDP + QLoRA**: Combines 4-bit quantization with FSDP (Specific to FSDP).\n\n## DeepSpeed {#sec-deepspeed}\n\n### Configuration {#sec-deepspeed-config}\n\nAdd to your YAML config:\n\n```{.yaml}\ndeepspeed: deepspeed_configs/zero1.json\n```\n\n### Usage {#sec-deepspeed-usage}\n\n```{.bash}\n# Fetch deepspeed configs (if not already present)\naxolotl fetch deepspeed_configs\n\n# Passing arg via config\naxolotl train config.yml\n\n# Passing arg via cli\naxolotl train config.yml --deepspeed deepspeed_configs/zero1.json\n```\n\n### ZeRO Stages {#sec-zero-stages}\n\nWe provide default configurations for:\n\n- ZeRO Stage 1 (`zero1.json`)\n- ZeRO Stage 1 with torch compile (`zero1_torch_compile.json`)\n- ZeRO Stage 2 (`zero2.json`)\n- ZeRO Stage 3 (`zero3.json`)\n- ZeRO Stage 3 with bf16 (`zero3_bf16.json`)\n- ZeRO Stage 3 with bf16 and CPU offload params (`zero3_bf16_cpuoffload_params.json`)\n- ZeRO Stage 3 with bf16 and CPU offload params and 
optimizer (`zero3_bf16_cpuoffload_all.json`)\n\n::: {.callout-tip}\n\nFor best performance, choose the configuration that shards and offloads the least while still fitting in VRAM.\n\nStart with Stage 1, and move to Stage 2 and then Stage 3 only if you run out of memory.\n\n:::\n\n## Fully Sharded Data Parallel (FSDP) {#sec-fsdp}\n\nFSDP allows you to shard model parameters, gradients, and optimizer states across data parallel workers.\n\n::: {.callout-note}\n\nFSDP2 is recommended for new users. FSDP1 is deprecated and will be removed in an upcoming release of Axolotl.\n\n:::\n\n### FSDP + QLoRA {#sec-fsdp-qlora}\n\nFor combining FSDP with QLoRA, see our [dedicated guide](fsdp_qlora.qmd).\n\n### Migrating from FSDP1 to FSDP2 {#sec-migrate-fsdp1-fsdp2}\n\nTo migrate your config from FSDP1 to FSDP2, you must use the `fsdp_version` top-level config field to specify the FSDP version, and\nalso follow the config field mapping below to update field names.\n\n#### Config mapping\n\nFSDP1 | FSDP2\n-------- | --------\nfsdp_sharding_strategy | reshard_after_forward\nfsdp_backward_prefetch_policy | **REMOVED**\nfsdp_backward_prefetch | **REMOVED**\nfsdp_forward_prefetch | **REMOVED**\nfsdp_sync_module_states | **REMOVED**\nfsdp_cpu_ram_efficient_loading | cpu_ram_efficient_loading\nfsdp_state_dict_type | state_dict_type\nfsdp_use_orig_params | **REMOVED**\nfsdp_activation_checkpointing | activation_checkpointing\nfsdp_offload_params | offload_params\nfsdp_auto_wrap_policy | auto_wrap_policy\nfsdp_transformer_layer_cls_to_wrap | transformer_layer_cls_to_wrap\n\nFor more details, please see the migration guide in the [torchtitan repo](https://github.com/pytorch/torchtitan/blob/main/docs/fsdp.md). 
In Axolotl,\nif you were using the following FSDP1 config:\n\n```{.yaml}\nfsdp_version: 1\nfsdp_config:\n  fsdp_offload_params: false\n  fsdp_cpu_ram_efficient_loading: true\n  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP\n  fsdp_transformer_layer_cls_to_wrap: Qwen3DecoderLayer\n  fsdp_state_dict_type: FULL_STATE_DICT\n  fsdp_sharding_strategy: FULL_SHARD\n```\n\nYou can migrate to the following FSDP2 config:\n\n```{.yaml}\nfsdp_version: 2\nfsdp_config:\n  offload_params: false\n  cpu_ram_efficient_loading: true\n  auto_wrap_policy: TRANSFORMER_BASED_WRAP\n  transformer_layer_cls_to_wrap: Qwen3DecoderLayer\n  state_dict_type: FULL_STATE_DICT\n  reshard_after_forward: true\n```\n\n### FSDP1 (deprecated) {#sec-fsdp-config}\n\n::: {.callout-note}\n\nUsing `fsdp` to configure FSDP is deprecated and will be removed in an upcoming release of Axolotl. Please use `fsdp_config` as above instead.\n\n:::\n\n```{.yaml}\nfsdp:\n  - full_shard\n  - auto_wrap\nfsdp_config:\n  fsdp_offload_params: true\n  fsdp_state_dict_type: FULL_STATE_DICT\n  fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer\n```\n\n\n## Sequence parallelism {#sec-sequence-parallelism}\n\nWe support sequence parallelism (SP) via the\n[ring-flash-attention](https://github.com/zhuzilin/ring-flash-attention) project. 
This\nallows one to split up sequences across GPUs, which is useful in the event that a\nsingle sequence causes OOM errors during model training.\n\nSee our [dedicated guide](sequence_parallelism.qmd) for more information.\n\n## Performance Optimization {#sec-performance}\n\n### Liger Kernel Integration {#sec-liger}\n\nPlease see [docs](custom_integrations.qmd#liger) for more info.\n\n## Troubleshooting {#sec-troubleshooting}\n\n### NCCL Issues {#sec-nccl}\n\nFor NCCL-related problems, see our [NCCL troubleshooting guide](nccl.qmd).\n\n### Common Problems {#sec-common-problems}\n\n::: {.panel-tabset}\n\n## Memory Issues\n\n- Reduce `micro_batch_size`\n- Reduce `eval_batch_size`\n- Adjust `gradient_accumulation_steps`\n- Consider using a higher ZeRO stage\n\n## Training Instability\n\n- Start with DeepSpeed ZeRO-2\n- Monitor loss values\n- Check learning rates\n\n:::\n\nFor more detailed troubleshooting, see our [debugging guide](debugging.qmd).\n"
  },
  {
    "path": "docs/multi-node.qmd",
    "content": "---\ntitle: Multi Node\ndescription: How to use Axolotl on multiple machines\n---\n\nBelow are three ways to train across multiple nodes with Axolotl.\n\n::: {.callout-important}\nEach machine needs a copy of Axolotl; we suggest using the same commit on every machine to ensure compatibility.\n\nYou will also need to have the same configuration file for your model on each machine.\n\nMake sure the main machine is reachable by other machines.\n:::\n\n## Accelerate\n\nYou will need to create a configuration for accelerate, either by running `accelerate config` and following the instructions, or by using the preset below:\n\n~/.cache/huggingface/accelerate/default_config.yaml\n```yaml\ncompute_environment: LOCAL_MACHINE\ndebug: false\ndistributed_type: FSDP\ndowncast_bf16: 'no'\nmachine_rank: 0 # Set to 0 for the main machine, increment by one for other machines\nmain_process_ip: 10.0.0.4 # Set to main machine's IP\nmain_process_port: 5000\nmain_training_function: main\nmixed_precision: bf16\nnum_machines: 2 # Change to the number of machines\nnum_processes: 4 # Total number of GPUs across all machines (for example: if you have 2 machines with 4 GPUs each, put 8)\nrdzv_backend: static\nsame_network: true\ntpu_env: []\ntpu_use_cluster: false\ntpu_use_sudo: false\nuse_cpu: false\n```\n\nConfigure your model to use FSDP in the Axolotl yaml. 
For example:\n```yaml\nfsdp_version: 2\nfsdp_config:\n  offload_params: true\n  state_dict_type: FULL_STATE_DICT\n  auto_wrap_policy: TRANSFORMER_BASED_WRAP\n  transformer_layer_cls_to_wrap: LlamaDecoderLayer\n  reshard_after_forward: true\n```\n\nNow launch with accelerate on each machine as you normally would; the processes will start once accelerate has been launched on every machine.\n\n## Raytrain\n\nPlease see the Ray Train docs [here](ray-integration.qmd).\n\n## Torchrun\n\nIf you are using InfiniBand, we recommend torchrun to utilize the full bandwidth.\n\nSet the following environment variables (change the buffer size/socket interface names depending on your system):\n\n```bash\nexport NCCL_IB_DISABLE=0\nexport NCCL_SOCKET_IFNAME=\"eth0,en,eth,em,bond\"\nexport NCCL_BUFFSIZE=2097152\n```\n\nRun the following on each node:\n\n### Option 1: New Axolotl CLI with launcher args (Recommended)\n\n```bash\naxolotl train config.yaml --launcher torchrun -- --nnodes $num_nodes --nproc_per_node $gpu_per_node --rdzv_id $rdzv_id --rdzv_backend c10d --rdzv_endpoint \"$head_node_ip:$head_node_port\"\n```\n\n### Option 2: Direct torchrun (Legacy)\n\n```bash\ntorchrun --nnodes $num_nodes --nproc_per_node $gpu_per_node --rdzv_id $rdzv_id --rdzv_backend c10d --rdzv_endpoint \"$head_node_ip:$head_node_port\" -m axolotl.cli.train config.yaml\n```\n\nPlease make sure to substitute the placeholder variables:\n\n- `num_nodes`: Number of nodes (containing GPUs)\n- `gpu_per_node`: Number of GPUs per node\n- `head_node_ip`: IP of the head node (make sure other machines can connect to this)\n- `head_node_port`: Port of the head node (make sure other machines can connect to this. 
Default 29400)\n- `rdzv_id`: A unique job ID shared by all nodes participating in the job.\n\nThe new CLI approach (Option 1) is recommended as it provides consistent argument handling and works seamlessly with other Axolotl CLI features.\n\nMore info on the available configs can be found in the PyTorch docs [here](https://pytorch.org/docs/stable/elastic/run.html).\n"
  },
  {
    "path": "docs/multimodal.qmd",
    "content": "---\ntitle: MultiModal / Vision Language Models (BETA)\nformat:\n  html:\n    toc: true\n    toc-depth: 3\n---\n\n## Supported Models\n\n- [Mllama](#sec-mllama)\n- [Llama4](#sec-llama4)\n- [Pixtral](#sec-pixtral)\n- [Llava-1.5](#sec-llava-15)\n- [Mistral-Small-3.1](#sec-mistral-small-31)\n- [Mistral-Small-4](#sec-mistral-small-4)\n- [Magistral-Small-2509](#sec-magistral-small-2509)\n- [Voxtral](#sec-voxtral)\n- [Gemma-3](#sec-gemma-3)\n- [Gemma-3n](#sec-gemma-3n)\n- [Qwen2-VL](#sec-qwen2-vl)\n- [Qwen2.5-VL](#sec-qwen25-vl)\n- [Qwen3-VL](#sec-qwen3-vl)\n- [Qwen3.5](#sec-qwen3-5)\n- [GLM-4.6V](#sec-glm-4-6v)\n- [SmolVLM2](#sec-smolvlm2)\n- [LFM2-VL](#sec-lfm2-vl)\n- [Intern-VL](#sec-intern-vl)\n\n## Usage\n\nMultimodal support is limited and doesn't have full feature parity with text-only training.\n\nHere are the hyperparams you'll need to use to finetune a multimodal model.\n\n```yaml\nprocessor_type: AutoProcessor\n\nskip_prepare_dataset: true\nremove_unused_columns: false  # leave columns in place as they are needed to handle image embeddings during training\nsample_packing: false  # not yet supported with multimodal\n\nchat_template:  # set per model if specified in the model sections below\n\n# example dataset\ndatasets:\n  - path: HuggingFaceH4/llava-instruct-mix-vsft\n    type: chat_template\n    split: train[:1%]\n\n# (optional) if doing lora, only finetune the language model,\n# leave the vision model and vision tower frozen\n# load_in_8bit: true\nadapter: lora\nlora_target_modules: 'model.language_model.layers.[\\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'\n\n# (optional) if you want to resize images to a set size\nimage_size: 512\nimage_resize_algorithm: bilinear\n```\n\nPlease see the [examples](https://github.com/axolotl-ai/axolotl/tree/main/examples) folder for full configs.\n\n::: {.callout-tip}\nSome of our chat_templates have been extended to support broader dataset types. 
This should not break any existing configs.\n:::\n\n::: {.callout-note}\nAs of now, we neither truncate nor drop samples based on `sequence_len`, as each architecture processes non-text tokens differently. We are looking for help on this.\n:::\n\n### Mllama {#sec-mllama}\n\n```yaml\nbase_model: meta-llama/Llama-3.2-11B-Vision-Instruct\n\nchat_template: llama3_2_vision\n```\n\n### Llama4 {#sec-llama4}\n\n```yaml\nbase_model: meta-llama/Llama-4-Scout-17B-16E-Instruct\n\nchat_template: llama4\n```\n\n### Pixtral {#sec-pixtral}\n\n```yaml\nbase_model: mistralai/Pixtral-12B-2409\n\nchat_template: pixtral\n```\n\n### Llava-1.5 {#sec-llava-15}\n\n```yaml\nbase_model: llava-hf/llava-1.5-7b-hf\n\nchat_template: llava\n```\n\n### Mistral-Small-3.1 {#sec-mistral-small-31}\n\n::: {.callout-tip}\nPlease make sure to install the vision library via `pip install 'mistral-common[opencv]==1.8.5'`\n:::\n\n```yaml\nbase_model: mistralai/Mistral-Small-3.1-24B-Instruct-2503\n```\n\n### Mistral-Small-4 {#sec-mistral-small-4}\n\n```yaml\nbase_model: mistralai/Mistral-Small-4-119B-2603\n```\n\n### Magistral-Small-2509 {#sec-magistral-small-2509}\n\n::: {.callout-tip}\nPlease make sure to install the vision library via `pip install 'mistral-common[opencv]==1.8.5'`\n:::\n\n```yaml\nbase_model: mistralai/Magistral-Small-2509\n```\n\n### Voxtral {#sec-voxtral}\n\n::: {.callout-tip}\nPlease make sure to install the audio libraries via `pip3 install librosa==0.11.0 'mistral_common[audio]==1.8.3'`\n:::\n\n```yaml\nbase_model: mistralai/Voxtral-Mini-3B-2507\n\nprocessor_type: VoxtralProcessor\n```\n\n### Gemma-3 {#sec-gemma-3}\n\n::: {.callout-tip}\nThe Gemma3-1B model is a text-only model, so please train it as a regular text model.\n:::\n\nFor multi-modal 4B/12B/27B models, use the following config:\n\n```yaml\nbase_model: google/gemma-3-4b-it\n\nchat_template: gemma3\n```\n\n### Gemma-3n {#sec-gemma-3n}\n\n::: {.callout-warning}\nThe model's initial loss and grad norm will be very high. 
We suspect this is due to the Conv layers in the vision model.\n:::\n\n::: {.callout-tip}\nPlease make sure to install `timm` via `pip3 install timm==1.0.17`\n:::\n\n```yaml\nbase_model: google/gemma-3n-E2B-it\n\nchat_template: gemma3n\n```\n\n### Qwen2-VL {#sec-qwen2-vl}\n\n```yaml\nbase_model: Qwen/Qwen2-VL-7B-Instruct\n\nchat_template: qwen2_vl\n```\n\n### Qwen2.5-VL {#sec-qwen25-vl}\n\n```yaml\nbase_model: Qwen/Qwen2.5-VL-7B-Instruct\n\nchat_template: qwen2_vl  # same as qwen2-vl\n```\n\n### Qwen3-VL {#sec-qwen3-vl}\n\n```yaml\nbase_model: Qwen/Qwen3-VL-4B-Instruct\n\nchat_template: qwen2_vl  # same as qwen2-vl\n```\n\n### Qwen3.5 {#sec-qwen3-5}\n\n```yaml\nbase_model: Qwen/Qwen3.5-9B\n\nchat_template: qwen3_5\n```\n\n### GLM-4.6V {#sec-glm-4-6v}\n\nBoth GLM-4.6V (106B MoE) and GLM-4.6V-Flash (9B) are supported.\n\n```yaml\n# GLM-4.6V (106B MoE version)\nbase_model: zai-org/GLM-4.6V\n\n# OR GLM-4.6V-Flash (9B version)\nbase_model: zai-org/GLM-4.6V-Flash\n```\n\n### SmolVLM2 {#sec-smolvlm2}\n\n::: {.callout-tip}\nPlease make sure to install `num2words` via `pip3 install num2words==0.5.14`\n:::\n\n```yaml\nbase_model: HuggingFaceTB/SmolVLM2-500M-Video-Instruct\n```\n\n### LFM2-VL {#sec-lfm2-vl}\n\n::: {.callout-warning}\nPlease uninstall `causal-conv1d` via `pip3 uninstall -y causal-conv1d`\n:::\n\n```yaml\nbase_model: LiquidAI/LFM2-VL-450M\n```\n\n### Intern-VL {#sec-intern-vl}\n\n::: {.callout-tip}\nPlease make sure to install `timm` via `pip3 install timm==1.0.19`\n:::\n\n```yaml\nbase_model: OpenGVLab/InternVL3_5-8B\n```\n\n## Dataset Format\n\nFor multi-modal datasets, we adopt an extended `chat_template` format similar to OpenAI's Message format.\n\n- Each message is an object with `role` and `content` keys.\n- `role` can be `system`, `user`, `assistant`, etc.\n- `content` is a list of parts, each with a `type` and one of `text`, `image`, `path`, `url`, `base64`, `audio`, or `video`.\n\n### Image\n\n::: {.callout-note}\nFor backwards compatibility:\n\n- If the dataset has an `images` or `image` column of 
`list[Image]`, it will be appended to the first `content` list as `{\"type\": \"image\", \"image\": ...}`. However, if the content already has a `{\"type\": \"image\"}` but no `image` key, its `image` key will be populated from the column.\n- If `content` is a string, it will be converted to a list with `type` as `text`.\n:::\n\nFor image loading, you can use the following keys within `content` alongside `\"type\": \"image\"`:\n\n- `\"path\": \"/path/to/image.jpg\"`\n- `\"url\": \"https://example.com/image.jpg\"`\n- `\"base64\": \"...\"`\n- `\"image\": PIL.Image`\n\n### Audio\n\nFor audio loading, you can use the following keys within `content` alongside `\"type\": \"audio\"`:\n\n- `\"path\": \"/path/to/audio.mp3\"`\n- `\"url\": \"https://example.com/audio.mp3\"`\n- `\"audio\": np.ndarray`\n\n::: {.callout-tip}\n\nYou may need to install `librosa` via `pip3 install librosa==0.11.0`.\n\n:::\n\n### Video\n\n::: {.callout-warning}\n\nThis is not well tested at the moment. We welcome contributors!\n\n:::\n\nFor video loading, you can use the following keys within `content` alongside `\"type\": \"video\"`:\n\n- `\"path\": \"/path/to/video.mp4\"`\n- `\"url\": \"https://example.com/video.mp4\"`\n- `\"video\": np.ndarray | list[PIL.Image.Image] | torch.Tensor` (or list of the aforementioned)\n\n### Example\n\nHere is an example of a multi-modal dataset:\n```json\n[\n  {\n    \"messages\": [\n        {\n            \"role\": \"system\",\n            \"content\": [\n              {\"type\": \"text\", \"text\": \"You are a helpful assistant.\"}\n              ]\n        },\n        {\n            \"role\": \"user\",\n            \"content\": [\n                {\"type\": \"image\", \"url\": \"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/bee.jpg\"},\n                {\"type\": \"text\", \"text\": \"Describe this image in detail.\"}\n            ]\n        },\n        {\n            \"role\": \"assistant\",\n            \"content\": [\n              {\"type\": 
\"text\", \"text\": \"The image is a bee.\"}\n            ]\n        }\n    ]\n  }\n]\n```\n\n## FAQ\n\n1. `PIL.UnidentifiedImageError: cannot identify image file ...`\n\n`PIL` could not retrieve the file at `url` using `requests`. Please check the URL for typos. Another possible cause is that the server is blocking the request.\n"
  },
  {
    "path": "docs/multipack.qmd",
    "content": "---\ntitle: Multipack (Sample Packing)\ndescription: Multipack is a technique to pack multiple sequences into a single batch to increase training throughput.\n---\n\n## Visualization of Multipack with Flash Attention\n\nBecause Flash Attention simply drops the attention mask, we do not need to\nconstruct a 4d attention mask. We only need to concatenate the sequences into\na single batch and let flash attention know where each new sequence begins.\n\n\n4k context, bsz = 4,\neach character represents 256 tokens\nX represents a padding token\n\n```\n   0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5\n[[ A A A A A A A A A A A ]\n   B B B B B B ]\n   C C C C C C C ]\n   D D D D ]]\n\n[[ E E E E E E E E ]\n [ F F F F ]\n [ G G G ]\n [ H H H H ]]\n\n[[ I I I ]\n [ J J J ]\n [ K K K K K]\n [ L L L ]]\n```\n\nafter padding to the longest input in each step\n```\n   0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5\n[[ A A A A A A A A A A A ]\n   B B B B B B X X X X X X ]\n   C C C C C C C X X X X ]\n   D D D D X X X X X X X ]]\n\n[[ E E E E E E E E ]\n [ F F F F X X X X ]\n [ G G G X X X X X ]\n [ H H H H X X X X ]]\n\n[[ I I I X X ]\n [ J J J X X ]\n [ K K K K K ]\n [ L L L X X ]]\n```\n\nwith packing (note: it's the same effective number of tokens per step, but a true bsz of 1)\n```\n   0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5\n[[ A A A A A A A A A A A B B B B B\n   B C C C C C C C D D D D E E E E\n   E E E E F F F F F G G G H H H H\n   I I I J J J J K K K K K L L L X ]]\n```\n\ncu_seqlens:\n[[ 0, 11, 17, 24, 28, 36, 41, 44, 48, 51, 55, 60, 64]]\n\n\n## Multipack without Flash Attention\n\nMultipack can still be achieved without Flash attention, but with lower packing\nefficiency as we are not able to join multiple batches into a single batch due to\ncontext length limits without flash attention. 
We can use either PyTorch's Scaled\nDot Product Attention implementation or native PyTorch attention implementation\nalong with [4d attention masks](https://github.com/huggingface/transformers/pull/27539)\nto pack sequences together and prevent tokens from attending across sequence boundaries.\n\n<img src=\"./images/4d-mask.png\" alt=\"4D attention mask for packed sequences\" width=\"800\">\n"
  },
  {
    "path": "docs/nccl.qmd",
    "content": "---\ntitle: NCCL\ndescription: Troubleshooting NCCL issues\n---\n\nNVIDIA NCCL is a library to facilitate and optimize multi-GPU communication operations, such as broadcast, all-gather, reduce, all-reduce, etc. Broadly, NCCL configuration is highly environment-specific and is configured via several [environment variables](https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/env.html). A common NCCL-related problem occurs when a long-running operation times out, causing the training process to abort:\n\n```text\nWatchdog caught collective operation timeout: WorkNCCL(SeqNum=42, OpType=ALLGATHER, Timeout(ms)=1800000) ran for 1806948 milliseconds before timing out.\n```\n\nOften, this timeout will happen after 30 minutes (the default setting) and is accompanied by below-average power consumption with near 100% GPU utilization before the error is raised. NVIDIA recommends [disabling PCI access control services (ACS)](https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/troubleshooting.html#pci-access-control-services-acs) as a possible solution if this is available to you.\n\nForcing cross-GPU communication via [NVLink](https://en.wikipedia.org/wiki/NVLink) may help without increasing timeouts. To verify that your configuration is leveraging NVLink, run the following command:\n\n```bash\nnvidia-smi nvlink --status\n```\n\nTo force NCCL to use NVLink, simply set this in the environment:\n\n```bash\nexport NCCL_P2P_LEVEL=NVL\n```\n\nIf NVLink is not available in your environment, there are other options for ``NCCL_P2P_LEVEL`` in the table below:\n\n| NCCL_P2P_LEVEL | Description |\n| -------------- | ----------- |\n| PIX | P2P data transfers through no more than a single PCIe bridge. Faster data transfer rates than paths involving multiple bridges, but slower compared to direct GPU-to-GPU communication. 
|\n| PXB | P2P data transfers through multiple PCIe bridges but not going through the PCIe Host Bridge; this path involves a complex routing process, potentially incurring a moderate level of latency. |\n| PHB | P2P data transfers occur over PCIe and through a PCIe Host Bridge, typically involving the CPU, which can facilitate direct memory access but might introduce additional latency compared to more direct paths (e.g., PIX, NVL) |\n\nRunning [NCCL Tests](https://github.com/NVIDIA/nccl-tests/blob/master/README.md) can help validate that your training job has acceptable data transfer speeds and pinpoint bottlenecks, for example:\n\n```bash\n./build/all_reduce_perf -b 8 -e 128M -f 2 -g 3\n```\n\nWhen debugging NCCL communication timeouts, it can be useful to enable additional logging in both PyTorch and NCCL:\n\n```bash\nexport NCCL_DEBUG=INFO\nexport NCCL_DEBUG_SUBSYS=ALL\nexport TORCH_DISTRIBUTED_DEBUG=INFO\nexport TORCHELASTIC_ERROR_FILE=/PATH/TO/torcherror.log\n```\n\nFinally, if you believe your training job needs more time, you can increase the timeout past 30 minutes by setting the ``ddp_timeout`` value in the Axolotl configuration. See [PyTorch init_process_group](https://pytorch.org/docs/stable/distributed.html#torch.distributed.init_process_group) for documentation on this value.\n"
  },
  {
    "path": "docs/nd_parallelism.qmd",
    "content": "---\ntitle: \"N-D Parallelism (Beta)\"\n---\n\nAxolotl enables training models at scale by composing different parallelism techniques. This is essential when:\n\n- A model's weights are too large to fit in a single GPU's memory.\n- A model's activations, especially with very long contexts, are too large for a single GPU.\n- You want to accelerate training by using multiple GPUs or nodes.\n\nor combinations of the above!\n\n## Core Concepts\n\nParallelism strategies can be combined. The key is understanding how each one divides the workload. PyTorch's `DeviceMesh` is the modern way to manage these combinations, creating a logical grid of your GPUs and assigning different parallel strategies to different dimensions of the grid.\n\n### Data Parallelism {#sec-dp}\n\nData Parallelism focuses on splitting the global data batch across GPUs.\n\n- Distributed Data Parallel (DDP): The classic approach. The full model is replicated on every GPU. Each GPU processes a different slice of the data batch. Gradients are then averaged across all GPUs after the backward pass to keep the models synchronized. This can substantially improve data throughput compared to single-device training, but requires that each GPU is able to hold the entire model, its gradients, and optimizer states.\n\n- [Fully Sharded Data Parallel (FSDP)](multi-gpu.qmd#sec-fsdp): A highly memory-efficient form of data parallelism (inspired by DeepSpeed's ZeRO). Instead of replicating the model, FSDP shards the model's *parameters, gradients, and optimizer states* across the GPUs in the data-parallel group. During computation, each GPU receives the specific parameters it needs via an `all_gather` operation just before they are used, and they can be discarded immediately after (`reshard_after_forward`).\n    - FSDP maps to ZeRO stages:\n        - ZeRO-2 (`reshard_after_forward=False`): Shards gradients and optimizer states. 
Model weights are replicated on each GPU.\n        - ZeRO-3 (`reshard_after_forward=True`): Shards gradients, optimizer states, AND model parameters. This provides the most memory savings at the cost of more communication (re-gathering parameters for both forward and backward passes).\n\n### [Experimental] Tensor Parallelism (TP) {#sec-tp}\n\nAlso known as \"horizontal model parallelism,\" as described in the [Megatron-LM paper](https://arxiv.org/pdf/1909.08053.pdf). Instead of splitting the batch, TP splits the model's layers themselves across GPUs.\n\n- How it works: For a linear layer `Y = XA`, the weight matrix `A` is split column-wise (`A = [A_1, A_2]`). The computation becomes `Y_1 = XA_1` and `Y_2 = XA_2`, which can happen in parallel on different GPUs. The final output `Y` is simply the concatenation of `Y_1` and `Y_2`. Check [this comment](https://github.com/huggingface/transformers/issues/10321#issuecomment-783543530) for more detailed info.\n- Requirement: TP involves frequent, small communications within a forward/backward pass. It requires a very fast interconnect between GPUs (e.g., NVLink) and is typically not recommended across different nodes.\n\n### Context Parallelism (CP) {#sec-cp}\n\nContext Parallelism, also called [Sequence Parallelism](sequence_parallelism.qmd), addresses the memory bottleneck from long sequences. The input sequence itself is split along the sequence length dimension and distributed across GPUs.\n\n- How it works: If you have a sequence of 8192 tokens and a `context_parallel_size` of 4, each GPU will only handle a chunk of 2048 tokens.\n- The Challenge: Attention is not local; every token needs to \"attend to\" every other token. Splitting the sequence breaks this.\n- The Solution (`ring-flash-attention`): An efficient communication protocol is used. 
To compute attention for its local sequence chunk, each GPU passes its Key-Value (KV) cache to its neighbor in a \"ring.\" After `N-1` steps, every GPU has seen the KV-cache from all other GPUs, allowing it to compute the correct attention values for its chunk. This is implemented using the highly optimized `flash-attention` kernel at each step.\n\n### Hybrid Sharding Data Parallel (HSDP) {#sec-hsdp}\n\nHSDP is a 2D strategy that intelligently combines FSDP and DDP, typically for multi-node training.\n\n- Intra-Node (within a machine): Use FSDP. This is efficient because GPUs on the same node have fast interconnects (NVLink), making the `all_gather` operations for sharded parameters fast.\n- Inter-Node (across machines): Use DDP. The gradient synchronization between nodes is less frequent than FSDP's parameter gathering, making it a better fit for the slower node-to-node network (e.g., Ethernet/Infiniband).\n- Example: With 2 nodes of 8 GPUs each (16 total), you could have `dp_shard_size=8` (FSDP within each node) and `dp_replicate_size=2` (DDP across the two nodes).\n\n## Usage\n\n```yaml\n# FSDP config. See https://docs.axolotl.ai/docs/multi-gpu.html#sec-fsdp\nfsdp_version: 2\nfsdp_config:\n  # ...\n\n# The number of GPUs to shard the model parameters across (FSDP dimension).\ndp_shard_size: 4\n\n# The number of times to replicate the sharded model (DDP dimension).\ndp_replicate_size: 2\n\n# Number of GPUs for Tensor Parallelism.\ntensor_parallel_size: 1  # (default is 1, no TP)\n\n# Number of GPUs for Context/Sequence Parallelism.\ncontext_parallel_size: 1 # (default is 1, no CP)\n```\n\nNote: We recommend FSDP. DeepSpeed is only compatible with `tensor_parallel_size`.\n\n## Examples\n\n::: {.callout-tip}\nSee our example configs [here](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/distributed-parallel).\n:::\n\n1.  
HSDP on 2 nodes with 4 GPUs each (8 GPUs total):\n    - You want FSDP within each node and DDP across nodes.\n    - Set `dp_shard_size: 4` and `dp_replicate_size: 2`.\n\n2.  FSDP + TP on a single 8-GPU node:\n    - You want to split the model across 4 GPUs using FSDP, and further split each layer across 2 GPUs with TP.\n    - Set `dp_shard_size: 4` and `tensor_parallel_size: 2`.\n\n3.  FSDP + CP on a single 8-GPU node for long context:\n    - You want to shard the model across all 8 GPUs and also split the sequence length across all 8 GPUs.\n    - Set `dp_shard_size: 8` and `context_parallel_size: 8`. Note: this means the data parallel group and context parallel group are the same. A more common setup might be to shard across a smaller group.\n\n## Support Matrix\n\nThis matrix describes how different parallelism methods can be combined in Axolotl.\n\n| Combination | `dp_replicate_size` | `dp_shard_size` | `tp_size` | `cp_size` | Status & Notes |\n| --- | :---: | :---: |:---:|:---:|---|\n| **FSDP** (ZeRO-3) | 1 | >1 | 1 | 1 | ✅ Fully supported. Shards model across all GPUs. |\n| **HSDP** | >1 | >1 | 1 | 1 | ✅ Fully supported. FSDP intra-node, DDP inter-node. |\n| **FSDP + TP** | 1 | >1 | >1 | 1 | ✅ **2D Parallelism**. Shards the model across a `dp_shard` group, and TP-splits layers within the `tp` group. |\n| **HSDP + TP** | >1 | >1 | >1 | 1 | ✅ **3D Parallelism**. A powerful but complex combination. |\n| **FSDP + CP** | 1 | >1 | 1 | >1 | ✅ **2D Parallelism**. Combines FSDP with context parallelism. |\n| **FSDP + TP + CP**| 1 | >1 | >1| >1| ✅ **3D Parallelism**. Another advanced combination. |\n| DDP + TP/CP | >1 | 1 | >1 | >1 | ❌ **Not Supported**. The `ParallelismConfig` explicitly prevents this, as composing pure DDP with TP or CP is currently not supported. You should use FSDP + TP/CP instead (`dp_shard_size > 1`). |\n| Just TP / CP | 1 | 1 | >1 | >1 | ✅ Supported. Useful for inference or when the model fits on one GPU but context is too long. 
|\n\n- `tp_size` refers to `tensor_parallel_size`\n- `cp_size` refers to `context_parallel_size`\n"
  },
  {
    "path": "docs/optimizations.qmd",
    "content": "---\ntitle: Optimizations Guide\ndescription: A guide to the performance and memory optimizations available in Axolotl.\n---\n\nAxolotl includes numerous optimizations to speed up training, reduce memory usage, and handle large models.\n\nThis guide provides a high-level overview and directs you to the detailed documentation for each feature.\n\n## Speed Optimizations\n\nThese optimizations focus on increasing training throughput and reducing total training time.\n\n### Sample Packing\n\nImproves GPU utilization by combining multiple short sequences into a single packed sequence for training. This requires enabling one of the [attention](#attention-implementations) implementations below.\n\n- **Config:** `sample_packing: true`\n- **Learn more:** [Sample Packing](multipack.qmd)\n\n### Attention Implementations\n\nUsing an optimized attention implementation is critical for training speed.\n\n- **[Flash Attention 2](https://github.com/Dao-AILab/flash-attention)**: `flash_attention: true`. **(Recommended)** The industry standard for fast attention on modern GPUs. Requires Ampere or higher. For AMD, check [AMD Support](https://github.com/Dao-AILab/flash-attention?tab=readme-ov-file#amd-rocm-support).\n- **[Flex Attention](https://pytorch.org/blog/flexattention/)**: `flex_attention: true`.\n- **[SDP Attention](https://docs.pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html)**: `sdp_attention: true`. PyTorch's native implementation.\n- **[Xformers](https://github.com/facebookresearch/xformers)**: `xformers_attention: true`. 
Works with FP16.\n\n*Note: You should only enable one attention backend.*\n\n### LoRA Optimizations\n\nLeverages optimized kernels to accelerate LoRA training and reduce memory usage.\n\n- **Learn more:** [LoRA Optimizations Documentation](lora_optims.qmd)\n\n## Memory Optimizations\n\nThese techniques help you fit larger models or use bigger batch sizes on your existing hardware.\n\n### Parameter Efficient Finetuning (LoRA & QLoRA)\n\nDrastically reduces memory by training a small set of \"adapter\" parameters instead of the full model. This is the most common and effective memory-saving technique.\n\n- Examples: Find configs with `lora` or `qlora` in the [examples directory](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/llama-3).\n- Config Reference: See `adapter`, `load_in_4bit`, and `load_in_8bit` in the [Configuration Reference](config-reference.qmd).\n\n### Gradient Checkpointing & Activation Offloading\n\nThese techniques save VRAM by changing how activations are handled.\n\n- Gradient Checkpointing: re-computes activations during the backward pass, trading compute time for VRAM.\n- Activation Offloading: moves activations to CPU RAM or disk, trading I/O overhead for VRAM.\n- Learn more: [Gradient Checkpointing and Offloading Docs](gradient_checkpointing.qmd)\n\n### Cut Cross Entropy (CCE)\n\nReduces VRAM usage by using an optimized cross-entropy loss calculation.\n\n- **Learn more:** [Custom Integrations - CCE](custom_integrations.qmd#cut-cross-entropy)\n\n### Liger Kernels\n\nProvides efficient Triton kernels to improve training speed and reduce memory usage.\n\n- **Learn more:** [Custom Integrations - Liger Kernels](custom_integrations.qmd#liger-kernels)\n\n### Expert Kernels\n\nOptimized kernel implementations for Mixture of Experts (MoE) model training.\n\n- **ScatterMoE**: Triton-based MoE kernels with fused LoRA support.\n- **SonicMoE**: CUTLASS-based MoE kernels for NVIDIA Hopper and Blackwell GPUs.\n\n- **Learn more:** [Custom 
Integrations - Kernels Integration](custom_integrations.qmd#kernels-integration)\n\n## Long Context Models\n\nTechniques to train models on sequences longer than their original context window.\n\n### RoPE Scaling\n\nExtends a model's context window by interpolating its Rotary Position Embeddings.\n\n- **Config:** Pass the `rope_scaling` config under `overrides_of_model_config`. To learn how to set RoPE, check the respective model's config.\n\n### Sequence Parallelism\n\nSplits long sequences across multiple GPUs, enabling training with sequence lengths that would not fit on a single device.\n\n- **Learn more:** [Sequence Parallelism Documentation](sequence_parallelism.qmd)\n\n### Arctic Long Sequence Training (ALST)\n\nALST is a recipe that combines several techniques to train long-context models efficiently. It typically involves:\n\n- TiledMLP to reduce memory usage in MLP layers.\n- Tiled Loss functions (like [CCE](#cut-cross-entropy-cce) or [Liger](#liger-kernels)).\n- Activation Offloading to CPU.\n\n- Example: [ALST Example Configuration](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/alst)\n\n## Large Models (Distributed Training)\n\nTo train models that don't fit on a single GPU, you'll need to use a distributed training strategy like FSDP or DeepSpeed. These frameworks shard the model weights, gradients, and optimizer states across multiple GPUs and nodes.\n\n- **Learn more:** [Multi-GPU Guide](multi-gpu.qmd)\n- **Learn more:** [Multi-Node Guide](multi-node.qmd)\n\n### N-D Parallelism (Beta)\n\nFor advanced scaling, Axolotl allows you to compose different parallelism techniques (e.g., Data, Tensor, Sequence Parallelism). 
This is a powerful approach to train an extremely large model by overcoming multiple bottlenecks at once.\n\n- **Learn more:** [N-D Parallelism Guide](nd_parallelism.qmd)\n\n\n## Quantization\n\nTechniques to reduce the precision of model weights for memory savings.\n\n### 4-bit Training (QLoRA)\n\nThe recommended approach for quantization-based training. It loads the base model in 4-bit using `bitsandbytes` and then trains QLoRA adapters. See [Parameter Efficient Finetuning](#parameter-efficient-finetuning-lora-qlora) for details.\n\n### FP8 Training\n\nEnables training with 8-bit floating point precision on supported hardware (e.g., NVIDIA Hopper series GPUs) for significant speed and memory gains.\n\n- **Example:** [Llama 3 FP8 FSDP Example](https://github.com/axolotl-ai-cloud/axolotl/blob/main/examples/llama-3/3b-fp8-fsdp2.yaml)\n\n### Quantization Aware Training (QAT)\n\nSimulates quantization effects during training, helping the model adapt and potentially improving the final accuracy of the quantized model.\n\n- **Learn more:** [QAT Documentation](qat.qmd)\n\n### GPTQ\n\nAllows you to finetune LoRA adapters on top of a model that has already been quantized using the GPTQ method.\n\n- **Example:** [GPTQ LoRA Example](https://github.com/axolotl-ai-cloud/axolotl/blob/main/examples/llama-2/gptq-lora.yml)\n\n### MoE Expert Quantization\n\nQuantizes MoE expert weights on load to reduce VRAM when training MoE models with adapters. Required for Transformers v5+ MoE models where experts use fused `nn.Parameter` tensors.\n\n- **Config:** `quantize_moe_experts: true`\n- **Learn more:** [MoE Expert Quantization](expert_quantization.qmd)\n"
  },
  {
    "path": "docs/optimizers.qmd",
    "content": "---\ntitle: Optimizers\ndescription: Configuring optimizers\n---\n\n## Overview\n\nAxolotl supports all optimizers supported by [transformers OptimizerNames](https://github.com/huggingface/transformers/blob/51f94ea06d19a6308c61bbb4dc97c40aabd12bad/src/transformers/training_args.py#L142-L187)\n\nHere is a list of optimizers supported by transformers as of `v4.54.0`:\n\n- `adamw_torch`\n- `adamw_torch_fused`\n- `adamw_torch_xla`\n- `adamw_torch_npu_fused`\n- `adamw_apex_fused`\n- `adafactor`\n- `adamw_anyprecision`\n- `adamw_torch_4bit`\n- `adamw_torch_8bit`\n- `ademamix`\n- `sgd`\n- `adagrad`\n- `adamw_bnb_8bit`\n- `adamw_8bit`  # alias for adamw_bnb_8bit\n- `ademamix_8bit`\n- `lion_8bit`\n- `lion_32bit`\n- `paged_adamw_32bit`\n- `paged_adamw_8bit`\n- `paged_ademamix_32bit`\n- `paged_ademamix_8bit`\n- `paged_lion_32bit`\n- `paged_lion_8bit`\n- `rmsprop`\n- `rmsprop_bnb`\n- `rmsprop_bnb_8bit`\n- `rmsprop_bnb_32bit`\n- `galore_adamw`\n- `galore_adamw_8bit`\n- `galore_adafactor`\n- `galore_adamw_layerwise`\n- `galore_adamw_8bit_layerwise`\n- `galore_adafactor_layerwise`\n- `lomo`\n- `adalomo`\n- `grokadamw`\n- `schedule_free_radam`\n- `schedule_free_adamw`\n- `schedule_free_sgd`\n- `apollo_adamw`\n- `apollo_adamw_layerwise`\n- `stable_adamw`\n\n\n## Custom Optimizers\n\nEnable custom optimizers by passing a string to the `optimizer` argument. 
Each optimizer will receive beta and epsilon args; however, some may accept additional args, which are detailed below.\n\n### optimi_adamw\n\n```yaml\noptimizer: optimi_adamw\n```\n\n### ao_adamw_4bit\n\nDeprecated: Please use `adamw_torch_4bit`.\n\n### ao_adamw_8bit\n\nDeprecated: Please use `adamw_torch_8bit`.\n\n### ao_adamw_fp8\n\n\n```yaml\noptimizer: ao_adamw_fp8\n```\n\n### adopt_adamw\n\nGitHub: [https://github.com/iShohei220/adopt](https://github.com/iShohei220/adopt)\nPaper: [https://arxiv.org/abs/2411.02853](https://arxiv.org/abs/2411.02853)\n\n```yaml\noptimizer: adopt_adamw\n```\n\n### came_pytorch\n\nGitHub: [https://github.com/yangluo7/CAME/tree/master](https://github.com/yangluo7/CAME/tree/master)\nPaper: [https://arxiv.org/abs/2307.02047](https://arxiv.org/abs/2307.02047)\n\n```yaml\noptimizer: came_pytorch\n\n# optional args (defaults below)\nadam_beta1: 0.9\nadam_beta2: 0.999\nadam_beta3: 0.9999\nadam_epsilon: 1e-30\nadam_epsilon2: 1e-16\n```\n\n### muon\n\nBlog: [https://kellerjordan.github.io/posts/muon/](https://kellerjordan.github.io/posts/muon/)\nPaper: [https://arxiv.org/abs/2502.16982v1](https://arxiv.org/abs/2502.16982v1)\n\n```yaml\noptimizer: muon\n```\n\n### dion\n\nMicrosoft's Dion (DIstributed OrthoNormalization) optimizer is a scalable and communication-efficient\northonormalizing optimizer that uses low-rank approximations to reduce gradient communication.\n\nGitHub: [https://github.com/microsoft/dion](https://github.com/microsoft/dion)\nPaper: [https://arxiv.org/pdf/2504.05295](https://arxiv.org/pdf/2504.05295)\nNote: Implementation written for PyTorch 2.7+ for DTensor\n\n```yaml\noptimizer: dion\ndion_lr: 0.01\ndion_momentum: 0.95\nlr: 0.00001  # learning rate for embeddings and parameters that fall back to AdamW\n```\n"
  },
  {
    "path": "docs/qat.qmd",
    "content": "---\ntitle: \"Quantization Aware Training (QAT)\"\nback-to-top-navigation: true\ntoc: true\ntoc-expand: 2\ntoc-depth: 4\n---\n\n## Overview\n\n[Quantization Aware Training](https://pytorch.org/blog/introduction-to-quantization-on-pytorch/#quantization-aware-training) (QAT) is a technique for improving the accuracy of models which are quantized\nby applying \"fake\" quantizations to the model's weights (and optionally, activations) during training. This fake\nquantization allows for the model to adjust for noise introduced by the quantization, so when the model is eventually\nquantized, the accuracy loss is minimized. We use the quantization techniques implemented in [torchao](https://github.com/pytorch/ao) to provide\nsupport for QAT and post-training quantization (PTQ) in axolotl.\n\nWe recommend reviewing the excellent QAT tutorial in the [torchtune library](https://pytorch.org/torchtune/main/tutorials/qat_finetune.html#quantizing-the-qat-model),\nand the QAT documentation in the [torchao library](https://github.com/pytorch/ao/tree/main/torchao/quantization/qat), for more details.\n\n## Configuring QAT in Axolotl\n\nTo enable QAT in axolotl, add the following to your configuration file:\n\n```yaml\nqat:\n  activation_dtype: # Optional[str] = \"int8\". Fake quantization layout to use for activation quantization. Valid options are \"int4\", \"int8\", \"float8\"\n  weight_dtype: # Optional[str] = \"int8\". Fake quantization layout to use for weight quantization. Valid options are \"int4\", \"fp8\", and \"nvfp4\".\n  group_size: # Optional[int] = 32. The number of elements in each group for per-group fake quantization\n  fake_quant_after_n_steps: # Optional[int] = None. 
The number of training steps after which fake quantization is applied.\n```\n\nWe support the following quantization schemes:\n\n- `Int4WeightOnly` (requires the `fbgemm-gpu` extra when installing Axolotl)\n- `Int8DynamicActivationInt4Weight`\n- `Float8DynamicActivationFloat8Weight`\n- `Float8DynamicActivationInt4Weight`\n- `NVFP4`\n\nOnce you have finished training, you must quantize your model using the same quantization configuration that you used to train it. You can use the [`quantize`](./quantize.qmd) command to do this.\n"
  },
  {
    "path": "docs/quantize.qmd",
    "content": "---\ntitle: \"Quantization with torchao\"\nback-to-top-navigation: true\ntoc: true\ntoc-expand: 2\ntoc-depth: 4\n---\n\nQuantization is a technique to lower the memory footprint of your model, potentially at the cost of accuracy or model performance. We support quantizing your model using the [torchao](https://github.com/pytorch/ao) library. Quantization is supported for both post-training quantization (PTQ) and quantization-aware training (QAT).\n\n\n::: {.callout-note}\n\nWe do not currently support quantization techniques such as GGUF, GPTQ, or EXL2.\n\n:::\n\n## Configuring Quantization in Axolotl\n\nQuantization is configured using the `quantization` key in your configuration file.\n\n```yaml\nbase_model: # The path to the model to quantize.\nquantization:\n  activation_dtype: # Optional[str] = \"int8\". Fake quantization layout to use for activation quantization. Valid options are \"int4\", \"int8\", \"float8\"\n  weight_dtype: # Optional[str] = \"int8\". Fake quantization layout to use for weight quantization. Valid options are \"int4\", \"fp8\", and \"nvfp4\".\n  group_size: # Optional[int] = 32. The number of elements in each group for per-group fake quantization\n  quantize_embedding: # Optional[bool] = False. 
Whether to quantize the embedding layer.\n\noutput_dir:  # The path to the output directory.\n```\n\nOnce quantization is complete, your quantized model will be saved in the `{output_dir}/quantized` directory.\n\nYou may also use the `quantize` command to quantize a model which has been trained with [QAT](./qat.qmd) - you can do this by using the existing QAT configuration file which\nyou used to train the model:\n\n```yaml\n# qat.yml\nqat:\n  activation_dtype: int8\n  weight_dtype: int4\n  group_size: 256\n\noutput_dir: # The path to the output directory used during training where the final checkpoint has been saved.\n```\n\n```bash\naxolotl quantize qat.yml\n```\n\nThis ensures that an identical quantization configuration is used to quantize the model as was used to train it.\n\n\n::: {.callout-note}\n\nIf you have configured pushing to hub with `hub_model_id`, your model hub name will have the quantization schema appended to it,\ne.g. `axolotl-ai-cloud/qat-nvfp4-llama3B` will become `axolotl-ai-cloud/qat-nvfp4-llama3B-nvfp4w`\n\n:::\n"
  },
  {
    "path": "docs/ray-integration.qmd",
    "content": "---\ntitle: Ray Train\ndescription: How to use Axolotl with Ray Train\n---\n\nAxolotl supports using Ray as an alternative to `accelerate` for orchestrating training. This is especially useful for multi-node training since you only have to set up code and dependencies on a single node and launch training as if you were using a single node.\n\nWith the `--use-ray` CLI flag, Axolotl will use Ray Train's [`TorchTrainer`](https://docs.ray.io/en/latest/train/api/doc/ray.train.torch.TorchTrainer.html#ray.train.torch.TorchTrainer) to run training.\n\n## Ray cluster setup\n\nA prerequisite for using the Ray Train integration is to set up a Ray cluster on your desired node(s). For a detailed guide on how you can get started with Ray clusters, check the official Ray docs [here](https://docs.ray.io/en/latest/cluster/getting-started.html).\n\nEvery Ray cluster has one _head_ node and a set of worker nodes. The head node is just like any other worker node, but it also runs certain special processes related to scheduling and orchestration. Ray-enabled scripts are run on the head node and, depending on the resources (number of CPUs, GPUs, etc.) they request, will be scheduled to run certain tasks on the worker nodes. For more on key concepts behind a Ray cluster, you can refer to this [doc](https://docs.ray.io/en/latest/cluster/key-concepts.html#cluster-key-concepts).\n\n## Sanity check\n\nTo run a sanity check on whether your Ray cluster is set up properly, execute the following on the head node:\n\n```bash\nray status\n```\n\nThe output should have a summary of your Ray cluster: a list of all the nodes in your cluster, the number of CPUs and GPUs in your cluster, etc. 
For example, if you have a cluster with 1 CPU-only head node and 2 4xL40S worker nodes, the output can look like this:\n\n\n```\nNode status\n---------------------------------------------------------------\nActive:\n 1 head\nIdle:\n 2 4xL40S:48CPU-384GB\nPending:\n (no pending nodes)\nRecent failures:\n (no failures)\n\nResources\n---------------------------------------------------------------\nUsage:\n 0.0/96.0 CPU\n 0.0/8.0 GPU\n 0B/800.00GiB memory\n 0B/229.57GiB object_store_memory\n\nDemands:\n (no resource demands)\n```\n\nYou should also be able to see the same on the [Ray dashboard](https://docs.ray.io/en/latest/ray-observability/getting-started.html).\n\n\n## Configuring training with Ray Train\n\nYou can find an example configuration at `examples/llama-3/lora-1b-ray.yml`.\n\nThe key parameters to note here are:\n\n```yaml\nuse_ray: true\nray_num_workers: 4\n# optional\nresources_per_worker:\n    GPU: 1\n```\n\n- `use_ray`: This is the flag that enables the Ray Train integration. You can either use the corresponding `--use-ray` flag in the CLI or set `use_ray` in the config file.\n- `ray_num_workers`: This is the number of workers/GPUs to use for training.\n- `resources_per_worker`: This is the Ray [resource request](https://docs.ray.io/en/latest/ray-core/scheduling/resources.html) for each worker. This can be used to request a specific GPU type or a custom resource for each worker. 
For example, if your ray cluster has GPUs of different types, and you only want to use NVIDIA L40S GPUs, you can do\n\n```yaml\nresources_per_worker:\n    accelerator_type:L40S: 0.001\n```\n\n## Launching training\n\nYou can simply run the following command on the head node:\n\n```bash\naxolotl train examples/llama-3/lora-1b-ray.yml --use-ray\n```\n\nThis will launch training on the head node and workers will be scheduled automatically by Ray Train to run on the appropriate head or worker nodes.\n\nYou can also monitor training progress on the Ray dashboard.\n\nComing back to the example on a Ray cluster with 1 head node and 2 4xL40S worker nodes, let's say you want to make use of all 8 GPUs. You would be able to just set `ray_num_workers: 8` and run the previous command. The Cluster tab will show the following:\n\n![Ray dashboard](./images/ray-cluster-dashboard.png)\n"
  },
  {
    "path": "docs/reward_modelling.qmd",
    "content": "---\ntitle: \"Reward Modelling\"\ndescription: \"Reward models are used to guide models towards behaviors preferred by humans, by training over large datasets annotated with human preferences.\"\n---\n\n### Overview\n\nReward modelling is a technique used to train models to predict the reward or value of a given input. This is particularly useful in reinforcement learning scenarios where the model needs to evaluate the quality of its actions or predictions.\nWe support the reward modelling techniques supported by `trl`.\n\n### (Outcome) Reward Models\n\nOutcome reward models are trained using data which contains preference annotations for an entire interaction between the user and model (i.e., rather than per-turn or per-step).\nFor improved training stability, you can use the `center_rewards_coefficient` parameter to encourage mean-zero reward outputs ([see TRL docs](https://huggingface.co/docs/trl/v0.10.1/en/reward_trainer#centering-rewards)).\n\n```yaml\nbase_model: google/gemma-2-2b\nmodel_type: AutoModelForSequenceClassification\nnum_labels: 1\ntokenizer_type: AutoTokenizer\n\nreward_model: true\nchat_template: gemma\ndatasets:\n  - path: argilla/distilabel-intel-orca-dpo-pairs\n    type: bradley_terry.chat_template\n\nval_set_size: 0.1\neval_steps: 100\n```\n\nBradley-Terry chat templates expect single-turn conversations in the following format:\n\n```json\n{\n    \"system\": \"...\", // optional\n    \"input\": \"...\",\n    \"chosen\": \"...\",\n    \"rejected\": \"...\"\n}\n```\n\n### Process Reward Models (PRM)\n\n::: {.callout-tip}\nCheck out our [PRM blog](https://axolotlai.substack.com/p/process-reward-models).\n:::\n\nProcess reward models are trained using data which contains preference annotations for each step in a series of interactions. 
Typically, PRMs are trained to provide reward signals over each step of a reasoning trace and are used for downstream reinforcement learning.\n\n```yaml\nbase_model: Qwen/Qwen2.5-3B\nmodel_type: AutoModelForTokenClassification\nnum_labels: 2\n\nprocess_reward_model: true\ndatasets:\n  - path: trl-lib/math_shepherd\n    type: stepwise_supervised\n    split: train\n\nval_set_size: 0.1\neval_steps: 100\n```\n\nPlease see [stepwise_supervised](dataset-formats/stepwise_supervised.qmd) for more details on the dataset format.\n"
  },
  {
    "path": "docs/rlhf.qmd",
    "content": "---\ntitle: \"RLHF (Beta)\"\ndescription: \"Reinforcement Learning from Human Feedback is a method whereby a language model is optimized from data using human feedback.\"\nback-to-top-navigation: true\ntoc: true\ntoc-expand: 2\ntoc-depth: 4\n---\n\n## Overview\n\nReinforcement Learning from Human Feedback is a method whereby a language model is optimized from data using human\nfeedback. Various methods include, but are not limited to:\n\n- [Direct Preference Optimization (DPO)](#dpo)\n- [Identity Preference Optimization (IPO)](#ipo)\n- [Kahneman-Tversky Optimization (KTO)](#kto)\n- [Odds Ratio Preference Optimization (ORPO)](#orpo)\n- [Group Relative Policy Optimization (GRPO)](#grpo)\n- [Group Reward-Decoupled Policy Optimization (GDPO)](#gdpo)\n\n\n## RLHF using Axolotl\n\n::: {.callout-important}\nThis is a BETA feature and many features are not fully implemented. You are encouraged to open new PRs to improve the integration and functionality.\n:::\n\nWe rely on the [TRL](https://github.com/huggingface/trl) library for implementations of various RL training methods, which we wrap to expose in axolotl. Each method has its own supported ways of loading datasets and prompt formats.\n\n::: {.callout-tip}\nYou can find what each method supports by going into `src/axolotl/prompt_strategies/{method}` where `{method}` is one of our supported methods. 
The `type: ` can be retrieved from `{method}.{function_name}`.\n:::\n\n### DPO\n\nExample config:\n\n```yaml\nrl: dpo\ndatasets:\n  - path: Intel/orca_dpo_pairs\n    split: train\n    type: chatml.intel\n  - path: argilla/ultrafeedback-binarized-preferences\n    split: train\n    type: chatml\n```\n\nDPO supports the following types with the following dataset format:\n\n#### chatml.argilla\n\n```json\n{\n    \"system\": \"...\", // optional\n    \"instruction\": \"...\",\n    \"chosen_response\": \"...\",\n    \"rejected_response\": \"...\"\n}\n```\n\n#### chatml.argilla_chat\n\n```json\n{\n    \"chosen\": [\n        {\"role\": \"user\", \"content\": \"...\"},\n        {\"role\": \"assistant\", \"content\": \"...\"}\n    ],\n    \"rejected\": [\n        {\"role\": \"user\", \"content\": \"...\"},\n        {\"role\": \"assistant\", \"content\": \"...\"}\n    ]\n}\n```\n\n#### chatml.icr\n\n```json\n{\n    \"system\": \"...\", // optional\n    \"input\": \"...\",\n    \"chosen\": \"...\",\n    \"rejected\": \"...\"\n}\n```\n\n#### chatml.intel\n\n```json\n{\n    \"system\": \"...\", // optional\n    \"question\": \"...\",\n    \"chosen\": \"...\",\n    \"rejected\": \"...\"\n}\n```\n\n#### chatml.prompt_pairs\n\n```json\n{\n    \"system\": \"...\", // optional\n    \"prompt\": \"...\",\n    \"chosen\": \"...\",\n    \"rejected\": \"...\"\n}\n```\n\n#### chatml.ultra\n\n```json\n{\n    \"system\": \"...\", // optional\n    \"prompt\": \"...\",\n    \"chosen\": [\n        {\"role\": \"user\", \"content\": \"...\"},\n        {\"role\": \"assistant\", \"content\": \"...\"}\n    ],\n    \"rejected\": [\n        {\"role\": \"user\", \"content\": \"...\"},\n        {\"role\": \"assistant\", \"content\": \"...\"}\n    ]\n}\n```\n\n#### llama3.argilla\n\n```json\n{\n    \"system\": \"...\", // optional\n    \"instruction\": \"...\",\n    \"chosen_response\": \"...\",\n    \"rejected_response\": \"...\"\n}\n```\n\n#### llama3.argilla_chat\n\n```json\n{\n    \"chosen\": [\n     
   {\"role\": \"user\", \"content\": \"...\"},\n        {\"role\": \"assistant\", \"content\": \"...\"}\n    ],\n    \"rejected\": [\n        {\"role\": \"user\", \"content\": \"...\"},\n        {\"role\": \"assistant\", \"content\": \"...\"}\n    ]\n}\n```\n\n#### llama3.icr\n\n```json\n{\n    \"system\": \"...\", // optional\n    \"input\": \"...\",\n    \"chosen\": \"...\",\n    \"rejected\": \"...\"\n}\n```\n\n#### llama3.intel\n\n```json\n{\n    \"system\": \"...\", // optional\n    \"question\": \"...\",\n    \"chosen\": \"...\",\n    \"rejected\": \"...\"\n}\n```\n\n#### llama3.prompt_pairs\n\n```json\n{\n    \"system\": \"...\", // optional\n    \"prompt\": \"...\",\n    \"chosen\": \"...\",\n    \"rejected\": \"...\"\n}\n```\n\n#### llama3.ultra\n\n```json\n{\n    \"system\": \"...\", // optional\n    \"prompt\": \"...\",\n    \"chosen\": [\n        {\"role\": \"user\", \"content\": \"...\"},\n        {\"role\": \"assistant\", \"content\": \"...\"}\n    ],\n    \"rejected\": [\n        {\"role\": \"user\", \"content\": \"...\"},\n        {\"role\": \"assistant\", \"content\": \"...\"}\n    ]\n}\n```\n\n#### zephyr.nectar\n\n```json\n{\n    \"prompt\": \"...\",\n    \"answers\": [\n        {\n            \"answer\": \"...\",\n            \"rank\": 1\n        },\n        {\n            \"answer\": \"...\",\n            \"rank\": 2\n        }\n        // ... 
more answers with ranks\n    ]\n}\n```\n\n#### chat_template.argilla_chat\n\n```json\n{\n    \"chosen\": [\n        {\"role\": \"user\", \"content\": \"...\"},\n        {\"role\": \"assistant\", \"content\": \"...\"}\n    ],\n    \"rejected\": [\n        {\"role\": \"user\", \"content\": \"...\"},\n        {\"role\": \"assistant\", \"content\": \"...\"}\n    ]\n}\n```\n\n#### chat_template.default\n\n```yaml\nrl: dpo\ndatasets:\n  - path: ...\n    split: train\n    type: chat_template.default\n    field_messages: \"messages\"\n    field_chosen: \"chosen\"\n    field_rejected: \"rejected\"\n    message_property_mappings:\n      role: role\n      content: content\n    roles:\n      user: [\"user\"]\n      assistant: [\"assistant\"]\n      system: [\"system\"]\n```\n\nSample input format:\n\n```json\n{\n    \"messages\": [\n        {\n            \"role\": \"system\",\n            \"content\": \"...\"\n        },\n        {\n            \"role\": \"user\",\n            \"content\": \"...\"\n        },\n        // ... 
more messages\n    ],\n    \"chosen\": {\n        \"role\": \"assistant\",\n        \"content\": \"...\"\n    },\n    \"rejected\": {\n        \"role\": \"assistant\",\n        \"content\": \"...\"\n    }\n}\n```\n\n#### user_defined.default\n\nFor custom behaviors,\n\n```yaml\nrl: dpo\ndatasets:\n  - path: ...\n    split: train\n    type:\n      field_prompt: \"prompt\"\n      field_system: \"system\"\n      field_chosen: \"chosen\"\n      field_rejected: \"rejected\"\n      prompt_format: \"{prompt}\"\n      chosen_format: \"{chosen}\"\n      rejected_format: \"{rejected}\"\n```\n\nThe input format is a simple JSON input with customizable fields based on the above config.\n\n```json\n{\n    \"system\": \"...\",  // optional\n    \"prompt\": \"...\",\n    \"chosen\": \"...\",\n    \"rejected\": \"...\"\n}\n```\n\n### IPO\n\nAs IPO is just DPO with a different loss function, all supported dataset formats for [DPO](#dpo) are also supported for IPO.\n\n```yaml\nrl: ipo\n```\n\n### ORPO\n\nPaper: https://arxiv.org/abs/2403.07691\n\n```yaml\nrl: orpo\norpo_alpha: 0.1\nremove_unused_columns: false\n\nchat_template: chatml\ndatasets:\n  - path: argilla/ultrafeedback-binarized-preferences-cleaned\n    type: chat_template.argilla\n```\n\nORPO supports the following types with the following dataset format:\n\n#### chat_template.argilla\n\n```json\n{\n    \"system\": \"...\",  // optional\n    \"prompt\": \"...\",  // if available, will be taken as user message for single-turn instead of from list below\n\n    // chosen/rejected should be same till last content and only even-number of alternating user/assistant turns\n    \"chosen\": [\n        {\"role\": \"user\", \"content\": \"...\"},\n        {\"role\": \"assistant\", \"content\": \"...\"}\n    ],\n    \"rejected\": [\n        {\"role\": \"user\", \"content\": \"...\"},\n        {\"role\": \"assistant\", \"content\": \"...\"}\n    ]\n}\n```\n\n### KTO\n\n```yaml\nrl: kto\nrl_beta: 0.1  # default\nkto_desirable_weight: 
1.0  # default\nkto_undesirable_weight: 1.0  # default\n\nremove_unused_columns: false\n\ndatasets:\n  - path: argilla/ultrafeedback-binarized-preferences-cleaned-kto\n    type: llama3.ultra\n    split: train\n\ngradient_checkpointing: true\ngradient_checkpointing_kwargs:\n  use_reentrant: true\n```\n\nKTO supports the following types with the following dataset format:\n\n#### chatml.argilla\n\n```json\n{\n    \"system\": \"...\", // optional\n    \"instruction\": \"...\",\n    \"completion\": \"...\"\n}\n```\n\n#### chatml.argilla_chat\n\n```json\n{\n    \"chosen\": [\n        {\"role\": \"user\", \"content\": \"...\"}\n    ],\n    \"completion\": [\n        {\"role\": \"assistant\", \"content\": \"...\"}\n    ]\n}\n```\n\n#### chatml.intel\n\n```json\n{\n    \"system\": \"...\", // optional\n    \"question\": \"...\",\n    \"completion\": \"...\"\n}\n```\n\n#### chatml.prompt_pairs\n\n```json\n{\n    \"system\": \"...\", // optional\n    \"prompt\": \"...\",\n    \"completion\": \"...\"\n}\n```\n\n#### chatml.ultra\n\n```json\n{\n    \"system\": \"...\", // optional\n    \"prompt\": \"...\",\n    \"completion\": \"...\"\n}\n```\n\n#### llama3.argilla\n\n```json\n{\n    \"system\": \"...\", // optional\n    \"instruction\": \"...\",\n    \"completion\": \"...\"\n}\n```\n\n#### llama3.argilla_chat\n\n```json\n{\n    \"completion\": [\n        {\"role\": \"user\", \"content\": \"...\"},\n        {\"role\": \"assistant\", \"content\": \"...\"}\n    ]\n}\n```\n\n#### llama3.intel\n\n```json\n{\n    \"system\": \"...\", // optional\n    \"question\": \"...\",\n    \"completion\": \"...\"\n}\n```\n\n#### llama3.prompt_pairs\n\n```json\n{\n    \"system\": \"...\", // optional\n    \"prompt\": \"...\",\n    \"completion\": \"...\"\n}\n```\n\n#### llama3.ultra\n\n```json\n{\n    \"system\": \"...\", // optional\n    \"prompt\": \"...\",\n    \"completion\": \"...\"\n}\n```\n\n#### user_defined.default\n\nFor custom behaviors,\n\n```yaml\nrl: kto\ndatasets:\n  - path: ...\n 
   split: train\n    type:\n      field_prompt: \"prompt\"\n      field_system: \"system\"\n      field_completion: \"completion\"\n      field_label: \"label\"\n      prompt_format: \"{prompt}\"\n      completion_format: \"{completion}\"\n```\n\nThe input format is a simple JSON input with customizable fields based on the above config.\n\n```json\n{\n    \"system\": \"...\",  // optional\n    \"prompt\": \"...\",\n    \"completion\": \"...\",\n    \"label\": \"...\"\n}\n```\n\n### GRPO\n\n::: {.callout-tip}\nCheck out our [GRPO cookbook](https://github.com/axolotl-ai-cloud/grpo_code).\n:::\n\nIn the latest GRPO implementation, `vLLM` is used to significantly speedup trajectory generation during training. In this example, we're using 4 GPUs - 2 for training, and 2 for vLLM:\n\n::: {.callout-important}\nMake sure you've installed the correct version of vLLM by including it as an extra when installing axolotl, e.g. `pip install axolotl[vllm]`.\n:::\n\n```yaml\nbase_model: Qwen/Qwen2.5-1.5B-Instruct\n\nvllm:\n    host: 0.0.0.0\n    port: 8000\n    tensor_parallel_size: 2\n    gpu_memory_utilization: 0.85\n    dtype: auto\n    # max_model_len: # you may find it useful to set the vLLM model context length if you know this beforehand\n\nrl: grpo\ntrl:\n    use_vllm: true\n    vllm_server_host: 0.0.0.0\n    vllm_server_port: 8000\n    vllm_server_timeout: 300\n```\n\n```bash\nCUDA_VISIBLE_DEVICES=2,3 axolotl vllm-serve grpo.yaml\n```\n\nYour `vLLM` instance will now attempt to spin up, and it's time to kick off training utilizing our remaining two GPUs. In another terminal, execute:\n\n```bash\nCUDA_VISIBLE_DEVICES=0,1 axolotl train grpo.yaml --num-processes 2\n```\n\n::: {.callout-note}\nDue to TRL's implementation with vLLM, the vLLM instance must use the last N GPUs instead of the first N GPUs. 
This is why in the example above, we use `CUDA_VISIBLE_DEVICES=2,3` for the vLLM instance.\n:::\n\n#### Reward functions\n\nGRPO uses custom reward functions and transformations. Please have them ready locally.\n\nFor example, to load OpenAI's GSM8K and use a random reward for completions:\n\n```python\n# rewards.py\nimport random\n\ndef rand_reward_func(completions, **kwargs) -> list[float]:\n    return [random.uniform(0, 1) for _ in completions]\n\ndef oai_gsm8k_transform(cfg, *args, **kwargs):\n    def transform_fn(example, tokenizer=None):\n        label = example[\"answer\"].split(\"####\")[-1].strip().replace(\",\", \"\")\n        return {\n            \"prompt\": [{\"role\": \"user\", \"content\": example[\"question\"]},],\n            \"answer\": label,\n        }\n    return transform_fn, {\"remove_columns\": [\"question\"]}\n```\n\n```yaml\nrl: grpo\n\ntrl:\n    beta: 0.001\n    max_completion_length: 256\n    use_vllm: True\n    num_generations: 4\n    reward_funcs: [\"rewards.rand_reward_func\"]    # format: '{file_name}.{fn_name}'\n    reward_weights: [1.0]\ndatasets:\n  - path: openai/gsm8k\n    name: main\n    type: rewards.oai_gsm8k_transform  # format: '{file_name}.{fn_name}'\n```\n\nTo see other examples of custom reward functions, please see [TRL GRPO Docs](https://github.com/huggingface/trl/blob/main/docs/source/grpo_trainer.md#using-a-custom-reward-function).\n\nTo see all configs, please see [TRLConfig](https://github.com/axolotl-ai-cloud/axolotl/blob/v0.9.2/src/axolotl/utils/schemas/trl.py).\n\n#### OpenEnv Rollout Functions\n\nGRPO supports custom rollout functions for OpenEnv-style environments, enabling interactive tasks like web browsing, code execution, or tool use. 
This allows you to implement custom generation logic that interacts with external environments.\n\nFor example, to implement a simple math-solving environment with step-by-step verification:\n\n```python\n# math_env.py\nimport re\n\ndef math_solver_rollout(model, processing_class, prompts, generation_config=None):\n    \"\"\"\n    Custom rollout function that generates step-by-step math solutions.\n\n    Args:\n        model: The language model\n        processing_class: The tokenizer or processor\n        prompts: List of prompt dicts (with 'messages' key for chat format)\n        generation_config: Optional generation configuration\n\n    Returns:\n        List of completion strings\n    \"\"\"\n    completions = []\n\n    for prompt in prompts:\n        # Render the chat template as text (tokenize=False) so reasoning steps can be appended\n        messages = prompt.get(\"messages\", [])\n        formatted_prompt = processing_class.apply_chat_template(\n            messages, tokenize=False, add_generation_prompt=True\n        )\n\n        # Generate step-by-step solution\n        full_response = \"\"\n        for step in range(5):  # Max 5 reasoning steps\n            current_input = formatted_prompt + full_response + \"\\nNext step:\"\n            inputs = processing_class(current_input, return_tensors=\"pt\").to(model.device)\n\n            outputs = model.generate(\n                **inputs,\n                max_new_tokens=100,\n                generation_config=generation_config,\n            )\n            step_text = processing_class.decode(\n                outputs[0][inputs.input_ids.shape[1]:],\n                skip_special_tokens=True\n            )\n\n            # Check if solution is complete\n            if \"FINAL ANSWER:\" in step_text:\n                full_response += step_text\n                break\n            full_response += step_text + \"\\n\"\n\n        completions.append(full_response)\n\n    return completions\n\ndef math_reward(prompts, completions, answer, **kwargs):\n    
\"\"\"Reward function that checks mathematical correctness\"\"\"\n    # `answer` is the dataset column of the same name; extra dataset columns are passed as keyword arguments\n    rewards = []\n    for completion, correct_answer in zip(completions, answer):\n        # Extract predicted answer\n        match = re.search(r\"FINAL ANSWER:\\s*(.+)\", completion)\n        predicted = match.group(1).strip() if match else \"\"\n\n        # Compare with correct answer\n        reward = 1.0 if predicted == str(correct_answer) else 0.0\n        rewards.append(reward)\n\n    return rewards\n\ndef math_transform(cfg, *args, **kwargs):\n    \"\"\"Transform dataset to GRPO format with answer field\"\"\"\n    def transform_fn(example, processing_class=None):\n        return {\n            \"prompt\": [{\"role\": \"user\", \"content\": example[\"question\"]}],\n            \"answer\": str(example[\"answer\"]),\n        }\n    return transform_fn, {\"remove_columns\": [\"question\"]}\n```\n\n```yaml\nrl: grpo\n\ntrl:\n  beta: 0.001\n  max_completion_length: 512\n  num_generations: 4\n  rollout_func: \"math_env.math_solver_rollout\"  # Custom rollout function\n  reward_funcs: [\"math_env.math_reward\"]\n  reward_weights: [1.0]\n\ndatasets:\n  - path: openai/gsm8k\n    name: main\n    type: math_env.math_transform\n```\n\nThe `rollout_func` parameter accepts a fully qualified name (e.g., `module_name.function_name`) that points to a callable function in your local directory. The function receives:\n\n- `model`: The language model\n- `processing_class`: The tokenizer/processing class\n- `prompts`: List of prompt dictionaries\n- `generation_config` (optional): Generation configuration\n\nIt should return a list of completion strings.\n\nFor more OpenEnv examples, see [TRL OpenEnv Documentation](https://huggingface.co/docs/trl/main/en/openenv).\n\n#### GRPO with DAPO/Dr. GRPO loss\n\nThe DAPO paper and, subsequently, the Dr. 
GRPO paper proposed an alternative loss function for GRPO to remediate the penalty in longer responses.\n\n```yaml\ntrl:\n  loss_type: dr_grpo\n  # Normalizes loss based on max completion length (default: 256)\n  max_completion_length:\n```\n\nFor more information, see [GRPO docs](https://huggingface.co/docs/trl/v0.17.0/en/grpo_trainer#loss-types).\n\n#### Async GRPO\n\nAsync GRPO overlaps vLLM generation with training by producing rollouts in a background thread. While the model trains on the current batch, the next batch is already being generated. This can significantly reduce wall-clock time per step.\n\n```yaml\ntrl:\n  use_data_producer: true     # Enable data producer protocol\n  use_vllm: true\n  async_prefetch: true         # Generate rollouts in background thread\n  prefetch_depth: 1            # Number of rollouts to prefetch\n  vllm_sync_interval: 2        # Sync weights to vLLM every N steps\n```\n\n::: {.callout-note}\nBecause the background thread generates completions with slightly stale model weights, async GRPO uses importance sampling correction to account for the distribution shift. This is controlled by `vllm_importance_sampling_correction: true` (default when async is enabled).\n:::\n\n##### vLLM LoRA Sync\n\nBy default, weight sync to vLLM merges the LoRA adapter into the base model and broadcasts all parameters via NCCL. LoRA sync is a faster alternative that saves only the adapter weights to the filesystem and has vLLM load them natively using Punica kernels.\n\n```yaml\nadapter: lora\nlora_r: 32\nlora_alpha: 64\nlora_target_linear: true\n\ntrl:\n  vllm_lora_sync: true         # Enable native LoRA sync\n```\n\nWhen `vllm_lora_sync: true` is set, axolotl automatically selects the LoRA-aware vLLM serve module. 
Start vLLM as usual:\n\n```bash\nCUDA_VISIBLE_DEVICES=0 axolotl vllm-serve config.yaml\n```\n\nThen start training on a separate GPU:\n\n```bash\nCUDA_VISIBLE_DEVICES=1 axolotl train config.yaml\n```\n\n::: {.callout-tip}\nLoRA sync is especially beneficial with multi-GPU training (FSDP/DeepSpeed), where NCCL merge-sync can cause GPU contention with vLLM generation.\n:::\n\n##### Streaming Partial Batch\n\nInstead of scoring the entire batch at once, streaming mode scores one prompt group at a time. This enables finer-grained zero-advantage skipping and reduces peak memory usage during scoring.\n\n```yaml\ntrl:\n  streaming_partial_batch: true\n```\n\n##### Importance Sampling Correction\n\nWhen using async prefetch, completions are generated from a slightly older version of the model. Importance sampling (IS) correction adjusts the policy gradient to account for this distribution shift.\n\n```yaml\ntrl:\n  vllm_importance_sampling_correction: true   # Enable IS correction\n  importance_sampling_level: token             # 'token' or 'sequence'\n  off_policy_mask_threshold: 0.5              # Mask sequences with IS ratio below this\n```\n\n- `importance_sampling_level: token` applies per-token IS ratios (recommended with Liger kernel)\n- `importance_sampling_level: sequence` applies per-sequence IS ratios\n- `off_policy_mask_threshold` masks out sequences where the IS ratio indicates they are too far off-policy\n\n##### Replay Buffer\n\nThe replay buffer caches rollout groups that had learning signal (non-zero reward variance) and uses them to replace zero-signal groups in later batches.\n\n```yaml\ntrl:\n  replay_buffer_size: 100       # Max cached groups (0 = disabled)\n  replay_recompute_logps: true  # Recompute log-probs for replayed data (recommended)\n```\n\n::: {.callout-note}\nWhen `replay_recompute_logps: true` (default), old log-probabilities are recomputed using the current model weights. 
This fixes the IS mismatch that would otherwise occur when replaying stale data.\n:::\n\n##### Deferred Re-rolling\n\nFailed prompts (where the model produces zero reward for all generations) are buffered and re-injected into later batches when the model may be better equipped to solve them.\n\n```yaml\ntrl:\n  reroll_start_fraction: 0.5    # Start re-rolling after 50% of training\n  reroll_max_groups: 1          # Max groups to replace per batch\n```\n\n##### Zero-Advantage Batch Skipping\n\nWhen all advantages in a micro-batch are zero (no learning signal), the forward/backward pass is skipped entirely. This is enabled by default and logged as `skipped_zero_adv_batches=1`.\n\n```yaml\ntrl:\n  skip_zero_advantage_batches: true   # default\n```\n\n##### Parallel Reward Workers\n\nReward functions that use `signal.alarm()` (e.g., `math_verify`) must run in the main thread. Parallel reward workers use subprocesses to work around this limitation while enabling concurrent reward computation.\n\n```yaml\ntrl:\n  reward_num_workers: 4         # Number of subprocess workers (1 = no parallelism)\n```\n\n##### Full Async GRPO Example\n\n```yaml\nbase_model: Qwen/Qwen2.5-1.5B-Instruct\n\nvllm:\n    host: 0.0.0.0\n    port: 8000\n    gpu_memory_utilization: 0.35\n    dtype: auto\n\nadapter: lora\nlora_r: 32\nlora_alpha: 64\nlora_target_linear: true\n\nrl: grpo\ntrl:\n  use_data_producer: true\n  use_vllm: true\n  async_prefetch: true\n  prefetch_depth: 1\n  vllm_sync_interval: 2\n  vllm_lora_sync: true\n  streaming_partial_batch: true\n  vllm_importance_sampling_correction: true\n  off_policy_mask_threshold: 0.5\n  importance_sampling_level: token\n  num_generations: 8\n  max_completion_length: 512\n  reward_funcs:\n    - rewards.accuracy_reward\n  reroll_start_fraction: 0.5\n  replay_buffer_size: 100\n  reward_num_workers: 4\n  skip_zero_advantage_batches: true\n\ndatasets:\n  - path: AI-MO/NuminaMath-TIR\n    type: rewards.prompt_transform\n    split: 
train\n\ngradient_accumulation_steps: 4\nmicro_batch_size: 2\nmax_steps: 500\nlearning_rate: 1e-5\nbf16: true\ngradient_checkpointing: true\n```\n\n```bash\n# Terminal 1: Start vLLM on GPU 0\nCUDA_VISIBLE_DEVICES=0 axolotl vllm-serve config.yaml\n\n# Terminal 2: Train on GPU 1\nCUDA_VISIBLE_DEVICES=1 axolotl train config.yaml\n```\n\n##### Multi-GPU Async GRPO\n\nAsync GRPO supports FSDP and DeepSpeed ZeRO-3 for multi-GPU training. vLLM runs on one GPU while training is distributed across the remaining GPUs.\n\n**FSDP:**\n\n```yaml\nfsdp:\n  - full_shard\n  - auto_wrap\nfsdp_config:\n  fsdp_transformer_layer_cls_to_wrap: Qwen2DecoderLayer\ngradient_checkpointing_kwargs:\n  use_reentrant: false\n```\n\n**DeepSpeed ZeRO-3:**\n\n```yaml\ndeepspeed: deepspeed_configs/zero3_bf16.json\ngradient_checkpointing_kwargs:\n  use_reentrant: true   # Required for ZeRO-3\n```\n\n```bash\n# Terminal 1: Start vLLM on GPU 0\nCUDA_VISIBLE_DEVICES=0 axolotl vllm-serve config.yaml\n\n# Terminal 2: Train on the remaining GPUs 1,2\nCUDA_VISIBLE_DEVICES=1,2 accelerate launch --num_processes 2 -m axolotl.cli.train config.yaml\n```\n\n::: {.callout-important}\nWith multi-GPU async prefetch, only rank 0 generates completions in the background thread. Results are broadcast to all ranks on the main thread. This avoids FSDP/DeepSpeed collective deadlocks from unsynchronized background threads.\n:::\n\n### GDPO\n\nGDPO (Group Reward-Decoupled Policy Optimization) extends GRPO for multi-reward training. It addresses the **reward advantage collapse** problem by normalizing each reward function independently before combining them.\n\n::: {.callout-tip}\nUse GDPO when training with multiple reward functions. With a single reward function, GRPO and GDPO produce equivalent results.\n:::\n\nPaper: [https://arxiv.org/pdf/2501.05242](https://arxiv.org/pdf/2501.05242)\n\nGDPO uses TRL's native `multi_objective_aggregation` parameter under the hood. 
When you set `rl: gdpo`, axolotl automatically configures TRL to use `normalize_then_sum` aggregation.\n\n```yaml\nbase_model: Qwen/Qwen2.5-1.5B-Instruct\n\nvllm:\n    host: 0.0.0.0\n    port: 8000\n    tensor_parallel_size: 2\n    gpu_memory_utilization: 0.85\n\nrl: gdpo\n\ntrl:\n    beta: 0.001\n    max_completion_length: 256\n    use_vllm: true\n    num_generations: 4\n    reward_funcs:\n        - rewards.format_reward\n        - rewards.correctness_reward\n    reward_weights: [1.0, 2.0]\n\ndatasets:\n    - path: openai/gsm8k\n      name: main\n      type: rewards.oai_gsm8k_transform\n```\n\nYou can also use GRPO with explicit aggregation control:\n\n```yaml\nrl: grpo\ntrl:\n    multi_objective_aggregation: normalize_then_sum  # GDPO behavior\n    # or: sum_then_normalize  # Default GRPO behavior\n```\n\n#### GDPO vs GRPO\n\n| Aspect | GRPO | GDPO |\n|--------|------|------|\n| **Aggregation** | `sum_then_normalize` | `normalize_then_sum` |\n| **Multi-reward** | May collapse advantages | Preserves reward signals |\n| **Single reward** | Standard behavior | Equivalent to GRPO |\n\n#### Why GDPO?\n\nWhen using multiple rewards with GRPO, different reward combinations can produce identical advantages:\n\n```\n# Example: format + correctness rewards\n[format=0, correct=3] → sum=3\n[format=1, correct=2] → sum=3  ← GRPO sees these as equal!\n[format=2, correct=1] → sum=3\n[format=3, correct=0] → sum=3\n```\n\nGDPO normalizes each reward independently, preserving their relative differences.\n\n#### Reward Functions\n\nGDPO uses the same reward function format as GRPO. Extra dataset columns (such as the `answer` column produced by `oai_gsm8k_transform`) are passed to reward functions as keyword arguments, so parameter names must match the column names:\n\n```python\n# rewards.py\ndef format_reward(completions, **kwargs) -> list[float]:\n    return [1.0 if len(c) > 10 else 0.0 for c in completions]\n\ndef correctness_reward(completions, answer, **kwargs) -> list[float]:\n    rewards = []\n    for completion, expected in zip(completions, answer):\n        score = 0.0  # Your scoring logic here (compare completion with expected)\n        rewards.append(score)\n    return rewards\n```\n\n#### Sequence 
Parallelism\n\nGDPO supports sequence parallelism for long-context training:\n\n```yaml\nrl: gdpo\ncontext_parallel_size: 2\n```\n\n### SimPO\n\nSimPO uses [CPOTrainer](https://huggingface.co/docs/trl/main/en/cpo_trainer) but with an alternative loss function.\n\n```yaml\nrl: simpo\nrl_beta: 0.1  # default in CPOTrainer\ncpo_alpha: 1.0  # default in CPOTrainer\nsimpo_gamma: 0.5  # default in CPOTrainer\n```\n\nThis method uses the same dataset format as [DPO](#dpo).\n\n### Using local dataset files\n\n```yaml\ndatasets:\n  - ds_type: json\n    data_files:\n      - orca_rlhf.jsonl\n    split: train\n    type: chatml.intel\n```\n\n### TRL auto-unwrapping for PEFT\n\nTRL supports auto-unwrapping PEFT models for RL training paradigms that rely on a reference model. This significantly reduces memory pressure because an additional reference model does not need to be loaded; reference model log-probabilities are instead obtained by disabling the PEFT adapters. This is enabled by default. To turn it off, pass the following config:\n\n```yaml\n# Load a separate reference model when training an adapter.\nrl_adapter_ref_model: true\n```\n"
  },
  {
    "path": "docs/scripts/examples-allowlist.yml",
    "content": "examples:\n  # December 2025\n  - name: kimi-linear\n    title: Kimi Linear\n  - name: plano\n    title: Plano Orchestrator\n  - name: mimo\n    title: MiMo\n  - name: internvl3_5\n    title: InternVL 3.5\n\n  # AllenAI\n  - name: olmo3\n    title: OLMo 3\n\n  # ArceeAI\n  - name: trinity\n    title: Trinity\n  - name: arcee\n    title: Arcee AFM\n\n  # MistralAI\n  - name: ministral3/think\n    title: Ministral 3 Thinking\n  - name: ministral3/vision\n    title: Ministral 3 Vision\n  - name: magistral/think\n    title: Magistral Thinking\n  - name: magistral/vision\n    title: Magistral Vision\n  - name: ministral\n    title: Ministral\n  - name: mistral-small\n    title: Mistral Small 3.1/3.2\n  - name: voxtral\n    title: Voxtral\n  - name: devstral\n    title: Devstral\n  - name: mistral\n    title: Mistral 7B\n\n  # Meta\n  - name: llama-4\n    title: Llama 4\n  - name: llama-2\n    title: Llama 2\n\n  # Alibaba\n  - name: qwen3-next\n    title: Qwen 3 Next\n  - name: qwen3\n    title: Qwen 3\n\n  # Google\n  - name: gemma3n\n    title: Gemma 3n\n\n  # Swiss AI\n  - name: apertus\n    title: Apertus\n\n  # GPT-OSS\n  - name: gpt-oss\n    title: GPT-OSS\n  - name: seed-oss\n    title: Seed-OSS\n\n  # Microsoft\n  - name: phi\n    title: Phi\n\n  # SmolVLM\n  - name: smolvlm2\n    title: SmolVLM 2\n\n  # IBM\n  - name: granite4\n    title: Granite 4\n\n  # LiquidAI\n  - name: LiquidAI\n    title: Liquid Foundation Models 2\n\n  # Other\n  - name: hunyuan\n    title: Hunyuan\n  - name: jamba\n    title: Jamba\n  - name: orpheus\n    title: Orpheus\n"
  },
  {
    "path": "docs/scripts/generate_config_docs.py",
    "content": "# type: ignore\n\n\"\"\"\nQuarto documentation generation from Pydantic models. Uses Pydantic model source code\nto automatically group fields, including inherited fields from parent classes.\n\"\"\"\n\nimport ast\nimport inspect\nimport textwrap\nimport types\nimport typing\nfrom typing import Any, FrozenSet, Type, Union\n\nfrom pydantic import BaseModel\n\nfrom axolotl.utils.schemas.config import AxolotlInputConfig\n\n\nclass QuartoGenerator:\n    \"\"\"Generate Quarto documentation from Pydantic models.\"\"\"\n\n    def __init__(self):\n        self._class_fields_cache = {}\n        self._inheritance_map_cache = {}\n        self._nested_models_cache = {}\n\n    def _get_direct_fields(self, cls: Type[BaseModel]) -> FrozenSet[str]:\n        \"\"\"Get fields defined directly in a single class (not inherited).\"\"\"\n        if cls in self._class_fields_cache:\n            return self._class_fields_cache[cls]\n\n        fields = set()\n\n        # Get annotated fields\n        if hasattr(cls, \"__annotations__\"):\n            fields.update(cls.__annotations__.keys())\n\n        # Filter out private/special methods\n        fields = {f for f in fields if not f.startswith(\"_\")}\n\n        result = frozenset(fields)\n        self._class_fields_cache[cls] = result\n        return result\n\n    def _is_pydantic_model(self, type_obj) -> bool:\n        \"\"\"Check if a type is a Pydantic BaseModel.\"\"\"\n        return inspect.isclass(type_obj) and issubclass(type_obj, BaseModel)\n\n    def _extract_nested_type(self, field_type) -> Any:\n        \"\"\"Extract the actual type from complex type annotations.\"\"\"\n        # Handle Annotated types (Python 3.9+)\n        if hasattr(typing, \"get_origin\") and hasattr(typing, \"get_args\"):\n            origin = typing.get_origin(field_type)\n            args = typing.get_args(field_type)\n\n            if origin is not None:\n                # Handle Annotated[SomeType, ...] 
- extract the first argument\n                if hasattr(typing, \"Annotated\") and origin is typing.Annotated:\n                    if args:\n                        return self._extract_nested_type(\n                            args[0]\n                        )  # Recursively process the actual type\n\n                # Handle list[SomeType], List[SomeType], etc.\n                elif origin in (list, typing.List):\n                    if args:\n                        return self._extract_nested_type(\n                            args[0]\n                        )  # Extract element type\n\n                # Handle Union types (including | syntax)\n                elif origin is typing.Union:\n                    # Get non-None types from the Union\n                    non_none_types = [arg for arg in args if arg is not type(None)]\n                    if len(non_none_types) >= 1:\n                        # Prioritize Pydantic models over primitive types\n                        pydantic_models = [\n                            arg\n                            for arg in non_none_types\n                            if self._is_pydantic_model(arg)\n                        ]\n                        if pydantic_models:\n                            # Return the first Pydantic model found\n                            return self._extract_nested_type(pydantic_models[0])\n\n                        # No Pydantic models, return the first non-None type\n                        return self._extract_nested_type(non_none_types[0])\n\n        # Handle new Python 3.10+ union syntax (PeftConfig | None)\n        if hasattr(field_type, \"__class__\") and field_type.__class__ is types.UnionType:\n            # Get non-None types from the Union\n            non_none_types = [\n                arg for arg in field_type.__args__ if arg is not type(None)\n            ]\n            if len(non_none_types) >= 1:\n                # Prioritize Pydantic models over primitive types\n        
        pydantic_models = [\n                    arg for arg in non_none_types if self._is_pydantic_model(arg)\n                ]\n                if pydantic_models:\n                    return self._extract_nested_type(pydantic_models[0])\n                return self._extract_nested_type(non_none_types[0])\n\n        # Handle old typing.Union syntax (fallback)\n        if hasattr(field_type, \"__origin__\"):\n            if field_type.__origin__ is Union:\n                # Get non-None types from the Union\n                non_none_types = [\n                    arg for arg in field_type.__args__ if arg is not type(None)\n                ]\n                if len(non_none_types) >= 1:\n                    # Prioritize Pydantic models over primitive types\n                    pydantic_models = [\n                        arg for arg in non_none_types if self._is_pydantic_model(arg)\n                    ]\n                    if pydantic_models:\n                        return self._extract_nested_type(pydantic_models[0])\n                    return self._extract_nested_type(non_none_types[0])\n            # Handle other generic types like dict[str, Any], etc.\n            elif hasattr(field_type, \"__args__\"):\n                return field_type\n\n        return field_type\n\n    def _extract_all_pydantic_models_from_type(\n        self, field_type\n    ) -> list[type[BaseModel]]:\n        \"\"\"Extract all Pydantic models from a type annotation, including from Unions.\"\"\"\n        models = []\n\n        if field_type is None:\n            return models\n\n        # Handle Annotated types\n        if hasattr(typing, \"get_origin\") and hasattr(typing, \"get_args\"):\n            origin = typing.get_origin(field_type)\n            args = typing.get_args(field_type)\n\n            if origin is not None:\n                # Handle Annotated[SomeType, ...] 
- extract from the first argument\n                if hasattr(typing, \"Annotated\") and origin is typing.Annotated:\n                    if args:\n                        models.extend(\n                            self._extract_all_pydantic_models_from_type(args[0])\n                        )\n                    return models\n\n                # Handle list[SomeType], List[SomeType], etc.\n                if origin in (list, typing.List):\n                    if args:\n                        models.extend(\n                            self._extract_all_pydantic_models_from_type(args[0])\n                        )\n                    return models\n\n                # Handle Union types\n                if origin is typing.Union:\n                    for arg in args:\n                        if arg is not type(None):  # Skip None type\n                            models.extend(\n                                self._extract_all_pydantic_models_from_type(arg)\n                            )\n                    return models\n\n        # Handle new Python 3.10+ union syntax\n        if hasattr(field_type, \"__class__\") and field_type.__class__ is types.UnionType:\n            for arg in field_type.__args__:\n                if arg is not type(None):  # Skip None type\n                    models.extend(self._extract_all_pydantic_models_from_type(arg))\n            return models\n\n        # Handle old typing.Union syntax (fallback)\n        if hasattr(field_type, \"__origin__\") and field_type.__origin__ is Union:\n            for arg in field_type.__args__:\n                if arg is not type(None):  # Skip None type\n                    models.extend(self._extract_all_pydantic_models_from_type(arg))\n            return models\n\n        # Check if this type itself is a Pydantic model\n        if self._is_pydantic_model(field_type):\n            models.append(field_type)\n\n        return models\n\n    def _get_nested_models(\n        self, model_class: 
type[BaseModel], visited=None\n    ) -> dict[str, type[BaseModel]]:\n        \"\"\"Get all nested Pydantic models from a model class.\"\"\"\n        if visited is None:\n            visited = set()\n\n        # Avoid infinite recursion\n        if model_class in visited:\n            return {}\n\n        if model_class in self._nested_models_cache:\n            return self._nested_models_cache[model_class]\n\n        visited.add(model_class)\n        nested_models = {}\n\n        # Check all fields in the model\n        for field_info in model_class.model_fields.values():\n            field_type = self._extract_nested_type(field_info.annotation)\n\n            if self._is_pydantic_model(field_type):\n                nested_models[field_type.__name__] = field_type\n                # Recursively get nested models from this nested model\n                deeper_nested = self._get_nested_models(field_type, visited.copy())\n                nested_models.update(deeper_nested)\n\n        self._nested_models_cache[model_class] = nested_models\n        return nested_models\n\n    def _build_inheritance_map(self, child_class: Type[BaseModel]):\n        \"\"\"Build inheritance map for a class and all its parents.\"\"\"\n        if child_class in self._inheritance_map_cache:\n            return self._inheritance_map_cache[child_class]\n\n        inheritance_map = {}\n\n        # Get MRO and filter out BaseModel and object\n        mro_classes = [\n            cls\n            for cls in child_class.__mro__\n            if cls not in (BaseModel, object) and hasattr(cls, \"__annotations__\")\n        ]\n\n        # Process each class in the MRO\n        for cls in mro_classes:\n            inheritance_map[cls] = self._get_direct_fields(cls)\n\n        self._inheritance_map_cache[child_class] = inheritance_map\n        return inheritance_map\n\n    def _wrap_comment(self, text: str, width: int = 88) -> list[str]:\n        \"\"\"Wrap a comment to specified width, accounting for '# 
' prefix.\"\"\"\n        if not text.strip():\n            return [\"#\"]\n\n        # Account for \"# \" prefix (2 characters)\n        content_width = width - 2\n        wrapped_lines = textwrap.wrap(text, width=content_width)\n        return [f\"# {line}\" for line in wrapped_lines]\n\n    def _extract_type_from_source(\n        self, model_class: type[BaseModel], field_name: str\n    ) -> str:\n        \"\"\"Extract the actual type annotation text from source code, checking inheritance chain.\"\"\"\n        # Use inheritance map to check classes efficiently\n        inheritance_map = self._build_inheritance_map(model_class)\n\n        # Check classes in MRO order\n        for cls in model_class.__mro__:\n            if cls in inheritance_map and field_name in inheritance_map[cls]:\n                type_annotation = self._get_type_from_class_source(cls, field_name)\n                if type_annotation != \"unknown\":\n                    return type_annotation\n\n        return \"unknown\"\n\n    def _get_type_from_class_source(self, class_obj: type, field_name: str) -> str:\n        \"\"\"Extract type annotation from a specific class's source code.\"\"\"\n        try:\n            source = inspect.getsource(class_obj)\n            tree = ast.parse(source)\n        except (OSError, TypeError):\n            return \"unknown\"\n\n        # Find the class definition\n        for node in tree.body:\n            if isinstance(node, ast.ClassDef) and node.name == class_obj.__name__:\n                # Find the field assignment\n                for body_node in node.body:\n                    if isinstance(body_node, ast.AnnAssign) and isinstance(\n                        body_node.target, ast.Name\n                    ):\n                        if body_node.target.id == field_name and body_node.annotation:\n                            return ast.unparse(body_node.annotation)\n                break\n\n        return \"unknown\"\n\n    def 
_extract_field_groups_from_all_classes(\n        self, model_class: type[BaseModel]\n    ) -> list[dict]:\n        \"\"\"Extract field groups from all classes in the inheritance hierarchy.\"\"\"\n        all_groups = []\n        inheritance_map = self._build_inheritance_map(model_class)\n\n        # Get all Pydantic base classes in MRO order (most specific first)\n        # This puts AxolotlInputConfig fields first, then parent class fields\n        pydantic_classes = [\n            cls\n            for cls in model_class.__mro__\n            if cls in inheritance_map and inheritance_map[cls]\n        ]\n\n        # Extract groups from each class\n        for cls in pydantic_classes:\n            class_groups = self._extract_field_groups_from_source(cls)\n            for group in class_groups:\n                all_groups.append(group)\n\n        # If no groups found, create a default grouping by class\n        if not all_groups:\n            for cls in pydantic_classes:\n                fields_in_class = inheritance_map[cls]\n                if fields_in_class:\n                    all_groups.append(\n                        {\n                            \"fields\": list(fields_in_class),\n                        }\n                    )\n\n        return all_groups\n\n    def _extract_field_groups_from_source(\n        self, model_class: type[BaseModel]\n    ) -> list[dict]:\n        \"\"\"Extract field groups from source code based on blank lines and comments.\"\"\"\n        try:\n            source = inspect.getsource(model_class)\n            tree = ast.parse(source)\n        except (OSError, TypeError):\n            # Fallback if we can't get source code\n            fields_in_class = self._get_direct_fields(model_class)\n            if fields_in_class:\n                return [\n                    {\n                        \"fields\": list(fields_in_class),\n                    }\n                ]\n            return []\n\n        groups = []\n        
current_group_fields = []\n        current_group_comment = None\n\n        # Find the class definition\n        class_node = None\n        for node in ast.walk(tree):\n            if isinstance(node, ast.ClassDef) and node.name == model_class.__name__:\n                class_node = node\n                break\n\n        if not class_node:\n            fields_in_class = self._get_direct_fields(model_class)\n            if fields_in_class:\n                return [\n                    {\n                        \"fields\": list(fields_in_class),\n                    }\n                ]\n            return []\n\n        # Parse the source lines to detect groupings\n        source_lines = source.split(\"\\n\")\n\n        # Get fields that are actually defined in this specific class\n        fields_in_class = self._get_direct_fields(model_class)\n\n        # Find assignments that correspond to model fields for THIS class only\n        field_assignments = []\n        for node in class_node.body:\n            if isinstance(node, ast.AnnAssign) and isinstance(node.target, ast.Name):\n                field_name = node.target.id\n                if field_name in fields_in_class:\n                    field_assignments.append(\n                        {\n                            \"name\": field_name,\n                            \"lineno\": node.lineno,\n                            \"end_lineno\": getattr(node, \"end_lineno\", node.lineno),\n                        }\n                    )\n\n        if not field_assignments:\n            if fields_in_class:\n                return [\n                    {\n                        \"fields\": list(fields_in_class),\n                    }\n                ]\n            return []\n\n        # Sort by line number\n        field_assignments.sort(key=lambda x: x[\"lineno\"])\n\n        # Group fields based on blank lines and comments\n        for i, field_info in enumerate(field_assignments):\n            field_name = 
field_info[\"name\"]\n            current_line = field_info[\"lineno\"]\n\n            # Check if this starts a new group (blank line before or significant gap)\n            is_new_group = False\n\n            if i == 0:\n                is_new_group = True\n            else:\n                prev_end_line = field_assignments[i - 1][\"end_lineno\"]\n\n                # Check for blank lines or comments between fields\n                lines_between = source_lines[prev_end_line : current_line - 1]\n                has_blank_line = any(line.strip() == \"\" for line in lines_between)\n                has_comment = any(\n                    line.strip().startswith(\"#\") for line in lines_between\n                )\n\n                # Start new group if there's a blank line or comment, or significant gap\n                if has_blank_line or has_comment or (current_line - prev_end_line > 3):\n                    is_new_group = True\n\n            if is_new_group and current_group_fields:\n                # Save the previous group\n                groups.append(\n                    {\n                        \"fields\": current_group_fields.copy(),\n                        \"description\": current_group_comment,\n                    }\n                )\n                current_group_fields = []\n                current_group_comment = None\n\n            current_group_fields.append(field_name)\n\n        # Add the final group\n        if current_group_fields:\n            groups.append(\n                {\n                    \"fields\": current_group_fields,\n                    \"description\": current_group_comment,\n                }\n            )\n\n        return groups\n\n    def _generate_field_documentation(\n        self,\n        model_class: type[BaseModel],\n        field_name: str,\n        field_info: dict,\n        field_type_str: str,\n        is_required: bool,\n        indent_level: int = 0,\n        visited_models: set = None,\n    ) -> 
list[str]:\n        \"\"\"Generate documentation for a single field, expanding nested models inline.\"\"\"\n        if visited_models is None:\n            visited_models = set()\n\n        lines = []\n        indent = \"  \" * indent_level\n\n        # Get the actual field type for nested model detection\n        if field_name in model_class.model_fields:\n            pydantic_field_info = model_class.model_fields[field_name]\n            actual_field_type = pydantic_field_info.annotation\n        else:\n            actual_field_type = None\n\n        # Add description comment if available\n        description = field_info.get(\"description\", \"\")\n        if description:\n            wrapped_lines = self._wrap_comment(description, width=88 - len(indent))\n            for line in wrapped_lines:\n                lines.append(f\"{indent}{line}\")\n\n        # Extract nested Pydantic models from the type annotation\n        nested_models = self._extract_all_pydantic_models_from_type(actual_field_type)\n\n        # Filter out already visited models to prevent infinite recursion\n        expandable_models = [\n            model for model in nested_models if model not in visited_models\n        ]\n\n        if expandable_models:\n            # This field contains Pydantic models that can be expanded\n\n            # Show the field with its full type annotation\n            field_line = f\"{indent}{field_name}: {field_type_str}\"\n            if field_info.get(\"default\") is not None:\n                field_line += f\" = {field_info['default']}\"\n            if is_required:\n                field_line += \" (required)\"\n            lines.append(field_line)\n\n            # Add to visited to prevent infinite recursion\n            new_visited = visited_models.copy()\n            new_visited.update(expandable_models)\n\n            # Expand each nested Pydantic model\n            for i, nested_model in enumerate(expandable_models):\n                if i > 0:\n         
           lines.append(\"\\n\")\n                lines.append(f\"{indent}  # For {nested_model.__name__}:\")\n\n                # Get nested model schema\n                try:\n                    nested_schema = nested_model.model_json_schema()\n                    nested_properties = nested_schema.get(\"properties\", {})\n                    nested_required = nested_schema.get(\"required\", [])\n                except Exception:\n                    # Fallback: use model fields directly\n                    nested_properties = {}\n                    nested_required = []\n                    for (\n                        nested_field_name,\n                        nested_field_info,\n                    ) in nested_model.model_fields.items():\n                        nested_description = \"\"\n                        if (\n                            hasattr(nested_field_info, \"json_schema_extra\")\n                            and nested_field_info.json_schema_extra\n                        ):\n                            nested_description = (\n                                nested_field_info.json_schema_extra.get(\n                                    \"description\", \"\"\n                                )\n                            )\n                        elif (\n                            hasattr(nested_field_info, \"description\")\n                            and nested_field_info.description\n                        ):\n                            nested_description = nested_field_info.description\n\n                        nested_default_val = None\n                        if (\n                            hasattr(nested_field_info, \"default\")\n                            and nested_field_info.default is not None\n                        ):\n                            if str(nested_field_info.default) != \"PydanticUndefined\":\n                                nested_default_val = nested_field_info.default\n\n                        
nested_properties[nested_field_name] = {\n                            \"type\": \"unknown\",\n                            \"description\": nested_description,\n                            \"default\": nested_default_val,\n                        }\n\n                        if nested_field_info.is_required():\n                            nested_required.append(nested_field_name)\n\n                # Get field groups for the nested model\n                nested_field_groups = self._extract_field_groups_from_all_classes(\n                    nested_model\n                )\n\n                # Generate nested fields with increased indentation\n                for i, group in enumerate(nested_field_groups):\n                    if not group[\"fields\"]:\n                        continue\n\n                    # Add blank line between groups (except before first group)\n                    if i > 0:\n                        lines.append(\"\")\n\n                    # Process nested fields\n                    for nested_field_name in group[\"fields\"]:\n                        if nested_field_name not in nested_properties:\n                            continue\n\n                        nested_field_info = nested_properties[nested_field_name]\n                        nested_field_type = self._extract_type_from_source(\n                            nested_model, nested_field_name\n                        )\n                        nested_is_required = nested_field_name in nested_required\n\n                        # Recursively generate documentation for nested field\n                        nested_lines = self._generate_field_documentation(\n                            nested_model,\n                            nested_field_name,\n                            nested_field_info,\n                            nested_field_type,\n                            nested_is_required,\n                            indent_level + 1,\n                            new_visited,\n          
              )\n                        lines.extend(nested_lines)\n        else:\n            # Regular field (no expandable nested models)\n            field_line = f\"{indent}{field_name}: {field_type_str}\"\n            if field_info.get(\"default\") is not None:\n                field_line += f\" = {field_info['default']}\"\n            if is_required:\n                field_line += \" (required)\"\n            lines.append(field_line)\n\n        return lines\n\n    def generate_qmd(\n        self,\n        model_class: type[BaseModel],\n        title: str | None = None,\n        expand_nested: bool = True,\n    ) -> str:\n        \"\"\"Auto-generate config reference documentation including inherited fields.\"\"\"\n\n        if title is None:\n            title = f\"{model_class.__name__} Reference\"\n\n        # Try to get JSON schema, with fallback for serialization issues\n        try:\n            schema = model_class.model_json_schema()\n            properties = schema.get(\"properties\", {})\n            required = schema.get(\"required\", [])\n        except Exception as e:\n            print(\n                f\"Warning: Could not generate JSON schema ({e}). 
Using model fields instead.\"\n            )\n            # Fallback: use model fields directly\n            properties = {}\n            required = []\n            for field_name, field_info in model_class.model_fields.items():\n                # Extract description from json_schema_extra or field info\n                description = \"\"\n                if (\n                    hasattr(field_info, \"json_schema_extra\")\n                    and field_info.json_schema_extra\n                ):\n                    description = field_info.json_schema_extra.get(\"description\", \"\")\n                elif hasattr(field_info, \"description\") and field_info.description:\n                    description = field_info.description\n\n                # Get default value\n                default_val = None\n                if hasattr(field_info, \"default\") and field_info.default is not None:\n                    # Handle special Pydantic default markers\n                    if str(field_info.default) != \"PydanticUndefined\":\n                        default_val = field_info.default\n\n                properties[field_name] = {\n                    \"type\": \"unknown\",\n                    \"description\": description,\n                    \"default\": default_val,\n                }\n\n                if field_info.is_required():\n                    required.append(field_name)\n\n        # Extract field groups from all classes in inheritance hierarchy\n        field_groups = self._extract_field_groups_from_all_classes(model_class)\n\n        # Start building QMD content\n        qmd_lines = [\n            \"---\",\n            f\"title: {title}\",\n            \"description: A complete list of all configuration options.\",\n            \"---\",\n            \"\",\n        ]\n\n        # Generate one big code block with all fields (inline nested expansion)\n        qmd_lines.append(\"```yaml\")\n\n        for i, group in enumerate(field_groups):\n            if not 
group[\"fields\"]:\n                continue\n\n            # Add blank line between groups (except before first group)\n            if i > 0:\n                qmd_lines.append(\"\")\n\n            # Process fields in the order they appear in source\n            for field_name in group[\"fields\"]:\n                if field_name not in properties:\n                    continue\n\n                field_info = properties[field_name]\n                field_type = self._extract_type_from_source(model_class, field_name)\n                is_required = field_name in required\n\n                if expand_nested:\n                    # Check if this field has nested models\n                    if field_name in model_class.model_fields:\n                        pydantic_field_info = model_class.model_fields[field_name]\n                        nested_models = self._extract_all_pydantic_models_from_type(\n                            pydantic_field_info.annotation\n                        )\n                        has_nested = bool(nested_models)\n                    else:\n                        has_nested = False\n\n                    # Add blank line before nested config\n                    if has_nested:\n                        qmd_lines.append(\"\")\n\n                    # Use the new inline generation method\n                    field_lines = self._generate_field_documentation(\n                        model_class,\n                        field_name,\n                        field_info,\n                        field_type,\n                        is_required,\n                        indent_level=0,\n                        visited_models=set(),\n                    )\n                    qmd_lines.extend(field_lines)\n\n                    # Add blank line after nested config\n                    if has_nested:\n                        qmd_lines.append(\"\")\n                else:\n                    # Original simple approach\n                    description = 
field_info.get(\"description\", \"\")\n                    default = field_info.get(\"default\")\n\n                    # Add wrapped comment for description\n                    if description:\n                        wrapped_lines = self._wrap_comment(description)\n                        qmd_lines.extend(wrapped_lines)\n\n                    line = f\"{field_name}: {field_type}\"\n                    if default is not None:\n                        line += f\" = {default}\"\n                    if is_required:\n                        line += \" (required)\"\n                    qmd_lines.append(line)\n\n        qmd_lines.append(\"```\")\n\n        # Join all lines and clean up any double newlines\n        content = \"\\n\".join(qmd_lines)\n\n        # Replace multiple consecutive newlines with just two newlines (one blank line)\n        import re\n\n        content = re.sub(r\"\\n{3,}\", \"\\n\\n\", content)\n\n        # Ensure single newline at the very end\n        content = content.rstrip(\"\\n\") + \"\\n\"\n\n        return content\n\n\ndef main():\n    generator = QuartoGenerator()\n\n    print(\"Generating config reference content...\")\n    qmd_content = generator.generate_qmd(AxolotlInputConfig, \"Config Reference\", True)\n\n    print(\"Writing to file...\")\n    with open(\"docs/config-reference.qmd\", \"w\", encoding=\"utf-8\") as f:\n        f.write(qmd_content)\n    print(\"Done!\")\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "docs/scripts/generate_examples_docs.py",
    "content": "\"\"\"\nauto generate example docs from allowlist\n\"\"\"\n\nimport re\nimport shutil\nimport sys\nfrom pathlib import Path\n\nimport yaml\n\n# Paths\nTHIS = Path(__file__).resolve()\nROOT = THIS.parents[2]  # repo root (docs/scripts -> docs -> ROOT)\nEXAMPLES_DIR = ROOT / \"examples\"\nOUTPUT_DIR = ROOT / \"docs\" / \"models\"\nALLOWLIST_YML = THIS.parent / \"examples-allowlist.yml\"\n\n\ndef slugify(name: str) -> str:\n    \"\"\"Convert a name to a slug (lowercase, hyphens for spaces).\"\"\"\n    s = re.sub(r\"[^a-zA-Z0-9\\s\\-]+\", \"\", name.strip())\n    s = re.sub(r\"\\s+\", \"-\", s).strip(\"-\").lower()\n    return s or \"example\"\n\n\ndef read_allowlist():\n    with open(ALLOWLIST_YML, \"r\", encoding=\"utf-8\") as f:\n        data = yaml.safe_load(f) or {}\n    items = data.get(\"examples\", [])\n    if not isinstance(items, list):\n        raise ValueError(\"`examples` must be a list in examples-allowlist.yml\")\n    return items\n\n\ndef find_readme(folder: Path) -> Path | None:\n    for name in (\"README.md\", \"Readme.md\", \"readme.md\"):\n        p = folder / name\n        if p.exists():\n            return p\n    return None\n\n\ndef remove_first_h1(md: str) -> tuple[str, str | None]:\n    \"\"\"\n    Remove the first H1 from markdown and return (modified_md, h1_title).\n    The H1 is removed since we use the frontmatter title instead.\n    \"\"\"\n    lines = md.splitlines()\n    result = []\n    h1_title = None\n    skipped_first = False\n\n    for line in lines:\n        if not skipped_first and line.startswith(\"# \"):\n            h1_title = line[2:].strip()\n            skipped_first = True\n            continue\n        result.append(line)\n\n    return \"\\n\".join(result), h1_title\n\n\nIMG_RE = re.compile(r\"!\\[[^\\]]*\\]\\(([^)]+)\\)\")\nLINK_RE = re.compile(r\"\\[([^\\]]+)\\]\\(([^)]+)\\)\")\n\n\ndef rewrite_and_copy_assets(md: str, src_dir: Path, dest_assets_root: Path) -> str:\n    \"\"\"\n    Copy local image assets 
referenced in markdown to\n    docs/examples/assets/... and rewrite the links.\n    \"\"\"\n    dest_assets = dest_assets_root / \"assets\"\n\n    def repl(m):\n        url = m.group(1).strip()\n        if re.match(r\"^(https?:)?//\", url):\n            return m.group(0)  # leave remote URLs\n        src_path = (src_dir / url).resolve()\n        if not src_path.exists():\n            return m.group(0)  # leave as-is if not found\n        rel = src_path.relative_to(src_dir)\n        # Create a unique asset path based on source directory name\n        asset_name = src_dir.name.replace(\"/\", \"-\")\n        dest_path = dest_assets / asset_name / rel\n        dest_path.parent.mkdir(parents=True, exist_ok=True)\n        shutil.copy2(src_path, dest_path)\n        new_rel = f\"assets/{asset_name}/{rel.as_posix()}\"\n        return m.group(0).replace(url, new_rel)\n\n    return IMG_RE.sub(repl, md)\n\n\ndef rewrite_readme_links(\n    md: str,\n    src_dir: Path,\n    examples_dir: Path,\n    parent_index_only: set,\n    current_src_path: str,\n    allowlist_entries: set,\n    current_output_path: str,\n) -> str:\n    \"\"\"\n    Rewrite links between README.md files to point to the correct .qmd files.\n    \"\"\"\n\n    def repl(m):\n        text = m.group(1)\n        url = m.group(2).strip()\n\n        # Skip remote URLs and anchor links\n        if re.match(r\"^(https?:)?//\", url) or url.startswith(\"#\"):\n            return m.group(0)\n\n        # Skip non-markdown files\n        if not url.lower().endswith(\".md\"):\n            return m.group(0)\n\n        # Resolve the target path\n        try:\n            target_path = (src_dir / url).resolve()\n\n            # Check if target is outside examples_dir\n            try:\n                rel_path = target_path.relative_to(examples_dir)\n            except ValueError:\n                # Target is outside examples_dir, leave as-is\n                return m.group(0)\n\n            parts = list(rel_path.parts)\n\n      
      # Determine the output path for the target\n            if len(parts) > 0 and parts[-1].lower() in (\"readme.md\", \"readme\"):\n                # This is a README link\n                if len(parts) == 1:\n                    # Link to root README -> index.qmd\n                    target_output = \"index.qmd\"\n                elif len(parts) == 2:\n                    if parts[0] == \".\":\n                        # Current directory README\n                        target_output = \"index.qmd\"\n                    else:\n                        # subdir/README.md\n                        parent_dir = parts[0]\n                        if parent_dir in parent_index_only:\n                            target_output = f\"{parent_dir}/index.qmd\"\n                        else:\n                            target_output = f\"{parent_dir}.qmd\"\n                else:\n                    # Deeper nesting: parent/subdir/README.md\n                    # Build the full path like \"parent/subdir\"\n                    full_path = \"/\".join(parts[:-1])  # Remove README.md\n                    # Check if this exact path is in allowlist\n                    if full_path in allowlist_entries:\n                        # This is a sub-entry with its own entry -> use .qmd\n                        target_output = f\"{full_path}.qmd\"\n                    elif parts[0] == \".\":\n                        # ./subdir/README.md -> check if subdir has own entry\n                        subdir = parts[1]\n                        if subdir in parent_index_only:\n                            target_output = f\"{subdir}/index.qmd\"\n                        else:\n                            target_output = f\"{subdir}.qmd\"\n                    else:\n                        # parent/subdir where parent doesn't have own entry\n                        target_output = f\"{full_path}/index.qmd\"\n            else:\n                # Regular .md file -> convert to .qmd, keep path 
structure\n                target_output = \"/\".join(parts)[:-2] + \"qmd\"\n\n            # Compute relative path from current output file to target\n            current_parts = current_output_path.split(\"/\")\n            target_parts = target_output.split(\"/\")\n\n            # Special case: if current is a subdir file and target is a single-component file at root\n            # Example: current=\"magistral/vision\", target=\"magistral.qmd\"\n            if len(current_parts) > 1 and len(target_parts) == 1:\n                # Current is in subdir, target is at root level\n                # Go up to root: ../ for each level\n                up_count = len(current_parts) - 1\n                rel_parts = [\"..\"] * up_count + [target_parts[0]]\n                new_url = \"/\".join(rel_parts)\n            else:\n                # Find common prefix\n                i = 0\n                while (\n                    i < min(len(current_parts) - 1, len(target_parts))\n                    and current_parts[i] == target_parts[i]\n                ):\n                    i += 1\n\n                # Build relative path: go up (../) then down to target\n                up_count = len(current_parts) - 1 - i\n                rel_parts = [\"..\"] * up_count + target_parts[i:]\n\n                if not rel_parts or rel_parts == [\"..\"]:\n                    # Points to same directory or parent\n                    new_url = \"/\".join(rel_parts) if rel_parts else \".\"\n                else:\n                    new_url = \"/\".join(rel_parts)\n\n            return f\"[{text}]({new_url})\"\n        except (ValueError, IndexError):\n            return m.group(0)\n\n    return LINK_RE.sub(repl, md)\n\n\ndef write_qmd(out_path: Path, title: str, body_md: str):\n    out_path.parent.mkdir(parents=True, exist_ok=True)\n    fm = f\"---\\ntitle: {title!r}\\nexecute:\\n  eval: false\\nformat:\\n  html:\\n    toc: true\\n---\\n\\n\"\n    out_path.write_text(fm + body_md, 
encoding=\"utf-8\")\n\n\ndef update_quarto_yml(generated: list[tuple[str, str, str]]):\n    \"\"\"\n    Update _quarto.yml with the generated example files in the correct order.\n    This keeps the sidebar in sync with the allowlist.\n\n    Model Guides is now nested under \"Getting Started\" section.\n    Creates nested sections for models with sub-entries (e.g., magistral, ministral3).\n    Parent pages are now flat files (e.g., ministral3.qmd) with sub-pages in subdirs.\n    \"\"\"\n    quarto_yml = ROOT / \"_quarto.yml\"\n    if not quarto_yml.exists():\n        print(f\"[WARN] {quarto_yml} not found, skipping update\", file=sys.stderr)\n        return\n\n    content = quarto_yml.read_text(encoding=\"utf-8\")\n\n    # First pass: find all parents that have sub-entries\n    parents_with_subs = set()\n    for path, _name, _title in generated:\n        if \"/\" in path:\n            parent = path.split(\"/\")[0]\n            parents_with_subs.add(parent)\n\n    # Build the YAML contents while preserving allowlist order\n    lines = []\n    processed_sections = set()\n\n    for path, _name, title in generated:\n        # Check if this is a parent page that has sub-pages\n        if path in parents_with_subs:\n            # This is a parent page with sub-pages - create a nested section\n            if path not in processed_sections:\n                processed_sections.add(path)\n                section_title = (\n                    title or path.replace(\"-\", \" \").replace(\"_\", \" \").title()\n                )\n                lines.append(f'                - section: \"{section_title}\"')\n                lines.append(\"                  contents:\")\n                # Add the parent page first\n                lines.append(f\"                    - docs/models/{path}.qmd\")\n                # Then add all sub-pages\n                for sub_path, _sub_name, _sub_title in generated:\n                    if \"/\" in sub_path and sub_path.split(\"/\")[0] == 
path:\n                        lines.append(\n                            f\"                    - docs/models/{sub_path}.qmd\"\n                        )\n        elif \"/\" not in path:\n            # This is a flat item with no sub-pages\n            # Skip if it was already included as part of a parent section\n            if path not in processed_sections:\n                lines.append(f\"                - docs/models/{path}.qmd\")\n\n    yaml_content = \"\\n\".join(lines) + \"\\n\"\n\n    # Pattern to match only the Model Guides contents, stopping at the next item\n    # in Getting Started (lines starting with 12 spaces: same level as the section)\n    pattern = r'(            - section: \"Model Guides\"\\n              contents:)([^\\n]*|.*?)(?=\\n            - |\\n        - section:|\\n\\nformat:)'\n\n    def replacement(match):\n        prefix = match.group(1)\n        return prefix + \"\\n\" + yaml_content\n\n    new_content = re.sub(pattern, replacement, content, flags=re.DOTALL)\n\n    if new_content != content:\n        quarto_yml.write_text(new_content, encoding=\"utf-8\")\n        print(f\"Updated {quarto_yml}\")\n    else:\n        print(f\"No changes needed for {quarto_yml}\")\n\n\ndef main():\n    allow = read_allowlist()\n    if not EXAMPLES_DIR.exists():\n        print(f\"[WARN] {EXAMPLES_DIR} not found\", file=sys.stderr)\n        return\n\n    (OUTPUT_DIR / \"assets\").mkdir(parents=True, exist_ok=True)\n\n    # First pass: identify which parents have their own entry vs only sub-entries\n    parent_entries = set()  # Parents that have their own entry\n    parent_with_subs = set()  # Parents that have sub-entries\n    allowlist_entries = set()  # All entries in allowlist\n\n    for item in allow:\n        if isinstance(item, str):\n            name = item\n        else:\n            name = item.get(\"name\")\n\n        allowlist_entries.add(name)\n\n        if \"/\" in name:\n            parent = name.split(\"/\")[0]\n            
parent_with_subs.add(parent)\n        else:\n            parent_entries.add(name)\n\n    # Parents with subs that DON'T have their own entry -> use index.qmd\n    parent_index_only = parent_with_subs - parent_entries\n\n    generated = []\n    seen_dirs = set()  # Track which parent directories we've created index for\n\n    for item in allow:\n        if isinstance(item, str):\n            name = item\n            title = None\n        else:\n            name = item.get(\"name\")\n            title = item.get(\"title\")\n\n        if not name:\n            print(f\"[WARN] Skipping item without name: {item}\", file=sys.stderr)\n            continue\n\n        src_dir = EXAMPLES_DIR / name\n        if not src_dir.exists() or not src_dir.is_dir():\n            print(f\"[WARN] Skipping {name} (not a directory)\", file=sys.stderr)\n            continue\n\n        readme = find_readme(src_dir)\n        if not readme:\n            print(f\"[WARN] Skipping {name} (no README.md)\", file=sys.stderr)\n            continue\n\n        md = readme.read_text(encoding=\"utf-8\")\n\n        # Determine output path first (needed for link rewriting)\n        parts = name.split(\"/\")\n        if len(parts) == 1:\n            # Simple case: no subdirectory\n            out_path = OUTPUT_DIR / f\"{parts[0]}.qmd\"\n            sidebar_path = parts[0]\n        else:\n            # Has subdirectory: e.g., magistral/think\n            parent = parts[0]\n            child = \"-\".join(parts[1:])  # handle nested subdirs\n            out_path = OUTPUT_DIR / parent / f\"{child}.qmd\"\n            sidebar_path = f\"{parent}/{child}\"\n\n        # Remove the first H1 (we use frontmatter title instead)\n        md, _ = remove_first_h1(md)\n        # Rewrite links between README files\n        md = rewrite_readme_links(\n            md,\n            src_dir,\n            EXAMPLES_DIR,\n            parent_index_only,\n            name,\n            allowlist_entries,\n            sidebar_path,\n  
      )\n        md = rewrite_and_copy_assets(md, src_dir, OUTPUT_DIR)\n\n        # Handle parent page generation for sub-entries\n        if len(parts) > 1:\n            # Has subdirectory: e.g., magistral/think\n            parent = parts[0]\n\n            # Create parent.qmd if not already done and parent doesn't have own entry\n            if parent not in seen_dirs and parent in parent_index_only:\n                parent_readme = find_readme(EXAMPLES_DIR / parent)\n                if parent_readme:\n                    parent_md = parent_readme.read_text(encoding=\"utf-8\")\n                    parent_md, _ = remove_first_h1(parent_md)\n                    parent_md = rewrite_readme_links(\n                        parent_md,\n                        EXAMPLES_DIR / parent,\n                        EXAMPLES_DIR,\n                        parent_index_only,\n                        parent,\n                        allowlist_entries,\n                        parent,\n                    )\n                    parent_md = rewrite_and_copy_assets(\n                        parent_md, EXAMPLES_DIR / parent, OUTPUT_DIR\n                    )\n                    parent_title = parent.replace(\"-\", \" \").replace(\"_\", \" \").title()\n                    write_qmd(OUTPUT_DIR / f\"{parent}.qmd\", parent_title, parent_md)\n                    generated.append((parent, parent, parent_title))\n                    seen_dirs.add(parent)\n\n        if not title:\n            title = name.replace(\"/\", \" \").replace(\"-\", \" \").title()\n\n        write_qmd(out_path, title, md)\n        generated.append((sidebar_path, name, title))\n\n    # Index page - preserve allowlist order\n    if generated:\n        listing = \"\\n\".join(\n            [f\"- [{title}]({path}.qmd)\" for path, name, title in generated]\n        )\n        index_md = (\n            \"# Model Guides\\n\\nBelow are the curated examples for training various model architectures:\\n\\n\"\n            + 
listing\n            + \"\\n\"\n        )\n        index_fm = (\n            \"---\\nexecute:\\n  eval: false\\nformat:\\n  html:\\n    toc: true\\n---\\n\\n\"\n        )\n        (OUTPUT_DIR / \"index.qmd\").write_text(index_fm + index_md, encoding=\"utf-8\")\n\n        # Auto-update _quarto.yml to keep sidebar in sync\n        update_quarto_yml(generated)\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "docs/sequence_parallelism.qmd",
    "content": "---\ntitle: Sequence Parallelism\ndescription: Train with long sequences split across multiple GPUs.\n---\n\nSequence parallelism is a technique that splits sequences across multiple GPUs,\nallowing you to train with very long sequences that wouldn't fit on a single GPU. Each\nGPU processes a different portion of the sequence, and the results are aggregated\nthrough a ring communication pattern.\n\n## When to Use Sequence Parallelism\n\nUse sequence parallelism when:\n\n- You need to train with sequence lengths that don't fit into a single GPU's memory\n- You have multiple GPUs available\n- You're experiencing OOM (Out Of Memory) errors with long sequences\n\n## Configuration\n\nTo enable sequence parallelism, add the following to your configuration file:\n\n```yaml\n# Set to a divisor (> 1) of the number of GPUs available\ncontext_parallel_size: 4  # Split sequences across 4 GPUs\n# Optional; strides across the key dimension. Larger values use more memory but should make training faster.\nheads_k_stride: 1\n# Optional; one of \"varlen_llama3\" or \"batch_ring\". Defaults to\n# \"varlen_llama3\" when `sample_packing: true`, and \"batch_ring\" otherwise.\nring_attn_func:\n```\n\nThe `context_parallel_size` should be a divisor of the total number of GPUs. For example:\n\n- With 8 GPUs, valid values would be 2, 4, or 8\n- With 4 GPUs, valid values would be 2 or 4\n\n## Implementation Details\n\nWhen sequence parallelism is enabled:\n\n1. Each sequence is divided into equal chunks across the GPUs in a sequence parallel group\n2. The data collator handles the chunking of input_ids, attention_mask, labels, and position_ids\n3. Position IDs are adjusted to maintain proper relative positions\n4. The trainer uses special ring communication patterns for attention operations\n\n## Requirements\n\nTo use sequence parallelism, you need:\n\n- Multiple GPUs (at least 2)\n- The `ring-flash-attn` package. 
Install with:\n  - `pip install 'axolotl[ring-flash-attn]'` (preferred)\n  - `pip install 'ring-flash-attn>=0.1.4'`\n\n## Limitations\n\n- Flash attention must be enabled for this to work (`flash_attention: true` in config YAML)\n- May have a small performance overhead due to communication between GPUs\n\n## Example\n\n```yaml\nbase_model: meta-llama/Meta-Llama-3-8B-Instruct\nsequence_len: 8192\n\n...\n\ncontext_parallel_size: 4  # Split each sequence into 4 parts, one per GPU\n# Optional; strides across the key dimension. Larger values use more memory but should make training faster.\nheads_k_stride: 1\n# Optional; one of \"varlen_llama3\" or \"batch_ring\". Defaults to\n# \"varlen_llama3\" when `sample_packing: true`, and \"batch_ring\" otherwise.\nring_attn_func:\n\n...\n```\n\nThis will train the Llama 3 8B model with 8K context length, with each sequence split\ninto 4 subsequences of length 2048 across 4 GPUs.\n\n## Sample Packing with Sequence Parallelism\n\nSequence parallelism is compatible with Axolotl's sample packing functionality. When using both features together:\n\n1. Samples are first packed together\n2. The packed sequences are then divided across GPUs in the sequence parallel group\n3. Position IDs are automatically adjusted to maintain proper relative positions\n\n## Effect on Batch Size\n\nWhen using sequence parallelism, your effective global batch size is **divided** by the `context_parallel_size`. This happens because:\n\n- Each group of `context_parallel_size` GPUs works on the same batch (just different parts of each sequence)\n- The number of batches processed per step decreases\n\nFor example:\n\n- With 8 GPUs and no sequence parallelism: 8 different batches processed per step\n- With 8 GPUs and `context_parallel_size=4`: Only 2 different batches processed per step (each split across 4 GPUs)\n- If your per-GPU `micro_batch_size` is 2, the global batch size decreases from 16 to 4\n"
  },
  {
    "path": "docs/streaming.qmd",
    "content": "---\ntitle: Streaming Datasets\ndescription: How to use streaming mode for large-scale datasets and memory-efficient training\norder: 10\n---\n\nStreaming enables memory-efficient training with large datasets by loading data\nincrementally rather than loading the entire dataset into memory at once.\n\nUse streaming when:\n\n- Your dataset is too large to fit in memory (e.g. when you're doing pretraining with massive text corpora)\n- You want to start training immediately without preprocessing the entire dataset\n\nStreaming works with both remote and locally stored datasets!\n\n::: {.callout-note}\nStreaming currently only supports a single dataset. Multi-dataset support will be added soon.\n:::\n\n\n## Configuration\n\n### Basic Streaming\n\nEnable streaming mode by setting the `streaming` flag:\n\n```yaml\nstreaming: true\n```\n\n### Pretraining with Streaming\n\nFor pretraining tasks, streaming is automatically enabled when using `pretraining_dataset`:\n\n```yaml\npretraining_dataset:\n  - path: HuggingFaceFW/fineweb-edu\n    type: pretrain\n    text_column: text\n    split: train\n\n# Optionally, enable sample packing\nstreaming_multipack_buffer_size: 10000\nsample_packing: true\n```\n\n### SFT with Streaming\n\nFor supervised fine-tuning with streaming:\n\n```yaml\nstreaming: true\ndatasets:\n  - path: tatsu-lab/alpaca\n    type: alpaca\n    split: train\n\n# Optionally, enable sample packing\nstreaming_multipack_buffer_size: 10000\nsample_packing: true\n```\n\n## Configuration Options\n\n### `streaming_multipack_buffer_size`\n\nControls the buffer size for multipack streaming (default: 10,000). This determines how\nmany samples are buffered before packing. Larger buffers can improve packing efficiency\nbut use more memory.\n\n### `shuffle_merged_datasets`\n\nWhen enabled, shuffles the streaming dataset using the buffer. 
This requires additional\nmemory for the shuffle buffer.\n\n## Sample Packing with Streaming\n\nSample packing is supported for streaming datasets. When enabled, multiple samples are\npacked into a single sequence to maximize GPU utilization:\n\n```yaml\nsample_packing: true\nstreaming_multipack_buffer_size: 10000\n\n# For SFT: attention is automatically isolated between packed samples\n# For pretraining: control with pretrain_multipack_attn\npretrain_multipack_attn: true  # prevent cross-attention between packed samples\n```\n\nFor more information, see our [documentation](multipack.qmd) on multipacking.\n\n## Important Considerations\n\n### Memory Usage\n\nWhile streaming reduces memory usage compared to loading entire datasets, you still need\nto consider:\n\n- You can control the memory usage by adjusting `streaming_multipack_buffer_size`\n- Sample packing requires buffering multiple samples\n- Shuffling requires additional memory for the shuffle buffer\n\n### Performance\n\n- Streaming may have slightly higher latency compared to preprocessed datasets, as samples are processed on-the-fly\n- Network speed and disk read speed are important when streaming from remote sources or a local dataset, respectively\n- Consider using `axolotl preprocess` for smaller or more frequently used datasets\n\n### Evaluation Datasets\n\nEvaluation datasets are not streamed to ensure consistent evaluation metrics. They're\nloaded normally even when training uses streaming.\n\n## Examples\n\nSee the `examples/streaming/` directory for complete configuration examples:\n\n- `pretrain.yaml`: Pretraining with streaming dataset\n- `sft.yaml`: Supervised fine-tuning with streaming\n"
  },
  {
    "path": "docs/telemetry.qmd",
    "content": "---\ntitle: Telemetry\ndescription: A description of the telemetry implementation in Axolotl.\n---\n\n# Telemetry in Axolotl\n\nAxolotl implements anonymous telemetry to help maintainers understand how the library\nis used and where users encounter issues. This data helps prioritize features, optimize\nperformance, and fix bugs.\n\n## Data Collection\n\nWe collect:\n\n- System info: OS, Python version, Axolotl version, PyTorch version, Transformers\nversion, etc.\n- Hardware info: CPU count, memory, GPU count and models\n- Runtime metrics: Training progress, memory usage, timing information\n- Usage patterns: Models (from a whitelist) and configurations used\n- Error tracking: Stack traces and error messages (sanitized to remove personal\ninformation)\n\nPersonally identifiable information (PII) is not collected.\n\n## Implementation\n\nTelemetry is implemented using PostHog and consists of:\n\n- `axolotl.telemetry.TelemetryManager`: A singleton class that initializes the\ntelemetry system and provides methods for tracking events.\n- `axolotl.telemetry.errors.send_errors`: A decorator that captures exceptions and\nsends sanitized stack traces.\n- `axolotl.telemetry.runtime_metrics.RuntimeMetricsTracker`: A class that tracks\nruntime metrics during training.\n- `axolotl.telemetry.callbacks.TelemetryCallback`: A Trainer callback that sends\nruntime metrics telemetry.\n\nThe telemetry system will block training startup for 10 seconds to ensure users are\naware of data collection, unless telemetry is explicitly enabled or disabled.\n\n## Opt-Out Mechanism\n\nTelemetry is **enabled by default** on an opt-out basis. To disable it, set\n`AXOLOTL_DO_NOT_TRACK=1` or `DO_NOT_TRACK=1`.\n\nA warning message will be logged on start to clearly inform users about telemetry.\nWe will remove this after some period.\n\nTo hide the warning message about telemetry that is displayed on train, etc. 
startup,\nexplicitly set: `AXOLOTL_DO_NOT_TRACK=0` (enable telemetry) or `AXOLOTL_DO_NOT_TRACK=1`\n(explicitly disable telemetry).\n\n## Privacy\n\n- All path-like config information is automatically redacted from telemetry data\n- Model information is only collected for whitelisted organizations\n    - See `axolotl/telemetry/whitelist.yaml` for the set of whitelisted organizations\n- Each run generates a unique anonymous ID\n    - This allows us to link different telemetry events within the same training run\n- Telemetry is only sent from the main process to avoid duplicate events\n"
  },
  {
    "path": "docs/torchao.qmd",
    "content": "---\ntitle: \"PyTorch ao\"\ndescription: \"Custom data types and layouts for training and inference\"\n---\n\nTo use experimental optimizers (`AdamWFp8`, `AdamW4bit`, `AdamW8bit`) from Pytorch Ao, please install the package as shown below.\n\n::: {.callout-tip}\nSome experimental optimizers are already present in regular Pytorch, so please re-check if you actually need this package!\n:::\n\n### Installation\n\nStable Release from the PyTorch index\n\n```bash\npip install torchao --extra-index-url https://download.pytorch.org/whl/cu121 # full options are cpu/cu118/cu121/cu124\n```\n\n\nNightly release\n\n```bash\npip install --pre torchao-nightly --index-url https://download.pytorch.org/whl/nightly/cu121 # full options are cpu/cu118/cu121/cu124\n```\n"
  },
  {
    "path": "docs/unsloth.qmd",
    "content": "---\ntitle: \"Unsloth\"\ndescription: \"Hyper-optimized QLoRA finetuning for single GPUs\"\n---\n\n### Overview\n\nUnsloth provides hand-written optimized kernels for LLM finetuning that slightly improve speed and VRAM over\nstandard industry baselines.\n\n::: {.callout-important}\nDue to breaking changes in transformers `v4.48.0`, users will need to downgrade to `<=v4.47.1` to use this patch.\n\nThis will later be deprecated in favor of [LoRA Optimizations](lora_optims.qmd).\n:::\n\n\n### Installation\n\nThe following will install the correct unsloth and extras from source.\n\n```bash\npython scripts/unsloth_install.py | sh\n```\n\n### Usage\n\nAxolotl exposes a few configuration options to try out unsloth and get most of the performance gains.\n\nOur unsloth integration is currently limited to the following model architectures:\n\n - llama\n\nThese options are specific to LoRA finetuning and cannot be used for multi-GPU finetuning:\n\n```yaml\nunsloth_lora_mlp: true\nunsloth_lora_qkv: true\nunsloth_lora_o: true\n```\n\nThese options are composable and can be used with multi-GPU finetuning:\n\n```yaml\nunsloth_cross_entropy_loss: true\nunsloth_rms_norm: true\nunsloth_rope: true\n```\n\n### Limitations\n\n- The LoRA kernel options (`unsloth_lora_*`) are single GPU only; i.e. no multi-GPU support\n- No deepspeed or FSDP support (requires multi-GPU)\n- LoRA + QLoRA support only. No full fine tunes or fp8 support.\n- Limited model architecture support. Llama, Phi, Gemma, Mistral only\n- No MoE support.\n"
  },
  {
    "path": "examples/LiquidAI/README.md",
    "content": "# Finetune Liquid Foundation Models 2 (LFM2) with Axolotl\n\n[Liquid Foundation Models 2 (LFM2)](https://huggingface.co/collections/LiquidAI/lfm2-686d721927015b2ad73eaa38) are a family of small, open-weight models from [Liquid AI](https://www.liquid.ai/) focused on quality, speed, and memory efficiency. Liquid AI released text-only [LFM2](https://huggingface.co/collections/LiquidAI/lfm2-686d721927015b2ad73eaa38) and text+vision [LFM2-VL](https://huggingface.co/collections/LiquidAI/lfm2-vl-68963bbc84a610f7638d5ffa) models.\n\nLFM2 features a new hybrid Liquid architecture with multiplicative gates, short-range convolutions, and grouped query attention, enabling fast training and inference.\n\nThis guide shows how to fine-tune both the LFM2 and LFM2-VL models with Axolotl.\n\nThanks to the team at LiquidAI for giving us early access to prepare for these releases.\n\n## Getting Started\n\n1.  Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html).\n\n    Here is an example of how to install from pip:\n    ```bash\n    # Ensure you have a compatible version of Pytorch installed\n    pip3 install packaging setuptools wheel ninja\n    pip3 install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'\n    ```\n\n2.  Run one of the finetuning examples below.\n\n    **LFM2**\n    ```bash\n    # FFT SFT (1x48GB @ 25GiB)\n    axolotl train examples/LiquidAI/lfm2-350m-fft.yaml\n    ```\n\n    **LFM2-VL**\n    ```bash\n    # LoRA SFT (1x48GB @ 2.7GiB)\n    axolotl train examples/LiquidAI/lfm2-vl-lora.yaml\n    ```\n\n    **LFM2-MoE**\n    ```bash\n    pip install git+https://github.com/huggingface/transformers.git@0c9a72e4576fe4c84077f066e585129c97bfd4e6\n\n    # LoRA SFT (1x48GB @ 16.2GiB)\n    axolotl train examples/LiquidAI/lfm2-8b-a1b-lora.yaml\n    ```\n\n### TIPS\n\n- **Installation Error**: If you encounter `ImportError: ... 
undefined symbol ...` or `ModuleNotFoundError: No module named 'causal_conv1d_cuda'`, the `causal-conv1d` package may have been installed incorrectly. Try uninstalling it:\n  ```bash\n  pip uninstall -y causal-conv1d\n  ```\n\n- **Dataset Loading**: Read more on how to load your own dataset in our [documentation](https://docs.axolotl.ai/docs/dataset_loading.html).\n- **Dataset Formats**:\n  - For LFM2 models, the dataset format follows the OpenAI Messages format as seen [here](https://docs.axolotl.ai/docs/dataset-formats/conversation.html#chat_template).\n  - For LFM2-VL models, Axolotl follows the multi-content Messages format. See our [Multimodal docs](https://docs.axolotl.ai/docs/multimodal.html#dataset-format) for details.\n\n## Optimization Guides\n\n- [Optimizations Guide](https://docs.axolotl.ai/docs/optimizations.html)\n\n## Related Resources\n\n- [LFM2 Blog](https://www.liquid.ai/blog/liquid-foundation-models-v2-our-second-series-of-generative-ai-models)\n- [LFM2-VL Blog](https://www.liquid.ai/blog/lfm2-vl-efficient-vision-language-models)\n- [LFM2-MoE Blog](https://www.liquid.ai/blog/lfm2-8b-a1b-an-efficient-on-device-mixture-of-experts)\n- [Axolotl Docs](https://docs.axolotl.ai)\n- [Axolotl GitHub](https://github.com/axolotl-ai-cloud/axolotl)\n- [Axolotl Discord](https://discord.gg/7m9sfhzaf3)\n"
  },
  {
    "path": "examples/LiquidAI/lfm2-350m-fft.yaml",
    "content": "base_model: LiquidAI/LFM2-350M\n\nplugins:\n  - axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin\n\neot_tokens:\n  - \"<|im_end|>\"\ndatasets:\n  - path: mlabonne/FineTome-100k\n    type: chat_template\n    split: train[:20%]\n    field_messages: conversations\n    message_field_role: from\n    message_field_content: value\ndataset_prepared_path: last_run_prepared\nval_set_size: 0.05\noutput_dir: ./outputs/out\n\nsequence_len: 4096\nsample_packing: true\n\n\nwandb_project:\nwandb_entity:\nwandb_watch:\nwandb_name:\nwandb_log_model:\n\ngradient_accumulation_steps: 2\nmicro_batch_size: 4\nnum_epochs: 1\noptimizer: adamw_torch_fused\nlr_scheduler: cosine\nlearning_rate: 5e-5\n\nbf16: true\ntf32: true\n\ngradient_checkpointing: false\nresume_from_checkpoint:\nlogging_steps: 1\nflash_attention: true\n\nwarmup_ratio: 0.1\nevals_per_epoch: 2\nsaves_per_epoch: 1\n\nweight_decay: 0.0\n\n# save_first_step: true  # uncomment this to validate checkpoint saving works with your config\n"
  },
  {
    "path": "examples/LiquidAI/lfm2-8b-a1b-lora.yaml",
    "content": "base_model: LiquidAI/LFM2-8B-A1B\n\nplugins:\n  - axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin\n\nload_in_8bit: true\n\neot_tokens:\n  - \"<|im_end|>\"\ndatasets:\n  - path: mlabonne/FineTome-100k\n    type: chat_template\n    split: train[:20%]\n    field_messages: conversations\n    message_field_role: from\n    message_field_content: value\ndataset_prepared_path: last_run_prepared\nval_set_size: 0.05\noutput_dir: ./outputs/out\n\nsequence_len: 4096\nsample_packing: true\n\nadapter: lora\nlora_model_dir:\n\nlora_r: 32\nlora_alpha: 16\nlora_dropout: 0.05\nlora_target_modules: 'model.layers.[\\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'\n\nwandb_project:\nwandb_entity:\nwandb_watch:\nwandb_name:\nwandb_log_model:\n\ngradient_accumulation_steps: 2\nmicro_batch_size: 4\nnum_epochs: 1\noptimizer: adamw_torch_fused\nlr_scheduler: cosine\nlearning_rate: 5e-5\n\nbf16: true\ntf32: true\n\ngradient_checkpointing: true\nresume_from_checkpoint:\nlogging_steps: 1\nflash_attention: true\n\nwarmup_ratio: 0.1\nevals_per_epoch: 2\nsaves_per_epoch: 1\n\nweight_decay: 0.0\n\n# save_first_step: true  # uncomment this to validate checkpoint saving works with your config\n"
  },
  {
    "path": "examples/LiquidAI/lfm2-vl-lora.yaml",
    "content": "base_model: LiquidAI/LFM2-VL-450M\ntrust_remote_code: true\nmodel_type: AutoModelForImageTextToText\nprocessor_type: AutoProcessor\n\nplugins:\n  - axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin\n\n# these 3 lines are needed for now to handle vision chat templates w images\nskip_prepare_dataset: true\nremove_unused_columns: false\nsample_packing: false\n\ndatasets:\n  - path: HuggingFaceH4/llava-instruct-mix-vsft\n    type: chat_template\n    split: train[:1%]\n\ndataset_prepared_path: last_run_prepared\nval_set_size: 0.0\noutput_dir: ./outputs/out\n\nadapter: lora\nlora_model_dir:\n\nsequence_len: 8192\npad_to_sequence_len: false\n\nlora_r: 32\nlora_alpha: 16\nlora_dropout: 0.05\nlora_target_modules: 'model.language_model.layers.[\\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'\n\nwandb_project:\nwandb_entity:\nwandb_watch:\nwandb_name:\nwandb_log_model:\n\ngradient_accumulation_steps: 4\nmicro_batch_size: 1\nnum_epochs: 1\noptimizer: adamw_bnb_8bit\nlr_scheduler: cosine\nlearning_rate: 0.0002\n\nbf16: true\nfp16:\ntf32: true\n\ngradient_checkpointing: true\nlogging_steps: 1\nflash_attention: true\neager_attention:\n\nwarmup_ratio: 0.1\nevals_per_epoch: 1\nsaves_per_epoch: 1\nweight_decay: 0.0\n\n# save_first_step: true  # uncomment this to validate checkpoint saving works with your config\n"
  },
  {
    "path": "examples/alst/README.md",
    "content": "# Arctic Long Sequence Training (ALST)\n\nArctic Long Sequence Training (ALST) is a technique for training long context models using a variety of optimization\ntechniques. It is a combination of:\n- TiledMLP: Leverage tiling over the sequence dimension on MLP layers to reduce memory usage\n- Tiled Loss: Using optimized loss functions like Liger-Kernel or Cut Cross Entropy to reduce memory usage\n- Activation Offloading: Offload activations to CPU RAM to reduce memory usage\n\nFor more information, you can check out the ALST paper [here](https://www.arxiv.org/abs/2506.13996).\n\n## Usage\n\n```yaml\ntiled_mlp: true\n\n# See Sequence Parallelism docs\n# https://docs.axolotl.ai/docs/sequence_parallelism.html\ncontext_parallel_size: int\n\nplugins:\n# See Cut Cross Entropy docs\n# https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy\n  - axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin\n\n# or Liger Kernel docs\n# https://docs.axolotl.ai/docs/custom_integrations.html#liger-kernels\n  - axolotl.integrations.liger.LigerPlugin\n# ...\n\n```\n"
  },
  {
    "path": "examples/alst/llama3-8b-deepspeed-alst.yaml",
    "content": "base_model: meta-llama/Llama-3.1-8B\n# Automatically upload checkpoint and final model to HF\n# hub_model_id: username/custom_model_name\n\ndatasets:\n  - path: togethercomputer/Long-Data-Collections\n    type: completion\n    field: text\n    data_files:\n      - pretrain/rp_sub.jsonl.zst\n  - path: princeton-nlp/TextbookChapters\n    type: completion\n    field: chapter\ndataset_prepared_path: last_run_prepared\nval_set_size: 0.0\noutput_dir: ./outputs/out\n\nsequence_len: 500_000\nmin_sample_len: 200_000\nsample_packing: true\n\ntiled_mlp: true\ncontext_parallel_size: 8\nplugins:\n  - axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin\n\ngradient_accumulation_steps: 1\nmicro_batch_size: 1\nnum_epochs: 1\noptimizer: adamw_torch_8bit\nlr_scheduler: cosine\nlearning_rate: 2e-5\n\nbf16: auto\ntf32: true\n\ngradient_checkpointing: true\nactivation_offloading: legacy\n\nresume_from_checkpoint:\nlogging_steps: 1\nflash_attention: true\n\nwarmup_steps: 100\nsaves_per_epoch: 1\nevals_per_epoch: 2\nweight_decay: 0.0\nspecial_tokens:\n  pad_token: <|end_of_text|>\n\ndeepspeed: deepspeed_configs/zero3_bf16_cpuoffload_all.json\n\n# save_first_step: true  # uncomment this to validate checkpoint saving works with your config\n"
  },
  {
    "path": "examples/alst/llama3-8b-fsdp2-alst.yaml",
    "content": "base_model: meta-llama/Llama-3.1-8B\n# Automatically upload checkpoint and final model to HF\n# hub_model_id: username/custom_model_name\n\ndatasets:\n  - path: togethercomputer/Long-Data-Collections\n    type: completion\n    field: text\n    data_files:\n      - pretrain/rp_sub.jsonl.zst\n  - path: princeton-nlp/TextbookChapters\n    type: completion\n    field: chapter\ndataset_prepared_path: last_run_prepared\nval_set_size: 0.0\noutput_dir: ./outputs/out\n\nsequence_len: 500_000\nmin_sample_len: 200_000\nsample_packing: true\n\ntiled_mlp: true\ncontext_parallel_size: 8\nplugins:\n  - axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin\n\ngradient_accumulation_steps: 1\nmicro_batch_size: 1\nnum_epochs: 1\noptimizer: adamw_torch_8bit\nlr_scheduler: cosine\nlearning_rate: 2e-5\n\nbf16: auto\ntf32: true\n\ngradient_checkpointing: true\nactivation_offloading: legacy\n\nresume_from_checkpoint:\nlogging_steps: 1\nflash_attention: true\n\nwarmup_steps: 100\nsaves_per_epoch: 1\nevals_per_epoch: 2\nweight_decay: 0.0\nspecial_tokens:\n  pad_token: <|end_of_text|>\n\nfsdp_version: 2\nfsdp_config:\n  offload_params: false  # offloading is currently not compatible with SP + torchao optimizer\n  state_dict_type: SHARDED_STATE_DICT\n  auto_wrap_policy: TRANSFORMER_BASED_WRAP\n  transformer_layer_cls_to_wrap: LlamaDecoderLayer\n  reshard_after_forward: true\n\n# save_first_step: true  # uncomment this to validate checkpoint saving works with your config\n"
  },
  {
    "path": "examples/apertus/README.md",
    "content": "# Finetune Swiss-AI's Apertus with Axolotl\n\n[Apertus](https://huggingface.co/collections/swiss-ai/apertus-llm-68b699e65415c231ace3b059) is a family of opensource models trained by Swiss-ai.\n\nThis guide shows how to fine-tune it with Axolotl with multi-turn conversations and proper masking.\n\n## Getting started\n\n1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html). You need to install from main as Apertus is only on nightly or use our latest [Docker images](https://docs.axolotl.ai/docs/docker.html).\n\n    Here is an example of how to install from main for pip:\n\n```bash\n# Ensure you have Pytorch installed (Pytorch 2.6.0 min)\ngit clone https://github.com/axolotl-ai-cloud/axolotl.git\ncd axolotl\n\npip3 install packaging==26.0 setuptools==75.8.0 wheel ninja\npip3 install --no-build-isolation -e '.[flash-attn]'\n\n# Install CCE https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy\npython scripts/cutcrossentropy_install.py | sh\n```\n\n2. (Optional, highly recommended) Install XIELU CUDA\n\n```bash\n## Recommended for reduced VRAM and faster speeds\n\n# Point to CUDA toolkit directory\n# For those using our Docker image, use the below path.\nexport CUDA_HOME=/usr/local/cuda\n\npip3 install git+https://github.com/nickjbrowning/XIELU@59d6031 --no-build-isolation --no-deps\n```\n\nFor any installation errors, see [XIELU Installation Issues](#xielu-installation-issues)\n\n3. Run the finetuning example:\n\n```bash\naxolotl train examples/apertus/apertus-8b-qlora.yaml\n```\n\nThis config uses about 8.7 GiB VRAM.\n\nLet us know how it goes. Happy finetuning! 
🚀\n\n### Tips\n\n- For inference, the official Apertus team recommends `top_p=0.9` and `temperature=0.8`.\n- You can instead use full parameter fine-tuning by removing the `adapter: qlora` and `load_in_4bit: true` from the config.\n- Read more on how to load your own dataset at [docs](https://docs.axolotl.ai/docs/dataset_loading.html).\n- The dataset format follows the OpenAI Messages format as seen [here](https://docs.axolotl.ai/docs/dataset-formats/conversation.html#chat_template).\n\n### XIELU Installation Issues\n\n#### `ModuleNotFoundError: No module named 'torch'`\n\nPlease check these one by one:\n- You are running in the correct environment\n- The environment has PyTorch installed\n- The CUDA toolkit is at `CUDA_HOME`\n\nIf those didn't help, please try the below solutions:\n\n1. Pass the Python executable to CMake via an environment variable and try installing again:\n\n    ```bash\n    Python_EXECUTABLE=$(which python) pip3 install git+https://github.com/nickjbrowning/XIELU@59d6031 --no-build-isolation --no-deps\n    ```\n\n2. Clone the repo and manually hardcode the Python path (use the output of `which python`):\n\n    ```bash\n    git clone https://github.com/nickjbrowning/XIELU\n    cd XIELU\n    git checkout 59d6031\n\n    cd xielu\n    nano CMakeLists.txt  # or vi depending on your preference\n    ```\n\n    ```diff\n    execute_process(\n    -    COMMAND ${Python_EXECUTABLE} -c \"import torch.utils; print(torch.utils.cmake_prefix_path)\"\n    +    COMMAND /root/miniconda3/envs/py3.11/bin/python -c \"import torch.utils; print(torch.utils.cmake_prefix_path)\"\n        RESULT_VARIABLE TORCH_CMAKE_PATH_RESULT\n        OUTPUT_VARIABLE TORCH_CMAKE_PATH_OUTPUT\n        ERROR_VARIABLE TORCH_CMAKE_PATH_ERROR\n    )\n    ```\n\n    ```bash\n    pip3 install .
 --no-build-isolation --no-deps\n    ```\n\n## Optimization Guides\n\n- [Multi-GPU Training](https://docs.axolotl.ai/docs/multi-gpu.html)\n- [Multi-Node Training](https://docs.axolotl.ai/docs/multi-node.html)\n- [LoRA Optimizations](https://docs.axolotl.ai/docs/lora_optims.html)\n\n## Related Resources\n\n- [Apertus Tech Report](https://github.com/swiss-ai/apertus-tech-report/blob/main/Apertus_Tech_Report.pdf)\n- [Axolotl Docs](https://docs.axolotl.ai)\n- [Axolotl Website](https://axolotl.ai)\n- [Axolotl GitHub](https://github.com/axolotl-ai-cloud/axolotl)\n- [Axolotl Discord](https://discord.gg/7m9sfhzaf3)\n"
  },
  {
    "path": "examples/apertus/apertus-8b-qlora.yaml",
    "content": "base_model: swiss-ai/Apertus-8B-Instruct-2509\n\n# Automatically upload checkpoint and final model to HF\n# hub_model_id: username/custom_model_name\n\nplugins:\n  - axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin\n\nload_in_8bit: false\nload_in_4bit: true\n\ndatasets:\n  - path: fozziethebeat/alpaca_messages_2k_test\n    type: chat_template\n\ndataset_prepared_path: last_run_prepared\nval_set_size: 0.1\noutput_dir: ./outputs/lora-out\n\nadapter: qlora\nlora_model_dir:\n\nsequence_len: 2048\nsample_packing: true\n\nlora_r: 32\nlora_alpha: 16\nlora_dropout: 0.05\nlora_target_linear: true\nlora_target_modules:\n  - gate_proj\n  - down_proj\n  - up_proj\n  - q_proj\n  - v_proj\n  - k_proj\n  - o_proj\n\nwandb_project:\nwandb_entity:\nwandb_watch:\nwandb_name:\nwandb_log_model:\n\ngradient_accumulation_steps: 4\nmicro_batch_size: 2\nnum_epochs: 1\noptimizer: adamw_bnb_8bit\nlr_scheduler: cosine\nlearning_rate: 0.0002\n\nbf16: auto\ntf32: false\n\ngradient_checkpointing: true\nresume_from_checkpoint:\nlogging_steps: 1\nflash_attention: true\n\nwarmup_ratio: 0.1\nevals_per_epoch: 1\nsaves_per_epoch: 1\n\n# save_first_step: true  # uncomment this to validate checkpoint saving works with your config\n"
  },
  {
    "path": "examples/arcee/README.md",
    "content": "# Finetune ArceeAI's AFM with Axolotl\n\n[Arcee Foundation Models (AFM)](https://huggingface.co/collections/arcee-ai/afm-45b-68823397c351603014963473) are a family of 4.5B parameter open weight models trained by Arcee.ai.\n\nThis guide shows how to fine-tune it with Axolotl with multi-turn conversations and proper masking.\n\nThanks to the team at Arcee.ai for using Axolotl in supervised fine-tuning the AFM model.\n\n## Getting started\n\n1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html). You need to install from main as AFM is only on nightly or use our latest [Docker images](https://docs.axolotl.ai/docs/docker.html).\n\n    Here is an example of how to install from main for pip:\n\n```bash\n# Ensure you have Pytorch installed (Pytorch 2.6.0 min)\ngit clone https://github.com/axolotl-ai-cloud/axolotl.git\ncd axolotl\n\npip3 install packaging==26.0 setuptools==75.8.0 wheel ninja\npip3 install --no-build-isolation -e '.[flash-attn]'\n\n# Install CCE https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy\npython scripts/cutcrossentropy_install.py | sh\n```\n\n2. Run the finetuning example:\n\n```bash\naxolotl train examples/arcee/afm-4.5b-qlora.yaml\n```\n\nThis config uses about 7.8GiB VRAM.\n\nLet us know how it goes. Happy finetuning! 
🚀\n\n### TIPS\n\n- For inference, the official Arcee.ai team recommends `top_p: 0.95`, `temperature: 0.5`, `top_k: 50`, and `repeat_penalty: 1.1`.\n- You can run a full finetuning by removing the `adapter: qlora` and `load_in_4bit: true` from the config.\n- Read more on how to load your own dataset at [docs](https://docs.axolotl.ai/docs/dataset_loading.html).\n- The dataset format follows the OpenAI Messages format as seen [here](https://docs.axolotl.ai/docs/dataset-formats/conversation.html#chat_template).\n\n## Optimization Guides\n\n- [Multi-GPU Training](https://docs.axolotl.ai/docs/multi-gpu.html)\n- [Multi-Node Training](https://docs.axolotl.ai/docs/multi-node.html)\n- [LoRA Optimizations](https://docs.axolotl.ai/docs/lora_optims.html)\n\n## Related Resources\n\n- [AFM Blog](https://docs.arcee.ai/arcee-foundation-models/introduction-to-arcee-foundation-models)\n- [Axolotl Docs](https://docs.axolotl.ai)\n- [Axolotl Website](https://axolotl.ai)\n- [Axolotl GitHub](https://github.com/axolotl-ai-cloud/axolotl)\n- [Axolotl Discord](https://discord.gg/7m9sfhzaf3)\n"
  },
  {
    "path": "examples/arcee/afm-4.5b-qlora.yaml",
    "content": "base_model: arcee-ai/AFM-4.5B\n\n# Automatically upload checkpoint and final model to HF\n# hub_model_id: username/custom_model_name\n\nplugins:\n  - axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin\n\nload_in_8bit: false\nload_in_4bit: true\n\ndatasets:\n  - path: fozziethebeat/alpaca_messages_2k_test\n    type: chat_template\n\ndataset_prepared_path: last_run_prepared\nval_set_size: 0.1\noutput_dir: ./outputs/lora-out\n\nadapter: qlora\nlora_model_dir:\n\nsequence_len: 2048\nsample_packing: true\n\nlora_r: 32\nlora_alpha: 16\nlora_dropout: 0.05\nlora_target_linear: true\nlora_target_modules:\n  - gate_proj\n  - down_proj\n  - up_proj\n  - q_proj\n  - v_proj\n  - k_proj\n  - o_proj\n\nwandb_project:\nwandb_entity:\nwandb_watch:\nwandb_name:\nwandb_log_model:\n\ngradient_accumulation_steps: 4\nmicro_batch_size: 2\nnum_epochs: 1\noptimizer: adamw_bnb_8bit\nlr_scheduler: cosine\nlearning_rate: 0.0002\n\nbf16: auto\ntf32: false\n\ngradient_checkpointing: true\nresume_from_checkpoint:\nlogging_steps: 1\nflash_attention: true\n\nwarmup_ratio: 0.1\nevals_per_epoch: 1\nsaves_per_epoch: 1\n\n# save_first_step: true  # uncomment this to validate checkpoint saving works with your config\n"
  },
  {
    "path": "examples/archived/README.md",
    "content": "# Archived Examples\n\nThis directory contains examples that are no longer maintained and may no longer be functional.\n\nWe keep them around for archival purposes in case they are useful to others.\n"
  },
  {
    "path": "examples/archived/cerebras/btlm-ft.yml",
    "content": "base_model: cerebras/btlm-3b-8k-base\n# optionally might have model_type or tokenizer_type\nmodel_type: AutoModelForCausalLM\ntokenizer_type: GPT2Tokenizer\n# Automatically upload checkpoint and final model to HF\n# hub_model_id: username/custom_model_name\n\ntrust_remote_code: true\ntokenizer_use_fast: true\ntokenizer_legacy: true\npush_dataset_to_hub:\nhf_use_auth_token: true\ndatasets:\n  - path: mhenrichsen/alpaca_2k_test\n    type: alpaca\ndataset_prepared_path: last_prepared_run\nval_set_size: 0.05\n\nadapter:\nlora_model_dir:\nsequence_len: 2048\nmax_packed_sequence_len:\nsample_packing: false\nsample_packing_eff_est:\nsample_packing_seq_len_multiplier:\ntotal_num_tokens:\n\nlora_r:\nlora_alpha:\nlora_dropout:\nlora_target_modules:\nlora_target_linear:\n\nwandb_project:\nwandb_entity:\nwandb_watch:\nwandb_name:\nwandb_log_model:\n\noutput_dir: ./outputs/btlm-out\ngradient_accumulation_steps: 1\nmicro_batch_size: 1\nnum_epochs: 1\noptimizer: adamw_torch_fused\nadam_beta2: 0.95\nadam_eps: 0.000000001\nmax_grad_norm: 1.0\n\ntorchdistx_path:\nlr_scheduler: cosine\nlr_quadratic_warmup: true\nlearning_rate: 0.000085\ntrain_on_inputs: true\ngroup_by_length: false\nbf16: auto\ntf32: true\n\ngradient_checkpointing: false\nresume_from_checkpoint:\nlogging_steps: 1\n\nflash_attention: true\nsdp_attention:\nflash_optimum:\n\ngptq_groupsize:\ngptq_model_v1:\n\nwarmup_ratio: 0.1\nevals_per_epoch: 4\nsaves_per_epoch: 1\nsave_total_limit:\n\nweight_decay: 0.1\nspecial_tokens:\n  pad_token: \"<|endoftext|>\"\nfsdp:\n#  - full_shard\n#  - auto_wrap\nfsdp_config:\n#  fsdp_state_dict_type: FULL_STATE_DICT\n#  fsdp_transformer_layer_cls_to_wrap: BTLMBlock\n"
  },
  {
    "path": "examples/archived/cerebras/qlora.yml",
    "content": "base_model: cerebras/Cerebras-GPT-1.3B\n# Automatically upload checkpoint and final model to HF\n# hub_model_id: username/custom_model_name\n\nload_in_8bit: false\nload_in_4bit: true\npush_dataset_to_hub:\ndatasets:\n  - path: teknium/GPT4-LLM-Cleaned\n    type: alpaca\ndataset_prepared_path:\nval_set_size: 0.05\nadapter: qlora\nlora_model_dir:\nsequence_len: 2048\nlora_r: 16\nlora_alpha: 32\nlora_dropout: 0.05\nlora_target_modules:\n  - c_fc\n  - c_attn\n  - c_proj\nlora_target_linear:\nwandb_project:\nwandb_entity:\nwandb_watch:\nwandb_name:\nwandb_log_model:\noutput_dir: ./outputs/qlora-out\nbatch_size: 4\nmicro_batch_size: 4\nnum_epochs: 2\noptimizer: paged_adamw_8bit\ntorchdistx_path:\nlr_scheduler: cosine\nlearning_rate: 0.0002\nbf16: auto\ntf32: true\ngradient_checkpointing: true\nresume_from_checkpoint:\nlogging_steps: 1\nxformers_attention: true\nflash_attention:\ngptq_groupsize:\ngptq_model_v1:\nwarmup_ratio: 0.1\nevals_per_epoch: 4\nsaves_per_epoch: 1\nweight_decay: 0.1\nspecial_tokens:\n  pad_token: \"<|endoftext|>\"\n"
  },
  {
    "path": "examples/archived/code-llama/13b/lora.yml",
    "content": "base_model: codellama/CodeLlama-13b-hf\n# optionally might have model_type or tokenizer_type\nmodel_type: LlamaForCausalLM\ntokenizer_type: CodeLlamaTokenizer\n# Automatically upload checkpoint and final model to HF\n# hub_model_id: username/custom_model_name\n\nload_in_8bit: true\nload_in_4bit: false\n\ndatasets:\n  - path: mhenrichsen/alpaca_2k_test\n    type: alpaca\ndataset_prepared_path:\nval_set_size: 0.05\noutput_dir: ./outputs/lora-out\n\nsequence_len: 4096\nsample_packing: true\n\n\nadapter: lora\nlora_model_dir:\nlora_r: 32\nlora_alpha: 16\nlora_dropout: 0.05\nlora_target_linear: true\n\nwandb_project:\nwandb_entity:\nwandb_watch:\nwandb_name:\nwandb_log_model:\n\ngradient_accumulation_steps: 4\nmicro_batch_size: 2\nnum_epochs: 4\noptimizer: adamw_bnb_8bit\nlr_scheduler: cosine\nlearning_rate: 0.0002\n\nbf16: auto\ntf32: false\n\ngradient_checkpointing: true\nresume_from_checkpoint:\nlogging_steps: 1\nflash_attention: true\n\nwarmup_ratio: 0.1\nevals_per_epoch: 4\nsaves_per_epoch: 1\nweight_decay: 0.0\nspecial_tokens:\n  bos_token: \"<s>\"\n  eos_token: \"</s>\"\n  unk_token: \"<unk>\"\n"
  },
  {
    "path": "examples/archived/code-llama/13b/qlora.yml",
    "content": "base_model: codellama/CodeLlama-13b-hf\n# optionally might have model_type or tokenizer_type\nmodel_type: LlamaForCausalLM\ntokenizer_type: CodeLlamaTokenizer\n# Automatically upload checkpoint and final model to HF\n# hub_model_id: username/custom_model_name\n\nload_in_8bit: false\nload_in_4bit: true\n\ndatasets:\n  - path: mhenrichsen/alpaca_2k_test\n    type: alpaca\ndataset_prepared_path:\nval_set_size: 0.05\noutput_dir: ./outputs/qlora-out\n\nadapter: qlora\nlora_model_dir:\n\nsequence_len: 4096\nsample_packing: true\n\n\nlora_r: 32\nlora_alpha: 16\nlora_dropout: 0.05\nlora_target_linear: true\n\nwandb_project:\nwandb_entity:\nwandb_watch:\nwandb_name:\nwandb_log_model:\n\ngradient_accumulation_steps: 4\nmicro_batch_size: 2\nnum_epochs: 4\noptimizer: paged_adamw_32bit\nlr_scheduler: cosine\nlearning_rate: 0.0002\n\nbf16: auto\ntf32: false\n\ngradient_checkpointing: true\nresume_from_checkpoint:\nlogging_steps: 1\nflash_attention: true\n\nwarmup_ratio: 0.1\nevals_per_epoch: 4\nsaves_per_epoch: 1\nweight_decay: 0.0\nspecial_tokens:\n  bos_token: \"<s>\"\n  eos_token: \"</s>\"\n  unk_token: \"<unk>\"\n"
  },
  {
    "path": "examples/archived/code-llama/34b/lora.yml",
    "content": "base_model: codellama/CodeLlama-34b-hf\n# optionally might have model_type or tokenizer_type\nmodel_type: LlamaForCausalLM\ntokenizer_type: CodeLlamaTokenizer\n# Automatically upload checkpoint and final model to HF\n# hub_model_id: username/custom_model_name\n\nload_in_8bit: true\nload_in_4bit: false\n\ndatasets:\n  - path: mhenrichsen/alpaca_2k_test\n    type: alpaca\ndataset_prepared_path:\nval_set_size: 0.05\noutput_dir: ./outputs/lora-out\n\nsequence_len: 4096\nsample_packing: true\n\n\nadapter: lora\nlora_model_dir:\nlora_r: 32\nlora_alpha: 16\nlora_dropout: 0.05\nlora_target_linear: true\n\nwandb_project:\nwandb_entity:\nwandb_watch:\nwandb_name:\nwandb_log_model:\n\ngradient_accumulation_steps: 4\nmicro_batch_size: 2\nnum_epochs: 4\noptimizer: adamw_bnb_8bit\nlr_scheduler: cosine\nlearning_rate: 0.0002\n\nbf16: auto\ntf32: false\n\ngradient_checkpointing: true\nresume_from_checkpoint:\nlogging_steps: 1\nflash_attention: true\n\nwarmup_ratio: 0.1\nevals_per_epoch: 4\nsaves_per_epoch: 1\nweight_decay: 0.0\nspecial_tokens:\n  bos_token: \"<s>\"\n  eos_token: \"</s>\"\n  unk_token: \"<unk>\"\n"
  },
  {
    "path": "examples/archived/code-llama/34b/qlora.yml",
    "content": "base_model: codellama/CodeLlama-34b-hf\n# optionally might have model_type or tokenizer_type\nmodel_type: LlamaForCausalLM\ntokenizer_type: CodeLlamaTokenizer\n# Automatically upload checkpoint and final model to HF\n# hub_model_id: username/custom_model_name\n\nload_in_8bit: false\nload_in_4bit: true\n\ndatasets:\n  - path: mhenrichsen/alpaca_2k_test\n    type: alpaca\ndataset_prepared_path:\nval_set_size: 0.05\noutput_dir: ./outputs/qlora-out\n\nadapter: qlora\nlora_model_dir:\n\nsequence_len: 4096\nsample_packing: true\n\n\nlora_r: 32\nlora_alpha: 16\nlora_dropout: 0.05\nlora_target_linear: true\n\nwandb_project:\nwandb_entity:\nwandb_watch:\nwandb_name:\nwandb_log_model:\n\ngradient_accumulation_steps: 4\nmicro_batch_size: 2\nnum_epochs: 4\noptimizer: paged_adamw_32bit\nlr_scheduler: cosine\nlearning_rate: 0.0002\n\nbf16: auto\ntf32: false\n\ngradient_checkpointing: true\nresume_from_checkpoint:\nlogging_steps: 1\nflash_attention: true\n\nwarmup_ratio: 0.1\nevals_per_epoch: 4\nsaves_per_epoch: 1\nweight_decay: 0.0\nspecial_tokens:\n  bos_token: \"<s>\"\n  eos_token: \"</s>\"\n  unk_token: \"<unk>\"\n"
  },
  {
    "path": "examples/archived/code-llama/7b/lora.yml",
    "content": "base_model: codellama/CodeLlama-7b-hf\n# optionally might have model_type or tokenizer_type\nmodel_type: LlamaForCausalLM\ntokenizer_type: CodeLlamaTokenizer\n# Automatically upload checkpoint and final model to HF\n# hub_model_id: username/custom_model_name\n\nload_in_8bit: true\nload_in_4bit: false\n\ndatasets:\n  - path: mhenrichsen/alpaca_2k_test\n    type: alpaca\ndataset_prepared_path:\nval_set_size: 0.05\noutput_dir: ./outputs/lora-out\n\nsequence_len: 4096\nsample_packing: true\n\n\nadapter: lora\nlora_model_dir:\nlora_r: 32\nlora_alpha: 16\nlora_dropout: 0.05\nlora_target_linear: true\n\nwandb_project:\nwandb_entity:\nwandb_watch:\nwandb_name:\nwandb_log_model:\n\ngradient_accumulation_steps: 4\nmicro_batch_size: 2\nnum_epochs: 4\noptimizer: adamw_bnb_8bit\nlr_scheduler: cosine\nlearning_rate: 0.0002\n\nbf16: auto\ntf32: false\n\ngradient_checkpointing: true\nresume_from_checkpoint:\nlogging_steps: 1\nflash_attention: true\n\nwarmup_ratio: 0.1\nevals_per_epoch: 4\nsaves_per_epoch: 1\nweight_decay: 0.0\nspecial_tokens:\n  bos_token: \"<s>\"\n  eos_token: \"</s>\"\n  unk_token: \"<unk>\"\n"
  },
  {
    "path": "examples/archived/code-llama/7b/qlora.yml",
    "content": "base_model: codellama/CodeLlama-7b-hf\n# optionally might have model_type or tokenizer_type\nmodel_type: LlamaForCausalLM\ntokenizer_type: CodeLlamaTokenizer\n# Automatically upload checkpoint and final model to HF\n# hub_model_id: username/custom_model_name\n\nload_in_8bit: false\nload_in_4bit: true\n\ndatasets:\n  - path: mhenrichsen/alpaca_2k_test\n    type: alpaca\ndataset_prepared_path:\nval_set_size: 0.05\noutput_dir: ./outputs/qlora-out\n\nadapter: qlora\nlora_model_dir:\n\nsequence_len: 4096\nsample_packing: true\n\n\nlora_r: 32\nlora_alpha: 16\nlora_dropout: 0.05\nlora_target_linear: true\n\nwandb_project:\nwandb_entity:\nwandb_watch:\nwandb_name:\nwandb_log_model:\n\ngradient_accumulation_steps: 4\nmicro_batch_size: 2\nnum_epochs: 4\noptimizer: paged_adamw_32bit\nlr_scheduler: cosine\nlearning_rate: 0.0002\n\nbf16: auto\ntf32: false\n\ngradient_checkpointing: true\nresume_from_checkpoint:\nlogging_steps: 1\nflash_attention: true\n\nwarmup_ratio: 0.1\nevals_per_epoch: 4\nsaves_per_epoch: 1\nweight_decay: 0.0\nspecial_tokens:\n  bos_token: \"<s>\"\n  eos_token: \"</s>\"\n  unk_token: \"<unk>\"\n"
  },
  {
    "path": "examples/archived/code-llama/README.md",
    "content": "# Overview\n\nThis is an example of Code Llama configurations for the 7b, 13b and 34b models.\n\nThe 7b variant fits on any 24GB VRAM GPU and will take up about 17GB of VRAM during training if using QLoRA and 20GB if using LoRA. On an RTX 4090 it trains 3 epochs of the default dataset in about 15 minutes.\n\nThe 13b variant will fit if you change the following settings to these values:\n\n```yaml\ngradient_accumulation_steps: 2\nmicro_batch_size: 1\n```\n\nThe 34b variant does not fit in 24GB of VRAM - you will need a GPU with 40GB+ of VRAM that also supports Flash Attention 2, such as an A6000 or A100.\n\n```shell\naccelerate launch scripts/finetune.py examples/archived/code-llama/[MODEL_SIZE]/qlora.yml\n```\n\nor\n\n```shell\naccelerate launch scripts/finetune.py examples/archived/code-llama/[MODEL_SIZE]/lora.yml\n```\n"
  },
  {
    "path": "examples/archived/dbrx/16bit-lora.yaml",
    "content": "base_model: LnL-AI/dbrx-base-converted-v2\n# Automatically upload checkpoint and final model to HF\n# hub_model_id: username/custom_model_name\n\ntrust_remote_code: true\n\ndatasets:\n  - path: tatsu-lab/alpaca\n    type: alpaca\ndataset_prepared_path: last_run_prepared\nval_set_size: 0.0\noutput_dir: ./outputs/out\n\nsequence_len: 512\nsample_packing: false\npad_to_sequence_len: false\n\nwandb_project:\nwandb_entity:\nwandb_watch:\nwandb_name:\nwandb_log_model:\n\nadapter: lora\nlora_model_dir:\nlora_r: 8\nlora_alpha: 16\nlora_dropout: 0.05\n# w1, w2, & v1 will hang the trainer\nlora_target_modules:\n  - q_proj # attn\n  - k_proj # attn\n  - v_proj # attn\n  - out_proj # attn\n  - layer # router\n#  - w1\n#  - w2\n#  - v1\n\ngradient_accumulation_steps: 1\nmicro_batch_size: 1\nnum_epochs: 1\noptimizer: paged_adamw_8bit\nlr_scheduler: cosine\nlearning_rate: 0.0002\n\nbf16: auto\ntf32: false\n\ngradient_checkpointing: false  # don't use with fsdp_activation_checkpointing\ngradient_checkpointing_kwargs:\n  use_reentrant: false\nresume_from_checkpoint:\nlogging_steps: 1\nflash_attention: true\n\nwarmup_ratio: 0.1\nevals_per_epoch:\nsaves_per_epoch: 1\n\nweight_decay: 0.0\nfsdp:\n  - full_shard\n  - auto_wrap\nfsdp_config:\n  fsdp_limit_all_gathers: true\n  fsdp_sync_module_states: true\n  fsdp_offload_params: false\n  fsdp_use_orig_params: false\n  fsdp_cpu_ram_efficient_loading: true\n  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP\n  fsdp_transformer_layer_cls_to_wrap: DbrxBlock\n  fsdp_state_dict_type: FULL_STATE_DICT\n  fsdp_activation_checkpointing: true\n"
  },
  {
    "path": "examples/archived/dbrx/8bit-lora.yaml",
    "content": "base_model: LnL-AI/dbrx-base-converted-v2\n# Automatically upload checkpoint and final model to HF\n# hub_model_id: username/custom_model_name\n\ntrust_remote_code: true\n\nload_in_8bit: true\nload_in_4bit: false\n\ndatasets:\n  - path: tatsu-lab/alpaca\n    type: alpaca\ndataset_prepared_path: last_run_prepared\nval_set_size: 0.0\noutput_dir: ./outputs/out\n\nsequence_len: 512\nsample_packing: false\npad_to_sequence_len: false\n\nwandb_project:\nwandb_entity:\nwandb_watch:\nwandb_name:\nwandb_log_model:\n\nadapter: lora\nlora_model_dir:\nlora_r: 8\nlora_alpha: 16\nlora_dropout: 0.05\n# w1, w2, & v1 will hang the trainer\nlora_target_modules:\n  - q_proj # attn\n  - k_proj # attn\n  - v_proj # attn\n  - out_proj # attn\n  - layer # router\n#  - w1\n#  - w2\n#  - v1\n\ngradient_accumulation_steps: 1\nmicro_batch_size: 1\nnum_epochs: 1\noptimizer: paged_adamw_8bit\nlr_scheduler: cosine\nlearning_rate: 0.0002\n\nbf16: auto\ntf32: false\n\ngradient_checkpointing: false  # don't use with fsdp_activation_checkpointing\ngradient_checkpointing_kwargs:\n  use_reentrant: false\nresume_from_checkpoint:\nlogging_steps: 1\nflash_attention: true\n\nwarmup_ratio: 0.1\nevals_per_epoch:\nsaves_per_epoch: 1\n\nweight_decay: 0.0\nfsdp:\n  - full_shard\n  - auto_wrap\nfsdp_config:\n  fsdp_limit_all_gathers: true\n  fsdp_sync_module_states: true\n  fsdp_offload_params: false\n  fsdp_use_orig_params: false\n  fsdp_cpu_ram_efficient_loading: true\n  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP\n  fsdp_transformer_layer_cls_to_wrap: DbrxBlock\n  fsdp_state_dict_type: FULL_STATE_DICT\n  fsdp_activation_checkpointing: true\n"
  },
  {
    "path": "examples/archived/dbrx/README.md",
    "content": "# DBRX MoE\n\nCurrently, for LoRA, only the `q_proj`, `k_proj`, `v_proj`, `out_proj` and `layer` Linear layers are trainable.\n\nWe are using the \"converted\" base models based on [this issue](https://huggingface.co/databricks/dbrx-instruct/discussions/10)\nwhere the Experts are fused as an `nn.Parameter` rather than an `nn.Linear` layer. However, the implementation\nis still a bit buggy and attempting to train a LoRA adapter over those `w1`, `w2` and `v1` layers\nresults in the trainer hanging.\n\n\n### FSDP\nWe've tested using the [`LnL-AI/dbrx-base-converted-v2`](https://huggingface.co/LnL-AI/dbrx-base-converted-v2) model as the base model for FSDP.\n\nThe high memory usage seen w/ FSDP is due to FSDP not supporting 8-bit optimizers.\n\n- 16-bit LoRA w/ FSDP\n  - ✅ w/o CPU Offload - 8x80GB uses ~80GiB/gpu\n  - ❌ w/ CPU Offload - `paged_adamw_8bit` optimizer errors from being on CPU\n- ✅ 8-bit LoRA w/ FSDP\n- ❌ 4-bit QLoRA w/ FSDP - errors w/: `Error an illegal memory access was encountered at line 90 in file /src/csrc/ops.cu`\n- ✅ bf16 full finetune w/ FSDP, freezing all but first 8 layers (8x80GB uses ~78GiB/gpu)\n\n\n### DeepSpeed\n\nWIP\n"
  },
  {
    "path": "examples/archived/dbrx/fft-ds-zero3.yaml",
    "content": "base_model: LnL-AI/dbrx-base-converted-v2\n# Automatically upload checkpoint and final model to HF\n# hub_model_id: username/custom_model_name\n\ntrust_remote_code: true\n\ndatasets:\n  - path: tatsu-lab/alpaca\n    type: alpaca\ndataset_prepared_path: last_run_prepared\nval_set_size: 0.0\noutput_dir: ./outputs/out\n\nsequence_len: 512\nsample_packing: false\npad_to_sequence_len: false\n\nunfrozen_parameters:\n  - transformer.blocks.[0-7].\n\nwandb_project:\nwandb_entity:\nwandb_watch:\nwandb_name:\nwandb_log_model:\n\ngradient_accumulation_steps: 1\nmicro_batch_size: 1\nnum_epochs: 1\noptimizer: paged_adamw_8bit\nlr_scheduler: cosine\nlearning_rate: 0.0002\n\nbf16: auto\ntf32: false\n\ngradient_checkpointing: true\ngradient_checkpointing_kwargs:\n  use_reentrant: false\nresume_from_checkpoint:\nlogging_steps: 1\nflash_attention: true\n\nwarmup_ratio: 0.1\nevals_per_epoch:\nsaves_per_epoch: 1\n\nweight_decay: 0.0\ndeepspeed: deepspeed_configs/zero3_bf16.json\n"
  },
  {
    "path": "examples/archived/deepcoder/deepcoder-14B-preview-lora.yml",
    "content": "base_model: agentica-org/DeepCoder-14B-Preview\n# Automatically upload checkpoint and final model to HF\n# hub_model_id: username/custom_model_name\n\nload_in_8bit: true\nload_in_4bit: false\nstrict: false\n\ndatasets:\n  - path: fozziethebeat/alpaca_messages_2k_test\n    type: chat_template\n\ndataset_prepared_path:\nval_set_size: 0.05\noutput_dir: ./outputs/lora-out\n\nsequence_len: 4096\nsample_packing: true\neval_sample_packing: false\n\n\nadapter: lora\nlora_model_dir:\nlora_r: 32\nlora_alpha: 16\nlora_dropout: 0.05\nlora_target_linear: true\n\nwandb_project:\nwandb_entity:\nwandb_watch:\nwandb_name:\nwandb_log_model:\n\ngradient_accumulation_steps: 2\nmicro_batch_size: 2\nnum_epochs: 4\noptimizer: adamw_bnb_8bit\nlr_scheduler: cosine\nlearning_rate: 0.0002\n\nbf16: auto\ntf32: true\n\ngradient_checkpointing: true\nresume_from_checkpoint:\nlogging_steps: 1\nflash_attention: true\n\nwarmup_ratio: 0.1\nevals_per_epoch: 1\nsaves_per_epoch: 1\nweight_decay: 0.0\nspecial_tokens:\n"
  },
  {
    "path": "examples/archived/falcon/config-7b-lora.yml",
    "content": "base_model: tiiuae/falcon-7b\n# optionally might have model_type or tokenizer_type\nmodel_type: AutoModelForCausalLM\ntokenizer_type: AutoTokenizer\n# Automatically upload checkpoint and final model to HF\n# hub_model_id: username/custom_model_name\n\n# required by falcon custom model code: https://huggingface.co/tiiuae/falcon-7b/tree/main\ntrust_remote_code: true\n\nload_in_8bit: true\nload_in_4bit: false\ngptq: false\npush_dataset_to_hub:\ndatasets:\n  - path: teknium/GPT4-LLM-Cleaned\n    type: alpaca:chat\ndataset_prepared_path:\nval_set_size: 0.05\nadapter: lora\nlora_model_dir:\nsequence_len: 2048\nmax_packed_sequence_len:\nlora_r: 16\nlora_alpha: 32\nlora_dropout: 0.0\nlora_target_linear: true\nwandb_project:\nwandb_entity:\nwandb_watch:\nwandb_name:\nwandb_log_model:\noutput_dir: ./outputs/falcon-7b\nbatch_size: 2\nmicro_batch_size: 1\nnum_epochs: 4\noptimizer: adamw_bnb_8bit\ntorchdistx_path:\nlr_scheduler: cosine\nlearning_rate: 0.00003\nbf16: auto\ntf32: true\ngradient_checkpointing: true\nresume_from_checkpoint:\nlogging_steps: 1\nxformers_attention: true\nflash_attention:\ngptq_groupsize:\ngptq_model_v1:\nwarmup_ratio: 0.1\nevals_per_epoch: 4\nsaves_per_epoch: 1\nweight_decay: 0.0\nspecial_tokens:\n  pad_token: \"<|endoftext|>\"\n  bos_token: \"<|endoftext|>\"\n  eos_token: \"<|endoftext|>\"\n"
  },
  {
    "path": "examples/archived/falcon/config-7b-qlora.yml",
    "content": "# 1b: tiiuae/falcon-rw-1b\n# 40b: tiiuae/falcon-40b\nbase_model: tiiuae/falcon-7b\n# optionally might have model_type or tokenizer_type\nmodel_type: AutoModelForCausalLM\ntokenizer_type: AutoTokenizer\n# Automatically upload checkpoint and final model to HF\n# hub_model_id: username/custom_model_name\n\n# required by falcon custom model code: https://huggingface.co/tiiuae/falcon-7b/tree/main\ntrust_remote_code: true\n\n\nload_in_8bit: false\n# enable 4bit for QLoRA\nload_in_4bit: true\ngptq: false\npush_dataset_to_hub:\ndatasets:\n  - path: QingyiSi/Alpaca-CoT\n    data_files:\n      - Chain-of-Thought/formatted_cot_data/gsm8k_train.json\n    type: \"alpaca:chat\"\ndataset_prepared_path:\nval_set_size: 0.05\n# enable QLoRA\nadapter: qlora\nlora_model_dir:\nsequence_len: 2048\nmax_packed_sequence_len:\n\n# hyperparameters from QLoRA paper Appendix B.2\n# \"We find hyperparameters to be largely robust across datasets\"\nlora_r: 64\nlora_alpha: 16\n# 0.1 for models up to 13B\n# 0.05 for 33B and 65B models\nlora_dropout: 0.05\n# add LoRA modules on all linear layers of the base model\nlora_target_linear: true\n\nwandb_project:\nwandb_entity:\nwandb_watch:\nwandb_name:\nwandb_log_model:\noutput_dir: ./outputs/qlora-out\n\n# QLoRA paper Table 9\n# - 16 for 7b & 13b\n# - 32 for 33b, 64 for 65b\n# Max size tested on A6000\n# - 7b: 40\n# - 40b: 4\n# decrease if OOM, increase for max VRAM utilization\nmicro_batch_size: 1\ngradient_accumulation_steps: 2\nnum_epochs: 4\n# Optimizer for QLoRA\noptimizer: paged_adamw_32bit\ntorchdistx_path:\nlr_scheduler: cosine\n# QLoRA paper Table 9\n# - 2e-4 for 7b & 13b\n# - 1e-4 for 33b & 65b\nlearning_rate: 0.0002\nbf16: auto\ntf32: true\ngradient_checkpointing: true\n# stop training after this many evaluation losses have increased in a row\n# https://huggingface.co/transformers/v4.2.2/_modules/transformers/trainer_callback.html#EarlyStoppingCallback\nearly_stopping_patience: 3\nresume_from_checkpoint:\nauto_resume_from_checkpoints: true\nlogging_steps: 1\nxformers_attention: true\nflash_attention:\ngptq_groupsize:\ngptq_model_v1:\nwarmup_ratio: 0.1\nevals_per_epoch: 4\nsaves_per_epoch: 1\nweight_decay: 0.000001\nspecial_tokens:\n  pad_token: \"<|endoftext|>\"\n  bos_token: \"<|endoftext|>\"\n  eos_token: \"<|endoftext|>\"\n"
  },
  {
    "path": "examples/archived/falcon/config-7b.yml",
    "content": "base_model: tiiuae/falcon-7b\n# optionally might have model_type or tokenizer_type\nmodel_type: AutoModelForCausalLM\ntokenizer_type: AutoTokenizer\n# Automatically upload checkpoint and final model to HF\n# hub_model_id: username/custom_model_name\n\n# required by falcon custom model code: https://huggingface.co/tiiuae/falcon-7b/tree/main\ntrust_remote_code: true\ngptq: false\npush_dataset_to_hub:\ndatasets:\n  - path: teknium/GPT4-LLM-Cleaned\n    type: alpaca:chat\ndataset_prepared_path:\nval_set_size: 0.05\nadapter:\nlora_model_dir:\nsequence_len: 2048\nmax_packed_sequence_len:\nlora_r: 64\nlora_alpha: 32\nlora_dropout: 0.0\nlora_target_linear: true\nwandb_project:\nwandb_entity:\nwandb_watch:\nwandb_name:\nwandb_log_model:\noutput_dir: ./outputs/falcon-7b\nbatch_size: 2\nmicro_batch_size: 1\nnum_epochs: 4\noptimizer: adamw_bnb_8bit\ntorchdistx_path:\nlr_scheduler: cosine\nlearning_rate: 0.00003\nbf16: auto\ntf32: true\ngradient_checkpointing: true\nresume_from_checkpoint:\nlogging_steps: 1\nxformers_attention: true\nflash_attention:\ngptq_groupsize:\ngptq_model_v1:\nwarmup_ratio: 0.1\nevals_per_epoch: 4\nsaves_per_epoch: 1\nweight_decay: 0.0\nspecial_tokens:\n  pad_token: \"<|endoftext|>\"\n  bos_token: \"<|endoftext|>\"\n  eos_token: \"<|endoftext|>\"\n"
  },
  {
    "path": "examples/archived/gemma/qlora.yml",
    "content": "# use google/gemma-7b if you have access\nbase_model: mhenrichsen/gemma-7b\n# optionally might have model_type or tokenizer_type\nmodel_type: AutoModelForCausalLM\ntokenizer_type: AutoTokenizer\n# Automatically upload checkpoint and final model to HF\n# hub_model_id: username/custom_model_name\n\nload_in_8bit: false\nload_in_4bit: true\n\n# huggingface repo\ndatasets:\n  - path: mhenrichsen/alpaca_2k_test\n    type: alpaca\nval_set_size: 0.1\noutput_dir: ./outputs/out\n\nadapter: qlora\nlora_r: 32\nlora_alpha: 16\nlora_dropout: 0.05\nlora_target_linear: true\n\nsequence_len: 4096\nsample_packing: true\neval_sample_packing: false\n\n\nwandb_project:\nwandb_entity:\nwandb_watch:\nwandb_name:\nwandb_log_model:\n\n\ngradient_accumulation_steps: 3\nmicro_batch_size: 2\nnum_epochs: 4\noptimizer: adamw_bnb_8bit\nlr_scheduler: cosine\nlearning_rate: 0.0002\n\nbf16: auto\ntf32: false\n\ngradient_checkpointing: true\nresume_from_checkpoint:\nlogging_steps: 1\nflash_attention: true\n\nwarmup_ratio: 0.1\nevals_per_epoch: 4\nsaves_per_epoch: 1\nweight_decay: 0.0\nspecial_tokens:\n"
  },
  {
    "path": "examples/archived/gptj/qlora.yml",
    "content": "base_model: EleutherAI/gpt-j-6b\n# Automatically upload checkpoint and final model to HF\n# hub_model_id: username/custom_model_name\n\nload_in_8bit: false\nload_in_4bit: true\npush_dataset_to_hub:\ndatasets:\n  - path: teknium/GPT4-LLM-Cleaned\n    type: alpaca\ndataset_prepared_path:\nval_set_size: 0.05\nadapter: qlora\nlora_model_dir:\nsequence_len: 2048\nmax_packed_sequence_len:\nlora_r: 8\nlora_alpha: 32\nlora_dropout: 0.05\nlora_target_linear: true\nwandb_project:\nwandb_entity:\nwandb_watch:\nwandb_name:\nwandb_log_model:\noutput_dir: ./outputs/qlora-out\ngradient_accumulation_steps: 2\nmicro_batch_size: 2\nnum_epochs: 2\noptimizer: paged_adamw_8bit\ntorchdistx_path:\nlr_scheduler: cosine\nlearning_rate: 0.0001\nbf16: auto\ntf32: true\ngradient_checkpointing: true\nresume_from_checkpoint:\nlogging_steps: 1\nxformers_attention: true\nflash_attention:\ngptq_groupsize:\ngptq_model_v1:\nwarmup_ratio: 0.1\nevals_per_epoch: 4\nsaves_per_epoch: 1\nweight_decay: 0.1\nspecial_tokens:\n  pad_token: \"<|endoftext|>\"\n"
  },
  {
    "path": "examples/archived/jeopardy-bot/config.yml",
    "content": "base_model: huggyllama/llama-7b\n# optionally might have model_type or tokenizer_type\nmodel_type: LlamaForCausalLM\ntokenizer_type: LlamaTokenizer\n# Automatically upload checkpoint and final model to HF\n# hub_model_id: username/custom_model_name\n\nload_in_8bit: false\ndatasets:\n  - path: openaccess-ai-collective/jeopardy\n    type: jeopardy\ndataset_prepared_path:\nval_set_size: 0.02\nadapter:\nlora_model_dir:\nsequence_len: 512\nmax_packed_sequence_len:\nlora_r:\nlora_alpha:\nlora_dropout:\nlora_target_modules:\nlora_fan_in_fan_out: false\nwandb_project:\nwandb_entity:\nwandb_watch:\nwandb_name:\nwandb_log_model:\noutput_dir: ./outputs/jeopardy-bot-7b\ngradient_accumulation_steps: 1\nmicro_batch_size: 1\nnum_epochs: 4\noptimizer: adamw_bnb_8bit\ntorchdistx_path:\nlr_scheduler: cosine\nlearning_rate: 0.00003\nbf16: auto\ntf32: true\nresume_from_checkpoint:\nlogging_steps: 5\nxformers_attention: true\nflash_attention:\ngptq_groupsize:\ngptq_model_v1:\nwarmup_ratio: 0.1\nevals_per_epoch: 4\nsaves_per_epoch: 1\nweight_decay: 0.1\ntokens:\n  bos_token: \"<s>\"\n  eos_token: \"</s>\"\n  unk_token: \"<unk>\"\n"
  },
  {
    "path": "examples/archived/mpt-7b/README.md",
    "content": "# MPT-7B\n\n```shell\naccelerate launch scripts/finetune.py examples/archived/mpt-7b/config.yml\n```\n"
  },
  {
    "path": "examples/archived/mpt-7b/config.yml",
    "content": "base_model: mosaicml/mpt-7b\n# optionally might have model_type or tokenizer_type\ntokenizer_type: AutoTokenizer\n# Automatically upload checkpoint and final model to HF\n# hub_model_id: username/custom_model_name\n\ntrust_remote_code: true  # required for mpt as their model class is not merged into transformers yet\nload_in_8bit: false\ndatasets:\n  - path: vicgalle/alpaca-gpt4\n    type: alpaca\ndataset_prepared_path:\nval_set_size: 0.02\nadapter:\nlora_model_dir:\nsequence_len: 2048\nmax_packed_sequence_len:\nlora_r: 8\nlora_alpha: 16\nlora_dropout: 0.05\nlora_target_modules:\n  - q_proj\n  - v_proj\nlora_fan_in_fan_out: false\nwandb_project: mpt-alpaca-7b\nwandb_entity:\nwandb_watch:\nwandb_name:\nwandb_log_model:\noutput_dir: ./outputs/mpt-alpaca-7b\ngradient_accumulation_steps: 1\nmicro_batch_size: 1\nnum_epochs: 4\noptimizer: adamw_bnb_8bit\ntorchdistx_path:\nlr_scheduler: cosine\nlearning_rate: 0.0000002\nbf16: auto\ntf32: true\nresume_from_checkpoint:\nlogging_steps: 5\nflash_attention:\ngptq_groupsize:\ngptq_model_v1:\nwarmup_ratio: 0.1\nevals_per_epoch: 4\nsaves_per_epoch: 1\nweight_decay: 0.0001\ntokens:\n  pad_token: \"<|padding|>\"\n  bos_token: \"<|endoftext|>\"\n  eos_token: \"<|endoftext|>\"\n  unk_token: \"<|endoftext|>\"\n"
  },
  {
    "path": "examples/archived/openllama-3b/README.md",
    "content": "# openllama-3b\n\nFull fine-tune\n```shell\naccelerate launch scripts/finetune.py examples/archived/openllama-3b/config.yml\n```\n\nLoRA\n```shell\naccelerate launch scripts/finetune.py examples/archived/openllama-3b/lora.yml\n```\n\nQLoRA\n```shell\naccelerate launch scripts/finetune.py examples/archived/openllama-3b/qlora.yml\n```\n"
  },
  {
    "path": "examples/archived/openllama-3b/config.yml",
    "content": "base_model: openlm-research/open_llama_3b_v2\n# optionally might have model_type or tokenizer_type\nmodel_type: LlamaForCausalLM\ntokenizer_type: LlamaTokenizer\n# Automatically upload checkpoint and final model to HF\n# hub_model_id: username/custom_model_name\npush_dataset_to_hub:\ndatasets:\n  - path: teknium/GPT4-LLM-Cleaned\n    type: alpaca\ndataset_prepared_path:\nval_set_size: 0.02\nadapter:\nlora_model_dir:\nsequence_len: 1024\nsample_packing: true\nlora_r:\nlora_alpha:\nlora_dropout:\nlora_target_modules:\nlora_target_linear:\nwandb_project:\nwandb_entity:\nwandb_watch:\nwandb_name:\nwandb_log_model:\noutput_dir: ./outputs/openllama-out\ngradient_accumulation_steps: 1\nmicro_batch_size: 1\nnum_epochs: 4\noptimizer: adamw_bnb_8bit\ntorchdistx_path:\nlr_scheduler: cosine\nlearning_rate: 0.000003\nfloat16: true\nbf16: false\nfp16: false\ntf32: false\ngradient_checkpointing: true\nresume_from_checkpoint:\nlogging_steps: 1\nflash_attention: true\ngptq_groupsize:\ngptq_model_v1:\nwarmup_ratio: 0.1\nevals_per_epoch: 4\nsaves_per_epoch: 1\nweight_decay: 0.1\nspecial_tokens:\n  bos_token: \"<s>\"\n  eos_token: \"</s>\"\n  unk_token: \"<unk>\"\n"
  },
  {
    "path": "examples/archived/openllama-3b/lora.yml",
    "content": "base_model: openlm-research/open_llama_3b_v2\n# optionally might have model_type or tokenizer_type\nmodel_type: LlamaForCausalLM\ntokenizer_type: LlamaTokenizer\n# Automatically upload checkpoint and final model to HF\n# hub_model_id: username/custom_model_name\n\nload_in_8bit: true\nload_in_4bit: false\npush_dataset_to_hub:\ndatasets:\n  - path: teknium/GPT4-LLM-Cleaned\n    type: alpaca\ndataset_prepared_path:\nval_set_size: 0.02\nadapter: lora\nlora_model_dir:\nsequence_len: 1024\nsample_packing: true\nlora_r: 8\nlora_alpha: 16\nlora_dropout: 0.0\nlora_target_modules:\n  - gate_proj\n  - down_proj\n  - up_proj\n  - q_proj\n  - v_proj\n  - k_proj\n  - o_proj\nwandb_project:\nwandb_entity:\nwandb_watch:\nwandb_name:\nwandb_log_model:\noutput_dir: ./outputs/lora-out\ngradient_accumulation_steps: 1\nmicro_batch_size: 2\nnum_epochs: 4\noptimizer: adamw_bnb_8bit\ntorchdistx_path:\nlr_scheduler: cosine\nlearning_rate: 0.0002\nbf16: false\nfp16: true\ntf32: false\ngradient_checkpointing: true\nresume_from_checkpoint:\nlogging_steps: 1\nflash_attention: true\ngptq_groupsize:\ngptq_model_v1:\nwarmup_ratio: 0.1\nevals_per_epoch: 4\nsaves_per_epoch: 1\nweight_decay: 0.1\nspecial_tokens:\n  bos_token: \"<s>\"\n  eos_token: \"</s>\"\n  unk_token: \"<unk>\"\n"
  },
  {
    "path": "examples/archived/openllama-3b/qlora.yml",
    "content": "base_model: openlm-research/open_llama_3b_v2\n# optionally might have model_type or tokenizer_type\nmodel_type: LlamaForCausalLM\ntokenizer_type: LlamaTokenizer\n# Automatically upload checkpoint and final model to HF\n# hub_model_id: username/custom_model_name\n\nload_in_8bit: false\nload_in_4bit: true\npush_dataset_to_hub:\ndatasets:\n  - path: teknium/GPT4-LLM-Cleaned\n    type: alpaca\ndataset_prepared_path:\nval_set_size: 0.05\nadapter: qlora\nlora_model_dir:\nsequence_len: 1024\nsample_packing: true\nlora_r: 8\nlora_alpha: 32\nlora_dropout: 0.05\nlora_target_linear: true\nwandb_project:\nwandb_entity:\nwandb_watch:\nwandb_name:\nwandb_log_model:\noutput_dir: ./outputs/qlora-out\ngradient_accumulation_steps: 1\nmicro_batch_size: 2\nnum_epochs: 4\noptimizer: paged_adamw_32bit\ntorchdistx_path:\nlr_scheduler: cosine\nlearning_rate: 0.0002\nbf16: false\nfp16: true\ntf32: false\ngradient_checkpointing: true\nresume_from_checkpoint:\nlogging_steps: 1\nflash_attention: true\ngptq_groupsize:\ngptq_model_v1:\nwarmup_ratio: 0.1\nevals_per_epoch: 4\nsaves_per_epoch: 1\nweight_decay: 0.1\nspecial_tokens:\n  bos_token: \"<s>\"\n  eos_token: \"</s>\"\n  unk_token: \"<unk>\"\n"
  },
  {
    "path": "examples/archived/pythia/lora.yml",
    "content": "base_model: EleutherAI/pythia-1.4b-deduped\n# Automatically upload checkpoint and final model to HF\n# hub_model_id: username/custom_model_name\n\nload_in_8bit: true\ndatasets:\n  - path: teknium/GPT4-LLM-Cleaned\n    type: alpaca\ndataset_prepared_path:\nval_set_size: 0.05\nadapter: lora\nlora_model_dir:\nsequence_len: 512\nlora_r: 16\nlora_alpha: 32\nlora_dropout: 0.05\nlora_target_modules:\n  - query_key_value\nlora_target_linear:\nlora_fan_in_fan_out: true  # pythia/GPTNeoX lora specific\nwandb_project:\nwandb_entity:\nwandb_watch:\nwandb_name:\nwandb_log_model:\noutput_dir: ./outputs/lora-alpaca-pythia\ngradient_accumulation_steps: 1\nmicro_batch_size: 4\nnum_epochs: 4\nlearning_rate: 0.00001\nbf16: auto\ntf32: true\nresume_from_checkpoint:\nweight_decay: 0.1\nevals_per_epoch: 4\nlogging_steps: 1\n"
  },
  {
    "path": "examples/archived/pythia-12b/README.md",
    "content": "# Pythia 12B\n\n- Single-GPU A100 only (?)\n\n```shell\npython scripts/finetune.py examples/archived/pythia-12b/config.yml\n```\n\n⚠️ Multi-GPU A100 - doesn't seem to work with multiple GPUs without causing OOM! ⚠️\n"
  },
  {
    "path": "examples/archived/pythia-12b/config.yml",
    "content": "base_model: EleutherAI/pythia-12b-deduped\nbase_model_ignore_patterns: pytorch*  # prefer safetensors\n# optionally might have model_type or tokenizer_type\nmodel_type: GPTNeoXForCausalLM\ntokenizer_type: AutoTokenizer\n# Automatically upload checkpoint and final model to HF\n# hub_model_id: username/custom_model_name\ngptq: false\ndevice_map: auto\ndatasets:\n  - path: vicgalle/alpaca-gpt4\n    type: alpaca\ndataset_prepared_path:\nval_set_size: 0.05\nadapter:\nlora_model_dir:\nsequence_len: 2048\nmax_packed_sequence_len: 2048\nlora_r: 64\nlora_alpha: 32\nlora_dropout: 0.0\nlora_target_linear: true\nlora_fan_in_fan_out: true  # pythia/GPTNeoX lora specific\nwandb_project:\nwandb_entity:\nwandb_watch:\nwandb_name:\nwandb_log_model:\noutput_dir: ./outputs/pythia-12b\ngradient_accumulation_steps: 1\nmicro_batch_size: 1\nnum_epochs: 5\nlearning_rate: 0.00003\noptimizer: adamw_bnb_8bit\nlr_scheduler: cosine\nbf16: false\nfp16: false\nfloat16: true\ntf32: true\nflash_optimum: true\nresume_from_checkpoint:\ngradient_checkpointing: true\n"
  },
  {
    "path": "examples/archived/qwen/README.md",
    "content": "# Qwen\n\nTODO\n\n# Qwen2 MoE\n\n✅ multipack\n✅ qwen2_moe 4-bit QLoRA\n✅ qwen2_moe 16-bit LoRA\n❓ qwen2_moe 8-bit LoRA\n"
  },
  {
    "path": "examples/archived/qwen/lora.yml",
    "content": "base_model: Qwen/Qwen-7B\n# optionally might have model_type or tokenizer_type\nmodel_type: AutoModelForCausalLM\ntokenizer_type: AutoTokenizer\n# Automatically upload checkpoint and final model to HF\n# hub_model_id: username/custom_model_name\n\ntrust_remote_code: true\n\nload_in_8bit: true\nload_in_4bit: false\n\ndatasets:\n  - path: mhenrichsen/alpaca_2k_test\n    type: alpaca\ndataset_prepared_path:\nval_set_size: 0.05\noutput_dir: ./outputs/lora-out\n\nsequence_len: 2048  # supports up to 8192\nsample_packing: false\npad_to_sequence_len:\n\nadapter: lora\nlora_model_dir:\nlora_r: 32\nlora_alpha: 16\nlora_dropout: 0.05\nlora_target_linear: true\n\nwandb_project:\nwandb_entity:\nwandb_watch:\nwandb_name:\nwandb_log_model:\n\ngradient_accumulation_steps: 4\nmicro_batch_size: 2\nnum_epochs: 4\noptimizer: adamw_bnb_8bit\nlr_scheduler: cosine\nlearning_rate: 0.0002\n\nbf16: auto\ntf32: false\n\ngradient_checkpointing: false\nresume_from_checkpoint:\nlogging_steps: 1\nflash_attention:\n\nwarmup_ratio: 0.1\nevals_per_epoch: 4\nsaves_per_epoch: 1\nweight_decay: 0.0\nspecial_tokens:\n"
  },
  {
    "path": "examples/archived/qwen/qlora.yml",
    "content": "base_model: Qwen/Qwen-7B\n# optionally might have model_type or tokenizer_type\nmodel_type: AutoModelForCausalLM\ntokenizer_type: AutoTokenizer\n# Automatically upload checkpoint and final model to HF\n# hub_model_id: username/custom_model_name\n\ntrust_remote_code: true\n\nload_in_8bit: false\nload_in_4bit: true\n\ndatasets:\n  - path: mhenrichsen/alpaca_2k_test\n    type: alpaca\ndataset_prepared_path:\nval_set_size: 0.05\noutput_dir: ./outputs/lora-out\n\nsequence_len: 2048  # supports up to 8192\nsample_packing: false\npad_to_sequence_len:\n\nadapter: qlora\nlora_model_dir:\nlora_r: 32\nlora_alpha: 16\nlora_dropout: 0.05\nlora_target_linear: true\n\nwandb_project:\nwandb_entity:\nwandb_watch:\nwandb_name:\nwandb_log_model:\n\ngradient_accumulation_steps: 4\nmicro_batch_size: 2\nnum_epochs: 4\noptimizer: adamw_bnb_8bit\nlr_scheduler: cosine\nlearning_rate: 0.0002\n\nbf16: auto\ntf32: false\n\ngradient_checkpointing: false\nresume_from_checkpoint:\nlogging_steps: 1\nflash_attention:\n\nwarmup_ratio: 0.1\nevals_per_epoch: 4\nsaves_per_epoch: 1\nweight_decay: 0.0\nspecial_tokens:\n"
  },
  {
    "path": "examples/archived/qwen/qwen2-moe-lora.yaml",
    "content": "base_model: Qwen/Qwen1.5-MoE-A2.7B\n# Automatically upload checkpoint and final model to HF\n# hub_model_id: username/custom_model_name\n\ntrust_remote_code: true\n\ndatasets:\n  - path: mhenrichsen/alpaca_2k_test\n    type: alpaca\ndataset_prepared_path:\nval_set_size: 0.05\noutput_dir: ./outputs/out\n\nsequence_len: 1024  # supports up to 32k\nsample_packing: false\npad_to_sequence_len: false\n\nadapter: lora\nlora_model_dir:\nlora_r: 32\nlora_alpha: 16\nlora_dropout: 0.05\nlora_target_linear: true\n\nwandb_project:\nwandb_entity:\nwandb_watch:\nwandb_name:\nwandb_log_model:\n\ngradient_accumulation_steps: 4\nmicro_batch_size: 1\nnum_epochs: 4\noptimizer: paged_adamw_8bit\nlr_scheduler: cosine\nlearning_rate: 0.0002\n\nbf16: auto\ntf32: true\n\ngradient_checkpointing: true\ngradient_checkpointing_kwargs:\n  use_reentrant: false\nresume_from_checkpoint:\nlogging_steps: 1\nflash_attention: true\n\nwarmup_ratio: 0.1\nevals_per_epoch: 4\nsaves_per_epoch: 1\nweight_decay: 0.0\nspecial_tokens:\n"
  },
  {
    "path": "examples/archived/qwen/qwen2-moe-qlora.yaml",
    "content": "base_model: Qwen/Qwen1.5-MoE-A2.7B\n# Automatically upload checkpoint and final model to HF\n# hub_model_id: username/custom_model_name\n\ntrust_remote_code: true\n\nload_in_8bit: false\nload_in_4bit: true\n\ndatasets:\n  - path: mhenrichsen/alpaca_2k_test\n    type: alpaca\ndataset_prepared_path:\nval_set_size: 0.05\noutput_dir: ./outputs/out\n\nsequence_len: 1024  # supports up to 32k\nsample_packing: false\npad_to_sequence_len: false\n\nadapter: qlora\nlora_model_dir:\nlora_r: 32\nlora_alpha: 16\nlora_dropout: 0.05\nlora_target_linear: true\n\nwandb_project:\nwandb_entity:\nwandb_watch:\nwandb_name:\nwandb_log_model:\n\ngradient_accumulation_steps: 4\nmicro_batch_size: 1\nnum_epochs: 4\noptimizer: paged_adamw_8bit\nlr_scheduler: cosine\nlearning_rate: 0.0002\n\nbf16: auto\ntf32: true\n\ngradient_checkpointing: true\ngradient_checkpointing_kwargs:\n  use_reentrant: false\nresume_from_checkpoint:\nlogging_steps: 1\nflash_attention: true\n\nwarmup_ratio: 0.1\nevals_per_epoch: 4\nsaves_per_epoch: 1\nweight_decay: 0.0\nspecial_tokens:\n"
  },
  {
    "path": "examples/archived/redpajama/README.md",
    "content": "# RedPajama 3B preview release\n\n```shell\naccelerate launch scripts/finetune.py examples/archived/redpajama/config-3b.yml\n```\n"
  },
  {
    "path": "examples/archived/redpajama/config-3b.yml",
    "content": "base_model: togethercomputer/RedPajama-INCITE-Chat-3B-v1\n# optionally might have model_type or tokenizer_type\nmodel_type: GPTNeoXForCausalLM\ntokenizer_type: AutoTokenizer\n# Automatically upload checkpoint and final model to HF\n# hub_model_id: username/custom_model_name\n\ntrust_remote_code:\nload_in_8bit: false\ndatasets:\n  - path: vicgalle/alpaca-gpt4\n    type: alpaca\ndataset_prepared_path:\nval_set_size: 0.02\nadapter:\nlora_model_dir:\nsequence_len: 2048\nmax_packed_sequence_len:\nlora_r: 8\nlora_alpha: 16\nlora_dropout: 0.05\nlora_target_modules:\n  - q_proj\n  - v_proj\nlora_fan_in_fan_out: false\nwandb_project: redpajama-alpaca-3b\nwandb_entity:\nwandb_watch:\nwandb_name:\nwandb_log_model:\noutput_dir: ./outputs/redpajama-alpaca-3b\nbatch_size: 4\nmicro_batch_size: 1\nnum_epochs: 4\noptimizer: adamw_bnb_8bit\ntorchdistx_path:\nlr_scheduler: cosine\nlearning_rate: 0.0000002\nbf16: auto\ntf32: true\nresume_from_checkpoint:\nlogging_steps: 5\nflash_attention:\ngptq_groupsize:\ngptq_model_v1:\nwarmup_ratio: 0.1\nevals_per_epoch: 4\nsaves_per_epoch: 1\nweight_decay: 0.0001\ntokens:\n  pad_token: \"<|padding|>\"\n  bos_token: \"<|endoftext|>\"\n  eos_token: \"<|endoftext|>\"\n  unk_token: \"<|endoftext|>\"\n"
  },
  {
    "path": "examples/archived/replit-3b/config-lora.yml",
    "content": "base_model: replit/replit-code-v1-3b\n# Automatically upload checkpoint and final model to HF\n# hub_model_id: username/custom_model_name\n\ntrust_remote_code: true\nload_in_8bit: false\ndatasets:\n  - path: vicgalle/alpaca-gpt4\n    type: alpaca\ndataset_prepared_path:\nval_set_size: 0.05\nadapter: lora\nlora_model_dir:\nsequence_len: 2048\nmax_packed_sequence_len:\nlora_r: 8\nlora_alpha: 16\nlora_dropout: 0.05\nlora_target_modules:\n  - Wqkv\n  - mlp_up\n  - mlp_down\nwandb_project: lora-replit\nwandb_entity:\nwandb_watch:\nwandb_name:\nwandb_log_model:\noutput_dir: ./outputs/lora-replit\nbatch_size: 8\nmicro_batch_size: 1\nnum_epochs: 4\noptimizer:\ntorchdistx_path:\nlr_scheduler:\nlearning_rate: 0.00001\nbf16: auto\ntf32: true\ngradient_checkpointing:\nresume_from_checkpoint:\nlogging_steps: 1\nflash_attention:\ngptq_groupsize:\ngptq_model_v1:\nwarmup_ratio: 0.1\nevals_per_epoch: 4\nsaves_per_epoch: 1\nweight_decay: 0\n#special_tokens:\n"
  },
  {
    "path": "examples/archived/stablelm-2/1.6b/fft.yml",
    "content": "base_model: stabilityai/stablelm-2-1_6b\n# optionally might have model_type or tokenizer_type\nmodel_type: AutoModelForCausalLM\ntokenizer_type: AutoTokenizer\n# Automatically upload checkpoint and final model to HF\n# hub_model_id: username/custom_model_name\n\ntrust_remote_code: true\n\ndatasets:\n  - path: mhenrichsen/alpaca_2k_test\n    type: alpaca\ndataset_prepared_path: last_run_prepared\nval_set_size: 0.05\noutput_dir: ./outputs/out\n\nsequence_len: 4096\nsample_packing: true\n\n\nadapter:\nlora_model_dir:\nlora_r:\nlora_alpha:\nlora_dropout:\nlora_target_linear:\n\nwandb_project:\nwandb_entity:\nwandb_watch:\nwandb_name:\nwandb_log_model:\n\ngradient_accumulation_steps: 1\nmicro_batch_size: 1\nnum_epochs: 1\noptimizer: adamw_bnb_8bit\nlr_scheduler: cosine\nlearning_rate: 0.0002\n\nbf16: auto\ntf32: false\n\ngradient_checkpointing: true\nresume_from_checkpoint:\nlogging_steps: 1\nflash_attention: true\nflash_attn_cross_entropy: false\nflash_attn_rms_norm: true\nflash_attn_fuse_mlp: true\n\nwarmup_ratio: 0.1\nevals_per_epoch: 4\nsaves_per_epoch: 1\n\ndeepspeed: #deepspeed_configs/zero2.json # multi-gpu only\nweight_decay: 0.1\nspecial_tokens:\n"
  },
  {
    "path": "examples/archived/stablelm-2/1.6b/lora.yml",
    "content": "base_model: stabilityai/stablelm-2-1_6b\n# optionally might have model_type or tokenizer_type\nmodel_type: AutoModelForCausalLM\ntokenizer_type: AutoTokenizer\n# Automatically upload checkpoint and final model to HF\n# hub_model_id: username/custom_model_name\n\ntrust_remote_code: true\n\nload_in_8bit: true\nload_in_4bit: false\n\ndatasets:\n  - path: mhenrichsen/alpaca_2k_test\n    type: alpaca\ndataset_prepared_path:\nval_set_size: 0.05\noutput_dir: ./outputs/lora-out\n\nsequence_len: 4096\nsample_packing: true\n\n\nadapter: lora\nlora_model_dir:\nlora_r: 32\nlora_alpha: 16\nlora_dropout: 0.05\nlora_target_linear: true\n\nwandb_project:\nwandb_entity:\nwandb_watch:\nwandb_name:\nwandb_log_model:\n\ngradient_accumulation_steps: 1\nmicro_batch_size: 1\nnum_epochs: 1\noptimizer: adamw_bnb_8bit\nlr_scheduler: cosine\nlearning_rate: 0.0002\n\nbf16: auto\ntf32: false\n\ngradient_checkpointing: true\nresume_from_checkpoint:\nlogging_steps: 1\nflash_attention: true\nflash_attn_cross_entropy: false\nflash_attn_rms_norm: true\n\nwarmup_ratio: 0.1\nevals_per_epoch: 4\nsaves_per_epoch: 1\nweight_decay: 0.0\nspecial_tokens:\n"
  },
  {
    "path": "examples/archived/stablelm-2/README.md",
    "content": "# StableLM 2\n\nThis directory contains examples for training and processing using StableLM-2. It also includes a section to help you estimate the GPU requirements for your specific use case.\n\n## Estimating GPU Requirements\n\n| type          | deepspeed | batch size | context length | vRAM GPU (GBs) |\n|---------------|-----------|------------|----------------|----------------|\n| full finetune | N/A       | 1          | 4096           | ~21.5GBs       |\n| full finetune | zero2     | 1          | 4096           | ~20GBs         |\n| lora          | N/A       | 1          | 4096           | ~16.6GBs       |\n\nThe above are estimates and might differ slightly depending on the setup, for example whether you pack your sequences or not (the above assumes you pack them to a length of 4096).\n\nThis blog post from Hamel Husain was a great resource for estimating these numbers: https://hamel.dev/notes/llm/03_estimating_vram.html\n\n## Training\nWe have example configs here for both full finetuning and LoRA using the popular alpaca dataset:\n\n```shell\n# preprocess the dataset\nCUDA_VISIBLE_DEVICES=\"\" python -m axolotl.cli.preprocess examples/archived/stablelm-2/1.6b/lora.yml\n```\n\nSingle GPU Training:\n```shell\npython -m axolotl.cli.train examples/archived/stablelm-2/1.6b/fft.yml --deepspeed deepspeed_configs/zero2.json\n# OR\npython -m axolotl.cli.train examples/archived/stablelm-2/1.6b/lora.yml\n```\n\nMulti-GPU Training with `accelerate`:\n```shell\n# make sure you've configured accelerate properly\naccelerate launch -m axolotl.cli.train examples/archived/stablelm-2/1.6b/fft.yml --deepspeed deepspeed_configs/zero2.json\n```\n"
  },
  {
    "path": "examples/archived/starcoder2/qlora.yml",
    "content": "base_model: bigcode/starcoder2-3b\n# Automatically upload checkpoint and final model to HF\n# hub_model_id: username/custom_model_name\n\nload_in_8bit: false\nload_in_4bit: true\n\ndatasets:\n  - path: mhenrichsen/alpaca_2k_test\n    type: alpaca\n\n\ndataset_prepared_path:\nval_set_size: 0.2\noutput_dir: ./outputs/qlora\n\nadapter: qlora\nlora_model_dir:\n\nsequence_len: 8192\nsample_packing: true\n\n\nlora_r: 32\nlora_alpha: 16\nlora_dropout: 0.05\nlora_target_linear: true\n\nwandb_project:\nwandb_entity:\nwandb_watch:\nwandb_run_id:\nwandb_log_model:\n\ngradient_accumulation_steps: 8\nmicro_batch_size: 2\nnum_epochs: 3\noptimizer: adamw_bnb_8bit\nlr_scheduler: cosine\nlearning_rate: 2e-5\n\nbf16: auto\nfp16: false\ntf32: false\n\ngradient_checkpointing: true\nresume_from_checkpoint:\nlogging_steps: 1\nflash_attention: true\n\nwarmup_ratio: 0.1\nevals_per_epoch: 4\neval_steps:\nsaves_per_epoch: 4\nsave_steps:\nsave_total_limit: 2\nweight_decay:\nspecial_tokens:\n"
  },
  {
    "path": "examples/archived/tiny-llama/README.md",
    "content": "# Overview\n\nThis is a simple example of how to finetune TinyLlama 1.1B using either LoRA or QLoRA:\n\nLoRA:\n\n```shell\naccelerate launch -m axolotl.cli.train examples/archived/tiny-llama/lora.yml\n```\n\nQLoRA:\n\n```shell\naccelerate launch -m axolotl.cli.train examples/archived/tiny-llama/qlora.yml\n```\n\nBoth take about 10 minutes to complete on a 4090.\n"
  },
  {
    "path": "examples/archived/tiny-llama/lora-mps.yml",
    "content": "base_model: TinyLlama/TinyLlama_v1.1\n# optionally might have model_type or tokenizer_type\nmodel_type: LlamaForCausalLM\ntokenizer_type: LlamaTokenizer\n# Automatically upload checkpoint and final model to HF\n# hub_model_id: username/custom_model_name\n\nload_in_8bit: true\nload_in_4bit: false\n\ndatasets:\n  - path: mhenrichsen/alpaca_2k_test\n    type: alpaca\ndataset_prepared_path:\nval_set_size: 0\noutput_dir: ./outputs/lora-out\n\nsequence_len: 4096\nsample_packing: true\n\neval_sample_packing: false\n\nadapter: lora\nlora_model_dir:\nlora_r: 32\nlora_alpha: 16\nlora_dropout: 0.05\nlora_target_linear: true\n\nwandb_project:\nwandb_entity:\nwandb_watch:\nwandb_name:\nwandb_log_model:\n\ngradient_accumulation_steps: 4\nmicro_batch_size: 2\nnum_epochs: 4\noptimizer: adamw_torch_fused\nlr_scheduler: cosine\nlearning_rate: 0.0002\n\nbf16: auto\nfp16: false\ntf32: true\n\ngradient_checkpointing: true\nresume_from_checkpoint:\nlogging_steps: 1\nflash_attention: false\n\nwarmup_ratio: 0.1\nevals_per_epoch: 0\nsaves_per_epoch: 1\nweight_decay: 0.0\nspecial_tokens:\n"
  },
  {
    "path": "examples/archived/tiny-llama/lora.yml",
    "content": "base_model: TinyLlama/TinyLlama_v1.1\n# optionally might have model_type or tokenizer_type\ntokenizer_type: AutoTokenizer\n# Automatically upload checkpoint and final model to HF\n# hub_model_id: username/custom_model_name\n\nload_in_8bit: true\nload_in_4bit: false\n\ndatasets:\n  - path: mhenrichsen/alpaca_2k_test\n    type: alpaca\ndataset_prepared_path:\nval_set_size: 0.05\noutput_dir: ./outputs/lora-out\n\nsequence_len: 4096\nsample_packing: true\neval_sample_packing: false\n\n\nadapter: lora\nlora_model_dir:\nlora_r: 32\nlora_alpha: 16\nlora_dropout: 0.05\nlora_target_linear: true\n\nwandb_project:\nwandb_entity:\nwandb_watch:\nwandb_name:\nwandb_log_model:\n\ngradient_accumulation_steps: 4\nmicro_batch_size: 2\nnum_epochs: 4\noptimizer: adamw_bnb_8bit\nlr_scheduler: cosine\nlearning_rate: 0.0002\n\nbf16: auto\ntf32: false\n\ngradient_checkpointing: true\nresume_from_checkpoint:\nlogging_steps: 1\nflash_attention: true\n\nwarmup_ratio: 0.1\nevals_per_epoch: 4\nsaves_per_epoch: 1\nweight_decay: 0.0\nspecial_tokens:\n"
  },
  {
    "path": "examples/archived/tiny-llama/pretrain.yml",
    "content": "base_model: TinyLlama/TinyLlama-1.1B-Chat-v1.0\n# optionally might have model_type or tokenizer_type\nmodel_type: LlamaForCausalLM\ntokenizer_type: LlamaTokenizer\n# Automatically upload checkpoint and final model to HF\n# hub_model_id: username/custom_model_name\n\nmax_steps: 200\npretraining_dataset:\n  - path: allenai/c4\n    name: en\n    type: pretrain\ndataset_prepared_path:\nval_set_size: 0.0\noutput_dir: ./outputs/model-out\n\nsequence_len: 2048\nsample_packing: true\n\nwandb_project:\nwandb_entity:\nwandb_watch:\nwandb_name:\nwandb_log_model:\n\ngradient_accumulation_steps: 4\nmicro_batch_size: 2\nnum_epochs: 4\noptimizer: adamw_bnb_8bit\nlr_scheduler: cosine\nlearning_rate: 0.0002\n\nbf16: auto\ntf32: false\n\ngradient_checkpointing: true\nresume_from_checkpoint:\nlogging_steps: 1\nflash_attention: true\n\nwarmup_ratio: 0.1\nevals_per_epoch:\nsaves_per_epoch: 1\nweight_decay: 0.0\nspecial_tokens:\n"
  },
  {
    "path": "examples/archived/tiny-llama/qlora.yml",
    "content": "base_model: TinyLlama/TinyLlama_v1.1\n# optionally might have model_type or tokenizer_type\nmodel_type: LlamaForCausalLM\ntokenizer_type: LlamaTokenizer\n# Automatically upload checkpoint and final model to HF\n# hub_model_id: username/custom_model_name\n\nload_in_8bit: false\nload_in_4bit: true\n\ndatasets:\n  - path: mhenrichsen/alpaca_2k_test\n    type: alpaca\ndataset_prepared_path:\nval_set_size: 0.05\noutput_dir: ./outputs/qlora-out\n\nadapter: qlora\nlora_model_dir:\n\nsequence_len: 4096\nsample_packing: true\neval_sample_packing: false\n\n\nlora_r: 32\nlora_alpha: 16\nlora_dropout: 0.05\nlora_target_linear: true\n\nwandb_project:\nwandb_entity:\nwandb_watch:\nwandb_name:\nwandb_log_model:\n\ngradient_accumulation_steps: 4\nmicro_batch_size: 2\nnum_epochs: 4\noptimizer: paged_adamw_32bit\nlr_scheduler: cosine\nlearning_rate: 0.0002\n\nbf16: auto\ntf32: false\n\ngradient_checkpointing: true\nresume_from_checkpoint:\nlogging_steps: 1\nflash_attention: true\n\nwarmup_ratio: 0.1\nevals_per_epoch: 4\nsaves_per_epoch: 1\nweight_decay: 0.0\nspecial_tokens:\n"
  },
  {
    "path": "examples/archived/xgen-7b/xgen-7b-8k-qlora.yml",
    "content": "# An example of finetuning Salesforce's XGen-7b model with 8k context using qlora\n# on Tim Dettmers' Guanaco dataset.\nbase_model: Salesforce/xgen-7b-8k-base\n# optionally might have model_type or tokenizer_type\nmodel_type: AutoModelForCausalLM\ntokenizer_type: AutoTokenizer\n# Automatically upload checkpoint and final model to HF\n# hub_model_id: username/custom_model_name\n\ntrust_remote_code: true\n\nload_in_8bit: false\n# enable 4bit for QLoRA\nload_in_4bit: true\ngptq: false\npush_dataset_to_hub:\ndatasets:\n  - path: timdettmers/openassistant-guanaco\n    data_files:\n      - openassistant_best_replies_train.jsonl\n    type: \"completion\"\ndataset_prepared_path:\nval_set_size: 0.05\n# enable QLoRA\nadapter: qlora\nlora_model_dir:\nsequence_len: 8192\nmax_packed_sequence_len:\n\n# hyperparameters from QLoRA paper Appendix B.2\n# \"We find hyperparameters to be largely robust across datasets\"\nlora_r: 64\nlora_alpha: 16\n# 0.1 for models up to 13B\n# 0.05 for 33B and 65B models\nlora_dropout: 0.05\n# add LoRA modules on all linear layers of the base model\nlora_target_linear: true\n\nwandb_project:\nwandb_entity:\nwandb_watch:\nwandb_name:\nwandb_log_model:\noutput_dir: ./outputs/qlora-out\n\n# QLoRA paper Table 9\n# - 16 for 7b & 13b\n# - 32 for 33b, 64 for 65b\n# Max size tested on A6000\n# - 7b: 40\n# - 40b: 4\n# decrease if OOM, increase for max VRAM utilization\nmicro_batch_size: 1\ngradient_accumulation_steps: 1\nnum_epochs: 4\n# Optimizer for QLoRA\noptimizer: paged_adamw_32bit\ntorchdistx_path:\nlr_scheduler: cosine\n# QLoRA paper Table 9\n# - 2e-4 for 7b & 13b\n# - 1e-4 for 33b & 65b\nlearning_rate: 0.00002\nbf16: auto\ntf32: false\ngradient_checkpointing: true\n# stop training after this many evaluation losses have increased in a row\n# https://huggingface.co/transformers/v4.2.2/_modules/transformers/trainer_callback.html#EarlyStoppingCallback\nearly_stopping_patience: 3\nresume_from_checkpoint:\nauto_resume_from_checkpoints: true\nlogging_steps: 1\nxformers_attention: true\nflash_attention:\ngptq_groupsize:\ngptq_model_v1:\nwarmup_ratio: 0.1\nevals_per_epoch: 4\nsaves_per_epoch: 1\nweight_decay: 0.0\nspecial_tokens:\n  eos_token: \"<|endoftext|>\"\n  bos_token: \"<|endoftext|>\"\n  unk_token: \"<|endoftext|>\"\n  pad_token: \"<|endoftext|>\"\n"
  },
  {
    "path": "examples/archived/yi-34B-chat/README.md",
    "content": "# Overview\n\nThis is an example of a Yi-34B-Chat configuration. It demonstrates that it is possible to finetune a 34B model on a GPU with 24GB of VRAM.\n\nTested on an RTX 4090 with `python -m axolotl.cli.train examples/archived/yi-34B-chat/qlora.yml`, a single epoch of finetuning on the alpaca dataset using QLoRA runs in 47 mins, using 97% of available memory.\n"
  },
  {
    "path": "examples/archived/yi-34B-chat/qlora.yml",
    "content": "base_model: 01-ai/Yi-34B-Chat\n# optionally might have model_type or tokenizer_type\nmodel_type: LlamaForCausalLM\ntokenizer_type: LlamaTokenizer\n# Automatically upload checkpoint and final model to HF\n# hub_model_id: username/custom_model_name\n\nload_in_8bit: false\nload_in_4bit: true\nsequence_len: 1024\nbf16: auto\ntf32: false\nflash_attention: true\nspecial_tokens:\n  bos_token: \"<|startoftext|>\"\n  eos_token: \"<|endoftext|>\"\n  unk_token: \"<unk>\"\n\n# Data\ndatasets:\n  - path: mhenrichsen/alpaca_2k_test\n    type: alpaca\nwarmup_ratio: 0.1\n\n# Iterations\nnum_epochs: 1\n\n# Evaluation\nval_set_size: 0.1\nevals_per_epoch: 5\neval_sample_packing: false\neval_batch_size: 1\n\n# LoRA\noutput_dir: ./outputs/qlora-out\nadapter: qlora\nlora_model_dir:\nlora_r: 32\nlora_alpha: 16\nlora_dropout: 0.05\nlora_target_linear: true\nlora_target_modules:\n\n# Sampling\nsample_packing: false\npad_to_sequence_len: false\n\n# Batching\ngradient_accumulation_steps: 4\nmicro_batch_size: 1\ngradient_checkpointing: true\n\n# wandb\nwandb_project:\n\n# Optimizer\noptimizer: paged_adamw_8bit\nlr_scheduler: cosine\nlearning_rate: 0.0002\n\n# Misc\nresume_from_checkpoint:\nlogging_steps: 1\nweight_decay: 0\n"
  },
  {
    "path": "examples/cloud/baseten.yaml",
    "content": "provider: baseten\nproject_name:\n\nsecrets:\n  - HF_TOKEN\n  - WANDB_API_KEY\n\ngpu: h100\ngpu_count: 8\nnode_count: 1\n"
  },
  {
    "path": "examples/cloud/modal.yaml",
    "content": "project_name:\nvolumes:\n  - name: axolotl-data\n    mount: /workspace/data\n  - name: axolotl-artifacts\n    mount: /workspace/artifacts\n\n# environment variables from local to set as secrets\nsecrets:\n  - HF_TOKEN\n  - WANDB_API_KEY\n\n# Which branch of axolotl to use remotely\nbranch:\n\n# additional custom commands when building the image\ndockerfile_commands:\n\ngpu: h100\ngpu_count: 1\n\n# Train specific configurations\nmemory: 128\ntimeout: 86400\n\n# Preprocess specific configurations\nmemory_preprocess: 32\ntimeout_preprocess: 14400\n"
  },
  {
    "path": "examples/cohere/command-r-7b-qlora.yml",
    "content": "base_model: CohereForAI/c4ai-command-r7b-12-2024\nmodel_type: AutoModelForCausalLM\ntokenizer_type: AutoTokenizer\n\nload_in_8bit: false\nload_in_4bit: true\n\nchat_template: cohere\n# dataset from the Hugging Face Hub\ndatasets:\n  - path: cgato/SlimOrcaDedupCleaned\n    type: chat_template\n    field_messages: conversations\n    message_property_mappings:\n      role: from\n      content: value\n\nval_set_size: 0.0\noutput_dir: ./outputs/out\n\nadapter: qlora\nlora_r: 32\nlora_alpha: 16\nlora_dropout: 0.05\nlora_target_linear: true\n\nsequence_len: 2048\nsample_packing: true\neval_sample_packing: false\n\n\nwandb_project:\nwandb_entity:\nwandb_watch:\nwandb_name:\nwandb_log_model:\n\ngradient_accumulation_steps: 4\nmicro_batch_size: 1\nnum_epochs: 4\noptimizer: adamw_bnb_8bit\nlr_scheduler: cosine\nlearning_rate: 0.0002\n\nbf16: auto\ntf32: true\n\ngradient_checkpointing: true\nresume_from_checkpoint:\nlogging_steps: 1\nflash_attention: true\n\nwarmup_ratio: 0.1\nevals_per_epoch:\nsaves_per_epoch: 1\nweight_decay: 0.0\nspecial_tokens:\n\n# save_first_step: true  # uncomment this to validate checkpoint saving works with your config\n"
  },
  {
    "path": "examples/colab-notebooks/colab-axolotl-example.ipynb",
    "content": "{\n \"cells\": [\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {\n    \"id\": \"OPLSwmgdrB7g\"\n   },\n   \"source\": [\n    \"# Fine-Tune Qwen3 14B with Axolotl\\n\",\n    \"\\n\",\n    \"[<img src=\\\"https://raw.githubusercontent.com/axolotl-ai-cloud/axolotl/main/image/axolotl-badge-web.png\\\" alt=\\\"Built with Axolotl\\\" width=\\\"200\\\" height=\\\"32\\\"/>](https://github.com/axolotl-ai-cloud/axolotl)\\n\",\n    \"\\n\",\n    \"Axolotl is the most performant LLM post-training framework available, delivering faster training with efficient, consistent and stable performance. Train your workload and ship your product 30% faster; saving you both time and money.\\n\",\n    \"\\n\",\n    \"- ⭐ us on [GitHub](https://github.com/axolotl-ai-cloud/axolotl)\\n\",\n    \"- 📜 Read the [Docs](http://docs.axolotl.ai/)\\n\",\n    \"- 💬 Chat with us on [Discord](https://discord.gg/mnpEYgRUmD)\\n\",\n    \"- 📰 Get updates on [X/Twitter](https://x.com/axolotl_ai)\\n\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {\n    \"id\": \"rVjKD7CbxIP3\"\n   },\n   \"source\": [\n    \"# Installation\\n\",\n    \"\\n\",\n    \"Axolotl is easy to install from [pip](https://pypi.org/project/axolotl/), or use our [pre-built Docker images](http://docs.axolotl.ai/docs/docker.html) for a hassle free dependency experience. 
See our [docs](http://docs.axolotl.ai/docs/installation.html) for more information.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {\n    \"id\": \"msOCO4NRmRLa\"\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"%%capture\\n\",\n    \"# This step can take ~5-10 minutes to install dependencies\\n\",\n    \"!pip install --no-build-isolation axolotl[flash-attn]>=0.9.1\\n\",\n    \"!pip install \\\"cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@63b15e6\\\"\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {\n    \"id\": \"N0OW0YeksDLr\"\n   },\n   \"source\": [\n    \"## Demo: Talk Like a Pirate\\n\",\n    \"\\n\",\n    \"In this demo, we are training the model ***to respond like a pirate***. This was chosen as a way to easily show how to train a model to respond in a certain style of your choosing (without being prompted) and is quite easy to validate within the scope of a Colab.\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {\n    \"id\": \"8Du2fANTsNCK\"\n   },\n   \"source\": [\n    \"### Upload your own dataset or use a Huggingface dataset\\n\",\n    \"\\n\",\n    \"You can choose to use your own JSONL file from your own [Google Drive](https://drive.google.com/drive/home); for example downloading the [Pirate-Ultrachat JSONL](https://huggingface.co/datasets/winglian/pirate-ultrachat-10k/blob/main/train.jsonl) to your Google Drive. 
JSONL datasets should be formatted similar to the [OpenAI dataset format](https://cookbook.openai.com/examples/chat_finetuning_data_prep).\\n\",\n    \"\\n\",\n    \"You can also simply use the [`winglian/pirate-ultrachat-10k`](https://huggingface.co/datasets/winglian/pirate-ultrachat-10k) dataset directly.\\n\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {\n    \"id\": \"fGEEjyQ-r_IV\"\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"# Default to HF dataset location\\n\",\n    \"dataset_id = \\\"winglian/pirate-ultrachat-10k\\\"\\n\",\n    \"uploaded = {}\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {\n    \"id\": \"c5MyYqk7vIsG\"\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"import os\\n\",\n    \"\\n\",\n    \"# Optionally, upload your own JSONL to your Google Drive\\n\",\n    \"GOOGLE_DRIVE_PATH = \\\"\\\"  # ex: \\\"MyDrive/Colab\\\\ Notebooks/train.jsonl\\\"\\n\",\n    \"\\n\",\n    \"# \\\"Select All\\\" permissions, or you may get the error:\\n\",\n    \"# \\\"MessageError: Error: credential propagation was unsuccessful\\\"\\n\",\n    \"if GOOGLE_DRIVE_PATH:\\n\",\n    \"    from google.colab import drive\\n\",\n    \"\\n\",\n    \"    # Mount your Google Drive\\n\",\n    \"    GOOGLE_DRIVE_MNT = \\\"/content/drive/\\\"\\n\",\n    \"    drive.mount(GOOGLE_DRIVE_MNT, force_remount=True)\\n\",\n    \"    tmp_path = os.path.join(GOOGLE_DRIVE_MNT, GOOGLE_DRIVE_PATH.lstrip(\\\"/\\\"))\\n\",\n    \"    # make sure file exists\\n\",\n    \"    if not os.path.isfile(tmp_path):\\n\",\n    \"        raise ValueError(f\\\"File {tmp_path} does not exist\\\")\\n\",\n    \"    dataset_id = tmp_path\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {\n    \"id\": \"U6pTk3A9xj1W\"\n   },\n   \"source\": [\n    \"# Configure for Supervised Fine-Tuning (SFT)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   
\"metadata\": {\n    \"colab\": {\n     \"base_uri\": \"https://localhost:8080/\",\n     \"height\": 151,\n     \"referenced_widgets\": [\n      \"388f618924274d21a066f098f4f1e744\",\n      \"7c95f85a2b1f47a1bd846d110c47bb3c\",\n      \"083f9cda8d754c168beee10d2f8955a2\",\n      \"62e1a65582f446a78612eaa804e08a7d\",\n      \"487a177d020f4605834878b2fdc7afa3\",\n      \"7fd44cf9ca6e4726bfd7ac21846d6a14\",\n      \"366a343b62fa47d8985a3bd464d99f9e\",\n      \"a0a11e929edd4189b79723d618522c33\",\n      \"e87ea87fcff247b5bbcc331ba79a8dc2\",\n      \"5e18768f7ad6434ba8b8b8a2e853e204\",\n      \"bb33aec33a6447078c31bfd728942994\"\n     ]\n    },\n    \"id\": \"fdRioqytmTtX\",\n    \"outputId\": \"f0acdcec-4b41-4a3f-ffed-c2d2d929158e\"\n   },\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"[2025-05-08 13:40:27,488] [INFO] [root.register:348] [PID:174] Attempting to load plugin: axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin\\n\",\n      \"[2025-05-08 13:40:27,493] [INFO] [root.register:351] [PID:174] Plugin loaded successfully: axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin\\n\",\n      \"[2025-05-08 13:40:27,959] [INFO] [axolotl.utils.schemas.config.check_eval_packing:721] [PID:174] [RANK:0] explicitly setting `eval_sample_packing` to match `sample_packing`\\u001b[39m\\n\",\n      \"[2025-05-08 13:40:27,960] [INFO] [axolotl.utils.schemas.config.hint_sample_packing_padding:514] [PID:174] [RANK:0] Setting `pad_to_sequence_len: true` to prevent memory leaks when sample_packing\\u001b[39m\\n\",\n      \"[2025-05-08 13:40:27,961] [INFO] [axolotl.utils.schemas.config.check_bf16:1251] [PID:174] [RANK:0] bf16 support detected, but not enabled for this configuration.\\u001b[39m\\n\"\n     ]\n    },\n    {\n     \"data\": {\n      \"application/vnd.jupyter.widget-view+json\": {\n       \"model_id\": \"388f618924274d21a066f098f4f1e744\",\n       \"version_major\": 2,\n       
\"version_minor\": 0\n      },\n      \"text/plain\": [\n       \"config.json:   0%|          | 0.00/728 [00:00<?, ?B/s]\"\n      ]\n     },\n     \"metadata\": {},\n     \"output_type\": \"display_data\"\n    },\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"[2025-05-08 13:40:28,590] [INFO] [axolotl.normalize_config:237] [PID:174] [RANK:0] cuda memory usage baseline: 0.000GB (+0.002GB cache, +0.359GB misc)\\u001b[39m\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"from axolotl.cli.config import load_cfg\\n\",\n    \"from axolotl.utils.dict import DictDefault\\n\",\n    \"\\n\",\n    \"# Axolotl provides full control and transparency over model and training configuration\\n\",\n    \"config = DictDefault(\\n\",\n    \"    base_model=\\\"Qwen/Qwen3-14B\\\",  # Use the instruct tuned model, but we're aligning it to be a pirate\\n\",\n    \"    load_in_4bit=True,  # set to True for qLoRA\\n\",\n    \"    adapter=\\\"qlora\\\",\\n\",\n    \"    lora_r=32,\\n\",\n    \"    lora_alpha=64,\\n\",\n    \"    lora_target_modules=[\\n\",\n    \"        \\\"q_proj\\\",\\n\",\n    \"        \\\"k_proj\\\",\\n\",\n    \"        \\\"v_proj\\\",\\n\",\n    \"        \\\"o_proj\\\",  # train self_attn linear modules\\n\",\n    \"        \\\"gate_proj\\\",\\n\",\n    \"        \\\"down_proj\\\",\\n\",\n    \"        \\\"up_proj\\\",  # train MLP linear modules\\n\",\n    \"    ],\\n\",\n    \"    lora_qkv_kernel=True,  # optimized triton kernels for LoRA\\n\",\n    \"    lora_o_kernel=True,\\n\",\n    \"    lora_mlp_kernel=True,\\n\",\n    \"    embeddings_skip_upcast=True,  # keep embeddings in fp16 so the model fits in 15GB VRAM\\n\",\n    \"    xformers_attention=True,  # use xformers on Colab w/ T4 for memory efficient attention, flash_attention only on Ampere or above\\n\",\n    \"    plugins=[\\n\",\n    \"        # more efficient training using Apple's Cut Cross Entropy; https://github.com/apple/ml-cross-entropy\\n\",\n    
\"        \\\"axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin\\\",\\n\",\n    \"    ],\\n\",\n    \"    sample_packing=True,  # 2-6x increase in tokens per micro-batch\\n\",\n    \"    # when using packing, use a slightly higher learning rate to account for fewer steps\\n\",\n    \"    # alternatively, reduce the micro_batch_size + gradient_accumulation_steps to achieve closer to the same number of steps/epoch\\n\",\n    \"    learning_rate=0.00019,\\n\",\n    \"    sequence_len=4096,  # larger sequence length improves packing efficiency for more tokens/sec\\n\",\n    \"    micro_batch_size=1,\\n\",\n    \"    gradient_accumulation_steps=1,\\n\",\n    \"    gradient_checkpointing=True,  # tradeoff reduced VRAM for increased time\\n\",\n    \"    gradient_checkpointing_kwargs={\\n\",\n    \"        \\\"use_reentrant\\\": False,\\n\",\n    \"    },\\n\",\n    \"    optimizer=\\\"paged_adamw_8bit\\\",\\n\",\n    \"    lr_scheduler=\\\"cosine\\\",\\n\",\n    \"    warmup_steps=5,\\n\",\n    \"    fp16=True,  # use float16 + automatic mixed precision, bfloat16 not supported on Colab w/ T4\\n\",\n    \"    bf16=False,\\n\",\n    \"    max_grad_norm=0.1,  # gradient clipping\\n\",\n    \"    num_epochs=1,\\n\",\n    \"    saves_per_epoch=2,  # how many checkpoints to save over one epoch\\n\",\n    \"    logging_steps=1,\\n\",\n    \"    output_dir=\\\"./outputs/qwen-sft-pirate-rrr\\\",\\n\",\n    \"    chat_template=\\\"qwen3\\\",\\n\",\n    \"    datasets=[\\n\",\n    \"        {\\n\",\n    \"            \\\"path\\\": dataset_id,  # Huggingface Dataset id or path to train.jsonl\\n\",\n    \"            \\\"type\\\": \\\"chat_template\\\",\\n\",\n    \"            \\\"split\\\": \\\"train\\\",\\n\",\n    \"            \\\"eot_tokens\\\": [\\\"<|im_end|>\\\"],\\n\",\n    \"        }\\n\",\n    \"    ],\\n\",\n    \"    dataloader_prefetch_factor=8,  # dataloader optimizations\\n\",\n    \"    dataloader_num_workers=2,\\n\",\n    \"    
dataloader_pin_memory=True,\\n\",\n    \")\\n\",\n    \"\\n\",\n    \"# validates the configuration\\n\",\n    \"cfg = load_cfg(config)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {\n    \"id\": \"715UpvnSoBIS\"\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"from axolotl.utils import set_pytorch_cuda_alloc_conf\\n\",\n    \"\\n\",\n    \"set_pytorch_cuda_alloc_conf()\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {\n    \"id\": \"Vc6MC-hwyH-n\"\n   },\n   \"source\": [\n    \"# Datasets\\n\",\n    \"\\n\",\n    \"Axolotl has a robust suite of loaders and transforms to parse most open datasets of any format into the appropriate chat template for your model. Axolotl will mask input tokens from the user's prompt so that the train loss is only calculated against the model's response. For more information, [see our documentation](http://docs.axolotl.ai/docs/dataset-formats/conversation.html) on dataset preparation.\\n\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {\n    \"colab\": {\n     \"base_uri\": \"https://localhost:8080/\",\n     \"height\": 1000,\n     \"referenced_widgets\": [\n      \"b82aa8c57f7c422a9a9c90f333ed2a99\",\n      \"c0991cf63ee6458b96e9a75e7a88b61a\",\n      \"71c8af139cd248b1b51101fd46a93f35\",\n      \"1d5117195d4b49eb8f1a73b18419f7ce\",\n      \"3c21e4a511b4441192c03b7f1d0976e9\",\n      \"ed28e2e0410d4e0b855467e798e53d66\",\n      \"d93f134f802b4b69b575bdaf07dbd27c\",\n      \"d0e9dce55cec4c1ca619a0ccf209d924\",\n      \"4c727d40ef0443449afc31724ee79f0c\",\n      \"0dea5caa27384f5689e3cab51f558727\",\n      \"a6f48410b9964fefba0c3009a77dc838\",\n      \"95caff42f08a4c2aa14c867b8f37f231\",\n      \"de7c37ee83e24f0c889e84d07279c2ec\",\n      \"9d4897eefb5f48259ffb2d23e332f752\",\n      \"253017b0d0534e54ab44e181f6d7c82d\",\n      \"27beaf06e41b472abdb544a43c720c5a\",\n      \"34cf3df51fbc41cabfdbba153c007f0e\",\n   
   \"ac764024cf1c4e08ba7749afd2cd20ac\",\n      \"30a81da86f8043eca301e86a8651201a\",\n      \"e8b7a81040904c1e89e58978223b1737\",\n      \"1c6f1f10667545aaab958016ba7e2c94\",\n      \"e6e969610738449887259063967f82b0\",\n      \"a138859f19b74fc0928dc236ab5359db\",\n      \"9b42e08b3c9548818488268768a118b1\",\n      \"12b56912736849fea2ad8124456fdc5c\",\n      \"879c8ab5873847a8833bd74123be90a4\",\n      \"20352e5f58d24bb8b1f3940efd14fe4a\",\n      \"d955dcaa0e944e719f3a06139dd54a03\",\n      \"d3de2662c7964f1ba96e58da382af720\",\n      \"97e36007e1304e1583fd81bfb13f0edd\",\n      \"c65dc74c7d6f4bab8f7dd28455161dd8\",\n      \"ef223e8504b64e3592589880326aaf41\",\n      \"598da69727bd4fb8b1caf465ac736d7a\",\n      \"5f86cd894de94c3280fadc1e2fd0ee13\",\n      \"a20927bf5f2c41f58c1e31ac858ab36c\",\n      \"0a46ad75c198463d843fb35e813642cb\",\n      \"09007681cf8d42aeb8c1d2f6a74e470a\",\n      \"ebc80d1a55fa47f4a5ea2756588569ec\",\n      \"1811cda0644e4190a9469d1774435d82\",\n      \"35c811d2ae8e43f3b5cecbdd3cfa857f\",\n      \"b8e39e4dddc3497fbc29ae45c66da759\",\n      \"63b4e563e85c4f03b1b72beda9577bcc\",\n      \"b195f160ca20442fadd8b5aed0ee41af\",\n      \"ca65e32eb52f48c09a84b33cb18f22cd\",\n      \"7cd0b85ebd204b7aba908417811ce4e0\",\n      \"7baeab52d6694c32b1efd1ea1a0a7782\",\n      \"519a7b154022443db6703f04a9142bae\",\n      \"d4183e9715f34d249942b8271cca3bdf\",\n      \"da2347ac94764a3fa2743343cf0d3cd2\",\n      \"93a44a11aa4846fa8efc6c1413ef1627\",\n      \"a55060adc3564407ac81ad7297d34aaa\",\n      \"d02274afd47b462291c745f261209d42\",\n      \"0f417447a7bd4a33acca96fa37aec877\",\n      \"63580b6fb30642479fe3000915bf551a\",\n      \"8f726dbfb45d4528afa33e36a6313267\",\n      \"03b093d592ba4386aa61f7b8483da660\",\n      \"b8766a88716948cf968f4563531a76d9\",\n      \"6f3a28b912714c6e931003549664bfa3\",\n      \"16d1283741404b7bb319094c992fce01\",\n      \"2a5bb0e818ab47be8cf6465988328503\",\n      \"2b3a2659b12244bd8548320320016dbf\",\n      
\"0cd7efffbb3c4c4b972e63749f61ab97\",\n      \"5ca240f31e6b44e3882c5eb37cd5a309\",\n      \"5eb06edeb58e4930b1affef2a59eae81\",\n      \"a4e5789584564049b83df7c6c54a3e08\",\n      \"ff3a94b146a948b6907f5d80c7157f99\",\n      \"258b7c635c1045329d4669e48c46ccd5\",\n      \"6f68ed9889f54ad2ae8a3b95ac263a83\",\n      \"80366349d81e4dcc892db6cd56e384f3\",\n      \"c73055099c084dca996159e23e162d0b\",\n      \"977f799afaac4a55b2dc1cffa7d5b63b\",\n      \"41f3b32c2f6b4034ae7a3b9124e28bc7\",\n      \"a10d0a76010f4e508c65a9b69ebc5156\",\n      \"f8ef805b776145c3bfa9ba8d90972058\",\n      \"cc587493c33c4f118d1b1170f85be24c\",\n      \"e40d1c1ac9494b3bade9858324e7ffdf\",\n      \"d65b6b060d9845779299491ac5599c31\",\n      \"0f6907ebbc6242c8bde059cef1e1bd29\",\n      \"5bdfd87fc6cd4f9dabef7cfee29c8060\",\n      \"64f54d4a744a4627a07c3c0120276f3b\",\n      \"65b75b9b8bc143cf997796af68ff6668\",\n      \"d6fe74e4255444368f8f90a62157d869\",\n      \"4d468f96ec924681ad65eb671674b93e\",\n      \"ad7599de524549c48bf2d3124ad4b299\",\n      \"0546d04aae644dde846c58a4afb598a6\",\n      \"897b77a56c09479bb11d7f2a30997e55\",\n      \"81c3db71ac704280ad030072655f1537\",\n      \"042e091f75694c47aee761e760e76773\",\n      \"ef0a3c7a6f14460fb4da096928ae249e\",\n      \"07fb3a2c8315494e97b447e672dfae06\",\n      \"ec030fc3c346426f9abc3a89892258d3\",\n      \"e3fb3fc6afe04b3c9b7ac61809ce78fa\",\n      \"c3be9109d63c485d9c0ef4f9bc0f9218\",\n      \"12815f401eba44658caa7b2e490137a8\",\n      \"30e02aa2d0d241979369e598287f2639\",\n      \"dfd2a2649b8341ef913207526708aff1\",\n      \"4f1977d7e4824ef1a14b65f0f42bba10\",\n      \"c6164e05a1914ae48083db9ad7f4ef7c\",\n      \"813621384dc748b0ad06775e22761c0b\",\n      \"dc892a596f6942d7973c616c38f0eebb\",\n      \"c84cc07789be48aebb322c23d355289e\",\n      \"bed8726b8069434687c75452e21f19e5\",\n      \"16a188a0b06d45f980dcf3933509fe0a\",\n      \"60c1a0d765c14a1d888317e6a507e4ea\",\n      \"0077aedc3d174560bce924ee89e9c006\",\n      
\"00321cce58884f6f9b3855a21fcd9187\",\n      \"fa864b41586f4a7aa56aeafd1d84eb75\",\n      \"3225603166b54e7aab766b9964a2f660\",\n      \"349eee9f56d64f0cba6fc24ff2c50c9b\",\n      \"7e5d3774060e4589aa65982da5ea4ef4\",\n      \"7c2485c6cdfe463da6fdb35982a1070d\",\n      \"ad1236893754446881e153adc9d5c962\",\n      \"daee63fd167e4441a32324b51b00ad2b\",\n      \"fe41858c6bd04c58840112b67c19a336\",\n      \"d262c82138024169b9f3aa034ca756fa\",\n      \"62e302ebdad64aada0ffe64ae1c873f3\",\n      \"bd1b0dfed6d34d16af33a4a58330f5ec\",\n      \"d07c8b97d3314f1c852e44bdd40f61ed\",\n      \"ebb69a2c3d0a4299a484698287b3087c\",\n      \"e5a82df528bb4e408797a3b6c2758f4a\",\n      \"f113ebd8c1c34806bea4dd7ed3035173\"\n     ]\n    },\n    \"id\": \"KQQhgK8FoDfF\",\n    \"outputId\": \"f69441d8-95f9-4885-c306-6c8709090ff6\"\n   },\n   \"outputs\": [\n    {\n     \"data\": {\n      \"application/vnd.jupyter.widget-view+json\": {\n       \"model_id\": \"b82aa8c57f7c422a9a9c90f333ed2a99\",\n       \"version_major\": 2,\n       \"version_minor\": 0\n      },\n      \"text/plain\": [\n       \"tokenizer_config.json:   0%|          | 0.00/9.68k [00:00<?, ?B/s]\"\n      ]\n     },\n     \"metadata\": {},\n     \"output_type\": \"display_data\"\n    },\n    {\n     \"data\": {\n      \"application/vnd.jupyter.widget-view+json\": {\n       \"model_id\": \"95caff42f08a4c2aa14c867b8f37f231\",\n       \"version_major\": 2,\n       \"version_minor\": 0\n      },\n      \"text/plain\": [\n       \"vocab.json:   0%|          | 0.00/2.78M [00:00<?, ?B/s]\"\n      ]\n     },\n     \"metadata\": {},\n     \"output_type\": \"display_data\"\n    },\n    {\n     \"data\": {\n      \"application/vnd.jupyter.widget-view+json\": {\n       \"model_id\": \"a138859f19b74fc0928dc236ab5359db\",\n       \"version_major\": 2,\n       \"version_minor\": 0\n      },\n      \"text/plain\": [\n       \"merges.txt:   0%|          | 0.00/1.67M [00:00<?, ?B/s]\"\n      ]\n     },\n     \"metadata\": {},\n     
\"output_type\": \"display_data\"\n    },\n    {\n     \"data\": {\n      \"application/vnd.jupyter.widget-view+json\": {\n       \"model_id\": \"5f86cd894de94c3280fadc1e2fd0ee13\",\n       \"version_major\": 2,\n       \"version_minor\": 0\n      },\n      \"text/plain\": [\n       \"tokenizer.json:   0%|          | 0.00/11.4M [00:00<?, ?B/s]\"\n      ]\n     },\n     \"metadata\": {},\n     \"output_type\": \"display_data\"\n    },\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"[2025-05-08 13:41:00,844] [DEBUG] [axolotl.utils.models.load_tokenizer:441] [PID:174] [RANK:0] EOS: 151645 / <|im_end|>\\u001b[39m\\n\",\n      \"[2025-05-08 13:41:00,845] [DEBUG] [axolotl.utils.models.load_tokenizer:442] [PID:174] [RANK:0] BOS: None / None\\u001b[39m\\n\",\n      \"[2025-05-08 13:41:00,846] [DEBUG] [axolotl.utils.models.load_tokenizer:443] [PID:174] [RANK:0] PAD: 151643 / <|endoftext|>\\u001b[39m\\n\",\n      \"[2025-05-08 13:41:00,847] [DEBUG] [axolotl.utils.models.load_tokenizer:444] [PID:174] [RANK:0] UNK: None / None\\u001b[39m\\n\",\n      \"[2025-05-08 13:41:00,869] [INFO] [axolotl.utils.data.sft.load_tokenized_prepared_datasets:271] [PID:174] [RANK:0] Unable to find prepared dataset in last_run_prepared/97037817611d38b3a9c681753c3c4c95\\u001b[39m\\n\",\n      \"[2025-05-08 13:41:00,870] [INFO] [axolotl.utils.data.sft.load_tokenized_prepared_datasets:272] [PID:174] [RANK:0] Loading raw datasets...\\u001b[39m\\n\",\n      \"\\u001b[33m[2025-05-08 13:41:00,870] [WARNING] [axolotl.utils.data.sft.load_tokenized_prepared_datasets:274] [PID:174] [RANK:0] Processing datasets during training can lead to VRAM instability. 
Please pre-process your dataset.\\u001b[39m\\n\",\n      \"[2025-05-08 13:41:00,871] [INFO] [axolotl.utils.data.sft.load_tokenized_prepared_datasets:281] [PID:174] [RANK:0] No seed provided, using default seed of 42\\u001b[39m\\n\"\n     ]\n    },\n    {\n     \"data\": {\n      \"application/vnd.jupyter.widget-view+json\": {\n       \"model_id\": \"7cd0b85ebd204b7aba908417811ce4e0\",\n       \"version_major\": 2,\n       \"version_minor\": 0\n      },\n      \"text/plain\": [\n       \"train.jsonl:   0%|          | 0.00/27.3M [00:00<?, ?B/s]\"\n      ]\n     },\n     \"metadata\": {},\n     \"output_type\": \"display_data\"\n    },\n    {\n     \"data\": {\n      \"application/vnd.jupyter.widget-view+json\": {\n       \"model_id\": \"03b093d592ba4386aa61f7b8483da660\",\n       \"version_major\": 2,\n       \"version_minor\": 0\n      },\n      \"text/plain\": [\n       \"Generating train split: 0 examples [00:00, ? examples/s]\"\n      ]\n     },\n     \"metadata\": {},\n     \"output_type\": \"display_data\"\n    },\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"[2025-05-08 13:41:04,196] [INFO] [axolotl.utils.data.sft.get_dataset_wrapper:484] [PID:174] [RANK:0] Loading dataset with base_type: chat_template and prompt_style: None\\u001b[39m\\n\",\n      \"[2025-05-08 13:41:04,233] [INFO] [axolotl.__call__:761] [PID:174] [RANK:0] Using chat template:\\n\",\n      \"---\\n\",\n      \"{%- if tools %}\\n\",\n      \"    {{- '<|im_start|>system\\\\n' }}\\n\",\n      \"    {%- if messages[0].role == 'system' %}\\n\",\n      \"        {{- messages[0].content + '\\\\n\\\\n' }}\\n\",\n      \"    {%- endif %}\\n\",\n      \"    {{- \\\"# Tools\\\\n\\\\nYou may call one or more functions to assist with the user query.\\\\n\\\\nYou are provided with function signatures within <tools></tools> XML tags:\\\\n<tools>\\\" }}\\n\",\n      \"    {%- for tool in tools %}\\n\",\n      \"        {{- \\\"\\\\n\\\" }}\\n\",\n      \"    
    {{- tool | tojson }}\\n\",\n      \"    {%- endfor %}\\n\",\n      \"    {{- \\\"\\\\n</tools>\\\\n\\\\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\\\\n<tool_call>\\\\n{\\\\\\\"name\\\\\\\": <function-name>, \\\\\\\"arguments\\\\\\\": <args-json-object>}\\\\n</tool_call><|im_end|>\\\\n\\\" }}\\n\",\n      \"{%- else %}\\n\",\n      \"    {%- if messages[0].role == 'system' %}\\n\",\n      \"        {{- '<|im_start|>system\\\\n' + messages[0].content + '<|im_end|>\\\\n' }}\\n\",\n      \"    {%- endif %}\\n\",\n      \"{%- endif %}\\n\",\n      \"{%- set ns = namespace(multi_step_tool=true, last_query_index=messages|length - 1) %}\\n\",\n      \"{%- for message in messages[::-1] %}\\n\",\n      \"    {%- set index = (messages|length - 1) - loop.index0 %}\\n\",\n      \"    {%- if ns.multi_step_tool and message.role == \\\"user\\\" and not(message.content.startswith('<tool_response>') and message.content.endswith('</tool_response>')) %}\\n\",\n      \"        {%- set ns.multi_step_tool = false %}\\n\",\n      \"        {%- set ns.last_query_index = index %}\\n\",\n      \"    {%- endif %}\\n\",\n      \"{%- endfor %}\\n\",\n      \"{%- for message in messages %}\\n\",\n      \"    {%- if (message.role == \\\"user\\\") or (message.role == \\\"system\\\" and not loop.first) %}\\n\",\n      \"        {{- '<|im_start|>' + message.role + '\\\\n' + message.content + '<|im_end|>' + '\\\\n' }}\\n\",\n      \"    {%- elif message.role == \\\"assistant\\\" %}\\n\",\n      \"        {%- set content = message.content %}\\n\",\n      \"        {%- set reasoning_content = '' %}\\n\",\n      \"        {%- if message.reasoning_content is defined and message.reasoning_content is not none %}\\n\",\n      \"            {%- set reasoning_content = message.reasoning_content %}\\n\",\n      \"        {%- else %}\\n\",\n      \"            {%- if '</think>' in message.content %}\\n\",\n      \"                
{%- set content = message.content.split('</think>')[-1].lstrip('\\\\n') %}\\n\",\n      \"                {%- set reasoning_content = message.content.split('</think>')[0].rstrip('\\\\n').split('<think>')[-1].lstrip('\\\\n') %}\\n\",\n      \"            {%- endif %}\\n\",\n      \"        {%- endif %}\\n\",\n      \"        {%- if loop.index0 > ns.last_query_index %}\\n\",\n      \"            {%- if loop.last or (not loop.last and reasoning_content) %}\\n\",\n      \"                {{- '<|im_start|>' + message.role + '\\\\n<think>\\\\n' + reasoning_content.strip('\\\\n') + '\\\\n</think>\\\\n\\\\n' + content.lstrip('\\\\n') }}\\n\",\n      \"            {%- else %}\\n\",\n      \"                {{- '<|im_start|>' + message.role + '\\\\n' + content }}\\n\",\n      \"            {%- endif %}\\n\",\n      \"        {%- else %}\\n\",\n      \"            {{- '<|im_start|>' + message.role + '\\\\n' + content }}\\n\",\n      \"        {%- endif %}\\n\",\n      \"        {%- if message.tool_calls %}\\n\",\n      \"            {%- for tool_call in message.tool_calls %}\\n\",\n      \"                {%- if (loop.first and content) or (not loop.first) %}\\n\",\n      \"                    {{- '\\\\n' }}\\n\",\n      \"                {%- endif %}\\n\",\n      \"                {%- if tool_call.function %}\\n\",\n      \"                    {%- set tool_call = tool_call.function %}\\n\",\n      \"                {%- endif %}\\n\",\n      \"                {{- '<tool_call>\\\\n{\\\"name\\\": \\\"' }}\\n\",\n      \"                {{- tool_call.name }}\\n\",\n      \"                {{- '\\\", \\\"arguments\\\": ' }}\\n\",\n      \"                {%- if tool_call.arguments is string %}\\n\",\n      \"                    {{- tool_call.arguments }}\\n\",\n      \"                {%- else %}\\n\",\n      \"                    {{- tool_call.arguments | tojson }}\\n\",\n      \"                {%- endif %}\\n\",\n      \"                {{- '}\\\\n</tool_call>' }}\\n\",\n      
\"            {%- endfor %}\\n\",\n      \"        {%- endif %}\\n\",\n      \"        {{- '<|im_end|>\\\\n' }}\\n\",\n      \"    {%- elif message.role == \\\"tool\\\" %}\\n\",\n      \"        {%- if loop.first or (messages[loop.index0 - 1].role != \\\"tool\\\") %}\\n\",\n      \"            {{- '<|im_start|>user' }}\\n\",\n      \"        {%- endif %}\\n\",\n      \"        {{- '\\\\n<tool_response>\\\\n' }}\\n\",\n      \"        {{- message.content }}\\n\",\n      \"        {{- '\\\\n</tool_response>' }}\\n\",\n      \"        {%- if loop.last or (messages[loop.index0 + 1].role != \\\"tool\\\") %}\\n\",\n      \"            {{- '<|im_end|>\\\\n' }}\\n\",\n      \"        {%- endif %}\\n\",\n      \"    {%- endif %}\\n\",\n      \"{%- endfor %}\\n\",\n      \"{%- if add_generation_prompt %}\\n\",\n      \"    {{- '<|im_start|>assistant\\\\n' }}\\n\",\n      \"    {%- if enable_thinking is defined and enable_thinking is false %}\\n\",\n      \"        {{- '<think>\\\\n\\\\n</think>\\\\n\\\\n' }}\\n\",\n      \"    {%- endif %}\\n\",\n      \"{%- endif %}\\n\",\n      \"---\\u001b[39m\\n\"\n     ]\n    },\n    {\n     \"data\": {\n      \"application/vnd.jupyter.widget-view+json\": {\n       \"model_id\": \"258b7c635c1045329d4669e48c46ccd5\",\n       \"version_major\": 2,\n       \"version_minor\": 0\n      },\n      \"text/plain\": [\n       \"Tokenizing Prompts (num_proc=2):   0%|          | 0/9985 [00:00<?, ? 
examples/s]\"\n      ]\n     },\n     \"metadata\": {},\n     \"output_type\": \"display_data\"\n    },\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"[2025-05-08 13:42:09,195] [INFO] [axolotl.utils.data.utils.drop_long_seq_in_dataset:177] [PID:174] [RANK:0] min_input_len: 23\\u001b[39m\\n\",\n      \"[2025-05-08 13:42:09,196] [INFO] [axolotl.utils.data.utils.drop_long_seq_in_dataset:179] [PID:174] [RANK:0] max_input_len: 3380\\u001b[39m\\n\"\n     ]\n    },\n    {\n     \"data\": {\n      \"application/vnd.jupyter.widget-view+json\": {\n       \"model_id\": \"0f6907ebbc6242c8bde059cef1e1bd29\",\n       \"version_major\": 2,\n       \"version_minor\": 0\n      },\n      \"text/plain\": [\n       \"Dropping Long Sequences (num_proc=2):   0%|          | 0/9985 [00:00<?, ? examples/s]\"\n      ]\n     },\n     \"metadata\": {},\n     \"output_type\": \"display_data\"\n    },\n    {\n     \"data\": {\n      \"application/vnd.jupyter.widget-view+json\": {\n       \"model_id\": \"ef0a3c7a6f14460fb4da096928ae249e\",\n       \"version_major\": 2,\n       \"version_minor\": 0\n      },\n      \"text/plain\": [\n       \"Drop Samples with Zero Trainable Tokens (num_proc=2):   0%|          | 0/9985 [00:00<?, ? examples/s]\"\n      ]\n     },\n     \"metadata\": {},\n     \"output_type\": \"display_data\"\n    },\n    {\n     \"data\": {\n      \"application/vnd.jupyter.widget-view+json\": {\n       \"model_id\": \"dc892a596f6942d7973c616c38f0eebb\",\n       \"version_major\": 2,\n       \"version_minor\": 0\n      },\n      \"text/plain\": [\n       \"Add position_id column (Sample Packing) (num_proc=2):   0%|          | 0/9985 [00:00<?, ? 
examples/s]\"\n      ]\n     },\n     \"metadata\": {},\n     \"output_type\": \"display_data\"\n    },\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"[2025-05-08 13:42:21,651] [INFO] [axolotl.utils.data.sft.load_tokenized_prepared_datasets:351] [PID:174] [RANK:0] Saving merged prepared dataset to disk... last_run_prepared/97037817611d38b3a9c681753c3c4c95\\u001b[39m\\n\"\n     ]\n    },\n    {\n     \"data\": {\n      \"application/vnd.jupyter.widget-view+json\": {\n       \"model_id\": \"7c2485c6cdfe463da6fdb35982a1070d\",\n       \"version_major\": 2,\n       \"version_minor\": 0\n      },\n      \"text/plain\": [\n       \"Saving the dataset (0/1 shards):   0%|          | 0/9985 [00:00<?, ? examples/s]\"\n      ]\n     },\n     \"metadata\": {},\n     \"output_type\": \"display_data\"\n    },\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"[2025-05-08 13:42:25,711] [INFO] [axolotl.utils.samplers.multipack.calc_min_len:411] [PID:174] [RANK:0] gather_len_batches: [1540]\\u001b[39m\\n\",\n      \"[2025-05-08 13:42:25,714] [INFO] [axolotl.calc_sample_packing_eff_est:491] [PID:174] [RANK:0] sample_packing_eff_est across ranks: [0.9987832601968344]\\u001b[39m\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"from axolotl.common.datasets import load_datasets\\n\",\n    \"\\n\",\n    \"# Load, parse and tokenize the datasets to be formatted with qwen3 chat template\\n\",\n    \"# Drop long samples from the dataset that overflow the max sequence length\\n\",\n    \"dataset_meta = load_datasets(cfg=cfg)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {\n    \"id\": \"mrSNfHpk0EAe\"\n   },\n   \"source\": [\n    \"# Training\\n\",\n    \"\\n\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {\n    \"colab\": {\n     \"base_uri\": \"https://localhost:8080/\",\n     \"height\": 1000,\n     
\"referenced_widgets\": [\n      \"004d9177a6a14118a5930dc3cc13147b\",\n      \"a80410b919e442c49aea15acc1ce1a72\",\n      \"c6e00f5224364822bc4239b176686919\",\n      \"ec11d1e5ae7b42c883d9b1f38a65356e\",\n      \"734185351eb543fa9a00a881dcbb9fe7\",\n      \"fa1282ccc7544e4f818e2f03ccffe4a5\",\n      \"bbbf575d2a4b4c6ea8389be79b2a6039\",\n      \"2a51b36be41745468e4c2d7a21b1c0d2\",\n      \"4fd114abe9f5494ab59858949f5055f1\",\n      \"936d04b5fe1b4c63bf0b080e423d051b\",\n      \"f1cef8e8dc2646fb9fd09f3b09081074\",\n      \"cdebbc55a1164c018546c2ac6f8c620c\",\n      \"a44f630e099e43899f20a77084ae60cd\",\n      \"c3725c7f79fe415fbd1ea336f0cc9cf1\",\n      \"0e50870ed0c643e0b6c18cc5d7ddae7f\",\n      \"c33ced495f70464aa4a3a91922090853\",\n      \"ed5ca967ad5342929e578ac6aa4dc4c0\",\n      \"af401d117d5047629d3a6e2361757b62\",\n      \"b191ac001a2e4962bc9a245fcdf26e6b\",\n      \"054c8dffadba48c6b895a6cc62448ecc\",\n      \"bfcdbba993b74972a9e3e575f86908ff\",\n      \"6ebb2ec171414e47a14765505f64bb3c\",\n      \"500e272208a246089613bf788a165271\",\n      \"200df5e79b9244849e589ecb0250a520\",\n      \"cc94432d08464affa3e58b560bdad194\",\n      \"3036608c71904ce9ae4bb2a9fa8802d9\",\n      \"adacfdcc1b0140efac56918e9ccf064e\",\n      \"f4a1795dc7514a718f478245f521f0ba\",\n      \"5e746eb25bbe416fb585fa24e79f5177\",\n      \"b5b65414154544aa8a71b1a39164aad7\",\n      \"f0a58fbd0fca4340890041f99fa2f8c8\",\n      \"5ca6be24acb548cea130bd58e9954c7c\",\n      \"5cfb02ee044b4011a378efa8b54a370f\",\n      \"4d05314858354e729d76094b3b0ce761\",\n      \"c42acf646f344a88b8c11f81e67f7206\",\n      \"7be6f04c284e4326bb4ff3d301e7b3c6\",\n      \"ffdbb12a2f2c4d14911685e7683e0ef0\",\n      \"bee3501b2a17427784a717e50a85e7fa\",\n      \"8bc9d8ba866c442b9118d9630009939c\",\n      \"9f56a2d9979c4bd8928c644c22c3ecdf\",\n      \"9503a45960984adc97b58e16c50662e0\",\n      \"da6e93f3e4984780b930fe7a706983ea\",\n      \"ab93eabd7cea4b94b4b7a387f101e8a1\",\n      
\"704f2f5a9b1c49d5a75a0025a5dda11b\",\n      \"dd0e646fad3f4a89ba23b39d162bd8d9\",\n      \"d43c6df07ddb466587807d6dbe1ff614\",\n      \"e0e8b840b8ea4d0d9db09afe99fa287d\",\n      \"9327977822be4b1294f80e876552e305\",\n      \"77304d1a46b3468a98483e02ec0ac4a4\",\n      \"8c4d4fc5a30f4e7cb3be53fe2adda33d\",\n      \"e90658f4bcb642baa78426012f863152\",\n      \"f7434f3e03124a1c938a39af79d7fa59\",\n      \"c1314f241a434c41b45d84dc4d3b30f8\",\n      \"37de928300e34184881039378bd75e7f\",\n      \"0e936d9dbf9c4fdd86bbfe9730dedc47\",\n      \"e21e180307e5485cbbe908672fd6639a\",\n      \"2e2b0c1599c341a198f632f46a40c90e\",\n      \"bff139df987d4a62abec6456cb27f3d4\",\n      \"ebe1cc366d324ad59b264c8b3c431441\",\n      \"114dece49dba437c8572ef94b23c3b1e\",\n      \"be724f04b03942b2a033a7e8898bb4fd\",\n      \"fcbab4d8dced41a18dfccce81e3a45a0\",\n      \"c1f9c267ba3f40039cdb5eb3267e8043\",\n      \"33b3b1d0295646edaac7b4822761aeb0\",\n      \"fba7aa824b38467ab3061b226114cdec\",\n      \"f3075dccbd2747b4a7913b66f44f2596\",\n      \"fe18bba7f3fb4c31bf840541f36b3425\",\n      \"fd4f333f7ece4450b04e1a9af1f9d2f6\",\n      \"f60a2bdb6b6b4e0e8c3508580e247132\",\n      \"c0892a1881de4eb4bfabc6a68f87ae99\",\n      \"1bec6297c90242a88672d195bc09d429\",\n      \"d1f9b10c130542f094c8fd3d1e23b5e9\",\n      \"e575d87a7efe4ec7b1efde489839d4a6\",\n      \"edc99591b9c747b689b94d0052fec14c\",\n      \"35cc989ca3374e7dba0cb166febc4bde\",\n      \"158c8b85dbf34de6a94b4e35e2fc7d5a\",\n      \"0b4c9753a7cb4354b8e5f187e6e1ad7c\",\n      \"4471ff62258549fba9514bb67050f965\",\n      \"9cd5211b5d8b457aa0002f1d17b80028\",\n      \"19127c7bb1554ccbac877059f9a82db0\",\n      \"f4667818b9d34a09891cd727a429a610\",\n      \"9ed02dc43412471a9ab47f3620ccf3a5\",\n      \"6932489232ec4ab18a160b1e7fbcdfe1\",\n      \"4540927d98f54466b434ba4c0edf045d\",\n      \"e400cbf14bcc446a9d33b210cd93550b\",\n      \"71002199df6b40c9a1ac40df5fb27a1b\",\n      \"4b27c267393640f28f6eae0875bd2ed9\",\n      
\"9858cb74a09748a39e8149baac96702c\",\n      \"eb1c9535e6a546098b760528b2ea387c\",\n      \"18357b321ce44d7b8bd9d1c886f69275\",\n      \"279937fe03bc4e4eb25b472d7e9df163\",\n      \"bca2c7185b6749fd899c06a2ba4c5e46\",\n      \"1f7d30f71bbd4547a9150d21da071055\",\n      \"e366ae3fceec4566b9ed303d6c5f90af\",\n      \"5dd7d150dbe04f08b165ce7f2c27cd11\",\n      \"b634bb73cfa743d09a5999101b840976\",\n      \"742b1030acfd414bbd9d5327b7e3826d\",\n      \"0f480e3a0b0a45d2a2d2dec3cad923f3\",\n      \"fcb30372e7404c5d8a1ad4df91e6c7b2\",\n      \"2860e3bb3baf4f7da058465850e800c5\",\n      \"3efd18ea8eaa41918894883da9541bfa\",\n      \"e09f1bcbb9d94c09be53e5e1303642c2\",\n      \"82177df57a494de8900c14c2f5185175\",\n      \"ccfcdc95baf646f8aeb3d516742383f2\",\n      \"8f5bd719974e41c3a8dd9a5b0d3d71e6\",\n      \"b87c84de30e84b3abf4871461fb9cbd3\",\n      \"e7d8e4fe58384e93a106de546068c65e\",\n      \"0aa8ab56b85f4171a79c3bc210594025\",\n      \"67da6c4260574869aa24c3cbc1bc1654\",\n      \"94b9088614464f60a203de39dbcae853\",\n      \"fea1b70fb46745feb5111b3929175b5d\",\n      \"f365820a3d3c42b2948abfe32065de14\",\n      \"823f1c78f15043e38bbd4dca3932a86a\",\n      \"a1959759c5424da9961fb2a308d4dee4\",\n      \"34c9c0137b504cd799c6bd6de69507c2\",\n      \"735d4f225b24414294fc1b213c61223c\",\n      \"5e5e15b0569b474c9620083b3ec6af55\",\n      \"03a3c744d716431488163b4358b80f92\",\n      \"a5434ee714f9498d83870544b67c0cb7\",\n      \"3aaecbf540f54a2db9ab0931e3b1fe57\",\n      \"9e333ed3b5014069ac1dd969255dd591\"\n     ]\n    },\n    \"id\": \"IwrpurmloGOy\",\n    \"outputId\": \"84fa167f-ba27-4255-d508-dc9df56ad39b\"\n   },\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"\\n\",\n      \"     #@@ #@@      @@# @@#\\n\",\n      \"    @@  @@          @@  @@           =@@#                               @@                 #@    =@@#.\\n\",\n      \"    @@    #@@@@@@@@@    @@           #@#@=                              
@@                 #@     .=@@\\n\",\n      \"      #@@@@@@@@@@@@@@@@@            =@# @#     ##=     ##    =####=+    @@      =#####+  =#@@###.   @@\\n\",\n      \"    @@@@@@@@@@/  +@@/  +@@          #@  =@=     #@=   @@   =@#+  +#@#   @@    =@#+  +#@#   #@.      @@\\n\",\n      \"    @@@@@@@@@@  ##@@  ##@@         =@#   @#      =@# @#    @@      @@   @@    @@      #@   #@       @@\\n\",\n      \"     @@@@@@@@@@@@@@@@@@@@          #@=+++#@=      =@@#     @@      @@   @@    @@      #@   #@       @@\\n\",\n      \"                                  =@#=====@@     =@# @#    @@      @@   @@    @@      #@   #@       @@\\n\",\n      \"    @@@@@@@@@@@@@@@@  @@@@        #@      #@=   #@=  +@@   #@#    =@#   @@.   =@#    =@#   #@.      @@\\n\",\n      \"                                 =@#       @#  #@=     #@   =#@@@@#=    +#@@=  +#@@@@#=    .##@@+   @@\\n\",\n      \"    @@@@  @@@@@@@@@@@@@@@@\\n\",\n      \"\\n\",\n      \"[2025-05-07 22:08:14,344] [INFO] [axolotl.monkeypatch.peft.utils.patch_peft_prep_code:76] [PID:1336] [RANK:0] patching prepare_model_for_kbit_training to allow for overrides\\u001b[39m\\n\",\n      \"[2025-05-07 22:08:14,549] [INFO] [axolotl.integrations.cut_cross_entropy.pre_model_load:80] [PID:1336] [RANK:0] Applying Cut Cross Entropy to model type: qwen3\\u001b[39m\\n\"\n     ]\n    },\n    {\n     \"data\": {\n      \"application/vnd.jupyter.widget-view+json\": {\n       \"model_id\": \"004d9177a6a14118a5930dc3cc13147b\",\n       \"version_major\": 2,\n       \"version_minor\": 0\n      },\n      \"text/plain\": [\n       \"model.safetensors.index.json:   0%|          | 0.00/36.5k [00:00<?, ?B/s]\"\n      ]\n     },\n     \"metadata\": {},\n     \"output_type\": \"display_data\"\n    },\n    {\n     \"data\": {\n      \"application/vnd.jupyter.widget-view+json\": {\n       \"model_id\": \"cdebbc55a1164c018546c2ac6f8c620c\",\n       \"version_major\": 2,\n       \"version_minor\": 0\n      },\n      \"text/plain\": [\n       
\"model-00001-of-00008.safetensors:   0%|          | 0.00/3.84G [00:00<?, ?B/s]\"\n      ]\n     },\n     \"metadata\": {},\n     \"output_type\": \"display_data\"\n    },\n    {\n     \"data\": {\n      \"application/vnd.jupyter.widget-view+json\": {\n       \"model_id\": \"500e272208a246089613bf788a165271\",\n       \"version_major\": 2,\n       \"version_minor\": 0\n      },\n      \"text/plain\": [\n       \"model-00002-of-00008.safetensors:   0%|          | 0.00/3.96G [00:00<?, ?B/s]\"\n      ]\n     },\n     \"metadata\": {},\n     \"output_type\": \"display_data\"\n    },\n    {\n     \"data\": {\n      \"application/vnd.jupyter.widget-view+json\": {\n       \"model_id\": \"4d05314858354e729d76094b3b0ce761\",\n       \"version_major\": 2,\n       \"version_minor\": 0\n      },\n      \"text/plain\": [\n       \"model-00003-of-00008.safetensors:   0%|          | 0.00/3.96G [00:00<?, ?B/s]\"\n      ]\n     },\n     \"metadata\": {},\n     \"output_type\": \"display_data\"\n    },\n    {\n     \"data\": {\n      \"application/vnd.jupyter.widget-view+json\": {\n       \"model_id\": \"dd0e646fad3f4a89ba23b39d162bd8d9\",\n       \"version_major\": 2,\n       \"version_minor\": 0\n      },\n      \"text/plain\": [\n       \"model-00004-of-00008.safetensors:   0%|          | 0.00/3.96G [00:00<?, ?B/s]\"\n      ]\n     },\n     \"metadata\": {},\n     \"output_type\": \"display_data\"\n    },\n    {\n     \"data\": {\n      \"application/vnd.jupyter.widget-view+json\": {\n       \"model_id\": \"e21e180307e5485cbbe908672fd6639a\",\n       \"version_major\": 2,\n       \"version_minor\": 0\n      },\n      \"text/plain\": [\n       \"model-00005-of-00008.safetensors:   0%|          | 0.00/3.96G [00:00<?, ?B/s]\"\n      ]\n     },\n     \"metadata\": {},\n     \"output_type\": \"display_data\"\n    },\n    {\n     \"data\": {\n      \"application/vnd.jupyter.widget-view+json\": {\n       \"model_id\": \"fe18bba7f3fb4c31bf840541f36b3425\",\n       \"version_major\": 2,\n 
      \"version_minor\": 0\n      },\n      \"text/plain\": [\n       \"model-00006-of-00008.safetensors:   0%|          | 0.00/3.96G [00:00<?, ?B/s]\"\n      ]\n     },\n     \"metadata\": {},\n     \"output_type\": \"display_data\"\n    },\n    {\n     \"data\": {\n      \"application/vnd.jupyter.widget-view+json\": {\n       \"model_id\": \"4471ff62258549fba9514bb67050f965\",\n       \"version_major\": 2,\n       \"version_minor\": 0\n      },\n      \"text/plain\": [\n       \"model-00007-of-00008.safetensors:   0%|          | 0.00/3.96G [00:00<?, ?B/s]\"\n      ]\n     },\n     \"metadata\": {},\n     \"output_type\": \"display_data\"\n    },\n    {\n     \"data\": {\n      \"application/vnd.jupyter.widget-view+json\": {\n       \"model_id\": \"eb1c9535e6a546098b760528b2ea387c\",\n       \"version_major\": 2,\n       \"version_minor\": 0\n      },\n      \"text/plain\": [\n       \"model-00008-of-00008.safetensors:   0%|          | 0.00/1.91G [00:00<?, ?B/s]\"\n      ]\n     },\n     \"metadata\": {},\n     \"output_type\": \"display_data\"\n    },\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"[2025-05-07 22:09:49,798] [INFO] [accelerate.utils.modeling.get_balanced_memory:990] [PID:1336] We will use 90% of the memory on device 0 for storing the model, and 10% for the buffer to avoid OOM. 
You can set `max_memory` in to a higher value to use more memory (at your own risk).\\n\"\n     ]\n    },\n    {\n     \"data\": {\n      \"application/vnd.jupyter.widget-view+json\": {\n       \"model_id\": \"2860e3bb3baf4f7da058465850e800c5\",\n       \"version_major\": 2,\n       \"version_minor\": 0\n      },\n      \"text/plain\": [\n       \"Loading checkpoint shards:   0%|          | 0/8 [00:00<?, ?it/s]\"\n      ]\n     },\n     \"metadata\": {},\n     \"output_type\": \"display_data\"\n    },\n    {\n     \"data\": {\n      \"application/vnd.jupyter.widget-view+json\": {\n       \"model_id\": \"fea1b70fb46745feb5111b3929175b5d\",\n       \"version_major\": 2,\n       \"version_minor\": 0\n      },\n      \"text/plain\": [\n       \"generation_config.json:   0%|          | 0.00/239 [00:00<?, ?B/s]\"\n      ]\n     },\n     \"metadata\": {},\n     \"output_type\": \"display_data\"\n    },\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"[2025-05-07 22:11:37,521] [INFO] [axolotl.utils.models.load_model:1302] [PID:1336] [RANK:0] cuda memory usage after model load: 9.264GB (+1.721GB cache, +0.375GB misc)\\u001b[39m\\n\",\n      \"[2025-05-07 22:11:37,532] [INFO] [axolotl.utils.models.prepare_model:1205] [PID:1336] [RANK:0] converting PEFT model w/ prepare_model_for_kbit_training\\u001b[39m\\n\",\n      \"[2025-05-07 22:11:37,537] [INFO] [axolotl.utils.models.load_model:1341] [PID:1336] [RANK:0] Converting modules to torch.float16\\u001b[39m\\n\",\n      \"trainable params: 128,450,560 || all params: 14,896,757,760 || trainable%: 0.8623\\n\",\n      \"[2025-05-07 22:11:40,170] [INFO] [axolotl.utils.models.load_model:1402] [PID:1336] [RANK:0] cuda memory usage after adapters: 9.743GB (+1.476GB cache, +0.375GB misc)\\u001b[39m\\n\"\n     ]\n    },\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"/usr/local/lib/python3.11/dist-packages/axolotl/core/trainers/base.py:64: 
FutureWarning: `tokenizer` is deprecated and will be removed in version 5.0.0 for `AxolotlTrainer.__init__`. Use `processing_class` instead.\\n\",\n      \"  super().__init__(*_args, **kwargs)\\n\",\n      \"No label_names provided for model class `PeftModelForCausalLM`. Since `PeftModel` hides base models input arguments, if label_names is not given, label_names can't be set automatically within `Trainer`. Note that empty label_names list will be used instead.\\n\"\n     ]\n    },\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"[2025-05-07 22:11:41,755] [INFO] [axolotl.train.save_initial_configs:359] [PID:1336] [RANK:0] Pre-saving adapter config to ./outputs/qwen-sft-pirate-rrr...\\u001b[39m\\n\",\n      \"[2025-05-07 22:11:41,756] [INFO] [axolotl.train.save_initial_configs:363] [PID:1336] [RANK:0] Pre-saving tokenizer to ./outputs/qwen-sft-pirate-rrr...\\u001b[39m\\n\",\n      \"[2025-05-07 22:11:41,974] [INFO] [axolotl.train.save_initial_configs:366] [PID:1336] [RANK:0] Pre-saving model config to ./outputs/qwen-sft-pirate-rrr...\\u001b[39m\\n\",\n      \"[2025-05-07 22:11:41,982] [INFO] [axolotl.train.execute_training:211] [PID:1336] [RANK:0] Starting trainer...\\u001b[39m\\n\",\n      \"[2025-05-07 22:11:45,047] [INFO] [axolotl.utils.samplers.multipack.calc_min_len:411] [PID:1336] [RANK:0] gather_len_batches: [1540]\\u001b[39m\\n\"\n     ]\n    },\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"You're using a Qwen2TokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.\\n\",\n      \"You're using a Qwen2TokenizerFast tokenizer. 
Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.\\n\"\n     ]\n    },\n    {\n     \"data\": {\n      \"text/html\": [\n       \"\\n\",\n       \"    <div>\\n\",\n       \"      \\n\",\n       \"      <progress value='25' max='25' style='width:300px; height:20px; vertical-align: middle;'></progress>\\n\",\n       \"      [25/25 09:25, Epoch 0/1]\\n\",\n       \"    </div>\\n\",\n       \"    <table border=\\\"1\\\" class=\\\"dataframe\\\">\\n\",\n       \"  <thead>\\n\",\n       \" <tr style=\\\"text-align: left;\\\">\\n\",\n       \"      <th>Step</th>\\n\",\n       \"      <th>Training Loss</th>\\n\",\n       \"    </tr>\\n\",\n       \"  </thead>\\n\",\n       \"  <tbody>\\n\",\n       \"    <tr>\\n\",\n       \"      <td>1</td>\\n\",\n       \"      <td>1.092300</td>\\n\",\n       \"    </tr>\\n\",\n       \"    <tr>\\n\",\n       \"      <td>2</td>\\n\",\n       \"      <td>1.554200</td>\\n\",\n       \"    </tr>\\n\",\n       \"    <tr>\\n\",\n       \"      <td>3</td>\\n\",\n       \"      <td>1.041400</td>\\n\",\n       \"    </tr>\\n\",\n       \"    <tr>\\n\",\n       \"      <td>4</td>\\n\",\n       \"      <td>1.733800</td>\\n\",\n       \"    </tr>\\n\",\n       \"    <tr>\\n\",\n       \"      <td>5</td>\\n\",\n       \"      <td>1.430000</td>\\n\",\n       \"    </tr>\\n\",\n       \"    <tr>\\n\",\n       \"      <td>6</td>\\n\",\n       \"      <td>1.258500</td>\\n\",\n       \"    </tr>\\n\",\n       \"    <tr>\\n\",\n       \"      <td>7</td>\\n\",\n       \"      <td>1.343600</td>\\n\",\n       \"    </tr>\\n\",\n       \"    <tr>\\n\",\n       \"      <td>8</td>\\n\",\n       \"      <td>1.101700</td>\\n\",\n       \"    </tr>\\n\",\n       \"    <tr>\\n\",\n       \"      <td>9</td>\\n\",\n       \"      <td>1.086500</td>\\n\",\n       \"    </tr>\\n\",\n       \"    <tr>\\n\",\n       \"      <td>10</td>\\n\",\n 
      \"      <td>0.813200</td>\\n\",\n       \"    </tr>\\n\",\n       \"    <tr>\\n\",\n       \"      <td>11</td>\\n\",\n       \"      <td>0.689600</td>\\n\",\n       \"    </tr>\\n\",\n       \"    <tr>\\n\",\n       \"      <td>12</td>\\n\",\n       \"      <td>0.826700</td>\\n\",\n       \"    </tr>\\n\",\n       \"    <tr>\\n\",\n       \"      <td>13</td>\\n\",\n       \"      <td>1.541800</td>\\n\",\n       \"    </tr>\\n\",\n       \"    <tr>\\n\",\n       \"      <td>14</td>\\n\",\n       \"      <td>0.948000</td>\\n\",\n       \"    </tr>\\n\",\n       \"    <tr>\\n\",\n       \"      <td>15</td>\\n\",\n       \"      <td>1.357000</td>\\n\",\n       \"    </tr>\\n\",\n       \"    <tr>\\n\",\n       \"      <td>16</td>\\n\",\n       \"      <td>1.085800</td>\\n\",\n       \"    </tr>\\n\",\n       \"    <tr>\\n\",\n       \"      <td>17</td>\\n\",\n       \"      <td>1.516800</td>\\n\",\n       \"    </tr>\\n\",\n       \"    <tr>\\n\",\n       \"      <td>18</td>\\n\",\n       \"      <td>1.146800</td>\\n\",\n       \"    </tr>\\n\",\n       \"    <tr>\\n\",\n       \"      <td>19</td>\\n\",\n       \"      <td>0.834800</td>\\n\",\n       \"    </tr>\\n\",\n       \"    <tr>\\n\",\n       \"      <td>20</td>\\n\",\n       \"      <td>0.968000</td>\\n\",\n       \"    </tr>\\n\",\n       \"    <tr>\\n\",\n       \"      <td>21</td>\\n\",\n       \"      <td>1.388800</td>\\n\",\n       \"    </tr>\\n\",\n       \"    <tr>\\n\",\n       \"      <td>22</td>\\n\",\n       \"      <td>1.511500</td>\\n\",\n       \"    </tr>\\n\",\n       \"    <tr>\\n\",\n       \"      <td>23</td>\\n\",\n       \"      <td>1.338500</td>\\n\",\n       \"    </tr>\\n\",\n       \"    <tr>\\n\",\n       \"      <td>24</td>\\n\",\n       \"      <td>1.206600</td>\\n\",\n       \"    </tr>\\n\",\n       \"    <tr>\\n\",\n       \"      <td>25</td>\\n\",\n       \"      <td>1.504600</td>\\n\",\n       \"    </tr>\\n\",\n       \"  </tbody>\\n\",\n       \"</table><p>\"\n      
],\n      \"text/plain\": [\n       \"<IPython.core.display.HTML object>\"\n      ]\n     },\n     \"metadata\": {},\n     \"output_type\": \"display_data\"\n    },\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"[2025-05-07 22:12:42,746] [INFO] [axolotl.callbacks.on_step_end:128] [PID:1336] [RANK:0] cuda memory usage while training: 9.768GB (+3.287GB cache, +0.646GB misc)\\u001b[39m\\n\",\n      \"[2025-05-07 22:21:46,859] [INFO] [axolotl.train.save_trained_model:231] [PID:1336] [RANK:0] Training completed! Saving pre-trained model to ./outputs/qwen-sft-pirate-rrr.\\u001b[39m\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"from axolotl.train import train\\n\",\n    \"\\n\",\n    \"# just train the first 25 steps for demo.\\n\",\n    \"# This is sufficient to align the model as we've used packing to maximize the trainable samples per step.\\n\",\n    \"cfg.max_steps = 25\\n\",\n    \"model, tokenizer, trainer = train(cfg=cfg, dataset_meta=dataset_meta)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {\n    \"id\": \"j1b9ypF78eCb\"\n   },\n   \"source\": [\n    \"# Inferencing the trained model\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {\n    \"colab\": {\n     \"base_uri\": \"https://localhost:8080/\"\n    },\n    \"id\": \"r3_vHhif8YEs\",\n    \"outputId\": \"e5050605-f6c9-421c-98f9-bde56a281eae\"\n   },\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"Ahoy there, matey! Shiver me timbers, ye be lookin' for the Pythagorean theorem, eh? Well, hold yer horses and listen up, for I'll be tellin' ye all about it in me own special way.\\n\",\n      \"\\n\",\n      \"The Pythagorean theorem be a real gem of a mathematical trick that helps ye find the length of a side of a right triangle. 
Now, a right triangle be a triangle with a right angle, which be that little corner that looks like a square. \\n\",\n      \"\\n\",\n      \"The theorem be named after a clever fellow named Pythagoras, who be a mathematician from ancient Greece. He discovered that if ye have a right triangle, the square of the length of the hypotenuse (that be the side opposite the right angle) be equal to the sum of the squares of the other two sides. \\n\",\n      \"\\n\",\n      \"In other words, if ye have a triangle with sides of length a, b, and c (\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"from transformers import TextStreamer\\n\",\n    \"\\n\",\n    \"messages = [\\n\",\n    \"    {\\n\",\n    \"        \\\"role\\\": \\\"user\\\",\\n\",\n    \"        \\\"content\\\": \\\"Explain the Pythagorean theorem to me.\\\",\\n\",\n    \"    },\\n\",\n    \"]\\n\",\n    \"\\n\",\n    \"prompt = tokenizer.apply_chat_template(\\n\",\n    \"    messages,\\n\",\n    \"    add_generation_prompt=True,\\n\",\n    \"    tokenize=False,\\n\",\n    \"    enable_thinking=False,\\n\",\n    \")\\n\",\n    \"\\n\",\n    \"outputs = model.generate(\\n\",\n    \"    **tokenizer(prompt, return_tensors=\\\"pt\\\").to(\\\"cuda\\\"),\\n\",\n    \"    max_new_tokens=192,\\n\",\n    \"    temperature=1.0,\\n\",\n    \"    top_p=0.8,\\n\",\n    \"    top_k=32,\\n\",\n    \"    streamer=TextStreamer(tokenizer, skip_prompt=True),\\n\",\n    \")\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {\n    \"id\": \"HoGwT2JRSIjA\"\n   },\n   \"source\": [\n    \"# Saving your trained model\\n\",\n    \"\\n\",\n    \"Axolotl automatically saves checkpoints to the `output_dir` path.\\n\",\n    \"\\n\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {\n    \"colab\": {\n     \"base_uri\": \"https://localhost:8080/\"\n    },\n    \"id\": \"5BmSbiy6NaaS\",\n    \"outputId\": \"f5e1d913-7d55-42d2-8340-f9f1b0bc2b38\"\n   },\n   \"outputs\": 
[\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"total 506M\\n\",\n      \"-rw-r--r-- 1 root root  845 May  7 22:21 adapter_config.json\\n\",\n      \"-rw-r--r-- 1 root root 491M May  7 22:21 adapter_model.safetensors\\n\",\n      \"-rw-r--r-- 1 root root  707 May  7 22:11 added_tokens.json\\n\",\n      \"drwxr-xr-x 2 root root 4.0K May  7 22:17 checkpoint-13\\n\",\n      \"drwxr-xr-x 2 root root 4.0K May  7 22:21 checkpoint-25\\n\",\n      \"-rw-r--r-- 1 root root 1.2K May  7 22:11 config.json\\n\",\n      \"-rw-r--r-- 1 root root 1.6M May  7 22:11 merges.txt\\n\",\n      \"-rw-r--r-- 1 root root 2.6K May  7 22:21 README.md\\n\",\n      \"-rw-r--r-- 1 root root  613 May  7 22:11 special_tokens_map.json\\n\",\n      \"-rw-r--r-- 1 root root 9.5K May  7 22:11 tokenizer_config.json\\n\",\n      \"-rw-r--r-- 1 root root  11M May  7 22:11 tokenizer.json\\n\",\n      \"-rw-r--r-- 1 root root 2.7M May  7 22:11 vocab.json\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"# Show the saved checkpoints in the output_dir\\n\",\n    \"!ls -lh \\\"./outputs/qwen-sft-pirate-rrr\\\"\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {\n    \"id\": \"_PCIFWxuOZd6\"\n   },\n   \"source\": [\n    \"Setting `hub_model_id: ` in the original config would have automatically uploaded the model to HuggingFace Hub (e.g. 
`hub_model_id: username/model_id`)\\n\",\n    \"\\n\",\n    \"If you prefer to manually upload the training artifacts, you can still upload the entire final checkpoint to the Hugging Face Hub from the CLI.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {\n    \"colab\": {\n     \"base_uri\": \"https://localhost:8080/\",\n     \"height\": 955,\n     \"referenced_widgets\": [\n      \"c12ea43372ac4d57bb9605f1a429b397\",\n      \"86816687746246b4a6105e8010384e25\",\n      \"6f05e9bebf7b40c9835808e77de6c236\",\n      \"c7433acd3c4841e6958ae8f7e87b1808\",\n      \"19c1e38389fa46c7b7e2152a56e1df34\",\n      \"0e067d8db8ed48308a718d5f57683fd1\",\n      \"131065f118274a1586ac38e39ed84ef0\",\n      \"8640ac440fbc4644b9a3af7ba3ae7183\",\n      \"5cea7996f02040b187ece0bb2d6a8d1f\",\n      \"2e257c8be2da40b4bb67a9e4ab6811f3\",\n      \"56e3768bef5a4b9db4168c5c17f509c2\",\n      \"62c028fdef904dedb9cdeca2b3bda725\",\n      \"a7cf477e80fc43e0ad82c7997b076dce\",\n      \"835bcc28a5564fb9b3d651bc8e32dc46\",\n      \"9f1c9a0695384bdaa6f8b847ef89bee8\",\n      \"b1bea589efa14258a9982071b87938bf\",\n      \"590eef89881545aa8bbef9a8bbe7fb00\",\n      \"4b1f04ff63d14a118fdd15814dff50e4\",\n      \"39789237703c4a418134243055c9cbf5\",\n      \"a3a945817f684328b34651fe052393ec\"\n     ]\n    },\n    \"id\": \"2yw8pLvlSMl8\",\n    \"outputId\": \"6e489ab2-4abe-4e28-84ca-959f912433a4\"\n   },\n   \"outputs\": [\n    {\n     \"data\": {\n      \"application/vnd.jupyter.widget-view+json\": {\n       \"model_id\": \"c12ea43372ac4d57bb9605f1a429b397\",\n       \"version_major\": 2,\n       \"version_minor\": 0\n      },\n      \"text/plain\": [\n       \"VBox(children=(HTML(value='<center> <img\\\\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…\"\n      ]\n     },\n     \"metadata\": {},\n     \"output_type\": \"display_data\"\n    },\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"It 
seems you are trying to upload a large folder at once. This might take some time and then fail if the folder is too large. For such cases, it is recommended to upload in smaller batches or to use `HfApi().upload_large_folder(...)`/`huggingface-cli upload-large-folder` instead. For more details, check out https://huggingface.co/docs/huggingface_hub/main/en/guides/upload#upload-a-large-folder.\\n\",\n      \"Start hashing 40 files.\\n\",\n      \"Finished hashing 40 files.\\n\",\n      \"Uploading files using Xet Storage..\\n\",\n      \"Uploading...:  87% 1.82G/2.10G [00:23<00:04, 67.3MB/s]Cancellation requested; stopping current tasks.\\n\",\n      \"Traceback (most recent call last):\\n\",\n      \"  File \\\"/usr/local/lib/python3.11/dist-packages/huggingface_hub/_commit_api.py\\\", line 598, in _upload_xet_files\\n\",\n      \"    upload_files(\\n\",\n      \"RuntimeError: Xet Runtime Error: Task cancelled; possible runtime shutdown in progress (task 9 was cancelled).\\n\",\n      \"\\n\",\n      \"During handling of the above exception, another exception occurred:\\n\",\n      \"\\n\",\n      \"Traceback (most recent call last):\\n\",\n      \"  File \\\"/usr/local/bin/huggingface-cli\\\", line 8, in <module>\\n\",\n      \"    sys.exit(main())\\n\",\n      \"             ^^^^^^\\n\",\n      \"  File \\\"/usr/local/lib/python3.11/dist-packages/huggingface_hub/commands/huggingface_cli.py\\\", line 57, in main\\n\",\n      \"    service.run()\\n\",\n      \"  File \\\"/usr/local/lib/python3.11/dist-packages/huggingface_hub/commands/upload.py\\\", line 207, in run\\n\",\n      \"    print(self._upload())\\n\",\n      \"          ^^^^^^^^^^^^^^\\n\",\n      \"  File \\\"/usr/local/lib/python3.11/dist-packages/huggingface_hub/commands/upload.py\\\", line 302, in _upload\\n\",\n      \"    return self.api.upload_folder(\\n\",\n      \"           ^^^^^^^^^^^^^^^^^^^^^^^\\n\",\n      \"  File 
\\\"/usr/local/lib/python3.11/dist-packages/huggingface_hub/utils/_validators.py\\\", line 114, in _inner_fn\\n\",\n      \"    return fn(*args, **kwargs)\\n\",\n      \"           ^^^^^^^^^^^^^^^^^^^\\n\",\n      \"  File \\\"/usr/local/lib/python3.11/dist-packages/huggingface_hub/hf_api.py\\\", line 1633, in _inner\\n\",\n      \"    return fn(self, *args, **kwargs)\\n\",\n      \"           ^^^^^^^^^^^^^^^^^^^^^^^^^\\n\",\n      \"  File \\\"/usr/local/lib/python3.11/dist-packages/huggingface_hub/hf_api.py\\\", line 4942, in upload_folder\\n\",\n      \"    commit_info = self.create_commit(\\n\",\n      \"                  ^^^^^^^^^^^^^^^^^^^\\n\",\n      \"  File \\\"/usr/local/lib/python3.11/dist-packages/huggingface_hub/utils/_validators.py\\\", line 114, in _inner_fn\\n\",\n      \"    return fn(*args, **kwargs)\\n\",\n      \"           ^^^^^^^^^^^^^^^^^^^\\n\",\n      \"  File \\\"/usr/local/lib/python3.11/dist-packages/huggingface_hub/hf_api.py\\\", line 1633, in _inner\\n\",\n      \"    return fn(self, *args, **kwargs)\\n\",\n      \"           ^^^^^^^^^^^^^^^^^^^^^^^^^\\n\",\n      \"  File \\\"/usr/local/lib/python3.11/dist-packages/huggingface_hub/hf_api.py\\\", line 4202, in create_commit\\n\",\n      \"    self.preupload_lfs_files(\\n\",\n      \"  File \\\"/usr/local/lib/python3.11/dist-packages/huggingface_hub/hf_api.py\\\", line 4483, in preupload_lfs_files\\n\",\n      \"    _upload_xet_files(**upload_kwargs, create_pr=create_pr)  # type: ignore [arg-type]\\n\",\n      \"    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\\n\",\n      \"  File \\\"/usr/local/lib/python3.11/dist-packages/huggingface_hub/utils/_validators.py\\\", line 114, in _inner_fn\\n\",\n      \"    return fn(*args, **kwargs)\\n\",\n      \"           ^^^^^^^^^^^^^^^^^^^\\n\",\n      \"  File \\\"/usr/local/lib/python3.11/dist-packages/huggingface_hub/_commit_api.py\\\", line 592, in _upload_xet_files\\n\",\n      \"    with progress_cm as progress:\\n\",\n      \"  
File \\\"/usr/local/lib/python3.11/dist-packages/tqdm/std.py\\\", line 1138, in __exit__\\n\",\n      \"    def __exit__(self, exc_type, exc_value, traceback):\\n\",\n      \"\\n\",\n      \"KeyboardInterrupt\\n\",\n      \"^C\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"from huggingface_hub import notebook_login\\n\",\n    \"\\n\",\n    \"# remove the partial epoch checkpoints\\n\",\n    \"!rm -rf \\\"./outputs/qwen-sft-pirate-rrr/checkpoint-*\\\"\\n\",\n    \"\\n\",\n    \"# HF Notebook login widget\\n\",\n    \"notebook_login()\\n\",\n    \"\\n\",\n    \"# upload the LoRA adapter for your model to HF, remember to update the username/model-name below\\n\",\n    \"!huggingface-cli upload --repo-type=model winglian/pirate-qwen-14B \\\"./outputs/qwen-sft-pirate-rrr\\\"\"\n   ]\n  }\n ],\n \"metadata\": {\n  \"accelerator\": \"GPU\",\n  \"colab\": {\n   \"gpuType\": \"T4\",\n   \"provenance\": []\n  },\n  \"kernelspec\": {\n   \"display_name\": \"Python 3\",\n   \"name\": \"python3\"\n  },\n  \"language_info\": {\n   \"name\": \"python\"\n  },\n  \"widgets\": {\n   \"application/vnd.jupyter.widget-state+json\": {\n    \"00321cce58884f6f9b3855a21fcd9187\": {\n     \"model_module\": \"@jupyter-widgets/controls\",\n     \"model_module_version\": \"1.5.0\",\n     \"model_name\": \"DescriptionStyleModel\",\n     \"state\": {\n      \"_model_module\": \"@jupyter-widgets/controls\",\n      \"_model_module_version\": \"1.5.0\",\n      \"_model_name\": \"DescriptionStyleModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/base\",\n      \"_view_module_version\": \"1.2.0\",\n      \"_view_name\": \"StyleView\",\n      \"description_width\": \"\"\n     }\n    },\n    \"004d9177a6a14118a5930dc3cc13147b\": {\n     \"model_module\": \"@jupyter-widgets/controls\",\n     \"model_module_version\": \"1.5.0\",\n     \"model_name\": \"HBoxModel\",\n     \"state\": {\n      \"_dom_classes\": [],\n      \"_model_module\": 
\"@jupyter-widgets/controls\",\n      \"_model_module_version\": \"1.5.0\",\n      \"_model_name\": \"HBoxModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/controls\",\n      \"_view_module_version\": \"1.5.0\",\n      \"_view_name\": \"HBoxView\",\n      \"box_style\": \"\",\n      \"children\": [\n       \"IPY_MODEL_a80410b919e442c49aea15acc1ce1a72\",\n       \"IPY_MODEL_c6e00f5224364822bc4239b176686919\",\n       \"IPY_MODEL_ec11d1e5ae7b42c883d9b1f38a65356e\"\n      ],\n      \"layout\": \"IPY_MODEL_734185351eb543fa9a00a881dcbb9fe7\"\n     }\n    },\n    \"0077aedc3d174560bce924ee89e9c006\": {\n     \"model_module\": \"@jupyter-widgets/base\",\n     \"model_module_version\": \"1.2.0\",\n     \"model_name\": \"LayoutModel\",\n     \"state\": {\n      \"_model_module\": \"@jupyter-widgets/base\",\n      \"_model_module_version\": \"1.2.0\",\n      \"_model_name\": \"LayoutModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/base\",\n      \"_view_module_version\": \"1.2.0\",\n      \"_view_name\": \"LayoutView\",\n      \"align_content\": null,\n      \"align_items\": null,\n      \"align_self\": null,\n      \"border\": null,\n      \"bottom\": null,\n      \"display\": null,\n      \"flex\": null,\n      \"flex_flow\": null,\n      \"grid_area\": null,\n      \"grid_auto_columns\": null,\n      \"grid_auto_flow\": null,\n      \"grid_auto_rows\": null,\n      \"grid_column\": null,\n      \"grid_gap\": null,\n      \"grid_row\": null,\n      \"grid_template_areas\": null,\n      \"grid_template_columns\": null,\n      \"grid_template_rows\": null,\n      \"height\": null,\n      \"justify_content\": null,\n      \"justify_items\": null,\n      \"left\": null,\n      \"margin\": null,\n      \"max_height\": null,\n      \"max_width\": null,\n      \"min_height\": null,\n      \"min_width\": null,\n      \"object_fit\": null,\n      \"object_position\": null,\n      \"order\": null,\n      
\"overflow\": null,\n      \"overflow_x\": null,\n      \"overflow_y\": null,\n      \"padding\": null,\n      \"right\": null,\n      \"top\": null,\n      \"visibility\": null,\n      \"width\": null\n     }\n    },\n    \"03a3c744d716431488163b4358b80f92\": {\n     \"model_module\": \"@jupyter-widgets/base\",\n     \"model_module_version\": \"1.2.0\",\n     \"model_name\": \"LayoutModel\",\n     \"state\": {\n      \"_model_module\": \"@jupyter-widgets/base\",\n      \"_model_module_version\": \"1.2.0\",\n      \"_model_name\": \"LayoutModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/base\",\n      \"_view_module_version\": \"1.2.0\",\n      \"_view_name\": \"LayoutView\",\n      \"align_content\": null,\n      \"align_items\": null,\n      \"align_self\": null,\n      \"border\": null,\n      \"bottom\": null,\n      \"display\": null,\n      \"flex\": null,\n      \"flex_flow\": null,\n      \"grid_area\": null,\n      \"grid_auto_columns\": null,\n      \"grid_auto_flow\": null,\n      \"grid_auto_rows\": null,\n      \"grid_column\": null,\n      \"grid_gap\": null,\n      \"grid_row\": null,\n      \"grid_template_areas\": null,\n      \"grid_template_columns\": null,\n      \"grid_template_rows\": null,\n      \"height\": null,\n      \"justify_content\": null,\n      \"justify_items\": null,\n      \"left\": null,\n      \"margin\": null,\n      \"max_height\": null,\n      \"max_width\": null,\n      \"min_height\": null,\n      \"min_width\": null,\n      \"object_fit\": null,\n      \"object_position\": null,\n      \"order\": null,\n      \"overflow\": null,\n      \"overflow_x\": null,\n      \"overflow_y\": null,\n      \"padding\": null,\n      \"right\": null,\n      \"top\": null,\n      \"visibility\": null,\n      \"width\": null\n     }\n    },\n    \"03b093d592ba4386aa61f7b8483da660\": {\n     \"model_module\": \"@jupyter-widgets/controls\",\n     \"model_module_version\": \"1.5.0\",\n     \"model_name\": 
\"HBoxModel\",\n     \"state\": {\n      \"_dom_classes\": [],\n      \"_model_module\": \"@jupyter-widgets/controls\",\n      \"_model_module_version\": \"1.5.0\",\n      \"_model_name\": \"HBoxModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/controls\",\n      \"_view_module_version\": \"1.5.0\",\n      \"_view_name\": \"HBoxView\",\n      \"box_style\": \"\",\n      \"children\": [\n       \"IPY_MODEL_b8766a88716948cf968f4563531a76d9\",\n       \"IPY_MODEL_6f3a28b912714c6e931003549664bfa3\",\n       \"IPY_MODEL_16d1283741404b7bb319094c992fce01\"\n      ],\n      \"layout\": \"IPY_MODEL_2a5bb0e818ab47be8cf6465988328503\"\n     }\n    },\n    \"042e091f75694c47aee761e760e76773\": {\n     \"model_module\": \"@jupyter-widgets/controls\",\n     \"model_module_version\": \"1.5.0\",\n     \"model_name\": \"DescriptionStyleModel\",\n     \"state\": {\n      \"_model_module\": \"@jupyter-widgets/controls\",\n      \"_model_module_version\": \"1.5.0\",\n      \"_model_name\": \"DescriptionStyleModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/base\",\n      \"_view_module_version\": \"1.2.0\",\n      \"_view_name\": \"StyleView\",\n      \"description_width\": \"\"\n     }\n    },\n    \"0546d04aae644dde846c58a4afb598a6\": {\n     \"model_module\": \"@jupyter-widgets/base\",\n     \"model_module_version\": \"1.2.0\",\n     \"model_name\": \"LayoutModel\",\n     \"state\": {\n      \"_model_module\": \"@jupyter-widgets/base\",\n      \"_model_module_version\": \"1.2.0\",\n      \"_model_name\": \"LayoutModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/base\",\n      \"_view_module_version\": \"1.2.0\",\n      \"_view_name\": \"LayoutView\",\n      \"align_content\": null,\n      \"align_items\": null,\n      \"align_self\": null,\n      \"border\": null,\n      \"bottom\": null,\n      \"display\": null,\n      \"flex\": null,\n      \"flex_flow\": null,\n      
\"grid_area\": null,\n      \"grid_auto_columns\": null,\n      \"grid_auto_flow\": null,\n      \"grid_auto_rows\": null,\n      \"grid_column\": null,\n      \"grid_gap\": null,\n      \"grid_row\": null,\n      \"grid_template_areas\": null,\n      \"grid_template_columns\": null,\n      \"grid_template_rows\": null,\n      \"height\": null,\n      \"justify_content\": null,\n      \"justify_items\": null,\n      \"left\": null,\n      \"margin\": null,\n      \"max_height\": null,\n      \"max_width\": null,\n      \"min_height\": null,\n      \"min_width\": null,\n      \"object_fit\": null,\n      \"object_position\": null,\n      \"order\": null,\n      \"overflow\": null,\n      \"overflow_x\": null,\n      \"overflow_y\": null,\n      \"padding\": null,\n      \"right\": null,\n      \"top\": null,\n      \"visibility\": null,\n      \"width\": null\n     }\n    },\n    \"054c8dffadba48c6b895a6cc62448ecc\": {\n     \"model_module\": \"@jupyter-widgets/controls\",\n     \"model_module_version\": \"1.5.0\",\n     \"model_name\": \"ProgressStyleModel\",\n     \"state\": {\n      \"_model_module\": \"@jupyter-widgets/controls\",\n      \"_model_module_version\": \"1.5.0\",\n      \"_model_name\": \"ProgressStyleModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/base\",\n      \"_view_module_version\": \"1.2.0\",\n      \"_view_name\": \"StyleView\",\n      \"bar_color\": null,\n      \"description_width\": \"\"\n     }\n    },\n    \"07fb3a2c8315494e97b447e672dfae06\": {\n     \"model_module\": \"@jupyter-widgets/controls\",\n     \"model_module_version\": \"1.5.0\",\n     \"model_name\": \"HTMLModel\",\n     \"state\": {\n      \"_dom_classes\": [],\n      \"_model_module\": \"@jupyter-widgets/controls\",\n      \"_model_module_version\": \"1.5.0\",\n      \"_model_name\": \"HTMLModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/controls\",\n      \"_view_module_version\": \"1.5.0\",\n      
\"_view_name\": \"HTMLView\",\n      \"description\": \"\",\n      \"description_tooltip\": null,\n      \"layout\": \"IPY_MODEL_12815f401eba44658caa7b2e490137a8\",\n      \"placeholder\": \"​\",\n      \"style\": \"IPY_MODEL_30e02aa2d0d241979369e598287f2639\",\n      \"value\": \"Drop Samples with Zero Trainable Tokens (num_proc=2): 100%\"\n     }\n    },\n    \"083f9cda8d754c168beee10d2f8955a2\": {\n     \"model_module\": \"@jupyter-widgets/controls\",\n     \"model_module_version\": \"1.5.0\",\n     \"model_name\": \"FloatProgressModel\",\n     \"state\": {\n      \"_dom_classes\": [],\n      \"_model_module\": \"@jupyter-widgets/controls\",\n      \"_model_module_version\": \"1.5.0\",\n      \"_model_name\": \"FloatProgressModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/controls\",\n      \"_view_module_version\": \"1.5.0\",\n      \"_view_name\": \"ProgressView\",\n      \"bar_style\": \"success\",\n      \"description\": \"\",\n      \"description_tooltip\": null,\n      \"layout\": \"IPY_MODEL_a0a11e929edd4189b79723d618522c33\",\n      \"max\": 728,\n      \"min\": 0,\n      \"orientation\": \"horizontal\",\n      \"style\": \"IPY_MODEL_e87ea87fcff247b5bbcc331ba79a8dc2\",\n      \"value\": 728\n     }\n    },\n    \"09007681cf8d42aeb8c1d2f6a74e470a\": {\n     \"model_module\": \"@jupyter-widgets/controls\",\n     \"model_module_version\": \"1.5.0\",\n     \"model_name\": \"HTMLModel\",\n     \"state\": {\n      \"_dom_classes\": [],\n      \"_model_module\": \"@jupyter-widgets/controls\",\n      \"_model_module_version\": \"1.5.0\",\n      \"_model_name\": \"HTMLModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/controls\",\n      \"_view_module_version\": \"1.5.0\",\n      \"_view_name\": \"HTMLView\",\n      \"description\": \"\",\n      \"description_tooltip\": null,\n      \"layout\": \"IPY_MODEL_b195f160ca20442fadd8b5aed0ee41af\",\n      \"placeholder\": \"​\",\n      \"style\": 
\"IPY_MODEL_ca65e32eb52f48c09a84b33cb18f22cd\",\n      \"value\": \" 11.4M/11.4M [00:00&lt;00:00, 21.8MB/s]\"\n     }\n    },\n    \"0a46ad75c198463d843fb35e813642cb\": {\n     \"model_module\": \"@jupyter-widgets/controls\",\n     \"model_module_version\": \"1.5.0\",\n     \"model_name\": \"FloatProgressModel\",\n     \"state\": {\n      \"_dom_classes\": [],\n      \"_model_module\": \"@jupyter-widgets/controls\",\n      \"_model_module_version\": \"1.5.0\",\n      \"_model_name\": \"FloatProgressModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/controls\",\n      \"_view_module_version\": \"1.5.0\",\n      \"_view_name\": \"ProgressView\",\n      \"bar_style\": \"success\",\n      \"description\": \"\",\n      \"description_tooltip\": null,\n      \"layout\": \"IPY_MODEL_b8e39e4dddc3497fbc29ae45c66da759\",\n      \"max\": 11422654,\n      \"min\": 0,\n      \"orientation\": \"horizontal\",\n      \"style\": \"IPY_MODEL_63b4e563e85c4f03b1b72beda9577bcc\",\n      \"value\": 11422654\n     }\n    },\n    \"0aa8ab56b85f4171a79c3bc210594025\": {\n     \"model_module\": \"@jupyter-widgets/controls\",\n     \"model_module_version\": \"1.5.0\",\n     \"model_name\": \"ProgressStyleModel\",\n     \"state\": {\n      \"_model_module\": \"@jupyter-widgets/controls\",\n      \"_model_module_version\": \"1.5.0\",\n      \"_model_name\": \"ProgressStyleModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/base\",\n      \"_view_module_version\": \"1.2.0\",\n      \"_view_name\": \"StyleView\",\n      \"bar_color\": null,\n      \"description_width\": \"\"\n     }\n    },\n    \"0b4c9753a7cb4354b8e5f187e6e1ad7c\": {\n     \"model_module\": \"@jupyter-widgets/controls\",\n     \"model_module_version\": \"1.5.0\",\n     \"model_name\": \"DescriptionStyleModel\",\n     \"state\": {\n      \"_model_module\": \"@jupyter-widgets/controls\",\n      \"_model_module_version\": \"1.5.0\",\n      \"_model_name\": 
\"DescriptionStyleModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/base\",\n      \"_view_module_version\": \"1.2.0\",\n      \"_view_name\": \"StyleView\",\n      \"description_width\": \"\"\n     }\n    },\n    \"0cd7efffbb3c4c4b972e63749f61ab97\": {\n     \"model_module\": \"@jupyter-widgets/controls\",\n     \"model_module_version\": \"1.5.0\",\n     \"model_name\": \"DescriptionStyleModel\",\n     \"state\": {\n      \"_model_module\": \"@jupyter-widgets/controls\",\n      \"_model_module_version\": \"1.5.0\",\n      \"_model_name\": \"DescriptionStyleModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/base\",\n      \"_view_module_version\": \"1.2.0\",\n      \"_view_name\": \"StyleView\",\n      \"description_width\": \"\"\n     }\n    },\n    \"0dea5caa27384f5689e3cab51f558727\": {\n     \"model_module\": \"@jupyter-widgets/base\",\n     \"model_module_version\": \"1.2.0\",\n     \"model_name\": \"LayoutModel\",\n     \"state\": {\n      \"_model_module\": \"@jupyter-widgets/base\",\n      \"_model_module_version\": \"1.2.0\",\n      \"_model_name\": \"LayoutModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/base\",\n      \"_view_module_version\": \"1.2.0\",\n      \"_view_name\": \"LayoutView\",\n      \"align_content\": null,\n      \"align_items\": null,\n      \"align_self\": null,\n      \"border\": null,\n      \"bottom\": null,\n      \"display\": null,\n      \"flex\": null,\n      \"flex_flow\": null,\n      \"grid_area\": null,\n      \"grid_auto_columns\": null,\n      \"grid_auto_flow\": null,\n      \"grid_auto_rows\": null,\n      \"grid_column\": null,\n      \"grid_gap\": null,\n      \"grid_row\": null,\n      \"grid_template_areas\": null,\n      \"grid_template_columns\": null,\n      \"grid_template_rows\": null,\n      \"height\": null,\n      \"justify_content\": null,\n      \"justify_items\": null,\n      \"left\": null,\n      
\"margin\": null,\n      \"max_height\": null,\n      \"max_width\": null,\n      \"min_height\": null,\n      \"min_width\": null,\n      \"object_fit\": null,\n      \"object_position\": null,\n      \"order\": null,\n      \"overflow\": null,\n      \"overflow_x\": null,\n      \"overflow_y\": null,\n      \"padding\": null,\n      \"right\": null,\n      \"top\": null,\n      \"visibility\": null,\n      \"width\": null\n     }\n    },\n    \"0e067d8db8ed48308a718d5f57683fd1\": {\n     \"model_module\": \"@jupyter-widgets/controls\",\n     \"model_module_version\": \"1.5.0\",\n     \"model_name\": \"HTMLModel\",\n     \"state\": {\n      \"_dom_classes\": [],\n      \"_model_module\": \"@jupyter-widgets/controls\",\n      \"_model_module_version\": \"1.5.0\",\n      \"_model_name\": \"HTMLModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/controls\",\n      \"_view_module_version\": \"1.5.0\",\n      \"_view_name\": \"HTMLView\",\n      \"description\": \"\",\n      \"description_tooltip\": null,\n      \"layout\": \"IPY_MODEL_b1bea589efa14258a9982071b87938bf\",\n      \"placeholder\": \"​\",\n      \"style\": \"IPY_MODEL_590eef89881545aa8bbef9a8bbe7fb00\",\n      \"value\": \"\\n<b>Pro Tip:</b> If you don't already have one, you can create a dedicated\\n'notebooks' token with 'write' access, that you can then easily reuse for all\\nnotebooks. 
</center>\"\n     }\n    },\n    \"0e50870ed0c643e0b6c18cc5d7ddae7f\": {\n     \"model_module\": \"@jupyter-widgets/controls\",\n     \"model_module_version\": \"1.5.0\",\n     \"model_name\": \"HTMLModel\",\n     \"state\": {\n      \"_dom_classes\": [],\n      \"_model_module\": \"@jupyter-widgets/controls\",\n      \"_model_module_version\": \"1.5.0\",\n      \"_model_name\": \"HTMLModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/controls\",\n      \"_view_module_version\": \"1.5.0\",\n      \"_view_name\": \"HTMLView\",\n      \"description\": \"\",\n      \"description_tooltip\": null,\n      \"layout\": \"IPY_MODEL_bfcdbba993b74972a9e3e575f86908ff\",\n      \"placeholder\": \"​\",\n      \"style\": \"IPY_MODEL_6ebb2ec171414e47a14765505f64bb3c\",\n      \"value\": \" 3.84G/3.84G [00:09&lt;00:00, 664MB/s]\"\n     }\n    },\n    \"0e936d9dbf9c4fdd86bbfe9730dedc47\": {\n     \"model_module\": \"@jupyter-widgets/controls\",\n     \"model_module_version\": \"1.5.0\",\n     \"model_name\": \"DescriptionStyleModel\",\n     \"state\": {\n      \"_model_module\": \"@jupyter-widgets/controls\",\n      \"_model_module_version\": \"1.5.0\",\n      \"_model_name\": \"DescriptionStyleModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/base\",\n      \"_view_module_version\": \"1.2.0\",\n      \"_view_name\": \"StyleView\",\n      \"description_width\": \"\"\n     }\n    },\n    \"0f417447a7bd4a33acca96fa37aec877\": {\n     \"model_module\": \"@jupyter-widgets/controls\",\n     \"model_module_version\": \"1.5.0\",\n     \"model_name\": \"ProgressStyleModel\",\n     \"state\": {\n      \"_model_module\": \"@jupyter-widgets/controls\",\n      \"_model_module_version\": \"1.5.0\",\n      \"_model_name\": \"ProgressStyleModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/base\",\n      \"_view_module_version\": \"1.2.0\",\n      \"_view_name\": \"StyleView\",\n      \"bar_color\": 
null,\n      \"description_width\": \"\"\n     }\n    },\n    \"0f480e3a0b0a45d2a2d2dec3cad923f3\": {\n     \"model_module\": \"@jupyter-widgets/base\",\n     \"model_module_version\": \"1.2.0\",\n     \"model_name\": \"LayoutModel\",\n     \"state\": {\n      \"_model_module\": \"@jupyter-widgets/base\",\n      \"_model_module_version\": \"1.2.0\",\n      \"_model_name\": \"LayoutModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/base\",\n      \"_view_module_version\": \"1.2.0\",\n      \"_view_name\": \"LayoutView\",\n      \"align_content\": null,\n      \"align_items\": null,\n      \"align_self\": null,\n      \"border\": null,\n      \"bottom\": null,\n      \"display\": null,\n      \"flex\": null,\n      \"flex_flow\": null,\n      \"grid_area\": null,\n      \"grid_auto_columns\": null,\n      \"grid_auto_flow\": null,\n      \"grid_auto_rows\": null,\n      \"grid_column\": null,\n      \"grid_gap\": null,\n      \"grid_row\": null,\n      \"grid_template_areas\": null,\n      \"grid_template_columns\": null,\n      \"grid_template_rows\": null,\n      \"height\": null,\n      \"justify_content\": null,\n      \"justify_items\": null,\n      \"left\": null,\n      \"margin\": null,\n      \"max_height\": null,\n      \"max_width\": null,\n      \"min_height\": null,\n      \"min_width\": null,\n      \"object_fit\": null,\n      \"object_position\": null,\n      \"order\": null,\n      \"overflow\": null,\n      \"overflow_x\": null,\n      \"overflow_y\": null,\n      \"padding\": null,\n      \"right\": null,\n      \"top\": null,\n      \"visibility\": null,\n      \"width\": null\n     }\n    },\n    \"0f6907ebbc6242c8bde059cef1e1bd29\": {\n     \"model_module\": \"@jupyter-widgets/controls\",\n     \"model_module_version\": \"1.5.0\",\n     \"model_name\": \"HBoxModel\",\n     \"state\": {\n      \"_dom_classes\": [],\n      \"_model_module\": \"@jupyter-widgets/controls\",\n      \"_model_module_version\": 
\"1.5.0\",\n      \"_model_name\": \"HBoxModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/controls\",\n      \"_view_module_version\": \"1.5.0\",\n      \"_view_name\": \"HBoxView\",\n      \"box_style\": \"\",\n      \"children\": [\n       \"IPY_MODEL_5bdfd87fc6cd4f9dabef7cfee29c8060\",\n       \"IPY_MODEL_64f54d4a744a4627a07c3c0120276f3b\",\n       \"IPY_MODEL_65b75b9b8bc143cf997796af68ff6668\"\n      ],\n      \"layout\": \"IPY_MODEL_d6fe74e4255444368f8f90a62157d869\"\n     }\n    },\n    \"114dece49dba437c8572ef94b23c3b1e\": {\n     \"model_module\": \"@jupyter-widgets/base\",\n     \"model_module_version\": \"1.2.0\",\n     \"model_name\": \"LayoutModel\",\n     \"state\": {\n      \"_model_module\": \"@jupyter-widgets/base\",\n      \"_model_module_version\": \"1.2.0\",\n      \"_model_name\": \"LayoutModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/base\",\n      \"_view_module_version\": \"1.2.0\",\n      \"_view_name\": \"LayoutView\",\n      \"align_content\": null,\n      \"align_items\": null,\n      \"align_self\": null,\n      \"border\": null,\n      \"bottom\": null,\n      \"display\": null,\n      \"flex\": null,\n      \"flex_flow\": null,\n      \"grid_area\": null,\n      \"grid_auto_columns\": null,\n      \"grid_auto_flow\": null,\n      \"grid_auto_rows\": null,\n      \"grid_column\": null,\n      \"grid_gap\": null,\n      \"grid_row\": null,\n      \"grid_template_areas\": null,\n      \"grid_template_columns\": null,\n      \"grid_template_rows\": null,\n      \"height\": null,\n      \"justify_content\": null,\n      \"justify_items\": null,\n      \"left\": null,\n      \"margin\": null,\n      \"max_height\": null,\n      \"max_width\": null,\n      \"min_height\": null,\n      \"min_width\": null,\n      \"object_fit\": null,\n      \"object_position\": null,\n      \"order\": null,\n      \"overflow\": null,\n      \"overflow_x\": null,\n      \"overflow_y\": 
null,\n      \"padding\": null,\n      \"right\": null,\n      \"top\": null,\n      \"visibility\": null,\n      \"width\": null\n     }\n    },\n    \"12815f401eba44658caa7b2e490137a8\": {\n     \"model_module\": \"@jupyter-widgets/base\",\n     \"model_module_version\": \"1.2.0\",\n     \"model_name\": \"LayoutModel\",\n     \"state\": {\n      \"_model_module\": \"@jupyter-widgets/base\",\n      \"_model_module_version\": \"1.2.0\",\n      \"_model_name\": \"LayoutModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/base\",\n      \"_view_module_version\": \"1.2.0\",\n      \"_view_name\": \"LayoutView\",\n      \"align_content\": null,\n      \"align_items\": null,\n      \"align_self\": null,\n      \"border\": null,\n      \"bottom\": null,\n      \"display\": null,\n      \"flex\": null,\n      \"flex_flow\": null,\n      \"grid_area\": null,\n      \"grid_auto_columns\": null,\n      \"grid_auto_flow\": null,\n      \"grid_auto_rows\": null,\n      \"grid_column\": null,\n      \"grid_gap\": null,\n      \"grid_row\": null,\n      \"grid_template_areas\": null,\n      \"grid_template_columns\": null,\n      \"grid_template_rows\": null,\n      \"height\": null,\n      \"justify_content\": null,\n      \"justify_items\": null,\n      \"left\": null,\n      \"margin\": null,\n      \"max_height\": null,\n      \"max_width\": null,\n      \"min_height\": null,\n      \"min_width\": null,\n      \"object_fit\": null,\n      \"object_position\": null,\n      \"order\": null,\n      \"overflow\": null,\n      \"overflow_x\": null,\n      \"overflow_y\": null,\n      \"padding\": null,\n      \"right\": null,\n      \"top\": null,\n      \"visibility\": null,\n      \"width\": null\n     }\n    },\n    \"12b56912736849fea2ad8124456fdc5c\": {\n     \"model_module\": \"@jupyter-widgets/controls\",\n     \"model_module_version\": \"1.5.0\",\n     \"model_name\": \"FloatProgressModel\",\n     \"state\": {\n      \"_dom_classes\": [],\n   
   \"_model_module\": \"@jupyter-widgets/controls\",\n      \"_model_module_version\": \"1.5.0\",\n      \"_model_name\": \"FloatProgressModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/controls\",\n      \"_view_module_version\": \"1.5.0\",\n      \"_view_name\": \"ProgressView\",\n      \"bar_style\": \"success\",\n      \"description\": \"\",\n      \"description_tooltip\": null,\n      \"layout\": \"IPY_MODEL_97e36007e1304e1583fd81bfb13f0edd\",\n      \"max\": 1671853,\n      \"min\": 0,\n      \"orientation\": \"horizontal\",\n      \"style\": \"IPY_MODEL_c65dc74c7d6f4bab8f7dd28455161dd8\",\n      \"value\": 1671853\n     }\n    },\n    \"131065f118274a1586ac38e39ed84ef0\": {\n     \"model_module\": \"@jupyter-widgets/base\",\n     \"model_module_version\": \"1.2.0\",\n     \"model_name\": \"LayoutModel\",\n     \"state\": {\n      \"_model_module\": \"@jupyter-widgets/base\",\n      \"_model_module_version\": \"1.2.0\",\n      \"_model_name\": \"LayoutModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/base\",\n      \"_view_module_version\": \"1.2.0\",\n      \"_view_name\": \"LayoutView\",\n      \"align_content\": null,\n      \"align_items\": \"center\",\n      \"align_self\": null,\n      \"border\": null,\n      \"bottom\": null,\n      \"display\": \"flex\",\n      \"flex\": null,\n      \"flex_flow\": \"column\",\n      \"grid_area\": null,\n      \"grid_auto_columns\": null,\n      \"grid_auto_flow\": null,\n      \"grid_auto_rows\": null,\n      \"grid_column\": null,\n      \"grid_gap\": null,\n      \"grid_row\": null,\n      \"grid_template_areas\": null,\n      \"grid_template_columns\": null,\n      \"grid_template_rows\": null,\n      \"height\": null,\n      \"justify_content\": null,\n      \"justify_items\": null,\n      \"left\": null,\n      \"margin\": null,\n      \"max_height\": null,\n      \"max_width\": null,\n      \"min_height\": null,\n      \"min_width\": null,\n   
   \"object_fit\": null,\n      \"object_position\": null,\n      \"order\": null,\n      \"overflow\": null,\n      \"overflow_x\": null,\n      \"overflow_y\": null,\n      \"padding\": null,\n      \"right\": null,\n      \"top\": null,\n      \"visibility\": null,\n      \"width\": \"50%\"\n     }\n    },\n    \"158c8b85dbf34de6a94b4e35e2fc7d5a\": {\n     \"model_module\": \"@jupyter-widgets/base\",\n     \"model_module_version\": \"1.2.0\",\n     \"model_name\": \"LayoutModel\",\n     \"state\": {\n      \"_model_module\": \"@jupyter-widgets/base\",\n      \"_model_module_version\": \"1.2.0\",\n      \"_model_name\": \"LayoutModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/base\",\n      \"_view_module_version\": \"1.2.0\",\n      \"_view_name\": \"LayoutView\",\n      \"align_content\": null,\n      \"align_items\": null,\n      \"align_self\": null,\n      \"border\": null,\n      \"bottom\": null,\n      \"display\": null,\n      \"flex\": null,\n      \"flex_flow\": null,\n      \"grid_area\": null,\n      \"grid_auto_columns\": null,\n      \"grid_auto_flow\": null,\n      \"grid_auto_rows\": null,\n      \"grid_column\": null,\n      \"grid_gap\": null,\n      \"grid_row\": null,\n      \"grid_template_areas\": null,\n      \"grid_template_columns\": null,\n      \"grid_template_rows\": null,\n      \"height\": null,\n      \"justify_content\": null,\n      \"justify_items\": null,\n      \"left\": null,\n      \"margin\": null,\n      \"max_height\": null,\n      \"max_width\": null,\n      \"min_height\": null,\n      \"min_width\": null,\n      \"object_fit\": null,\n      \"object_position\": null,\n      \"order\": null,\n      \"overflow\": null,\n      \"overflow_x\": null,\n      \"overflow_y\": null,\n      \"padding\": null,\n      \"right\": null,\n      \"top\": null,\n      \"visibility\": null,\n      \"width\": null\n     }\n    },\n    \"16a188a0b06d45f980dcf3933509fe0a\": {\n     \"model_module\": 
\"@jupyter-widgets/controls\",\n     \"model_module_version\": \"1.5.0\",\n     \"model_name\": \"HTMLModel\",\n     \"state\": {\n      \"_dom_classes\": [],\n      \"_model_module\": \"@jupyter-widgets/controls\",\n      \"_model_module_version\": \"1.5.0\",\n      \"_model_name\": \"HTMLModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/controls\",\n      \"_view_module_version\": \"1.5.0\",\n      \"_view_name\": \"HTMLView\",\n      \"description\": \"\",\n      \"description_tooltip\": null,\n      \"layout\": \"IPY_MODEL_349eee9f56d64f0cba6fc24ff2c50c9b\",\n      \"placeholder\": \"​\",\n      \"style\": \"IPY_MODEL_7e5d3774060e4589aa65982da5ea4ef4\",\n      \"value\": \" 9985/9985 [00:04&lt;00:00, 2604.11 examples/s]\"\n     }\n    },\n    \"16d1283741404b7bb319094c992fce01\": {\n     \"model_module\": \"@jupyter-widgets/controls\",\n     \"model_module_version\": \"1.5.0\",\n     \"model_name\": \"HTMLModel\",\n     \"state\": {\n      \"_dom_classes\": [],\n      \"_model_module\": \"@jupyter-widgets/controls\",\n      \"_model_module_version\": \"1.5.0\",\n      \"_model_name\": \"HTMLModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/controls\",\n      \"_view_module_version\": \"1.5.0\",\n      \"_view_name\": \"HTMLView\",\n      \"description\": \"\",\n      \"description_tooltip\": null,\n      \"layout\": \"IPY_MODEL_a4e5789584564049b83df7c6c54a3e08\",\n      \"placeholder\": \"​\",\n      \"style\": \"IPY_MODEL_ff3a94b146a948b6907f5d80c7157f99\",\n      \"value\": \" 9985/0 [00:00&lt;00:00, 50763.46 examples/s]\"\n     }\n    },\n    \"1811cda0644e4190a9469d1774435d82\": {\n     \"model_module\": \"@jupyter-widgets/base\",\n     \"model_module_version\": \"1.2.0\",\n     \"model_name\": \"LayoutModel\",\n     \"state\": {\n      \"_model_module\": \"@jupyter-widgets/base\",\n      \"_model_module_version\": \"1.2.0\",\n      \"_model_name\": \"LayoutModel\",\n      \"_view_count\": 
null,\n      \"_view_module\": \"@jupyter-widgets/base\",\n      \"_view_module_version\": \"1.2.0\",\n      \"_view_name\": \"LayoutView\",\n      \"align_content\": null,\n      \"align_items\": null,\n      \"align_self\": null,\n      \"border\": null,\n      \"bottom\": null,\n      \"display\": null,\n      \"flex\": null,\n      \"flex_flow\": null,\n      \"grid_area\": null,\n      \"grid_auto_columns\": null,\n      \"grid_auto_flow\": null,\n      \"grid_auto_rows\": null,\n      \"grid_column\": null,\n      \"grid_gap\": null,\n      \"grid_row\": null,\n      \"grid_template_areas\": null,\n      \"grid_template_columns\": null,\n      \"grid_template_rows\": null,\n      \"height\": null,\n      \"justify_content\": null,\n      \"justify_items\": null,\n      \"left\": null,\n      \"margin\": null,\n      \"max_height\": null,\n      \"max_width\": null,\n      \"min_height\": null,\n      \"min_width\": null,\n      \"object_fit\": null,\n      \"object_position\": null,\n      \"order\": null,\n      \"overflow\": null,\n      \"overflow_x\": null,\n      \"overflow_y\": null,\n      \"padding\": null,\n      \"right\": null,\n      \"top\": null,\n      \"visibility\": null,\n      \"width\": null\n     }\n    },\n    \"18357b321ce44d7b8bd9d1c886f69275\": {\n     \"model_module\": \"@jupyter-widgets/controls\",\n     \"model_module_version\": \"1.5.0\",\n     \"model_name\": \"HTMLModel\",\n     \"state\": {\n      \"_dom_classes\": [],\n      \"_model_module\": \"@jupyter-widgets/controls\",\n      \"_model_module_version\": \"1.5.0\",\n      \"_model_name\": \"HTMLModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/controls\",\n      \"_view_module_version\": \"1.5.0\",\n      \"_view_name\": \"HTMLView\",\n      \"description\": \"\",\n      \"description_tooltip\": null,\n      \"layout\": \"IPY_MODEL_e366ae3fceec4566b9ed303d6c5f90af\",\n      \"placeholder\": \"​\",\n      \"style\": 
\"IPY_MODEL_5dd7d150dbe04f08b165ce7f2c27cd11\",\n      \"value\": \"model-00008-of-00008.safetensors: 100%\"\n     }\n    },\n    \"19127c7bb1554ccbac877059f9a82db0\": {\n     \"model_module\": \"@jupyter-widgets/controls\",\n     \"model_module_version\": \"1.5.0\",\n     \"model_name\": \"FloatProgressModel\",\n     \"state\": {\n      \"_dom_classes\": [],\n      \"_model_module\": \"@jupyter-widgets/controls\",\n      \"_model_module_version\": \"1.5.0\",\n      \"_model_name\": \"FloatProgressModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/controls\",\n      \"_view_module_version\": \"1.5.0\",\n      \"_view_name\": \"ProgressView\",\n      \"bar_style\": \"danger\",\n      \"description\": \"\",\n      \"description_tooltip\": null,\n      \"layout\": \"IPY_MODEL_e400cbf14bcc446a9d33b210cd93550b\",\n      \"max\": 3963750880,\n      \"min\": 0,\n      \"orientation\": \"horizontal\",\n      \"style\": \"IPY_MODEL_71002199df6b40c9a1ac40df5fb27a1b\",\n      \"value\": 3963750502\n     }\n    },\n    \"19c1e38389fa46c7b7e2152a56e1df34\": {\n     \"model_module\": \"@jupyter-widgets/controls\",\n     \"model_module_version\": \"1.5.0\",\n     \"model_name\": \"ButtonModel\",\n     \"state\": {\n      \"_dom_classes\": [],\n      \"_model_module\": \"@jupyter-widgets/controls\",\n      \"_model_module_version\": \"1.5.0\",\n      \"_model_name\": \"ButtonModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/controls\",\n      \"_view_module_version\": \"1.5.0\",\n      \"_view_name\": \"ButtonView\",\n      \"button_style\": \"\",\n      \"description\": \"Login\",\n      \"disabled\": false,\n      \"icon\": \"\",\n      \"layout\": \"IPY_MODEL_835bcc28a5564fb9b3d651bc8e32dc46\",\n      \"style\": \"IPY_MODEL_9f1c9a0695384bdaa6f8b847ef89bee8\",\n      \"tooltip\": \"\"\n     }\n    },\n    \"1bec6297c90242a88672d195bc09d429\": {\n     \"model_module\": \"@jupyter-widgets/base\",\n     
\"model_module_version\": \"1.2.0\",\n     \"model_name\": \"LayoutModel\",\n     \"state\": {\n      \"_model_module\": \"@jupyter-widgets/base\",\n      \"_model_module_version\": \"1.2.0\",\n      \"_model_name\": \"LayoutModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/base\",\n      \"_view_module_version\": \"1.2.0\",\n      \"_view_name\": \"LayoutView\",\n      \"align_content\": null,\n      \"align_items\": null,\n      \"align_self\": null,\n      \"border\": null,\n      \"bottom\": null,\n      \"display\": null,\n      \"flex\": null,\n      \"flex_flow\": null,\n      \"grid_area\": null,\n      \"grid_auto_columns\": null,\n      \"grid_auto_flow\": null,\n      \"grid_auto_rows\": null,\n      \"grid_column\": null,\n      \"grid_gap\": null,\n      \"grid_row\": null,\n      \"grid_template_areas\": null,\n      \"grid_template_columns\": null,\n      \"grid_template_rows\": null,\n      \"height\": null,\n      \"justify_content\": null,\n      \"justify_items\": null,\n      \"left\": null,\n      \"margin\": null,\n      \"max_height\": null,\n      \"max_width\": null,\n      \"min_height\": null,\n      \"min_width\": null,\n      \"object_fit\": null,\n      \"object_position\": null,\n      \"order\": null,\n      \"overflow\": null,\n      \"overflow_x\": null,\n      \"overflow_y\": null,\n      \"padding\": null,\n      \"right\": null,\n      \"top\": null,\n      \"visibility\": null,\n      \"width\": null\n     }\n    },\n    \"1c6f1f10667545aaab958016ba7e2c94\": {\n     \"model_module\": \"@jupyter-widgets/base\",\n     \"model_module_version\": \"1.2.0\",\n     \"model_name\": \"LayoutModel\",\n     \"state\": {\n      \"_model_module\": \"@jupyter-widgets/base\",\n      \"_model_module_version\": \"1.2.0\",\n      \"_model_name\": \"LayoutModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/base\",\n      \"_view_module_version\": \"1.2.0\",\n      \"_view_name\": 
\"LayoutView\",\n      \"align_content\": null,\n      \"align_items\": null,\n      \"align_self\": null,\n      \"border\": null,\n      \"bottom\": null,\n      \"display\": null,\n      \"flex\": null,\n      \"flex_flow\": null,\n      \"grid_area\": null,\n      \"grid_auto_columns\": null,\n      \"grid_auto_flow\": null,\n      \"grid_auto_rows\": null,\n      \"grid_column\": null,\n      \"grid_gap\": null,\n      \"grid_row\": null,\n      \"grid_template_areas\": null,\n      \"grid_template_columns\": null,\n      \"grid_template_rows\": null,\n      \"height\": null,\n      \"justify_content\": null,\n      \"justify_items\": null,\n      \"left\": null,\n      \"margin\": null,\n      \"max_height\": null,\n      \"max_width\": null,\n      \"min_height\": null,\n      \"min_width\": null,\n      \"object_fit\": null,\n      \"object_position\": null,\n      \"order\": null,\n      \"overflow\": null,\n      \"overflow_x\": null,\n      \"overflow_y\": null,\n      \"padding\": null,\n      \"right\": null,\n      \"top\": null,\n      \"visibility\": null,\n      \"width\": null\n     }\n    },\n    \"1d5117195d4b49eb8f1a73b18419f7ce\": {\n     \"model_module\": \"@jupyter-widgets/controls\",\n     \"model_module_version\": \"1.5.0\",\n     \"model_name\": \"HTMLModel\",\n     \"state\": {\n      \"_dom_classes\": [],\n      \"_model_module\": \"@jupyter-widgets/controls\",\n      \"_model_module_version\": \"1.5.0\",\n      \"_model_name\": \"HTMLModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/controls\",\n      \"_view_module_version\": \"1.5.0\",\n      \"_view_name\": \"HTMLView\",\n      \"description\": \"\",\n      \"description_tooltip\": null,\n      \"layout\": \"IPY_MODEL_0dea5caa27384f5689e3cab51f558727\",\n      \"placeholder\": \"​\",\n      \"style\": \"IPY_MODEL_a6f48410b9964fefba0c3009a77dc838\",\n      \"value\": \" 9.68k/9.68k [00:00&lt;00:00, 812kB/s]\"\n     }\n    },\n    
\"1f7d30f71bbd4547a9150d21da071055\": {\n     \"model_module\": \"@jupyter-widgets/base\",\n     \"model_module_version\": \"1.2.0\",\n     \"model_name\": \"LayoutModel\",\n     \"state\": {\n      \"_model_module\": \"@jupyter-widgets/base\",\n      \"_model_module_version\": \"1.2.0\",\n      \"_model_name\": \"LayoutModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/base\",\n      \"_view_module_version\": \"1.2.0\",\n      \"_view_name\": \"LayoutView\",\n      \"align_content\": null,\n      \"align_items\": null,\n      \"align_self\": null,\n      \"border\": null,\n      \"bottom\": null,\n      \"display\": null,\n      \"flex\": null,\n      \"flex_flow\": null,\n      \"grid_area\": null,\n      \"grid_auto_columns\": null,\n      \"grid_auto_flow\": null,\n      \"grid_auto_rows\": null,\n      \"grid_column\": null,\n      \"grid_gap\": null,\n      \"grid_row\": null,\n      \"grid_template_areas\": null,\n      \"grid_template_columns\": null,\n      \"grid_template_rows\": null,\n      \"height\": null,\n      \"justify_content\": null,\n      \"justify_items\": null,\n      \"left\": null,\n      \"margin\": null,\n      \"max_height\": null,\n      \"max_width\": null,\n      \"min_height\": null,\n      \"min_width\": null,\n      \"object_fit\": null,\n      \"object_position\": null,\n      \"order\": null,\n      \"overflow\": null,\n      \"overflow_x\": null,\n      \"overflow_y\": null,\n      \"padding\": null,\n      \"right\": null,\n      \"top\": null,\n      \"visibility\": null,\n      \"width\": null\n     }\n    },\n    \"200df5e79b9244849e589ecb0250a520\": {\n     \"model_module\": \"@jupyter-widgets/controls\",\n     \"model_module_version\": \"1.5.0\",\n     \"model_name\": \"HTMLModel\",\n     \"state\": {\n      \"_dom_classes\": [],\n      \"_model_module\": \"@jupyter-widgets/controls\",\n      \"_model_module_version\": \"1.5.0\",\n      \"_model_name\": \"HTMLModel\",\n      
\"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/controls\",\n      \"_view_module_version\": \"1.5.0\",\n      \"_view_name\": \"HTMLView\",\n      \"description\": \"\",\n      \"description_tooltip\": null,\n      \"layout\": \"IPY_MODEL_f4a1795dc7514a718f478245f521f0ba\",\n      \"placeholder\": \"​\",\n      \"style\": \"IPY_MODEL_5e746eb25bbe416fb585fa24e79f5177\",\n      \"value\": \"model-00002-of-00008.safetensors: 100%\"\n     }\n    },\n    \"20352e5f58d24bb8b1f3940efd14fe4a\": {\n     \"model_module\": \"@jupyter-widgets/base\",\n     \"model_module_version\": \"1.2.0\",\n     \"model_name\": \"LayoutModel\",\n     \"state\": {\n      \"_model_module\": \"@jupyter-widgets/base\",\n      \"_model_module_version\": \"1.2.0\",\n      \"_model_name\": \"LayoutModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/base\",\n      \"_view_module_version\": \"1.2.0\",\n      \"_view_name\": \"LayoutView\",\n      \"align_content\": null,\n      \"align_items\": null,\n      \"align_self\": null,\n      \"border\": null,\n      \"bottom\": null,\n      \"display\": null,\n      \"flex\": null,\n      \"flex_flow\": null,\n      \"grid_area\": null,\n      \"grid_auto_columns\": null,\n      \"grid_auto_flow\": null,\n      \"grid_auto_rows\": null,\n      \"grid_column\": null,\n      \"grid_gap\": null,\n      \"grid_row\": null,\n      \"grid_template_areas\": null,\n      \"grid_template_columns\": null,\n      \"grid_template_rows\": null,\n      \"height\": null,\n      \"justify_content\": null,\n      \"justify_items\": null,\n      \"left\": null,\n      \"margin\": null,\n      \"max_height\": null,\n      \"max_width\": null,\n      \"min_height\": null,\n      \"min_width\": null,\n      \"object_fit\": null,\n      \"object_position\": null,\n      \"order\": null,\n      \"overflow\": null,\n      \"overflow_x\": null,\n      \"overflow_y\": null,\n      \"padding\": null,\n      \"right\": null,\n      
\"top\": null,\n      \"visibility\": null,\n      \"width\": null\n     }\n    },\n    \"253017b0d0534e54ab44e181f6d7c82d\": {\n     \"model_module\": \"@jupyter-widgets/controls\",\n     \"model_module_version\": \"1.5.0\",\n     \"model_name\": \"HTMLModel\",\n     \"state\": {\n      \"_dom_classes\": [],\n      \"_model_module\": \"@jupyter-widgets/controls\",\n      \"_model_module_version\": \"1.5.0\",\n      \"_model_name\": \"HTMLModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/controls\",\n      \"_view_module_version\": \"1.5.0\",\n      \"_view_name\": \"HTMLView\",\n      \"description\": \"\",\n      \"description_tooltip\": null,\n      \"layout\": \"IPY_MODEL_1c6f1f10667545aaab958016ba7e2c94\",\n      \"placeholder\": \"​\",\n      \"style\": \"IPY_MODEL_e6e969610738449887259063967f82b0\",\n      \"value\": \" 2.78M/2.78M [00:00&lt;00:00, 17.8MB/s]\"\n     }\n    },\n    \"258b7c635c1045329d4669e48c46ccd5\": {\n     \"model_module\": \"@jupyter-widgets/controls\",\n     \"model_module_version\": \"1.5.0\",\n     \"model_name\": \"HBoxModel\",\n     \"state\": {\n      \"_dom_classes\": [],\n      \"_model_module\": \"@jupyter-widgets/controls\",\n      \"_model_module_version\": \"1.5.0\",\n      \"_model_name\": \"HBoxModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/controls\",\n      \"_view_module_version\": \"1.5.0\",\n      \"_view_name\": \"HBoxView\",\n      \"box_style\": \"\",\n      \"children\": [\n       \"IPY_MODEL_6f68ed9889f54ad2ae8a3b95ac263a83\",\n       \"IPY_MODEL_80366349d81e4dcc892db6cd56e384f3\",\n       \"IPY_MODEL_c73055099c084dca996159e23e162d0b\"\n      ],\n      \"layout\": \"IPY_MODEL_977f799afaac4a55b2dc1cffa7d5b63b\"\n     }\n    },\n    \"279937fe03bc4e4eb25b472d7e9df163\": {\n     \"model_module\": \"@jupyter-widgets/controls\",\n     \"model_module_version\": \"1.5.0\",\n     \"model_name\": \"FloatProgressModel\",\n     \"state\": {\n      
\"_dom_classes\": [],\n      \"_model_module\": \"@jupyter-widgets/controls\",\n      \"_model_module_version\": \"1.5.0\",\n      \"_model_name\": \"FloatProgressModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/controls\",\n      \"_view_module_version\": \"1.5.0\",\n      \"_view_name\": \"ProgressView\",\n      \"bar_style\": \"danger\",\n      \"description\": \"\",\n      \"description_tooltip\": null,\n      \"layout\": \"IPY_MODEL_b634bb73cfa743d09a5999101b840976\",\n      \"max\": 1912371880,\n      \"min\": 0,\n      \"orientation\": \"horizontal\",\n      \"style\": \"IPY_MODEL_742b1030acfd414bbd9d5327b7e3826d\",\n      \"value\": 1912371698\n     }\n    },\n    \"27beaf06e41b472abdb544a43c720c5a\": {\n     \"model_module\": \"@jupyter-widgets/base\",\n     \"model_module_version\": \"1.2.0\",\n     \"model_name\": \"LayoutModel\",\n     \"state\": {\n      \"_model_module\": \"@jupyter-widgets/base\",\n      \"_model_module_version\": \"1.2.0\",\n      \"_model_name\": \"LayoutModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/base\",\n      \"_view_module_version\": \"1.2.0\",\n      \"_view_name\": \"LayoutView\",\n      \"align_content\": null,\n      \"align_items\": null,\n      \"align_self\": null,\n      \"border\": null,\n      \"bottom\": null,\n      \"display\": null,\n      \"flex\": null,\n      \"flex_flow\": null,\n      \"grid_area\": null,\n      \"grid_auto_columns\": null,\n      \"grid_auto_flow\": null,\n      \"grid_auto_rows\": null,\n      \"grid_column\": null,\n      \"grid_gap\": null,\n      \"grid_row\": null,\n      \"grid_template_areas\": null,\n      \"grid_template_columns\": null,\n      \"grid_template_rows\": null,\n      \"height\": null,\n      \"justify_content\": null,\n      \"justify_items\": null,\n      \"left\": null,\n      \"margin\": null,\n      \"max_height\": null,\n      \"max_width\": null,\n      \"min_height\": null,\n      
\"min_width\": null,\n      \"object_fit\": null,\n      \"object_position\": null,\n      \"order\": null,\n      \"overflow\": null,\n      \"overflow_x\": null,\n      \"overflow_y\": null,\n      \"padding\": null,\n      \"right\": null,\n      \"top\": null,\n      \"visibility\": null,\n      \"width\": null\n     }\n    },\n    \"2860e3bb3baf4f7da058465850e800c5\": {\n     \"model_module\": \"@jupyter-widgets/controls\",\n     \"model_module_version\": \"1.5.0\",\n     \"model_name\": \"HBoxModel\",\n     \"state\": {\n      \"_dom_classes\": [],\n      \"_model_module\": \"@jupyter-widgets/controls\",\n      \"_model_module_version\": \"1.5.0\",\n      \"_model_name\": \"HBoxModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/controls\",\n      \"_view_module_version\": \"1.5.0\",\n      \"_view_name\": \"HBoxView\",\n      \"box_style\": \"\",\n      \"children\": [\n       \"IPY_MODEL_3efd18ea8eaa41918894883da9541bfa\",\n       \"IPY_MODEL_e09f1bcbb9d94c09be53e5e1303642c2\",\n       \"IPY_MODEL_82177df57a494de8900c14c2f5185175\"\n      ],\n      \"layout\": \"IPY_MODEL_ccfcdc95baf646f8aeb3d516742383f2\"\n     }\n    },\n    \"2a51b36be41745468e4c2d7a21b1c0d2\": {\n     \"model_module\": \"@jupyter-widgets/base\",\n     \"model_module_version\": \"1.2.0\",\n     \"model_name\": \"LayoutModel\",\n     \"state\": {\n      \"_model_module\": \"@jupyter-widgets/base\",\n      \"_model_module_version\": \"1.2.0\",\n      \"_model_name\": \"LayoutModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/base\",\n      \"_view_module_version\": \"1.2.0\",\n      \"_view_name\": \"LayoutView\",\n      \"align_content\": null,\n      \"align_items\": null,\n      \"align_self\": null,\n      \"border\": null,\n      \"bottom\": null,\n      \"display\": null,\n      \"flex\": null,\n      \"flex_flow\": null,\n      \"grid_area\": null,\n      \"grid_auto_columns\": null,\n      \"grid_auto_flow\": null,\n    
  \"grid_auto_rows\": null,\n      \"grid_column\": null,\n      \"grid_gap\": null,\n      \"grid_row\": null,\n      \"grid_template_areas\": null,\n      \"grid_template_columns\": null,\n      \"grid_template_rows\": null,\n      \"height\": null,\n      \"justify_content\": null,\n      \"justify_items\": null,\n      \"left\": null,\n      \"margin\": null,\n      \"max_height\": null,\n      \"max_width\": null,\n      \"min_height\": null,\n      \"min_width\": null,\n      \"object_fit\": null,\n      \"object_position\": null,\n      \"order\": null,\n      \"overflow\": null,\n      \"overflow_x\": null,\n      \"overflow_y\": null,\n      \"padding\": null,\n      \"right\": null,\n      \"top\": null,\n      \"visibility\": null,\n      \"width\": null\n     }\n    },\n    \"2a5bb0e818ab47be8cf6465988328503\": {\n     \"model_module\": \"@jupyter-widgets/base\",\n     \"model_module_version\": \"1.2.0\",\n     \"model_name\": \"LayoutModel\",\n     \"state\": {\n      \"_model_module\": \"@jupyter-widgets/base\",\n      \"_model_module_version\": \"1.2.0\",\n      \"_model_name\": \"LayoutModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/base\",\n      \"_view_module_version\": \"1.2.0\",\n      \"_view_name\": \"LayoutView\",\n      \"align_content\": null,\n      \"align_items\": null,\n      \"align_self\": null,\n      \"border\": null,\n      \"bottom\": null,\n      \"display\": null,\n      \"flex\": null,\n      \"flex_flow\": null,\n      \"grid_area\": null,\n      \"grid_auto_columns\": null,\n      \"grid_auto_flow\": null,\n      \"grid_auto_rows\": null,\n      \"grid_column\": null,\n      \"grid_gap\": null,\n      \"grid_row\": null,\n      \"grid_template_areas\": null,\n      \"grid_template_columns\": null,\n      \"grid_template_rows\": null,\n      \"height\": null,\n      \"justify_content\": null,\n      \"justify_items\": null,\n      \"left\": null,\n      \"margin\": null,\n      
\"max_height\": null,\n      \"max_width\": null,\n      \"min_height\": null,\n      \"min_width\": null,\n      \"object_fit\": null,\n      \"object_position\": null,\n      \"order\": null,\n      \"overflow\": null,\n      \"overflow_x\": null,\n      \"overflow_y\": null,\n      \"padding\": null,\n      \"right\": null,\n      \"top\": null,\n      \"visibility\": null,\n      \"width\": null\n     }\n    },\n    \"2b3a2659b12244bd8548320320016dbf\": {\n     \"model_module\": \"@jupyter-widgets/base\",\n     \"model_module_version\": \"1.2.0\",\n     \"model_name\": \"LayoutModel\",\n     \"state\": {\n      \"_model_module\": \"@jupyter-widgets/base\",\n      \"_model_module_version\": \"1.2.0\",\n      \"_model_name\": \"LayoutModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/base\",\n      \"_view_module_version\": \"1.2.0\",\n      \"_view_name\": \"LayoutView\",\n      \"align_content\": null,\n      \"align_items\": null,\n      \"align_self\": null,\n      \"border\": null,\n      \"bottom\": null,\n      \"display\": null,\n      \"flex\": null,\n      \"flex_flow\": null,\n      \"grid_area\": null,\n      \"grid_auto_columns\": null,\n      \"grid_auto_flow\": null,\n      \"grid_auto_rows\": null,\n      \"grid_column\": null,\n      \"grid_gap\": null,\n      \"grid_row\": null,\n      \"grid_template_areas\": null,\n      \"grid_template_columns\": null,\n      \"grid_template_rows\": null,\n      \"height\": null,\n      \"justify_content\": null,\n      \"justify_items\": null,\n      \"left\": null,\n      \"margin\": null,\n      \"max_height\": null,\n      \"max_width\": null,\n      \"min_height\": null,\n      \"min_width\": null,\n      \"object_fit\": null,\n      \"object_position\": null,\n      \"order\": null,\n      \"overflow\": null,\n      \"overflow_x\": null,\n      \"overflow_y\": null,\n      \"padding\": null,\n      \"right\": null,\n      \"top\": null,\n      \"visibility\": null,\n      
\"width\": null\n     }\n    },\n    \"2e257c8be2da40b4bb67a9e4ab6811f3\": {\n     \"model_module\": \"@jupyter-widgets/base\",\n     \"model_module_version\": \"1.2.0\",\n     \"model_name\": \"LayoutModel\",\n     \"state\": {\n      \"_model_module\": \"@jupyter-widgets/base\",\n      \"_model_module_version\": \"1.2.0\",\n      \"_model_name\": \"LayoutModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/base\",\n      \"_view_module_version\": \"1.2.0\",\n      \"_view_name\": \"LayoutView\",\n      \"align_content\": null,\n      \"align_items\": null,\n      \"align_self\": null,\n      \"border\": null,\n      \"bottom\": null,\n      \"display\": null,\n      \"flex\": null,\n      \"flex_flow\": null,\n      \"grid_area\": null,\n      \"grid_auto_columns\": null,\n      \"grid_auto_flow\": null,\n      \"grid_auto_rows\": null,\n      \"grid_column\": null,\n      \"grid_gap\": null,\n      \"grid_row\": null,\n      \"grid_template_areas\": null,\n      \"grid_template_columns\": null,\n      \"grid_template_rows\": null,\n      \"height\": null,\n      \"justify_content\": null,\n      \"justify_items\": null,\n      \"left\": null,\n      \"margin\": null,\n      \"max_height\": null,\n      \"max_width\": null,\n      \"min_height\": null,\n      \"min_width\": null,\n      \"object_fit\": null,\n      \"object_position\": null,\n      \"order\": null,\n      \"overflow\": null,\n      \"overflow_x\": null,\n      \"overflow_y\": null,\n      \"padding\": null,\n      \"right\": null,\n      \"top\": null,\n      \"visibility\": null,\n      \"width\": null\n     }\n    },\n    \"2e2b0c1599c341a198f632f46a40c90e\": {\n     \"model_module\": \"@jupyter-widgets/controls\",\n     \"model_module_version\": \"1.5.0\",\n     \"model_name\": \"HTMLModel\",\n     \"state\": {\n      \"_dom_classes\": [],\n      \"_model_module\": \"@jupyter-widgets/controls\",\n      \"_model_module_version\": \"1.5.0\",\n      \"_model_name\": 
\"HTMLModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/controls\",\n      \"_view_module_version\": \"1.5.0\",\n      \"_view_name\": \"HTMLView\",\n      \"description\": \"\",\n      \"description_tooltip\": null,\n      \"layout\": \"IPY_MODEL_be724f04b03942b2a033a7e8898bb4fd\",\n      \"placeholder\": \"​\",\n      \"style\": \"IPY_MODEL_fcbab4d8dced41a18dfccce81e3a45a0\",\n      \"value\": \"model-00005-of-00008.safetensors: 100%\"\n     }\n    },\n    \"3036608c71904ce9ae4bb2a9fa8802d9\": {\n     \"model_module\": \"@jupyter-widgets/controls\",\n     \"model_module_version\": \"1.5.0\",\n     \"model_name\": \"HTMLModel\",\n     \"state\": {\n      \"_dom_classes\": [],\n      \"_model_module\": \"@jupyter-widgets/controls\",\n      \"_model_module_version\": \"1.5.0\",\n      \"_model_name\": \"HTMLModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/controls\",\n      \"_view_module_version\": \"1.5.0\",\n      \"_view_name\": \"HTMLView\",\n      \"description\": \"\",\n      \"description_tooltip\": null,\n      \"layout\": \"IPY_MODEL_5ca6be24acb548cea130bd58e9954c7c\",\n      \"placeholder\": \"​\",\n      \"style\": \"IPY_MODEL_5cfb02ee044b4011a378efa8b54a370f\",\n      \"value\": \" 3.96G/3.96G [00:10&lt;00:00, 531MB/s]\"\n     }\n    },\n    \"30a81da86f8043eca301e86a8651201a\": {\n     \"model_module\": \"@jupyter-widgets/base\",\n     \"model_module_version\": \"1.2.0\",\n     \"model_name\": \"LayoutModel\",\n     \"state\": {\n      \"_model_module\": \"@jupyter-widgets/base\",\n      \"_model_module_version\": \"1.2.0\",\n      \"_model_name\": \"LayoutModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/base\",\n      \"_view_module_version\": \"1.2.0\",\n      \"_view_name\": \"LayoutView\",\n      \"align_content\": null,\n      \"align_items\": null,\n      \"align_self\": null,\n      \"border\": null,\n      \"bottom\": null,\n      \"display\": 
null,\n      \"flex\": null,\n      \"flex_flow\": null,\n      \"grid_area\": null,\n      \"grid_auto_columns\": null,\n      \"grid_auto_flow\": null,\n      \"grid_auto_rows\": null,\n      \"grid_column\": null,\n      \"grid_gap\": null,\n      \"grid_row\": null,\n      \"grid_template_areas\": null,\n      \"grid_template_columns\": null,\n      \"grid_template_rows\": null,\n      \"height\": null,\n      \"justify_content\": null,\n      \"justify_items\": null,\n      \"left\": null,\n      \"margin\": null,\n      \"max_height\": null,\n      \"max_width\": null,\n      \"min_height\": null,\n      \"min_width\": null,\n      \"object_fit\": null,\n      \"object_position\": null,\n      \"order\": null,\n      \"overflow\": null,\n      \"overflow_x\": null,\n      \"overflow_y\": null,\n      \"padding\": null,\n      \"right\": null,\n      \"top\": null,\n      \"visibility\": null,\n      \"width\": null\n     }\n    },\n    \"30e02aa2d0d241979369e598287f2639\": {\n     \"model_module\": \"@jupyter-widgets/controls\",\n     \"model_module_version\": \"1.5.0\",\n     \"model_name\": \"DescriptionStyleModel\",\n     \"state\": {\n      \"_model_module\": \"@jupyter-widgets/controls\",\n      \"_model_module_version\": \"1.5.0\",\n      \"_model_name\": \"DescriptionStyleModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/base\",\n      \"_view_module_version\": \"1.2.0\",\n      \"_view_name\": \"StyleView\",\n      \"description_width\": \"\"\n     }\n    },\n    \"3225603166b54e7aab766b9964a2f660\": {\n     \"model_module\": \"@jupyter-widgets/controls\",\n     \"model_module_version\": \"1.5.0\",\n     \"model_name\": \"ProgressStyleModel\",\n     \"state\": {\n      \"_model_module\": \"@jupyter-widgets/controls\",\n      \"_model_module_version\": \"1.5.0\",\n      \"_model_name\": \"ProgressStyleModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/base\",\n      
\"_view_module_version\": \"1.2.0\",\n      \"_view_name\": \"StyleView\",\n      \"bar_color\": null,\n      \"description_width\": \"\"\n     }\n    },\n    \"33b3b1d0295646edaac7b4822761aeb0\": {\n     \"model_module\": \"@jupyter-widgets/controls\",\n     \"model_module_version\": \"1.5.0\",\n     \"model_name\": \"ProgressStyleModel\",\n     \"state\": {\n      \"_model_module\": \"@jupyter-widgets/controls\",\n      \"_model_module_version\": \"1.5.0\",\n      \"_model_name\": \"ProgressStyleModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/base\",\n      \"_view_module_version\": \"1.2.0\",\n      \"_view_name\": \"StyleView\",\n      \"bar_color\": null,\n      \"description_width\": \"\"\n     }\n    },\n    \"349eee9f56d64f0cba6fc24ff2c50c9b\": {\n     \"model_module\": \"@jupyter-widgets/base\",\n     \"model_module_version\": \"1.2.0\",\n     \"model_name\": \"LayoutModel\",\n     \"state\": {\n      \"_model_module\": \"@jupyter-widgets/base\",\n      \"_model_module_version\": \"1.2.0\",\n      \"_model_name\": \"LayoutModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/base\",\n      \"_view_module_version\": \"1.2.0\",\n      \"_view_name\": \"LayoutView\",\n      \"align_content\": null,\n      \"align_items\": null,\n      \"align_self\": null,\n      \"border\": null,\n      \"bottom\": null,\n      \"display\": null,\n      \"flex\": null,\n      \"flex_flow\": null,\n      \"grid_area\": null,\n      \"grid_auto_columns\": null,\n      \"grid_auto_flow\": null,\n      \"grid_auto_rows\": null,\n      \"grid_column\": null,\n      \"grid_gap\": null,\n      \"grid_row\": null,\n      \"grid_template_areas\": null,\n      \"grid_template_columns\": null,\n      \"grid_template_rows\": null,\n      \"height\": null,\n      \"justify_content\": null,\n      \"justify_items\": null,\n      \"left\": null,\n      \"margin\": null,\n      \"max_height\": null,\n      \"max_width\": 
null,\n      \"min_height\": null,\n      \"min_width\": null,\n      \"object_fit\": null,\n      \"object_position\": null,\n      \"order\": null,\n      \"overflow\": null,\n      \"overflow_x\": null,\n      \"overflow_y\": null,\n      \"padding\": null,\n      \"right\": null,\n      \"top\": null,\n      \"visibility\": null,\n      \"width\": null\n     }\n    },\n    \"34c9c0137b504cd799c6bd6de69507c2\": {\n     \"model_module\": \"@jupyter-widgets/base\",\n     \"model_module_version\": \"1.2.0\",\n     \"model_name\": \"LayoutModel\",\n     \"state\": {\n      \"_model_module\": \"@jupyter-widgets/base\",\n      \"_model_module_version\": \"1.2.0\",\n      \"_model_name\": \"LayoutModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/base\",\n      \"_view_module_version\": \"1.2.0\",\n      \"_view_name\": \"LayoutView\",\n      \"align_content\": null,\n      \"align_items\": null,\n      \"align_self\": null,\n      \"border\": null,\n      \"bottom\": null,\n      \"display\": null,\n      \"flex\": null,\n      \"flex_flow\": null,\n      \"grid_area\": null,\n      \"grid_auto_columns\": null,\n      \"grid_auto_flow\": null,\n      \"grid_auto_rows\": null,\n      \"grid_column\": null,\n      \"grid_gap\": null,\n      \"grid_row\": null,\n      \"grid_template_areas\": null,\n      \"grid_template_columns\": null,\n      \"grid_template_rows\": null,\n      \"height\": null,\n      \"justify_content\": null,\n      \"justify_items\": null,\n      \"left\": null,\n      \"margin\": null,\n      \"max_height\": null,\n      \"max_width\": null,\n      \"min_height\": null,\n      \"min_width\": null,\n      \"object_fit\": null,\n      \"object_position\": null,\n      \"order\": null,\n      \"overflow\": null,\n      \"overflow_x\": null,\n      \"overflow_y\": null,\n      \"padding\": null,\n      \"right\": null,\n      \"top\": null,\n      \"visibility\": null,\n      \"width\": null\n     }\n    },\n    
\"34cf3df51fbc41cabfdbba153c007f0e\": {\n     \"model_module\": \"@jupyter-widgets/base\",\n     \"model_module_version\": \"1.2.0\",\n     \"model_name\": \"LayoutModel\",\n     \"state\": {\n      \"_model_module\": \"@jupyter-widgets/base\",\n      \"_model_module_version\": \"1.2.0\",\n      \"_model_name\": \"LayoutModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/base\",\n      \"_view_module_version\": \"1.2.0\",\n      \"_view_name\": \"LayoutView\",\n      \"align_content\": null,\n      \"align_items\": null,\n      \"align_self\": null,\n      \"border\": null,\n      \"bottom\": null,\n      \"display\": null,\n      \"flex\": null,\n      \"flex_flow\": null,\n      \"grid_area\": null,\n      \"grid_auto_columns\": null,\n      \"grid_auto_flow\": null,\n      \"grid_auto_rows\": null,\n      \"grid_column\": null,\n      \"grid_gap\": null,\n      \"grid_row\": null,\n      \"grid_template_areas\": null,\n      \"grid_template_columns\": null,\n      \"grid_template_rows\": null,\n      \"height\": null,\n      \"justify_content\": null,\n      \"justify_items\": null,\n      \"left\": null,\n      \"margin\": null,\n      \"max_height\": null,\n      \"max_width\": null,\n      \"min_height\": null,\n      \"min_width\": null,\n      \"object_fit\": null,\n      \"object_position\": null,\n      \"order\": null,\n      \"overflow\": null,\n      \"overflow_x\": null,\n      \"overflow_y\": null,\n      \"padding\": null,\n      \"right\": null,\n      \"top\": null,\n      \"visibility\": null,\n      \"width\": null\n     }\n    },\n    \"35c811d2ae8e43f3b5cecbdd3cfa857f\": {\n     \"model_module\": \"@jupyter-widgets/controls\",\n     \"model_module_version\": \"1.5.0\",\n     \"model_name\": \"DescriptionStyleModel\",\n     \"state\": {\n      \"_model_module\": \"@jupyter-widgets/controls\",\n      \"_model_module_version\": \"1.5.0\",\n      \"_model_name\": \"DescriptionStyleModel\",\n      \"_view_count\": 
null,\n      \"_view_module\": \"@jupyter-widgets/base\",\n      \"_view_module_version\": \"1.2.0\",\n      \"_view_name\": \"StyleView\",\n      \"description_width\": \"\"\n     }\n    },\n    \"35cc989ca3374e7dba0cb166febc4bde\": {\n     \"model_module\": \"@jupyter-widgets/controls\",\n     \"model_module_version\": \"1.5.0\",\n     \"model_name\": \"ProgressStyleModel\",\n     \"state\": {\n      \"_model_module\": \"@jupyter-widgets/controls\",\n      \"_model_module_version\": \"1.5.0\",\n      \"_model_name\": \"ProgressStyleModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/base\",\n      \"_view_module_version\": \"1.2.0\",\n      \"_view_name\": \"StyleView\",\n      \"bar_color\": null,\n      \"description_width\": \"\"\n     }\n    },\n    \"366a343b62fa47d8985a3bd464d99f9e\": {\n     \"model_module\": \"@jupyter-widgets/controls\",\n     \"model_module_version\": \"1.5.0\",\n     \"model_name\": \"DescriptionStyleModel\",\n     \"state\": {\n      \"_model_module\": \"@jupyter-widgets/controls\",\n      \"_model_module_version\": \"1.5.0\",\n      \"_model_name\": \"DescriptionStyleModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/base\",\n      \"_view_module_version\": \"1.2.0\",\n      \"_view_name\": \"StyleView\",\n      \"description_width\": \"\"\n     }\n    },\n    \"37de928300e34184881039378bd75e7f\": {\n     \"model_module\": \"@jupyter-widgets/base\",\n     \"model_module_version\": \"1.2.0\",\n     \"model_name\": \"LayoutModel\",\n     \"state\": {\n      \"_model_module\": \"@jupyter-widgets/base\",\n      \"_model_module_version\": \"1.2.0\",\n      \"_model_name\": \"LayoutModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/base\",\n      \"_view_module_version\": \"1.2.0\",\n      \"_view_name\": \"LayoutView\",\n      \"align_content\": null,\n      \"align_items\": null,\n      \"align_self\": null,\n      \"border\": null,\n      
\"bottom\": null,\n      \"display\": null,\n      \"flex\": null,\n      \"flex_flow\": null,\n      \"grid_area\": null,\n      \"grid_auto_columns\": null,\n      \"grid_auto_flow\": null,\n      \"grid_auto_rows\": null,\n      \"grid_column\": null,\n      \"grid_gap\": null,\n      \"grid_row\": null,\n      \"grid_template_areas\": null,\n      \"grid_template_columns\": null,\n      \"grid_template_rows\": null,\n      \"height\": null,\n      \"justify_content\": null,\n      \"justify_items\": null,\n      \"left\": null,\n      \"margin\": null,\n      \"max_height\": null,\n      \"max_width\": null,\n      \"min_height\": null,\n      \"min_width\": null,\n      \"object_fit\": null,\n      \"object_position\": null,\n      \"order\": null,\n      \"overflow\": null,\n      \"overflow_x\": null,\n      \"overflow_y\": null,\n      \"padding\": null,\n      \"right\": null,\n      \"top\": null,\n      \"visibility\": null,\n      \"width\": null\n     }\n    },\n    \"388f618924274d21a066f098f4f1e744\": {\n     \"model_module\": \"@jupyter-widgets/controls\",\n     \"model_module_version\": \"1.5.0\",\n     \"model_name\": \"HBoxModel\",\n     \"state\": {\n      \"_dom_classes\": [],\n      \"_model_module\": \"@jupyter-widgets/controls\",\n      \"_model_module_version\": \"1.5.0\",\n      \"_model_name\": \"HBoxModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/controls\",\n      \"_view_module_version\": \"1.5.0\",\n      \"_view_name\": \"HBoxView\",\n      \"box_style\": \"\",\n      \"children\": [\n       \"IPY_MODEL_7c95f85a2b1f47a1bd846d110c47bb3c\",\n       \"IPY_MODEL_083f9cda8d754c168beee10d2f8955a2\",\n       \"IPY_MODEL_62e1a65582f446a78612eaa804e08a7d\"\n      ],\n      \"layout\": \"IPY_MODEL_487a177d020f4605834878b2fdc7afa3\"\n     }\n    },\n    \"39789237703c4a418134243055c9cbf5\": {\n     \"model_module\": \"@jupyter-widgets/base\",\n     \"model_module_version\": \"1.2.0\",\n     \"model_name\": 
\"LayoutModel\",\n     \"state\": {\n      \"_model_module\": \"@jupyter-widgets/base\",\n      \"_model_module_version\": \"1.2.0\",\n      \"_model_name\": \"LayoutModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/base\",\n      \"_view_module_version\": \"1.2.0\",\n      \"_view_name\": \"LayoutView\",\n      \"align_content\": null,\n      \"align_items\": null,\n      \"align_self\": null,\n      \"border\": null,\n      \"bottom\": null,\n      \"display\": null,\n      \"flex\": null,\n      \"flex_flow\": null,\n      \"grid_area\": null,\n      \"grid_auto_columns\": null,\n      \"grid_auto_flow\": null,\n      \"grid_auto_rows\": null,\n      \"grid_column\": null,\n      \"grid_gap\": null,\n      \"grid_row\": null,\n      \"grid_template_areas\": null,\n      \"grid_template_columns\": null,\n      \"grid_template_rows\": null,\n      \"height\": null,\n      \"justify_content\": null,\n      \"justify_items\": null,\n      \"left\": null,\n      \"margin\": null,\n      \"max_height\": null,\n      \"max_width\": null,\n      \"min_height\": null,\n      \"min_width\": null,\n      \"object_fit\": null,\n      \"object_position\": null,\n      \"order\": null,\n      \"overflow\": null,\n      \"overflow_x\": null,\n      \"overflow_y\": null,\n      \"padding\": null,\n      \"right\": null,\n      \"top\": null,\n      \"visibility\": null,\n      \"width\": null\n     }\n    },\n    \"3aaecbf540f54a2db9ab0931e3b1fe57\": {\n     \"model_module\": \"@jupyter-widgets/base\",\n     \"model_module_version\": \"1.2.0\",\n     \"model_name\": \"LayoutModel\",\n     \"state\": {\n      \"_model_module\": \"@jupyter-widgets/base\",\n      \"_model_module_version\": \"1.2.0\",\n      \"_model_name\": \"LayoutModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/base\",\n      \"_view_module_version\": \"1.2.0\",\n      \"_view_name\": \"LayoutView\",\n      \"align_content\": null,\n      
\"align_items\": null,\n      \"align_self\": null,\n      \"border\": null,\n      \"bottom\": null,\n      \"display\": null,\n      \"flex\": null,\n      \"flex_flow\": null,\n      \"grid_area\": null,\n      \"grid_auto_columns\": null,\n      \"grid_auto_flow\": null,\n      \"grid_auto_rows\": null,\n      \"grid_column\": null,\n      \"grid_gap\": null,\n      \"grid_row\": null,\n      \"grid_template_areas\": null,\n      \"grid_template_columns\": null,\n      \"grid_template_rows\": null,\n      \"height\": null,\n      \"justify_content\": null,\n      \"justify_items\": null,\n      \"left\": null,\n      \"margin\": null,\n      \"max_height\": null,\n      \"max_width\": null,\n      \"min_height\": null,\n      \"min_width\": null,\n      \"object_fit\": null,\n      \"object_position\": null,\n      \"order\": null,\n      \"overflow\": null,\n      \"overflow_x\": null,\n      \"overflow_y\": null,\n      \"padding\": null,\n      \"right\": null,\n      \"top\": null,\n      \"visibility\": null,\n      \"width\": null\n     }\n    },\n    \"3c21e4a511b4441192c03b7f1d0976e9\": {\n     \"model_module\": \"@jupyter-widgets/base\",\n     \"model_module_version\": \"1.2.0\",\n     \"model_name\": \"LayoutModel\",\n     \"state\": {\n      \"_model_module\": \"@jupyter-widgets/base\",\n      \"_model_module_version\": \"1.2.0\",\n      \"_model_name\": \"LayoutModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/base\",\n      \"_view_module_version\": \"1.2.0\",\n      \"_view_name\": \"LayoutView\",\n      \"align_content\": null,\n      \"align_items\": null,\n      \"align_self\": null,\n      \"border\": null,\n      \"bottom\": null,\n      \"display\": null,\n      \"flex\": null,\n      \"flex_flow\": null,\n      \"grid_area\": null,\n      \"grid_auto_columns\": null,\n      \"grid_auto_flow\": null,\n      \"grid_auto_rows\": null,\n      \"grid_column\": null,\n      \"grid_gap\": null,\n      \"grid_row\": 
null,\n      \"grid_template_areas\": null,\n      \"grid_template_columns\": null,\n      \"grid_template_rows\": null,\n      \"height\": null,\n      \"justify_content\": null,\n      \"justify_items\": null,\n      \"left\": null,\n      \"margin\": null,\n      \"max_height\": null,\n      \"max_width\": null,\n      \"min_height\": null,\n      \"min_width\": null,\n      \"object_fit\": null,\n      \"object_position\": null,\n      \"order\": null,\n      \"overflow\": null,\n      \"overflow_x\": null,\n      \"overflow_y\": null,\n      \"padding\": null,\n      \"right\": null,\n      \"top\": null,\n      \"visibility\": null,\n      \"width\": null\n     }\n    },\n    \"3efd18ea8eaa41918894883da9541bfa\": {\n     \"model_module\": \"@jupyter-widgets/controls\",\n     \"model_module_version\": \"1.5.0\",\n     \"model_name\": \"HTMLModel\",\n     \"state\": {\n      \"_dom_classes\": [],\n      \"_model_module\": \"@jupyter-widgets/controls\",\n      \"_model_module_version\": \"1.5.0\",\n      \"_model_name\": \"HTMLModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/controls\",\n      \"_view_module_version\": \"1.5.0\",\n      \"_view_name\": \"HTMLView\",\n      \"description\": \"\",\n      \"description_tooltip\": null,\n      \"layout\": \"IPY_MODEL_8f5bd719974e41c3a8dd9a5b0d3d71e6\",\n      \"placeholder\": \"​\",\n      \"style\": \"IPY_MODEL_b87c84de30e84b3abf4871461fb9cbd3\",\n      \"value\": \"Loading checkpoint shards: 100%\"\n     }\n    },\n    \"41f3b32c2f6b4034ae7a3b9124e28bc7\": {\n     \"model_module\": \"@jupyter-widgets/base\",\n     \"model_module_version\": \"1.2.0\",\n     \"model_name\": \"LayoutModel\",\n     \"state\": {\n      \"_model_module\": \"@jupyter-widgets/base\",\n      \"_model_module_version\": \"1.2.0\",\n      \"_model_name\": \"LayoutModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/base\",\n      \"_view_module_version\": \"1.2.0\",\n      
\"_view_name\": \"LayoutView\",\n      \"align_content\": null,\n      \"align_items\": null,\n      \"align_self\": null,\n      \"border\": null,\n      \"bottom\": null,\n      \"display\": null,\n      \"flex\": null,\n      \"flex_flow\": null,\n      \"grid_area\": null,\n      \"grid_auto_columns\": null,\n      \"grid_auto_flow\": null,\n      \"grid_auto_rows\": null,\n      \"grid_column\": null,\n      \"grid_gap\": null,\n      \"grid_row\": null,\n      \"grid_template_areas\": null,\n      \"grid_template_columns\": null,\n      \"grid_template_rows\": null,\n      \"height\": null,\n      \"justify_content\": null,\n      \"justify_items\": null,\n      \"left\": null,\n      \"margin\": null,\n      \"max_height\": null,\n      \"max_width\": null,\n      \"min_height\": null,\n      \"min_width\": null,\n      \"object_fit\": null,\n      \"object_position\": null,\n      \"order\": null,\n      \"overflow\": null,\n      \"overflow_x\": null,\n      \"overflow_y\": null,\n      \"padding\": null,\n      \"right\": null,\n      \"top\": null,\n      \"visibility\": null,\n      \"width\": null\n     }\n    },\n    \"4471ff62258549fba9514bb67050f965\": {\n     \"model_module\": \"@jupyter-widgets/controls\",\n     \"model_module_version\": \"1.5.0\",\n     \"model_name\": \"HBoxModel\",\n     \"state\": {\n      \"_dom_classes\": [],\n      \"_model_module\": \"@jupyter-widgets/controls\",\n      \"_model_module_version\": \"1.5.0\",\n      \"_model_name\": \"HBoxModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/controls\",\n      \"_view_module_version\": \"1.5.0\",\n      \"_view_name\": \"HBoxView\",\n      \"box_style\": \"\",\n      \"children\": [\n       \"IPY_MODEL_9cd5211b5d8b457aa0002f1d17b80028\",\n       \"IPY_MODEL_19127c7bb1554ccbac877059f9a82db0\",\n       \"IPY_MODEL_f4667818b9d34a09891cd727a429a610\"\n      ],\n      \"layout\": \"IPY_MODEL_9ed02dc43412471a9ab47f3620ccf3a5\"\n     }\n    },\n    
\"4540927d98f54466b434ba4c0edf045d\": {\n     \"model_module\": \"@jupyter-widgets/controls\",\n     \"model_module_version\": \"1.5.0\",\n     \"model_name\": \"DescriptionStyleModel\",\n     \"state\": {\n      \"_model_module\": \"@jupyter-widgets/controls\",\n      \"_model_module_version\": \"1.5.0\",\n      \"_model_name\": \"DescriptionStyleModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/base\",\n      \"_view_module_version\": \"1.2.0\",\n      \"_view_name\": \"StyleView\",\n      \"description_width\": \"\"\n     }\n    },\n    \"487a177d020f4605834878b2fdc7afa3\": {\n     \"model_module\": \"@jupyter-widgets/base\",\n     \"model_module_version\": \"1.2.0\",\n     \"model_name\": \"LayoutModel\",\n     \"state\": {\n      \"_model_module\": \"@jupyter-widgets/base\",\n      \"_model_module_version\": \"1.2.0\",\n      \"_model_name\": \"LayoutModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/base\",\n      \"_view_module_version\": \"1.2.0\",\n      \"_view_name\": \"LayoutView\",\n      \"align_content\": null,\n      \"align_items\": null,\n      \"align_self\": null,\n      \"border\": null,\n      \"bottom\": null,\n      \"display\": null,\n      \"flex\": null,\n      \"flex_flow\": null,\n      \"grid_area\": null,\n      \"grid_auto_columns\": null,\n      \"grid_auto_flow\": null,\n      \"grid_auto_rows\": null,\n      \"grid_column\": null,\n      \"grid_gap\": null,\n      \"grid_row\": null,\n      \"grid_template_areas\": null,\n      \"grid_template_columns\": null,\n      \"grid_template_rows\": null,\n      \"height\": null,\n      \"justify_content\": null,\n      \"justify_items\": null,\n      \"left\": null,\n      \"margin\": null,\n      \"max_height\": null,\n      \"max_width\": null,\n      \"min_height\": null,\n      \"min_width\": null,\n      \"object_fit\": null,\n      \"object_position\": null,\n      \"order\": null,\n      \"overflow\": null,\n      
\"overflow_x\": null,\n      \"overflow_y\": null,\n      \"padding\": null,\n      \"right\": null,\n      \"top\": null,\n      \"visibility\": null,\n      \"width\": null\n     }\n    },\n    \"4b1f04ff63d14a118fdd15814dff50e4\": {\n     \"model_module\": \"@jupyter-widgets/controls\",\n     \"model_module_version\": \"1.5.0\",\n     \"model_name\": \"LabelModel\",\n     \"state\": {\n      \"_dom_classes\": [],\n      \"_model_module\": \"@jupyter-widgets/controls\",\n      \"_model_module_version\": \"1.5.0\",\n      \"_model_name\": \"LabelModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/controls\",\n      \"_view_module_version\": \"1.5.0\",\n      \"_view_name\": \"LabelView\",\n      \"description\": \"\",\n      \"description_tooltip\": null,\n      \"layout\": \"IPY_MODEL_39789237703c4a418134243055c9cbf5\",\n      \"placeholder\": \"​\",\n      \"style\": \"IPY_MODEL_a3a945817f684328b34651fe052393ec\",\n      \"value\": \"Connecting...\"\n     }\n    },\n    \"4b27c267393640f28f6eae0875bd2ed9\": {\n     \"model_module\": \"@jupyter-widgets/base\",\n     \"model_module_version\": \"1.2.0\",\n     \"model_name\": \"LayoutModel\",\n     \"state\": {\n      \"_model_module\": \"@jupyter-widgets/base\",\n      \"_model_module_version\": \"1.2.0\",\n      \"_model_name\": \"LayoutModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/base\",\n      \"_view_module_version\": \"1.2.0\",\n      \"_view_name\": \"LayoutView\",\n      \"align_content\": null,\n      \"align_items\": null,\n      \"align_self\": null,\n      \"border\": null,\n      \"bottom\": null,\n      \"display\": null,\n      \"flex\": null,\n      \"flex_flow\": null,\n      \"grid_area\": null,\n      \"grid_auto_columns\": null,\n      \"grid_auto_flow\": null,\n      \"grid_auto_rows\": null,\n      \"grid_column\": null,\n      \"grid_gap\": null,\n      \"grid_row\": null,\n      \"grid_template_areas\": null,\n      
\"grid_template_columns\": null,\n      \"grid_template_rows\": null,\n      \"height\": null,\n      \"justify_content\": null,\n      \"justify_items\": null,\n      \"left\": null,\n      \"margin\": null,\n      \"max_height\": null,\n      \"max_width\": null,\n      \"min_height\": null,\n      \"min_width\": null,\n      \"object_fit\": null,\n      \"object_position\": null,\n      \"order\": null,\n      \"overflow\": null,\n      \"overflow_x\": null,\n      \"overflow_y\": null,\n      \"padding\": null,\n      \"right\": null,\n      \"top\": null,\n      \"visibility\": null,\n      \"width\": null\n     }\n    },\n    \"4c727d40ef0443449afc31724ee79f0c\": {\n     \"model_module\": \"@jupyter-widgets/controls\",\n     \"model_module_version\": \"1.5.0\",\n     \"model_name\": \"ProgressStyleModel\",\n     \"state\": {\n      \"_model_module\": \"@jupyter-widgets/controls\",\n      \"_model_module_version\": \"1.5.0\",\n      \"_model_name\": \"ProgressStyleModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/base\",\n      \"_view_module_version\": \"1.2.0\",\n      \"_view_name\": \"StyleView\",\n      \"bar_color\": null,\n      \"description_width\": \"\"\n     }\n    },\n    \"4d05314858354e729d76094b3b0ce761\": {\n     \"model_module\": \"@jupyter-widgets/controls\",\n     \"model_module_version\": \"1.5.0\",\n     \"model_name\": \"HBoxModel\",\n     \"state\": {\n      \"_dom_classes\": [],\n      \"_model_module\": \"@jupyter-widgets/controls\",\n      \"_model_module_version\": \"1.5.0\",\n      \"_model_name\": \"HBoxModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/controls\",\n      \"_view_module_version\": \"1.5.0\",\n      \"_view_name\": \"HBoxView\",\n      \"box_style\": \"\",\n      \"children\": [\n       \"IPY_MODEL_c42acf646f344a88b8c11f81e67f7206\",\n       \"IPY_MODEL_7be6f04c284e4326bb4ff3d301e7b3c6\",\n       \"IPY_MODEL_ffdbb12a2f2c4d14911685e7683e0ef0\"\n      
],\n      \"layout\": \"IPY_MODEL_bee3501b2a17427784a717e50a85e7fa\"\n     }\n    },\n    \"4d468f96ec924681ad65eb671674b93e\": {\n     \"model_module\": \"@jupyter-widgets/base\",\n     \"model_module_version\": \"1.2.0\",\n     \"model_name\": \"LayoutModel\",\n     \"state\": {\n      \"_model_module\": \"@jupyter-widgets/base\",\n      \"_model_module_version\": \"1.2.0\",\n      \"_model_name\": \"LayoutModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/base\",\n      \"_view_module_version\": \"1.2.0\",\n      \"_view_name\": \"LayoutView\",\n      \"align_content\": null,\n      \"align_items\": null,\n      \"align_self\": null,\n      \"border\": null,\n      \"bottom\": null,\n      \"display\": null,\n      \"flex\": null,\n      \"flex_flow\": null,\n      \"grid_area\": null,\n      \"grid_auto_columns\": null,\n      \"grid_auto_flow\": null,\n      \"grid_auto_rows\": null,\n      \"grid_column\": null,\n      \"grid_gap\": null,\n      \"grid_row\": null,\n      \"grid_template_areas\": null,\n      \"grid_template_columns\": null,\n      \"grid_template_rows\": null,\n      \"height\": null,\n      \"justify_content\": null,\n      \"justify_items\": null,\n      \"left\": null,\n      \"margin\": null,\n      \"max_height\": null,\n      \"max_width\": null,\n      \"min_height\": null,\n      \"min_width\": null,\n      \"object_fit\": null,\n      \"object_position\": null,\n      \"order\": null,\n      \"overflow\": null,\n      \"overflow_x\": null,\n      \"overflow_y\": null,\n      \"padding\": null,\n      \"right\": null,\n      \"top\": null,\n      \"visibility\": null,\n      \"width\": null\n     }\n    },\n    \"4f1977d7e4824ef1a14b65f0f42bba10\": {\n     \"model_module\": \"@jupyter-widgets/controls\",\n     \"model_module_version\": \"1.5.0\",\n     \"model_name\": \"ProgressStyleModel\",\n     \"state\": {\n      \"_model_module\": \"@jupyter-widgets/controls\",\n      \"_model_module_version\": 
\"1.5.0\",\n      \"_model_name\": \"ProgressStyleModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/base\",\n      \"_view_module_version\": \"1.2.0\",\n      \"_view_name\": \"StyleView\",\n      \"bar_color\": null,\n      \"description_width\": \"\"\n     }\n    },\n    \"4fd114abe9f5494ab59858949f5055f1\": {\n     \"model_module\": \"@jupyter-widgets/controls\",\n     \"model_module_version\": \"1.5.0\",\n     \"model_name\": \"ProgressStyleModel\",\n     \"state\": {\n      \"_model_module\": \"@jupyter-widgets/controls\",\n      \"_model_module_version\": \"1.5.0\",\n      \"_model_name\": \"ProgressStyleModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/base\",\n      \"_view_module_version\": \"1.2.0\",\n      \"_view_name\": \"StyleView\",\n      \"bar_color\": null,\n      \"description_width\": \"\"\n     }\n    },\n    \"500e272208a246089613bf788a165271\": {\n     \"model_module\": \"@jupyter-widgets/controls\",\n     \"model_module_version\": \"1.5.0\",\n     \"model_name\": \"HBoxModel\",\n     \"state\": {\n      \"_dom_classes\": [],\n      \"_model_module\": \"@jupyter-widgets/controls\",\n      \"_model_module_version\": \"1.5.0\",\n      \"_model_name\": \"HBoxModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/controls\",\n      \"_view_module_version\": \"1.5.0\",\n      \"_view_name\": \"HBoxView\",\n      \"box_style\": \"\",\n      \"children\": [\n       \"IPY_MODEL_200df5e79b9244849e589ecb0250a520\",\n       \"IPY_MODEL_cc94432d08464affa3e58b560bdad194\",\n       \"IPY_MODEL_3036608c71904ce9ae4bb2a9fa8802d9\"\n      ],\n      \"layout\": \"IPY_MODEL_adacfdcc1b0140efac56918e9ccf064e\"\n     }\n    },\n    \"519a7b154022443db6703f04a9142bae\": {\n     \"model_module\": \"@jupyter-widgets/controls\",\n     \"model_module_version\": \"1.5.0\",\n     \"model_name\": \"FloatProgressModel\",\n     \"state\": {\n      \"_dom_classes\": [],\n      
\"_model_module\": \"@jupyter-widgets/controls\",\n      \"_model_module_version\": \"1.5.0\",\n      \"_model_name\": \"FloatProgressModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/controls\",\n      \"_view_module_version\": \"1.5.0\",\n      \"_view_name\": \"ProgressView\",\n      \"bar_style\": \"success\",\n      \"description\": \"\",\n      \"description_tooltip\": null,\n      \"layout\": \"IPY_MODEL_d02274afd47b462291c745f261209d42\",\n      \"max\": 27341251,\n      \"min\": 0,\n      \"orientation\": \"horizontal\",\n      \"style\": \"IPY_MODEL_0f417447a7bd4a33acca96fa37aec877\",\n      \"value\": 27341251\n     }\n    },\n    \"56e3768bef5a4b9db4168c5c17f509c2\": {\n     \"model_module\": \"@jupyter-widgets/controls\",\n     \"model_module_version\": \"1.5.0\",\n     \"model_name\": \"DescriptionStyleModel\",\n     \"state\": {\n      \"_model_module\": \"@jupyter-widgets/controls\",\n      \"_model_module_version\": \"1.5.0\",\n      \"_model_name\": \"DescriptionStyleModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/base\",\n      \"_view_module_version\": \"1.2.0\",\n      \"_view_name\": \"StyleView\",\n      \"description_width\": \"\"\n     }\n    },\n    \"590eef89881545aa8bbef9a8bbe7fb00\": {\n     \"model_module\": \"@jupyter-widgets/controls\",\n     \"model_module_version\": \"1.5.0\",\n     \"model_name\": \"DescriptionStyleModel\",\n     \"state\": {\n      \"_model_module\": \"@jupyter-widgets/controls\",\n      \"_model_module_version\": \"1.5.0\",\n      \"_model_name\": \"DescriptionStyleModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/base\",\n      \"_view_module_version\": \"1.2.0\",\n      \"_view_name\": \"StyleView\",\n      \"description_width\": \"\"\n     }\n    },\n    \"598da69727bd4fb8b1caf465ac736d7a\": {\n     \"model_module\": \"@jupyter-widgets/controls\",\n     \"model_module_version\": \"1.5.0\",\n     
\"model_name\": \"DescriptionStyleModel\",\n     \"state\": {\n      \"_model_module\": \"@jupyter-widgets/controls\",\n      \"_model_module_version\": \"1.5.0\",\n      \"_model_name\": \"DescriptionStyleModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/base\",\n      \"_view_module_version\": \"1.2.0\",\n      \"_view_name\": \"StyleView\",\n      \"description_width\": \"\"\n     }\n    },\n    \"5bdfd87fc6cd4f9dabef7cfee29c8060\": {\n     \"model_module\": \"@jupyter-widgets/controls\",\n     \"model_module_version\": \"1.5.0\",\n     \"model_name\": \"HTMLModel\",\n     \"state\": {\n      \"_dom_classes\": [],\n      \"_model_module\": \"@jupyter-widgets/controls\",\n      \"_model_module_version\": \"1.5.0\",\n      \"_model_name\": \"HTMLModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/controls\",\n      \"_view_module_version\": \"1.5.0\",\n      \"_view_name\": \"HTMLView\",\n      \"description\": \"\",\n      \"description_tooltip\": null,\n      \"layout\": \"IPY_MODEL_4d468f96ec924681ad65eb671674b93e\",\n      \"placeholder\": \"​\",\n      \"style\": \"IPY_MODEL_ad7599de524549c48bf2d3124ad4b299\",\n      \"value\": \"Dropping Long Sequences (num_proc=2): 100%\"\n     }\n    },\n    \"5ca240f31e6b44e3882c5eb37cd5a309\": {\n     \"model_module\": \"@jupyter-widgets/base\",\n     \"model_module_version\": \"1.2.0\",\n     \"model_name\": \"LayoutModel\",\n     \"state\": {\n      \"_model_module\": \"@jupyter-widgets/base\",\n      \"_model_module_version\": \"1.2.0\",\n      \"_model_name\": \"LayoutModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/base\",\n      \"_view_module_version\": \"1.2.0\",\n      \"_view_name\": \"LayoutView\",\n      \"align_content\": null,\n      \"align_items\": null,\n      \"align_self\": null,\n      \"border\": null,\n      \"bottom\": null,\n      \"display\": null,\n      \"flex\": null,\n      \"flex_flow\": null,\n 
     \"grid_area\": null,\n      \"grid_auto_columns\": null,\n      \"grid_auto_flow\": null,\n      \"grid_auto_rows\": null,\n      \"grid_column\": null,\n      \"grid_gap\": null,\n      \"grid_row\": null,\n      \"grid_template_areas\": null,\n      \"grid_template_columns\": null,\n      \"grid_template_rows\": null,\n      \"height\": null,\n      \"justify_content\": null,\n      \"justify_items\": null,\n      \"left\": null,\n      \"margin\": null,\n      \"max_height\": null,\n      \"max_width\": null,\n      \"min_height\": null,\n      \"min_width\": null,\n      \"object_fit\": null,\n      \"object_position\": null,\n      \"order\": null,\n      \"overflow\": null,\n      \"overflow_x\": null,\n      \"overflow_y\": null,\n      \"padding\": null,\n      \"right\": null,\n      \"top\": null,\n      \"visibility\": null,\n      \"width\": \"20px\"\n     }\n    },\n    \"5ca6be24acb548cea130bd58e9954c7c\": {\n     \"model_module\": \"@jupyter-widgets/base\",\n     \"model_module_version\": \"1.2.0\",\n     \"model_name\": \"LayoutModel\",\n     \"state\": {\n      \"_model_module\": \"@jupyter-widgets/base\",\n      \"_model_module_version\": \"1.2.0\",\n      \"_model_name\": \"LayoutModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/base\",\n      \"_view_module_version\": \"1.2.0\",\n      \"_view_name\": \"LayoutView\",\n      \"align_content\": null,\n      \"align_items\": null,\n      \"align_self\": null,\n      \"border\": null,\n      \"bottom\": null,\n      \"display\": null,\n      \"flex\": null,\n      \"flex_flow\": null,\n      \"grid_area\": null,\n      \"grid_auto_columns\": null,\n      \"grid_auto_flow\": null,\n      \"grid_auto_rows\": null,\n      \"grid_column\": null,\n      \"grid_gap\": null,\n      \"grid_row\": null,\n      \"grid_template_areas\": null,\n      \"grid_template_columns\": null,\n      \"grid_template_rows\": null,\n      \"height\": null,\n      \"justify_content\": 
null,\n      \"justify_items\": null,\n      \"left\": null,\n      \"margin\": null,\n      \"max_height\": null,\n      \"max_width\": null,\n      \"min_height\": null,\n      \"min_width\": null,\n      \"object_fit\": null,\n      \"object_position\": null,\n      \"order\": null,\n      \"overflow\": null,\n      \"overflow_x\": null,\n      \"overflow_y\": null,\n      \"padding\": null,\n      \"right\": null,\n      \"top\": null,\n      \"visibility\": null,\n      \"width\": null\n     }\n    },\n    \"5cea7996f02040b187ece0bb2d6a8d1f\": {\n     \"model_module\": \"@jupyter-widgets/controls\",\n     \"model_module_version\": \"1.5.0\",\n     \"model_name\": \"DescriptionStyleModel\",\n     \"state\": {\n      \"_model_module\": \"@jupyter-widgets/controls\",\n      \"_model_module_version\": \"1.5.0\",\n      \"_model_name\": \"DescriptionStyleModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/base\",\n      \"_view_module_version\": \"1.2.0\",\n      \"_view_name\": \"StyleView\",\n      \"description_width\": \"\"\n     }\n    },\n    \"5cfb02ee044b4011a378efa8b54a370f\": {\n     \"model_module\": \"@jupyter-widgets/controls\",\n     \"model_module_version\": \"1.5.0\",\n     \"model_name\": \"DescriptionStyleModel\",\n     \"state\": {\n      \"_model_module\": \"@jupyter-widgets/controls\",\n      \"_model_module_version\": \"1.5.0\",\n      \"_model_name\": \"DescriptionStyleModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/base\",\n      \"_view_module_version\": \"1.2.0\",\n      \"_view_name\": \"StyleView\",\n      \"description_width\": \"\"\n     }\n    },\n    \"5dd7d150dbe04f08b165ce7f2c27cd11\": {\n     \"model_module\": \"@jupyter-widgets/controls\",\n     \"model_module_version\": \"1.5.0\",\n     \"model_name\": \"DescriptionStyleModel\",\n     \"state\": {\n      \"_model_module\": \"@jupyter-widgets/controls\",\n      \"_model_module_version\": \"1.5.0\",\n      
\"_model_name\": \"DescriptionStyleModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/base\",\n      \"_view_module_version\": \"1.2.0\",\n      \"_view_name\": \"StyleView\",\n      \"description_width\": \"\"\n     }\n    },\n    \"5e18768f7ad6434ba8b8b8a2e853e204\": {\n     \"model_module\": \"@jupyter-widgets/base\",\n     \"model_module_version\": \"1.2.0\",\n     \"model_name\": \"LayoutModel\",\n     \"state\": {\n      \"_model_module\": \"@jupyter-widgets/base\",\n      \"_model_module_version\": \"1.2.0\",\n      \"_model_name\": \"LayoutModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/base\",\n      \"_view_module_version\": \"1.2.0\",\n      \"_view_name\": \"LayoutView\",\n      \"align_content\": null,\n      \"align_items\": null,\n      \"align_self\": null,\n      \"border\": null,\n      \"bottom\": null,\n      \"display\": null,\n      \"flex\": null,\n      \"flex_flow\": null,\n      \"grid_area\": null,\n      \"grid_auto_columns\": null,\n      \"grid_auto_flow\": null,\n      \"grid_auto_rows\": null,\n      \"grid_column\": null,\n      \"grid_gap\": null,\n      \"grid_row\": null,\n      \"grid_template_areas\": null,\n      \"grid_template_columns\": null,\n      \"grid_template_rows\": null,\n      \"height\": null,\n      \"justify_content\": null,\n      \"justify_items\": null,\n      \"left\": null,\n      \"margin\": null,\n      \"max_height\": null,\n      \"max_width\": null,\n      \"min_height\": null,\n      \"min_width\": null,\n      \"object_fit\": null,\n      \"object_position\": null,\n      \"order\": null,\n      \"overflow\": null,\n      \"overflow_x\": null,\n      \"overflow_y\": null,\n      \"padding\": null,\n      \"right\": null,\n      \"top\": null,\n      \"visibility\": null,\n      \"width\": null\n     }\n    },\n    \"5e5e15b0569b474c9620083b3ec6af55\": {\n     \"model_module\": \"@jupyter-widgets/controls\",\n     
\"model_module_version\": \"1.5.0\",\n     \"model_name\": \"DescriptionStyleModel\",\n     \"state\": {\n      \"_model_module\": \"@jupyter-widgets/controls\",\n      \"_model_module_version\": \"1.5.0\",\n      \"_model_name\": \"DescriptionStyleModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/base\",\n      \"_view_module_version\": \"1.2.0\",\n      \"_view_name\": \"StyleView\",\n      \"description_width\": \"\"\n     }\n    },\n    \"5e746eb25bbe416fb585fa24e79f5177\": {\n     \"model_module\": \"@jupyter-widgets/controls\",\n     \"model_module_version\": \"1.5.0\",\n     \"model_name\": \"DescriptionStyleModel\",\n     \"state\": {\n      \"_model_module\": \"@jupyter-widgets/controls\",\n      \"_model_module_version\": \"1.5.0\",\n      \"_model_name\": \"DescriptionStyleModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/base\",\n      \"_view_module_version\": \"1.2.0\",\n      \"_view_name\": \"StyleView\",\n      \"description_width\": \"\"\n     }\n    },\n    \"5eb06edeb58e4930b1affef2a59eae81\": {\n     \"model_module\": \"@jupyter-widgets/controls\",\n     \"model_module_version\": \"1.5.0\",\n     \"model_name\": \"ProgressStyleModel\",\n     \"state\": {\n      \"_model_module\": \"@jupyter-widgets/controls\",\n      \"_model_module_version\": \"1.5.0\",\n      \"_model_name\": \"ProgressStyleModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/base\",\n      \"_view_module_version\": \"1.2.0\",\n      \"_view_name\": \"StyleView\",\n      \"bar_color\": null,\n      \"description_width\": \"\"\n     }\n    },\n    \"5f86cd894de94c3280fadc1e2fd0ee13\": {\n     \"model_module\": \"@jupyter-widgets/controls\",\n     \"model_module_version\": \"1.5.0\",\n     \"model_name\": \"HBoxModel\",\n     \"state\": {\n      \"_dom_classes\": [],\n      \"_model_module\": \"@jupyter-widgets/controls\",\n      \"_model_module_version\": \"1.5.0\",\n      
\"_model_name\": \"HBoxModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/controls\",\n      \"_view_module_version\": \"1.5.0\",\n      \"_view_name\": \"HBoxView\",\n      \"box_style\": \"\",\n      \"children\": [\n       \"IPY_MODEL_a20927bf5f2c41f58c1e31ac858ab36c\",\n       \"IPY_MODEL_0a46ad75c198463d843fb35e813642cb\",\n       \"IPY_MODEL_09007681cf8d42aeb8c1d2f6a74e470a\"\n      ],\n      \"layout\": \"IPY_MODEL_ebc80d1a55fa47f4a5ea2756588569ec\"\n     }\n    },\n    \"60c1a0d765c14a1d888317e6a507e4ea\": {\n     \"model_module\": \"@jupyter-widgets/base\",\n     \"model_module_version\": \"1.2.0\",\n     \"model_name\": \"LayoutModel\",\n     \"state\": {\n      \"_model_module\": \"@jupyter-widgets/base\",\n      \"_model_module_version\": \"1.2.0\",\n      \"_model_name\": \"LayoutModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/base\",\n      \"_view_module_version\": \"1.2.0\",\n      \"_view_name\": \"LayoutView\",\n      \"align_content\": null,\n      \"align_items\": null,\n      \"align_self\": null,\n      \"border\": null,\n      \"bottom\": null,\n      \"display\": null,\n      \"flex\": null,\n      \"flex_flow\": null,\n      \"grid_area\": null,\n      \"grid_auto_columns\": null,\n      \"grid_auto_flow\": null,\n      \"grid_auto_rows\": null,\n      \"grid_column\": null,\n      \"grid_gap\": null,\n      \"grid_row\": null,\n      \"grid_template_areas\": null,\n      \"grid_template_columns\": null,\n      \"grid_template_rows\": null,\n      \"height\": null,\n      \"justify_content\": null,\n      \"justify_items\": null,\n      \"left\": null,\n      \"margin\": null,\n      \"max_height\": null,\n      \"max_width\": null,\n      \"min_height\": null,\n      \"min_width\": null,\n      \"object_fit\": null,\n      \"object_position\": null,\n      \"order\": null,\n      \"overflow\": null,\n      \"overflow_x\": null,\n      \"overflow_y\": null,\n      
\"padding\": null,\n      \"right\": null,\n      \"top\": null,\n      \"visibility\": null,\n      \"width\": null\n     }\n    },\n    \"62c028fdef904dedb9cdeca2b3bda725\": {\n     \"model_module\": \"@jupyter-widgets/base\",\n     \"model_module_version\": \"1.2.0\",\n     \"model_name\": \"LayoutModel\",\n     \"state\": {\n      \"_model_module\": \"@jupyter-widgets/base\",\n      \"_model_module_version\": \"1.2.0\",\n      \"_model_name\": \"LayoutModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/base\",\n      \"_view_module_version\": \"1.2.0\",\n      \"_view_name\": \"LayoutView\",\n      \"align_content\": null,\n      \"align_items\": null,\n      \"align_self\": null,\n      \"border\": null,\n      \"bottom\": null,\n      \"display\": null,\n      \"flex\": null,\n      \"flex_flow\": null,\n      \"grid_area\": null,\n      \"grid_auto_columns\": null,\n      \"grid_auto_flow\": null,\n      \"grid_auto_rows\": null,\n      \"grid_column\": null,\n      \"grid_gap\": null,\n      \"grid_row\": null,\n      \"grid_template_areas\": null,\n      \"grid_template_columns\": null,\n      \"grid_template_rows\": null,\n      \"height\": null,\n      \"justify_content\": null,\n      \"justify_items\": null,\n      \"left\": null,\n      \"margin\": null,\n      \"max_height\": null,\n      \"max_width\": null,\n      \"min_height\": null,\n      \"min_width\": null,\n      \"object_fit\": null,\n      \"object_position\": null,\n      \"order\": null,\n      \"overflow\": null,\n      \"overflow_x\": null,\n      \"overflow_y\": null,\n      \"padding\": null,\n      \"right\": null,\n      \"top\": null,\n      \"visibility\": null,\n      \"width\": null\n     }\n    },\n    \"62e1a65582f446a78612eaa804e08a7d\": {\n     \"model_module\": \"@jupyter-widgets/controls\",\n     \"model_module_version\": \"1.5.0\",\n     \"model_name\": \"HTMLModel\",\n     \"state\": {\n      \"_dom_classes\": [],\n      \"_model_module\": 
\"@jupyter-widgets/controls\",\n      \"_model_module_version\": \"1.5.0\",\n      \"_model_name\": \"HTMLModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/controls\",\n      \"_view_module_version\": \"1.5.0\",\n      \"_view_name\": \"HTMLView\",\n      \"description\": \"\",\n      \"description_tooltip\": null,\n      \"layout\": \"IPY_MODEL_5e18768f7ad6434ba8b8b8a2e853e204\",\n      \"placeholder\": \"​\",\n      \"style\": \"IPY_MODEL_bb33aec33a6447078c31bfd728942994\",\n      \"value\": \" 728/728 [00:00&lt;00:00, 20.3kB/s]\"\n     }\n    },\n    \"62e302ebdad64aada0ffe64ae1c873f3\": {\n     \"model_module\": \"@jupyter-widgets/base\",\n     \"model_module_version\": \"1.2.0\",\n     \"model_name\": \"LayoutModel\",\n     \"state\": {\n      \"_model_module\": \"@jupyter-widgets/base\",\n      \"_model_module_version\": \"1.2.0\",\n      \"_model_name\": \"LayoutModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/base\",\n      \"_view_module_version\": \"1.2.0\",\n      \"_view_name\": \"LayoutView\",\n      \"align_content\": null,\n      \"align_items\": null,\n      \"align_self\": null,\n      \"border\": null,\n      \"bottom\": null,\n      \"display\": null,\n      \"flex\": null,\n      \"flex_flow\": null,\n      \"grid_area\": null,\n      \"grid_auto_columns\": null,\n      \"grid_auto_flow\": null,\n      \"grid_auto_rows\": null,\n      \"grid_column\": null,\n      \"grid_gap\": null,\n      \"grid_row\": null,\n      \"grid_template_areas\": null,\n      \"grid_template_columns\": null,\n      \"grid_template_rows\": null,\n      \"height\": null,\n      \"justify_content\": null,\n      \"justify_items\": null,\n      \"left\": null,\n      \"margin\": null,\n      \"max_height\": null,\n      \"max_width\": null,\n      \"min_height\": null,\n      \"min_width\": null,\n      \"object_fit\": null,\n      \"object_position\": null,\n      \"order\": null,\n      \"overflow\": 
null,\n      \"overflow_x\": null,\n      \"overflow_y\": null,\n      \"padding\": null,\n      \"right\": null,\n      \"top\": null,\n      \"visibility\": null,\n      \"width\": null\n     }\n    },\n    \"63580b6fb30642479fe3000915bf551a\": {\n     \"model_module\": \"@jupyter-widgets/base\",\n     \"model_module_version\": \"1.2.0\",\n     \"model_name\": \"LayoutModel\",\n     \"state\": {\n      \"_model_module\": \"@jupyter-widgets/base\",\n      \"_model_module_version\": \"1.2.0\",\n      \"_model_name\": \"LayoutModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/base\",\n      \"_view_module_version\": \"1.2.0\",\n      \"_view_name\": \"LayoutView\",\n      \"align_content\": null,\n      \"align_items\": null,\n      \"align_self\": null,\n      \"border\": null,\n      \"bottom\": null,\n      \"display\": null,\n      \"flex\": null,\n      \"flex_flow\": null,\n      \"grid_area\": null,\n      \"grid_auto_columns\": null,\n      \"grid_auto_flow\": null,\n      \"grid_auto_rows\": null,\n      \"grid_column\": null,\n      \"grid_gap\": null,\n      \"grid_row\": null,\n      \"grid_template_areas\": null,\n      \"grid_template_columns\": null,\n      \"grid_template_rows\": null,\n      \"height\": null,\n      \"justify_content\": null,\n      \"justify_items\": null,\n      \"left\": null,\n      \"margin\": null,\n      \"max_height\": null,\n      \"max_width\": null,\n      \"min_height\": null,\n      \"min_width\": null,\n      \"object_fit\": null,\n      \"object_position\": null,\n      \"order\": null,\n      \"overflow\": null,\n      \"overflow_x\": null,\n      \"overflow_y\": null,\n      \"padding\": null,\n      \"right\": null,\n      \"top\": null,\n      \"visibility\": null,\n      \"width\": null\n     }\n    },\n    \"63b4e563e85c4f03b1b72beda9577bcc\": {\n     \"model_module\": \"@jupyter-widgets/controls\",\n     \"model_module_version\": \"1.5.0\",\n     \"model_name\": 
\"ProgressStyleModel\",\n     \"state\": {\n      \"_model_module\": \"@jupyter-widgets/controls\",\n      \"_model_module_version\": \"1.5.0\",\n      \"_model_name\": \"ProgressStyleModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/base\",\n      \"_view_module_version\": \"1.2.0\",\n      \"_view_name\": \"StyleView\",\n      \"bar_color\": null,\n      \"description_width\": \"\"\n     }\n    },\n    \"64f54d4a744a4627a07c3c0120276f3b\": {\n     \"model_module\": \"@jupyter-widgets/controls\",\n     \"model_module_version\": \"1.5.0\",\n     \"model_name\": \"FloatProgressModel\",\n     \"state\": {\n      \"_dom_classes\": [],\n      \"_model_module\": \"@jupyter-widgets/controls\",\n      \"_model_module_version\": \"1.5.0\",\n      \"_model_name\": \"FloatProgressModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/controls\",\n      \"_view_module_version\": \"1.5.0\",\n      \"_view_name\": \"ProgressView\",\n      \"bar_style\": \"success\",\n      \"description\": \"\",\n      \"description_tooltip\": null,\n      \"layout\": \"IPY_MODEL_0546d04aae644dde846c58a4afb598a6\",\n      \"max\": 9985,\n      \"min\": 0,\n      \"orientation\": \"horizontal\",\n      \"style\": \"IPY_MODEL_897b77a56c09479bb11d7f2a30997e55\",\n      \"value\": 9985\n     }\n    },\n    \"65b75b9b8bc143cf997796af68ff6668\": {\n     \"model_module\": \"@jupyter-widgets/controls\",\n     \"model_module_version\": \"1.5.0\",\n     \"model_name\": \"HTMLModel\",\n     \"state\": {\n      \"_dom_classes\": [],\n      \"_model_module\": \"@jupyter-widgets/controls\",\n      \"_model_module_version\": \"1.5.0\",\n      \"_model_name\": \"HTMLModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/controls\",\n      \"_view_module_version\": \"1.5.0\",\n      \"_view_name\": \"HTMLView\",\n      \"description\": \"\",\n      \"description_tooltip\": null,\n      \"layout\": 
\"IPY_MODEL_81c3db71ac704280ad030072655f1537\",\n      \"placeholder\": \"​\",\n      \"style\": \"IPY_MODEL_042e091f75694c47aee761e760e76773\",\n      \"value\": \" 9985/9985 [00:02&lt;00:00, 3977.47 examples/s]\"\n     }\n    },\n    \"67da6c4260574869aa24c3cbc1bc1654\": {\n     \"model_module\": \"@jupyter-widgets/base\",\n     \"model_module_version\": \"1.2.0\",\n     \"model_name\": \"LayoutModel\",\n     \"state\": {\n      \"_model_module\": \"@jupyter-widgets/base\",\n      \"_model_module_version\": \"1.2.0\",\n      \"_model_name\": \"LayoutModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/base\",\n      \"_view_module_version\": \"1.2.0\",\n      \"_view_name\": \"LayoutView\",\n      \"align_content\": null,\n      \"align_items\": null,\n      \"align_self\": null,\n      \"border\": null,\n      \"bottom\": null,\n      \"display\": null,\n      \"flex\": null,\n      \"flex_flow\": null,\n      \"grid_area\": null,\n      \"grid_auto_columns\": null,\n      \"grid_auto_flow\": null,\n      \"grid_auto_rows\": null,\n      \"grid_column\": null,\n      \"grid_gap\": null,\n      \"grid_row\": null,\n      \"grid_template_areas\": null,\n      \"grid_template_columns\": null,\n      \"grid_template_rows\": null,\n      \"height\": null,\n      \"justify_content\": null,\n      \"justify_items\": null,\n      \"left\": null,\n      \"margin\": null,\n      \"max_height\": null,\n      \"max_width\": null,\n      \"min_height\": null,\n      \"min_width\": null,\n      \"object_fit\": null,\n      \"object_position\": null,\n      \"order\": null,\n      \"overflow\": null,\n      \"overflow_x\": null,\n      \"overflow_y\": null,\n      \"padding\": null,\n      \"right\": null,\n      \"top\": null,\n      \"visibility\": null,\n      \"width\": null\n     }\n    },\n    \"6932489232ec4ab18a160b1e7fbcdfe1\": {\n     \"model_module\": \"@jupyter-widgets/base\",\n     \"model_module_version\": \"1.2.0\",\n     
\"model_name\": \"LayoutModel\",\n     \"state\": {\n      \"_model_module\": \"@jupyter-widgets/base\",\n      \"_model_module_version\": \"1.2.0\",\n      \"_model_name\": \"LayoutModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/base\",\n      \"_view_module_version\": \"1.2.0\",\n      \"_view_name\": \"LayoutView\",\n      \"align_content\": null,\n      \"align_items\": null,\n      \"align_self\": null,\n      \"border\": null,\n      \"bottom\": null,\n      \"display\": null,\n      \"flex\": null,\n      \"flex_flow\": null,\n      \"grid_area\": null,\n      \"grid_auto_columns\": null,\n      \"grid_auto_flow\": null,\n      \"grid_auto_rows\": null,\n      \"grid_column\": null,\n      \"grid_gap\": null,\n      \"grid_row\": null,\n      \"grid_template_areas\": null,\n      \"grid_template_columns\": null,\n      \"grid_template_rows\": null,\n      \"height\": null,\n      \"justify_content\": null,\n      \"justify_items\": null,\n      \"left\": null,\n      \"margin\": null,\n      \"max_height\": null,\n      \"max_width\": null,\n      \"min_height\": null,\n      \"min_width\": null,\n      \"object_fit\": null,\n      \"object_position\": null,\n      \"order\": null,\n      \"overflow\": null,\n      \"overflow_x\": null,\n      \"overflow_y\": null,\n      \"padding\": null,\n      \"right\": null,\n      \"top\": null,\n      \"visibility\": null,\n      \"width\": null\n     }\n    },\n    \"6ebb2ec171414e47a14765505f64bb3c\": {\n     \"model_module\": \"@jupyter-widgets/controls\",\n     \"model_module_version\": \"1.5.0\",\n     \"model_name\": \"DescriptionStyleModel\",\n     \"state\": {\n      \"_model_module\": \"@jupyter-widgets/controls\",\n      \"_model_module_version\": \"1.5.0\",\n      \"_model_name\": \"DescriptionStyleModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/base\",\n      \"_view_module_version\": \"1.2.0\",\n      \"_view_name\": \"StyleView\",\n   
   \"description_width\": \"\"\n     }\n    },\n    \"6f05e9bebf7b40c9835808e77de6c236\": {\n     \"model_module\": \"@jupyter-widgets/controls\",\n     \"model_module_version\": \"1.5.0\",\n     \"model_name\": \"PasswordModel\",\n     \"state\": {\n      \"_dom_classes\": [],\n      \"_model_module\": \"@jupyter-widgets/controls\",\n      \"_model_module_version\": \"1.5.0\",\n      \"_model_name\": \"PasswordModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/controls\",\n      \"_view_module_version\": \"1.5.0\",\n      \"_view_name\": \"PasswordView\",\n      \"continuous_update\": true,\n      \"description\": \"Token:\",\n      \"description_tooltip\": null,\n      \"disabled\": false,\n      \"layout\": \"IPY_MODEL_2e257c8be2da40b4bb67a9e4ab6811f3\",\n      \"placeholder\": \"​\",\n      \"style\": \"IPY_MODEL_56e3768bef5a4b9db4168c5c17f509c2\",\n      \"value\": \"\"\n     }\n    },\n    \"6f3a28b912714c6e931003549664bfa3\": {\n     \"model_module\": \"@jupyter-widgets/controls\",\n     \"model_module_version\": \"1.5.0\",\n     \"model_name\": \"FloatProgressModel\",\n     \"state\": {\n      \"_dom_classes\": [],\n      \"_model_module\": \"@jupyter-widgets/controls\",\n      \"_model_module_version\": \"1.5.0\",\n      \"_model_name\": \"FloatProgressModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/controls\",\n      \"_view_module_version\": \"1.5.0\",\n      \"_view_name\": \"ProgressView\",\n      \"bar_style\": \"success\",\n      \"description\": \"\",\n      \"description_tooltip\": null,\n      \"layout\": \"IPY_MODEL_5ca240f31e6b44e3882c5eb37cd5a309\",\n      \"max\": 1,\n      \"min\": 0,\n      \"orientation\": \"horizontal\",\n      \"style\": \"IPY_MODEL_5eb06edeb58e4930b1affef2a59eae81\",\n      \"value\": 1\n     }\n    },\n    \"6f68ed9889f54ad2ae8a3b95ac263a83\": {\n     \"model_module\": \"@jupyter-widgets/controls\",\n     \"model_module_version\": \"1.5.0\",\n     
\"model_name\": \"HTMLModel\",\n     \"state\": {\n      \"_dom_classes\": [],\n      \"_model_module\": \"@jupyter-widgets/controls\",\n      \"_model_module_version\": \"1.5.0\",\n      \"_model_name\": \"HTMLModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/controls\",\n      \"_view_module_version\": \"1.5.0\",\n      \"_view_name\": \"HTMLView\",\n      \"description\": \"\",\n      \"description_tooltip\": null,\n      \"layout\": \"IPY_MODEL_41f3b32c2f6b4034ae7a3b9124e28bc7\",\n      \"placeholder\": \"​\",\n      \"style\": \"IPY_MODEL_a10d0a76010f4e508c65a9b69ebc5156\",\n      \"value\": \"Tokenizing Prompts (num_proc=2): 100%\"\n     }\n    },\n    \"704f2f5a9b1c49d5a75a0025a5dda11b\": {\n     \"model_module\": \"@jupyter-widgets/controls\",\n     \"model_module_version\": \"1.5.0\",\n     \"model_name\": \"DescriptionStyleModel\",\n     \"state\": {\n      \"_model_module\": \"@jupyter-widgets/controls\",\n      \"_model_module_version\": \"1.5.0\",\n      \"_model_name\": \"DescriptionStyleModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/base\",\n      \"_view_module_version\": \"1.2.0\",\n      \"_view_name\": \"StyleView\",\n      \"description_width\": \"\"\n     }\n    },\n    \"71002199df6b40c9a1ac40df5fb27a1b\": {\n     \"model_module\": \"@jupyter-widgets/controls\",\n     \"model_module_version\": \"1.5.0\",\n     \"model_name\": \"ProgressStyleModel\",\n     \"state\": {\n      \"_model_module\": \"@jupyter-widgets/controls\",\n      \"_model_module_version\": \"1.5.0\",\n      \"_model_name\": \"ProgressStyleModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/base\",\n      \"_view_module_version\": \"1.2.0\",\n      \"_view_name\": \"StyleView\",\n      \"bar_color\": null,\n      \"description_width\": \"\"\n     }\n    },\n    \"71c8af139cd248b1b51101fd46a93f35\": {\n     \"model_module\": \"@jupyter-widgets/controls\",\n     
\"model_module_version\": \"1.5.0\",\n     \"model_name\": \"FloatProgressModel\",\n     \"state\": {\n      \"_dom_classes\": [],\n      \"_model_module\": \"@jupyter-widgets/controls\",\n      \"_model_module_version\": \"1.5.0\",\n      \"_model_name\": \"FloatProgressModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/controls\",\n      \"_view_module_version\": \"1.5.0\",\n      \"_view_name\": \"ProgressView\",\n      \"bar_style\": \"success\",\n      \"description\": \"\",\n      \"description_tooltip\": null,\n      \"layout\": \"IPY_MODEL_d0e9dce55cec4c1ca619a0ccf209d924\",\n      \"max\": 9675,\n      \"min\": 0,\n      \"orientation\": \"horizontal\",\n      \"style\": \"IPY_MODEL_4c727d40ef0443449afc31724ee79f0c\",\n      \"value\": 9675\n     }\n    },\n    \"734185351eb543fa9a00a881dcbb9fe7\": {\n     \"model_module\": \"@jupyter-widgets/base\",\n     \"model_module_version\": \"1.2.0\",\n     \"model_name\": \"LayoutModel\",\n     \"state\": {\n      \"_model_module\": \"@jupyter-widgets/base\",\n      \"_model_module_version\": \"1.2.0\",\n      \"_model_name\": \"LayoutModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/base\",\n      \"_view_module_version\": \"1.2.0\",\n      \"_view_name\": \"LayoutView\",\n      \"align_content\": null,\n      \"align_items\": null,\n      \"align_self\": null,\n      \"border\": null,\n      \"bottom\": null,\n      \"display\": null,\n      \"flex\": null,\n      \"flex_flow\": null,\n      \"grid_area\": null,\n      \"grid_auto_columns\": null,\n      \"grid_auto_flow\": null,\n      \"grid_auto_rows\": null,\n      \"grid_column\": null,\n      \"grid_gap\": null,\n      \"grid_row\": null,\n      \"grid_template_areas\": null,\n      \"grid_template_columns\": null,\n      \"grid_template_rows\": null,\n      \"height\": null,\n      \"justify_content\": null,\n      \"justify_items\": null,\n      \"left\": null,\n      \"margin\": null,\n    
  \"max_height\": null,\n      \"max_width\": null,\n      \"min_height\": null,\n      \"min_width\": null,\n      \"object_fit\": null,\n      \"object_position\": null,\n      \"order\": null,\n      \"overflow\": null,\n      \"overflow_x\": null,\n      \"overflow_y\": null,\n      \"padding\": null,\n      \"right\": null,\n      \"top\": null,\n      \"visibility\": null,\n      \"width\": null\n     }\n    },\n    \"735d4f225b24414294fc1b213c61223c\": {\n     \"model_module\": \"@jupyter-widgets/base\",\n     \"model_module_version\": \"1.2.0\",\n     \"model_name\": \"LayoutModel\",\n     \"state\": {\n      \"_model_module\": \"@jupyter-widgets/base\",\n      \"_model_module_version\": \"1.2.0\",\n      \"_model_name\": \"LayoutModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/base\",\n      \"_view_module_version\": \"1.2.0\",\n      \"_view_name\": \"LayoutView\",\n      \"align_content\": null,\n      \"align_items\": null,\n      \"align_self\": null,\n      \"border\": null,\n      \"bottom\": null,\n      \"display\": null,\n      \"flex\": null,\n      \"flex_flow\": null,\n      \"grid_area\": null,\n      \"grid_auto_columns\": null,\n      \"grid_auto_flow\": null,\n      \"grid_auto_rows\": null,\n      \"grid_column\": null,\n      \"grid_gap\": null,\n      \"grid_row\": null,\n      \"grid_template_areas\": null,\n      \"grid_template_columns\": null,\n      \"grid_template_rows\": null,\n      \"height\": null,\n      \"justify_content\": null,\n      \"justify_items\": null,\n      \"left\": null,\n      \"margin\": null,\n      \"max_height\": null,\n      \"max_width\": null,\n      \"min_height\": null,\n      \"min_width\": null,\n      \"object_fit\": null,\n      \"object_position\": null,\n      \"order\": null,\n      \"overflow\": null,\n      \"overflow_x\": null,\n      \"overflow_y\": null,\n      \"padding\": null,\n      \"right\": null,\n      \"top\": null,\n      \"visibility\": null,\n    
  \"width\": null\n     }\n    },\n    \"742b1030acfd414bbd9d5327b7e3826d\": {\n     \"model_module\": \"@jupyter-widgets/controls\",\n     \"model_module_version\": \"1.5.0\",\n     \"model_name\": \"ProgressStyleModel\",\n     \"state\": {\n      \"_model_module\": \"@jupyter-widgets/controls\",\n      \"_model_module_version\": \"1.5.0\",\n      \"_model_name\": \"ProgressStyleModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/base\",\n      \"_view_module_version\": \"1.2.0\",\n      \"_view_name\": \"StyleView\",\n      \"bar_color\": null,\n      \"description_width\": \"\"\n     }\n    },\n    \"77304d1a46b3468a98483e02ec0ac4a4\": {\n     \"model_module\": \"@jupyter-widgets/base\",\n     \"model_module_version\": \"1.2.0\",\n     \"model_name\": \"LayoutModel\",\n     \"state\": {\n      \"_model_module\": \"@jupyter-widgets/base\",\n      \"_model_module_version\": \"1.2.0\",\n      \"_model_name\": \"LayoutModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/base\",\n      \"_view_module_version\": \"1.2.0\",\n      \"_view_name\": \"LayoutView\",\n      \"align_content\": null,\n      \"align_items\": null,\n      \"align_self\": null,\n      \"border\": null,\n      \"bottom\": null,\n      \"display\": null,\n      \"flex\": null,\n      \"flex_flow\": null,\n      \"grid_area\": null,\n      \"grid_auto_columns\": null,\n      \"grid_auto_flow\": null,\n      \"grid_auto_rows\": null,\n      \"grid_column\": null,\n      \"grid_gap\": null,\n      \"grid_row\": null,\n      \"grid_template_areas\": null,\n      \"grid_template_columns\": null,\n      \"grid_template_rows\": null,\n      \"height\": null,\n      \"justify_content\": null,\n      \"justify_items\": null,\n      \"left\": null,\n      \"margin\": null,\n      \"max_height\": null,\n      \"max_width\": null,\n      \"min_height\": null,\n      \"min_width\": null,\n      \"object_fit\": null,\n      \"object_position\": 
null,\n      \"order\": null,\n      \"overflow\": null,\n      \"overflow_x\": null,\n      \"overflow_y\": null,\n      \"padding\": null,\n      \"right\": null,\n      \"top\": null,\n      \"visibility\": null,\n      \"width\": null\n     }\n    },\n    \"7baeab52d6694c32b1efd1ea1a0a7782\": {\n     \"model_module\": \"@jupyter-widgets/controls\",\n     \"model_module_version\": \"1.5.0\",\n     \"model_name\": \"HTMLModel\",\n     \"state\": {\n      \"_dom_classes\": [],\n      \"_model_module\": \"@jupyter-widgets/controls\",\n      \"_model_module_version\": \"1.5.0\",\n      \"_model_name\": \"HTMLModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/controls\",\n      \"_view_module_version\": \"1.5.0\",\n      \"_view_name\": \"HTMLView\",\n      \"description\": \"\",\n      \"description_tooltip\": null,\n      \"layout\": \"IPY_MODEL_93a44a11aa4846fa8efc6c1413ef1627\",\n      \"placeholder\": \"​\",\n      \"style\": \"IPY_MODEL_a55060adc3564407ac81ad7297d34aaa\",\n      \"value\": \"train.jsonl: 100%\"\n     }\n    },\n    \"7be6f04c284e4326bb4ff3d301e7b3c6\": {\n     \"model_module\": \"@jupyter-widgets/controls\",\n     \"model_module_version\": \"1.5.0\",\n     \"model_name\": \"FloatProgressModel\",\n     \"state\": {\n      \"_dom_classes\": [],\n      \"_model_module\": \"@jupyter-widgets/controls\",\n      \"_model_module_version\": \"1.5.0\",\n      \"_model_name\": \"FloatProgressModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/controls\",\n      \"_view_module_version\": \"1.5.0\",\n      \"_view_name\": \"ProgressView\",\n      \"bar_style\": \"danger\",\n      \"description\": \"\",\n      \"description_tooltip\": null,\n      \"layout\": \"IPY_MODEL_9503a45960984adc97b58e16c50662e0\",\n      \"max\": 3963750880,\n      \"min\": 0,\n      \"orientation\": \"horizontal\",\n      \"style\": \"IPY_MODEL_da6e93f3e4984780b930fe7a706983ea\",\n      \"value\": 3963750502\n     }\n   
 },\n    \"7c2485c6cdfe463da6fdb35982a1070d\": {\n     \"model_module\": \"@jupyter-widgets/controls\",\n     \"model_module_version\": \"1.5.0\",\n     \"model_name\": \"HBoxModel\",\n     \"state\": {\n      \"_dom_classes\": [],\n      \"_model_module\": \"@jupyter-widgets/controls\",\n      \"_model_module_version\": \"1.5.0\",\n      \"_model_name\": \"HBoxModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/controls\",\n      \"_view_module_version\": \"1.5.0\",\n      \"_view_name\": \"HBoxView\",\n      \"box_style\": \"\",\n      \"children\": [\n       \"IPY_MODEL_ad1236893754446881e153adc9d5c962\",\n       \"IPY_MODEL_daee63fd167e4441a32324b51b00ad2b\",\n       \"IPY_MODEL_fe41858c6bd04c58840112b67c19a336\"\n      ],\n      \"layout\": \"IPY_MODEL_d262c82138024169b9f3aa034ca756fa\"\n     }\n    },\n    \"7c95f85a2b1f47a1bd846d110c47bb3c\": {\n     \"model_module\": \"@jupyter-widgets/controls\",\n     \"model_module_version\": \"1.5.0\",\n     \"model_name\": \"HTMLModel\",\n     \"state\": {\n      \"_dom_classes\": [],\n      \"_model_module\": \"@jupyter-widgets/controls\",\n      \"_model_module_version\": \"1.5.0\",\n      \"_model_name\": \"HTMLModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/controls\",\n      \"_view_module_version\": \"1.5.0\",\n      \"_view_name\": \"HTMLView\",\n      \"description\": \"\",\n      \"description_tooltip\": null,\n      \"layout\": \"IPY_MODEL_7fd44cf9ca6e4726bfd7ac21846d6a14\",\n      \"placeholder\": \"​\",\n      \"style\": \"IPY_MODEL_366a343b62fa47d8985a3bd464d99f9e\",\n      \"value\": \"config.json: 100%\"\n     }\n    },\n    \"7cd0b85ebd204b7aba908417811ce4e0\": {\n     \"model_module\": \"@jupyter-widgets/controls\",\n     \"model_module_version\": \"1.5.0\",\n     \"model_name\": \"HBoxModel\",\n     \"state\": {\n      \"_dom_classes\": [],\n      \"_model_module\": \"@jupyter-widgets/controls\",\n      \"_model_module_version\": 
\"1.5.0\",\n      \"_model_name\": \"HBoxModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/controls\",\n      \"_view_module_version\": \"1.5.0\",\n      \"_view_name\": \"HBoxView\",\n      \"box_style\": \"\",\n      \"children\": [\n       \"IPY_MODEL_7baeab52d6694c32b1efd1ea1a0a7782\",\n       \"IPY_MODEL_519a7b154022443db6703f04a9142bae\",\n       \"IPY_MODEL_d4183e9715f34d249942b8271cca3bdf\"\n      ],\n      \"layout\": \"IPY_MODEL_da2347ac94764a3fa2743343cf0d3cd2\"\n     }\n    },\n    \"7e5d3774060e4589aa65982da5ea4ef4\": {\n     \"model_module\": \"@jupyter-widgets/controls\",\n     \"model_module_version\": \"1.5.0\",\n     \"model_name\": \"DescriptionStyleModel\",\n     \"state\": {\n      \"_model_module\": \"@jupyter-widgets/controls\",\n      \"_model_module_version\": \"1.5.0\",\n      \"_model_name\": \"DescriptionStyleModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/base\",\n      \"_view_module_version\": \"1.2.0\",\n      \"_view_name\": \"StyleView\",\n      \"description_width\": \"\"\n     }\n    },\n    \"7fd44cf9ca6e4726bfd7ac21846d6a14\": {\n     \"model_module\": \"@jupyter-widgets/base\",\n     \"model_module_version\": \"1.2.0\",\n     \"model_name\": \"LayoutModel\",\n     \"state\": {\n      \"_model_module\": \"@jupyter-widgets/base\",\n      \"_model_module_version\": \"1.2.0\",\n      \"_model_name\": \"LayoutModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/base\",\n      \"_view_module_version\": \"1.2.0\",\n      \"_view_name\": \"LayoutView\",\n      \"align_content\": null,\n      \"align_items\": null,\n      \"align_self\": null,\n      \"border\": null,\n      \"bottom\": null,\n      \"display\": null,\n      \"flex\": null,\n      \"flex_flow\": null,\n      \"grid_area\": null,\n      \"grid_auto_columns\": null,\n      \"grid_auto_flow\": null,\n      \"grid_auto_rows\": null,\n      \"grid_column\": null,\n      
\"grid_gap\": null,\n      \"grid_row\": null,\n      \"grid_template_areas\": null,\n      \"grid_template_columns\": null,\n      \"grid_template_rows\": null,\n      \"height\": null,\n      \"justify_content\": null,\n      \"justify_items\": null,\n      \"left\": null,\n      \"margin\": null,\n      \"max_height\": null,\n      \"max_width\": null,\n      \"min_height\": null,\n      \"min_width\": null,\n      \"object_fit\": null,\n      \"object_position\": null,\n      \"order\": null,\n      \"overflow\": null,\n      \"overflow_x\": null,\n      \"overflow_y\": null,\n      \"padding\": null,\n      \"right\": null,\n      \"top\": null,\n      \"visibility\": null,\n      \"width\": null\n     }\n    },\n    \"80366349d81e4dcc892db6cd56e384f3\": {\n     \"model_module\": \"@jupyter-widgets/controls\",\n     \"model_module_version\": \"1.5.0\",\n     \"model_name\": \"FloatProgressModel\",\n     \"state\": {\n      \"_dom_classes\": [],\n      \"_model_module\": \"@jupyter-widgets/controls\",\n      \"_model_module_version\": \"1.5.0\",\n      \"_model_name\": \"FloatProgressModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/controls\",\n      \"_view_module_version\": \"1.5.0\",\n      \"_view_name\": \"ProgressView\",\n      \"bar_style\": \"success\",\n      \"description\": \"\",\n      \"description_tooltip\": null,\n      \"layout\": \"IPY_MODEL_f8ef805b776145c3bfa9ba8d90972058\",\n      \"max\": 9985,\n      \"min\": 0,\n      \"orientation\": \"horizontal\",\n      \"style\": \"IPY_MODEL_cc587493c33c4f118d1b1170f85be24c\",\n      \"value\": 9985\n     }\n    },\n    \"813621384dc748b0ad06775e22761c0b\": {\n     \"model_module\": \"@jupyter-widgets/controls\",\n     \"model_module_version\": \"1.5.0\",\n     \"model_name\": \"DescriptionStyleModel\",\n     \"state\": {\n      \"_model_module\": \"@jupyter-widgets/controls\",\n      \"_model_module_version\": \"1.5.0\",\n      \"_model_name\": 
\"DescriptionStyleModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/base\",\n      \"_view_module_version\": \"1.2.0\",\n      \"_view_name\": \"StyleView\",\n      \"description_width\": \"\"\n     }\n    },\n    \"81c3db71ac704280ad030072655f1537\": {\n     \"model_module\": \"@jupyter-widgets/base\",\n     \"model_module_version\": \"1.2.0\",\n     \"model_name\": \"LayoutModel\",\n     \"state\": {\n      \"_model_module\": \"@jupyter-widgets/base\",\n      \"_model_module_version\": \"1.2.0\",\n      \"_model_name\": \"LayoutModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/base\",\n      \"_view_module_version\": \"1.2.0\",\n      \"_view_name\": \"LayoutView\",\n      \"align_content\": null,\n      \"align_items\": null,\n      \"align_self\": null,\n      \"border\": null,\n      \"bottom\": null,\n      \"display\": null,\n      \"flex\": null,\n      \"flex_flow\": null,\n      \"grid_area\": null,\n      \"grid_auto_columns\": null,\n      \"grid_auto_flow\": null,\n      \"grid_auto_rows\": null,\n      \"grid_column\": null,\n      \"grid_gap\": null,\n      \"grid_row\": null,\n      \"grid_template_areas\": null,\n      \"grid_template_columns\": null,\n      \"grid_template_rows\": null,\n      \"height\": null,\n      \"justify_content\": null,\n      \"justify_items\": null,\n      \"left\": null,\n      \"margin\": null,\n      \"max_height\": null,\n      \"max_width\": null,\n      \"min_height\": null,\n      \"min_width\": null,\n      \"object_fit\": null,\n      \"object_position\": null,\n      \"order\": null,\n      \"overflow\": null,\n      \"overflow_x\": null,\n      \"overflow_y\": null,\n      \"padding\": null,\n      \"right\": null,\n      \"top\": null,\n      \"visibility\": null,\n      \"width\": null\n     }\n    },\n    \"82177df57a494de8900c14c2f5185175\": {\n     \"model_module\": \"@jupyter-widgets/controls\",\n     \"model_module_version\": \"1.5.0\",\n 
    \"model_name\": \"HTMLModel\",\n     \"state\": {\n      \"_dom_classes\": [],\n      \"_model_module\": \"@jupyter-widgets/controls\",\n      \"_model_module_version\": \"1.5.0\",\n      \"_model_name\": \"HTMLModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/controls\",\n      \"_view_module_version\": \"1.5.0\",\n      \"_view_name\": \"HTMLView\",\n      \"description\": \"\",\n      \"description_tooltip\": null,\n      \"layout\": \"IPY_MODEL_67da6c4260574869aa24c3cbc1bc1654\",\n      \"placeholder\": \"​\",\n      \"style\": \"IPY_MODEL_94b9088614464f60a203de39dbcae853\",\n      \"value\": \" 8/8 [01:47&lt;00:00, 11.64s/it]\"\n     }\n    },\n    \"823f1c78f15043e38bbd4dca3932a86a\": {\n     \"model_module\": \"@jupyter-widgets/controls\",\n     \"model_module_version\": \"1.5.0\",\n     \"model_name\": \"FloatProgressModel\",\n     \"state\": {\n      \"_dom_classes\": [],\n      \"_model_module\": \"@jupyter-widgets/controls\",\n      \"_model_module_version\": \"1.5.0\",\n      \"_model_name\": \"FloatProgressModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/controls\",\n      \"_view_module_version\": \"1.5.0\",\n      \"_view_name\": \"ProgressView\",\n      \"bar_style\": \"success\",\n      \"description\": \"\",\n      \"description_tooltip\": null,\n      \"layout\": \"IPY_MODEL_03a3c744d716431488163b4358b80f92\",\n      \"max\": 239,\n      \"min\": 0,\n      \"orientation\": \"horizontal\",\n      \"style\": \"IPY_MODEL_a5434ee714f9498d83870544b67c0cb7\",\n      \"value\": 239\n     }\n    },\n    \"835bcc28a5564fb9b3d651bc8e32dc46\": {\n     \"model_module\": \"@jupyter-widgets/base\",\n     \"model_module_version\": \"1.2.0\",\n     \"model_name\": \"LayoutModel\",\n     \"state\": {\n      \"_model_module\": \"@jupyter-widgets/base\",\n      \"_model_module_version\": \"1.2.0\",\n      \"_model_name\": \"LayoutModel\",\n      \"_view_count\": null,\n      \"_view_module\": 
\"@jupyter-widgets/base\",\n      \"_view_module_version\": \"1.2.0\",\n      \"_view_name\": \"LayoutView\",\n      \"align_content\": null,\n      \"align_items\": null,\n      \"align_self\": null,\n      \"border\": null,\n      \"bottom\": null,\n      \"display\": null,\n      \"flex\": null,\n      \"flex_flow\": null,\n      \"grid_area\": null,\n      \"grid_auto_columns\": null,\n      \"grid_auto_flow\": null,\n      \"grid_auto_rows\": null,\n      \"grid_column\": null,\n      \"grid_gap\": null,\n      \"grid_row\": null,\n      \"grid_template_areas\": null,\n      \"grid_template_columns\": null,\n      \"grid_template_rows\": null,\n      \"height\": null,\n      \"justify_content\": null,\n      \"justify_items\": null,\n      \"left\": null,\n      \"margin\": null,\n      \"max_height\": null,\n      \"max_width\": null,\n      \"min_height\": null,\n      \"min_width\": null,\n      \"object_fit\": null,\n      \"object_position\": null,\n      \"order\": null,\n      \"overflow\": null,\n      \"overflow_x\": null,\n      \"overflow_y\": null,\n      \"padding\": null,\n      \"right\": null,\n      \"top\": null,\n      \"visibility\": null,\n      \"width\": null\n     }\n    },\n    \"8640ac440fbc4644b9a3af7ba3ae7183\": {\n     \"model_module\": \"@jupyter-widgets/base\",\n     \"model_module_version\": \"1.2.0\",\n     \"model_name\": \"LayoutModel\",\n     \"state\": {\n      \"_model_module\": \"@jupyter-widgets/base\",\n      \"_model_module_version\": \"1.2.0\",\n      \"_model_name\": \"LayoutModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/base\",\n      \"_view_module_version\": \"1.2.0\",\n      \"_view_name\": \"LayoutView\",\n      \"align_content\": null,\n      \"align_items\": null,\n      \"align_self\": null,\n      \"border\": null,\n      \"bottom\": null,\n      \"display\": null,\n      \"flex\": null,\n      \"flex_flow\": null,\n      \"grid_area\": null,\n      \"grid_auto_columns\": 
null,\n      \"grid_auto_flow\": null,\n      \"grid_auto_rows\": null,\n      \"grid_column\": null,\n      \"grid_gap\": null,\n      \"grid_row\": null,\n      \"grid_template_areas\": null,\n      \"grid_template_columns\": null,\n      \"grid_template_rows\": null,\n      \"height\": null,\n      \"justify_content\": null,\n      \"justify_items\": null,\n      \"left\": null,\n      \"margin\": null,\n      \"max_height\": null,\n      \"max_width\": null,\n      \"min_height\": null,\n      \"min_width\": null,\n      \"object_fit\": null,\n      \"object_position\": null,\n      \"order\": null,\n      \"overflow\": null,\n      \"overflow_x\": null,\n      \"overflow_y\": null,\n      \"padding\": null,\n      \"right\": null,\n      \"top\": null,\n      \"visibility\": null,\n      \"width\": null\n     }\n    },\n    \"86816687746246b4a6105e8010384e25\": {\n     \"model_module\": \"@jupyter-widgets/controls\",\n     \"model_module_version\": \"1.5.0\",\n     \"model_name\": \"HTMLModel\",\n     \"state\": {\n      \"_dom_classes\": [],\n      \"_model_module\": \"@jupyter-widgets/controls\",\n      \"_model_module_version\": \"1.5.0\",\n      \"_model_name\": \"HTMLModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/controls\",\n      \"_view_module_version\": \"1.5.0\",\n      \"_view_name\": \"HTMLView\",\n      \"description\": \"\",\n      \"description_tooltip\": null,\n      \"layout\": \"IPY_MODEL_8640ac440fbc4644b9a3af7ba3ae7183\",\n      \"placeholder\": \"​\",\n      \"style\": \"IPY_MODEL_5cea7996f02040b187ece0bb2d6a8d1f\",\n      \"value\": \"<center> <img\\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.svg\\nalt='Hugging Face'> <br> Copy a token from <a\\nhref=\\\"https://huggingface.co/settings/tokens\\\" target=\\\"_blank\\\">your Hugging Face\\ntokens page</a> and paste it below. 
<br> Immediately click login after copying\\nyour token or it might be stored in plain text in this notebook file. </center>\"\n     }\n    },\n    \"879c8ab5873847a8833bd74123be90a4\": {\n     \"model_module\": \"@jupyter-widgets/controls\",\n     \"model_module_version\": \"1.5.0\",\n     \"model_name\": \"HTMLModel\",\n     \"state\": {\n      \"_dom_classes\": [],\n      \"_model_module\": \"@jupyter-widgets/controls\",\n      \"_model_module_version\": \"1.5.0\",\n      \"_model_name\": \"HTMLModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/controls\",\n      \"_view_module_version\": \"1.5.0\",\n      \"_view_name\": \"HTMLView\",\n      \"description\": \"\",\n      \"description_tooltip\": null,\n      \"layout\": \"IPY_MODEL_ef223e8504b64e3592589880326aaf41\",\n      \"placeholder\": \"​\",\n      \"style\": \"IPY_MODEL_598da69727bd4fb8b1caf465ac736d7a\",\n      \"value\": \" 1.67M/1.67M [00:00&lt;00:00, 19.0MB/s]\"\n     }\n    },\n    \"897b77a56c09479bb11d7f2a30997e55\": {\n     \"model_module\": \"@jupyter-widgets/controls\",\n     \"model_module_version\": \"1.5.0\",\n     \"model_name\": \"ProgressStyleModel\",\n     \"state\": {\n      \"_model_module\": \"@jupyter-widgets/controls\",\n      \"_model_module_version\": \"1.5.0\",\n      \"_model_name\": \"ProgressStyleModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/base\",\n      \"_view_module_version\": \"1.2.0\",\n      \"_view_name\": \"StyleView\",\n      \"bar_color\": null,\n      \"description_width\": \"\"\n     }\n    },\n    \"8bc9d8ba866c442b9118d9630009939c\": {\n     \"model_module\": \"@jupyter-widgets/base\",\n     \"model_module_version\": \"1.2.0\",\n     \"model_name\": \"LayoutModel\",\n     \"state\": {\n      \"_model_module\": \"@jupyter-widgets/base\",\n      \"_model_module_version\": \"1.2.0\",\n      \"_model_name\": \"LayoutModel\",\n      \"_view_count\": null,\n      \"_view_module\": 
\"@jupyter-widgets/base\",\n      \"_view_module_version\": \"1.2.0\",\n      \"_view_name\": \"LayoutView\",\n      \"align_content\": null,\n      \"align_items\": null,\n      \"align_self\": null,\n      \"border\": null,\n      \"bottom\": null,\n      \"display\": null,\n      \"flex\": null,\n      \"flex_flow\": null,\n      \"grid_area\": null,\n      \"grid_auto_columns\": null,\n      \"grid_auto_flow\": null,\n      \"grid_auto_rows\": null,\n      \"grid_column\": null,\n      \"grid_gap\": null,\n      \"grid_row\": null,\n      \"grid_template_areas\": null,\n      \"grid_template_columns\": null,\n      \"grid_template_rows\": null,\n      \"height\": null,\n      \"justify_content\": null,\n      \"justify_items\": null,\n      \"left\": null,\n      \"margin\": null,\n      \"max_height\": null,\n      \"max_width\": null,\n      \"min_height\": null,\n      \"min_width\": null,\n      \"object_fit\": null,\n      \"object_position\": null,\n      \"order\": null,\n      \"overflow\": null,\n      \"overflow_x\": null,\n      \"overflow_y\": null,\n      \"padding\": null,\n      \"right\": null,\n      \"top\": null,\n      \"visibility\": null,\n      \"width\": null\n     }\n    },\n    \"8c4d4fc5a30f4e7cb3be53fe2adda33d\": {\n     \"model_module\": \"@jupyter-widgets/base\",\n     \"model_module_version\": \"1.2.0\",\n     \"model_name\": \"LayoutModel\",\n     \"state\": {\n      \"_model_module\": \"@jupyter-widgets/base\",\n      \"_model_module_version\": \"1.2.0\",\n      \"_model_name\": \"LayoutModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/base\",\n      \"_view_module_version\": \"1.2.0\",\n      \"_view_name\": \"LayoutView\",\n      \"align_content\": null,\n      \"align_items\": null,\n      \"align_self\": null,\n      \"border\": null,\n      \"bottom\": null,\n      \"display\": null,\n      \"flex\": null,\n      \"flex_flow\": null,\n      \"grid_area\": null,\n      \"grid_auto_columns\": 
null,\n      \"grid_auto_flow\": null,\n      \"grid_auto_rows\": null,\n      \"grid_column\": null,\n      \"grid_gap\": null,\n      \"grid_row\": null,\n      \"grid_template_areas\": null,\n      \"grid_template_columns\": null,\n      \"grid_template_rows\": null,\n      \"height\": null,\n      \"justify_content\": null,\n      \"justify_items\": null,\n      \"left\": null,\n      \"margin\": null,\n      \"max_height\": null,\n      \"max_width\": null,\n      \"min_height\": null,\n      \"min_width\": null,\n      \"object_fit\": null,\n      \"object_position\": null,\n      \"order\": null,\n      \"overflow\": null,\n      \"overflow_x\": null,\n      \"overflow_y\": null,\n      \"padding\": null,\n      \"right\": null,\n      \"top\": null,\n      \"visibility\": null,\n      \"width\": null\n     }\n    },\n    \"8f5bd719974e41c3a8dd9a5b0d3d71e6\": {\n     \"model_module\": \"@jupyter-widgets/base\",\n     \"model_module_version\": \"1.2.0\",\n     \"model_name\": \"LayoutModel\",\n     \"state\": {\n      \"_model_module\": \"@jupyter-widgets/base\",\n      \"_model_module_version\": \"1.2.0\",\n      \"_model_name\": \"LayoutModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/base\",\n      \"_view_module_version\": \"1.2.0\",\n      \"_view_name\": \"LayoutView\",\n      \"align_content\": null,\n      \"align_items\": null,\n      \"align_self\": null,\n      \"border\": null,\n      \"bottom\": null,\n      \"display\": null,\n      \"flex\": null,\n      \"flex_flow\": null,\n      \"grid_area\": null,\n      \"grid_auto_columns\": null,\n      \"grid_auto_flow\": null,\n      \"grid_auto_rows\": null,\n      \"grid_column\": null,\n      \"grid_gap\": null,\n      \"grid_row\": null,\n      \"grid_template_areas\": null,\n      \"grid_template_columns\": null,\n      \"grid_template_rows\": null,\n      \"height\": null,\n      \"justify_content\": null,\n      \"justify_items\": null,\n      \"left\": null,\n 
     \"margin\": null,\n      \"max_height\": null,\n      \"max_width\": null,\n      \"min_height\": null,\n      \"min_width\": null,\n      \"object_fit\": null,\n      \"object_position\": null,\n      \"order\": null,\n      \"overflow\": null,\n      \"overflow_x\": null,\n      \"overflow_y\": null,\n      \"padding\": null,\n      \"right\": null,\n      \"top\": null,\n      \"visibility\": null,\n      \"width\": null\n     }\n    },\n    \"8f726dbfb45d4528afa33e36a6313267\": {\n     \"model_module\": \"@jupyter-widgets/controls\",\n     \"model_module_version\": \"1.5.0\",\n     \"model_name\": \"DescriptionStyleModel\",\n     \"state\": {\n      \"_model_module\": \"@jupyter-widgets/controls\",\n      \"_model_module_version\": \"1.5.0\",\n      \"_model_name\": \"DescriptionStyleModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/base\",\n      \"_view_module_version\": \"1.2.0\",\n      \"_view_name\": \"StyleView\",\n      \"description_width\": \"\"\n     }\n    },\n    \"9327977822be4b1294f80e876552e305\": {\n     \"model_module\": \"@jupyter-widgets/controls\",\n     \"model_module_version\": \"1.5.0\",\n     \"model_name\": \"HTMLModel\",\n     \"state\": {\n      \"_dom_classes\": [],\n      \"_model_module\": \"@jupyter-widgets/controls\",\n      \"_model_module_version\": \"1.5.0\",\n      \"_model_name\": \"HTMLModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/controls\",\n      \"_view_module_version\": \"1.5.0\",\n      \"_view_name\": \"HTMLView\",\n      \"description\": \"\",\n      \"description_tooltip\": null,\n      \"layout\": \"IPY_MODEL_37de928300e34184881039378bd75e7f\",\n      \"placeholder\": \"​\",\n      \"style\": \"IPY_MODEL_0e936d9dbf9c4fdd86bbfe9730dedc47\",\n      \"value\": \" 3.96G/3.96G [00:13&lt;00:00, 273MB/s]\"\n     }\n    },\n    \"936d04b5fe1b4c63bf0b080e423d051b\": {\n     \"model_module\": \"@jupyter-widgets/base\",\n     
\"model_module_version\": \"1.2.0\",\n     \"model_name\": \"LayoutModel\",\n     \"state\": {\n      \"_model_module\": \"@jupyter-widgets/base\",\n      \"_model_module_version\": \"1.2.0\",\n      \"_model_name\": \"LayoutModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/base\",\n      \"_view_module_version\": \"1.2.0\",\n      \"_view_name\": \"LayoutView\",\n      \"align_content\": null,\n      \"align_items\": null,\n      \"align_self\": null,\n      \"border\": null,\n      \"bottom\": null,\n      \"display\": null,\n      \"flex\": null,\n      \"flex_flow\": null,\n      \"grid_area\": null,\n      \"grid_auto_columns\": null,\n      \"grid_auto_flow\": null,\n      \"grid_auto_rows\": null,\n      \"grid_column\": null,\n      \"grid_gap\": null,\n      \"grid_row\": null,\n      \"grid_template_areas\": null,\n      \"grid_template_columns\": null,\n      \"grid_template_rows\": null,\n      \"height\": null,\n      \"justify_content\": null,\n      \"justify_items\": null,\n      \"left\": null,\n      \"margin\": null,\n      \"max_height\": null,\n      \"max_width\": null,\n      \"min_height\": null,\n      \"min_width\": null,\n      \"object_fit\": null,\n      \"object_position\": null,\n      \"order\": null,\n      \"overflow\": null,\n      \"overflow_x\": null,\n      \"overflow_y\": null,\n      \"padding\": null,\n      \"right\": null,\n      \"top\": null,\n      \"visibility\": null,\n      \"width\": null\n     }\n    },\n    \"93a44a11aa4846fa8efc6c1413ef1627\": {\n     \"model_module\": \"@jupyter-widgets/base\",\n     \"model_module_version\": \"1.2.0\",\n     \"model_name\": \"LayoutModel\",\n     \"state\": {\n      \"_model_module\": \"@jupyter-widgets/base\",\n      \"_model_module_version\": \"1.2.0\",\n      \"_model_name\": \"LayoutModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/base\",\n      \"_view_module_version\": \"1.2.0\",\n      \"_view_name\": 
\"LayoutView\",\n      \"align_content\": null,\n      \"align_items\": null,\n      \"align_self\": null,\n      \"border\": null,\n      \"bottom\": null,\n      \"display\": null,\n      \"flex\": null,\n      \"flex_flow\": null,\n      \"grid_area\": null,\n      \"grid_auto_columns\": null,\n      \"grid_auto_flow\": null,\n      \"grid_auto_rows\": null,\n      \"grid_column\": null,\n      \"grid_gap\": null,\n      \"grid_row\": null,\n      \"grid_template_areas\": null,\n      \"grid_template_columns\": null,\n      \"grid_template_rows\": null,\n      \"height\": null,\n      \"justify_content\": null,\n      \"justify_items\": null,\n      \"left\": null,\n      \"margin\": null,\n      \"max_height\": null,\n      \"max_width\": null,\n      \"min_height\": null,\n      \"min_width\": null,\n      \"object_fit\": null,\n      \"object_position\": null,\n      \"order\": null,\n      \"overflow\": null,\n      \"overflow_x\": null,\n      \"overflow_y\": null,\n      \"padding\": null,\n      \"right\": null,\n      \"top\": null,\n      \"visibility\": null,\n      \"width\": null\n     }\n    },\n    \"94b9088614464f60a203de39dbcae853\": {\n     \"model_module\": \"@jupyter-widgets/controls\",\n     \"model_module_version\": \"1.5.0\",\n     \"model_name\": \"DescriptionStyleModel\",\n     \"state\": {\n      \"_model_module\": \"@jupyter-widgets/controls\",\n      \"_model_module_version\": \"1.5.0\",\n      \"_model_name\": \"DescriptionStyleModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/base\",\n      \"_view_module_version\": \"1.2.0\",\n      \"_view_name\": \"StyleView\",\n      \"description_width\": \"\"\n     }\n    },\n    \"9503a45960984adc97b58e16c50662e0\": {\n     \"model_module\": \"@jupyter-widgets/base\",\n     \"model_module_version\": \"1.2.0\",\n     \"model_name\": \"LayoutModel\",\n     \"state\": {\n      \"_model_module\": \"@jupyter-widgets/base\",\n      \"_model_module_version\": 
\"1.2.0\",\n      \"_model_name\": \"LayoutModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/base\",\n      \"_view_module_version\": \"1.2.0\",\n      \"_view_name\": \"LayoutView\",\n      \"align_content\": null,\n      \"align_items\": null,\n      \"align_self\": null,\n      \"border\": null,\n      \"bottom\": null,\n      \"display\": null,\n      \"flex\": null,\n      \"flex_flow\": null,\n      \"grid_area\": null,\n      \"grid_auto_columns\": null,\n      \"grid_auto_flow\": null,\n      \"grid_auto_rows\": null,\n      \"grid_column\": null,\n      \"grid_gap\": null,\n      \"grid_row\": null,\n      \"grid_template_areas\": null,\n      \"grid_template_columns\": null,\n      \"grid_template_rows\": null,\n      \"height\": null,\n      \"justify_content\": null,\n      \"justify_items\": null,\n      \"left\": null,\n      \"margin\": null,\n      \"max_height\": null,\n      \"max_width\": null,\n      \"min_height\": null,\n      \"min_width\": null,\n      \"object_fit\": null,\n      \"object_position\": null,\n      \"order\": null,\n      \"overflow\": null,\n      \"overflow_x\": null,\n      \"overflow_y\": null,\n      \"padding\": null,\n      \"right\": null,\n      \"top\": null,\n      \"visibility\": null,\n      \"width\": null\n     }\n    },\n    \"95caff42f08a4c2aa14c867b8f37f231\": {\n     \"model_module\": \"@jupyter-widgets/controls\",\n     \"model_module_version\": \"1.5.0\",\n     \"model_name\": \"HBoxModel\",\n     \"state\": {\n      \"_dom_classes\": [],\n      \"_model_module\": \"@jupyter-widgets/controls\",\n      \"_model_module_version\": \"1.5.0\",\n      \"_model_name\": \"HBoxModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/controls\",\n      \"_view_module_version\": \"1.5.0\",\n      \"_view_name\": \"HBoxView\",\n      \"box_style\": \"\",\n      \"children\": [\n       \"IPY_MODEL_de7c37ee83e24f0c889e84d07279c2ec\",\n       
\"IPY_MODEL_9d4897eefb5f48259ffb2d23e332f752\",\n       \"IPY_MODEL_253017b0d0534e54ab44e181f6d7c82d\"\n      ],\n      \"layout\": \"IPY_MODEL_27beaf06e41b472abdb544a43c720c5a\"\n     }\n    },\n    \"977f799afaac4a55b2dc1cffa7d5b63b\": {\n     \"model_module\": \"@jupyter-widgets/base\",\n     \"model_module_version\": \"1.2.0\",\n     \"model_name\": \"LayoutModel\",\n     \"state\": {\n      \"_model_module\": \"@jupyter-widgets/base\",\n      \"_model_module_version\": \"1.2.0\",\n      \"_model_name\": \"LayoutModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/base\",\n      \"_view_module_version\": \"1.2.0\",\n      \"_view_name\": \"LayoutView\",\n      \"align_content\": null,\n      \"align_items\": null,\n      \"align_self\": null,\n      \"border\": null,\n      \"bottom\": null,\n      \"display\": null,\n      \"flex\": null,\n      \"flex_flow\": null,\n      \"grid_area\": null,\n      \"grid_auto_columns\": null,\n      \"grid_auto_flow\": null,\n      \"grid_auto_rows\": null,\n      \"grid_column\": null,\n      \"grid_gap\": null,\n      \"grid_row\": null,\n      \"grid_template_areas\": null,\n      \"grid_template_columns\": null,\n      \"grid_template_rows\": null,\n      \"height\": null,\n      \"justify_content\": null,\n      \"justify_items\": null,\n      \"left\": null,\n      \"margin\": null,\n      \"max_height\": null,\n      \"max_width\": null,\n      \"min_height\": null,\n      \"min_width\": null,\n      \"object_fit\": null,\n      \"object_position\": null,\n      \"order\": null,\n      \"overflow\": null,\n      \"overflow_x\": null,\n      \"overflow_y\": null,\n      \"padding\": null,\n      \"right\": null,\n      \"top\": null,\n      \"visibility\": null,\n      \"width\": null\n     }\n    },\n    \"97e36007e1304e1583fd81bfb13f0edd\": {\n     \"model_module\": \"@jupyter-widgets/base\",\n     \"model_module_version\": \"1.2.0\",\n     \"model_name\": \"LayoutModel\",\n     
\"state\": {\n      \"_model_module\": \"@jupyter-widgets/base\",\n      \"_model_module_version\": \"1.2.0\",\n      \"_model_name\": \"LayoutModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/base\",\n      \"_view_module_version\": \"1.2.0\",\n      \"_view_name\": \"LayoutView\",\n      \"align_content\": null,\n      \"align_items\": null,\n      \"align_self\": null,\n      \"border\": null,\n      \"bottom\": null,\n      \"display\": null,\n      \"flex\": null,\n      \"flex_flow\": null,\n      \"grid_area\": null,\n      \"grid_auto_columns\": null,\n      \"grid_auto_flow\": null,\n      \"grid_auto_rows\": null,\n      \"grid_column\": null,\n      \"grid_gap\": null,\n      \"grid_row\": null,\n      \"grid_template_areas\": null,\n      \"grid_template_columns\": null,\n      \"grid_template_rows\": null,\n      \"height\": null,\n      \"justify_content\": null,\n      \"justify_items\": null,\n      \"left\": null,\n      \"margin\": null,\n      \"max_height\": null,\n      \"max_width\": null,\n      \"min_height\": null,\n      \"min_width\": null,\n      \"object_fit\": null,\n      \"object_position\": null,\n      \"order\": null,\n      \"overflow\": null,\n      \"overflow_x\": null,\n      \"overflow_y\": null,\n      \"padding\": null,\n      \"right\": null,\n      \"top\": null,\n      \"visibility\": null,\n      \"width\": null\n     }\n    },\n    \"9858cb74a09748a39e8149baac96702c\": {\n     \"model_module\": \"@jupyter-widgets/controls\",\n     \"model_module_version\": \"1.5.0\",\n     \"model_name\": \"DescriptionStyleModel\",\n     \"state\": {\n      \"_model_module\": \"@jupyter-widgets/controls\",\n      \"_model_module_version\": \"1.5.0\",\n      \"_model_name\": \"DescriptionStyleModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/base\",\n      \"_view_module_version\": \"1.2.0\",\n      \"_view_name\": \"StyleView\",\n      \"description_width\": \"\"\n     
}\n    },\n    \"9b42e08b3c9548818488268768a118b1\": {\n     \"model_module\": \"@jupyter-widgets/controls\",\n     \"model_module_version\": \"1.5.0\",\n     \"model_name\": \"HTMLModel\",\n     \"state\": {\n      \"_dom_classes\": [],\n      \"_model_module\": \"@jupyter-widgets/controls\",\n      \"_model_module_version\": \"1.5.0\",\n      \"_model_name\": \"HTMLModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/controls\",\n      \"_view_module_version\": \"1.5.0\",\n      \"_view_name\": \"HTMLView\",\n      \"description\": \"\",\n      \"description_tooltip\": null,\n      \"layout\": \"IPY_MODEL_d955dcaa0e944e719f3a06139dd54a03\",\n      \"placeholder\": \"​\",\n      \"style\": \"IPY_MODEL_d3de2662c7964f1ba96e58da382af720\",\n      \"value\": \"merges.txt: 100%\"\n     }\n    },\n    \"9cd5211b5d8b457aa0002f1d17b80028\": {\n     \"model_module\": \"@jupyter-widgets/controls\",\n     \"model_module_version\": \"1.5.0\",\n     \"model_name\": \"HTMLModel\",\n     \"state\": {\n      \"_dom_classes\": [],\n      \"_model_module\": \"@jupyter-widgets/controls\",\n      \"_model_module_version\": \"1.5.0\",\n      \"_model_name\": \"HTMLModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/controls\",\n      \"_view_module_version\": \"1.5.0\",\n      \"_view_name\": \"HTMLView\",\n      \"description\": \"\",\n      \"description_tooltip\": null,\n      \"layout\": \"IPY_MODEL_6932489232ec4ab18a160b1e7fbcdfe1\",\n      \"placeholder\": \"​\",\n      \"style\": \"IPY_MODEL_4540927d98f54466b434ba4c0edf045d\",\n      \"value\": \"model-00007-of-00008.safetensors: 100%\"\n     }\n    },\n    \"9d4897eefb5f48259ffb2d23e332f752\": {\n     \"model_module\": \"@jupyter-widgets/controls\",\n     \"model_module_version\": \"1.5.0\",\n     \"model_name\": \"FloatProgressModel\",\n     \"state\": {\n      \"_dom_classes\": [],\n      \"_model_module\": \"@jupyter-widgets/controls\",\n      
\"_model_module_version\": \"1.5.0\",\n      \"_model_name\": \"FloatProgressModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/controls\",\n      \"_view_module_version\": \"1.5.0\",\n      \"_view_name\": \"ProgressView\",\n      \"bar_style\": \"success\",\n      \"description\": \"\",\n      \"description_tooltip\": null,\n      \"layout\": \"IPY_MODEL_30a81da86f8043eca301e86a8651201a\",\n      \"max\": 2776833,\n      \"min\": 0,\n      \"orientation\": \"horizontal\",\n      \"style\": \"IPY_MODEL_e8b7a81040904c1e89e58978223b1737\",\n      \"value\": 2776833\n     }\n    },\n    \"9e333ed3b5014069ac1dd969255dd591\": {\n     \"model_module\": \"@jupyter-widgets/controls\",\n     \"model_module_version\": \"1.5.0\",\n     \"model_name\": \"DescriptionStyleModel\",\n     \"state\": {\n      \"_model_module\": \"@jupyter-widgets/controls\",\n      \"_model_module_version\": \"1.5.0\",\n      \"_model_name\": \"DescriptionStyleModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/base\",\n      \"_view_module_version\": \"1.2.0\",\n      \"_view_name\": \"StyleView\",\n      \"description_width\": \"\"\n     }\n    },\n    \"9ed02dc43412471a9ab47f3620ccf3a5\": {\n     \"model_module\": \"@jupyter-widgets/base\",\n     \"model_module_version\": \"1.2.0\",\n     \"model_name\": \"LayoutModel\",\n     \"state\": {\n      \"_model_module\": \"@jupyter-widgets/base\",\n      \"_model_module_version\": \"1.2.0\",\n      \"_model_name\": \"LayoutModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/base\",\n      \"_view_module_version\": \"1.2.0\",\n      \"_view_name\": \"LayoutView\",\n      \"align_content\": null,\n      \"align_items\": null,\n      \"align_self\": null,\n      \"border\": null,\n      \"bottom\": null,\n      \"display\": null,\n      \"flex\": null,\n      \"flex_flow\": null,\n      \"grid_area\": null,\n      \"grid_auto_columns\": null,\n      
\"grid_auto_flow\": null,\n      \"grid_auto_rows\": null,\n      \"grid_column\": null,\n      \"grid_gap\": null,\n      \"grid_row\": null,\n      \"grid_template_areas\": null,\n      \"grid_template_columns\": null,\n      \"grid_template_rows\": null,\n      \"height\": null,\n      \"justify_content\": null,\n      \"justify_items\": null,\n      \"left\": null,\n      \"margin\": null,\n      \"max_height\": null,\n      \"max_width\": null,\n      \"min_height\": null,\n      \"min_width\": null,\n      \"object_fit\": null,\n      \"object_position\": null,\n      \"order\": null,\n      \"overflow\": null,\n      \"overflow_x\": null,\n      \"overflow_y\": null,\n      \"padding\": null,\n      \"right\": null,\n      \"top\": null,\n      \"visibility\": null,\n      \"width\": null\n     }\n    },\n    \"9f1c9a0695384bdaa6f8b847ef89bee8\": {\n     \"model_module\": \"@jupyter-widgets/controls\",\n     \"model_module_version\": \"1.5.0\",\n     \"model_name\": \"ButtonStyleModel\",\n     \"state\": {\n      \"_model_module\": \"@jupyter-widgets/controls\",\n      \"_model_module_version\": \"1.5.0\",\n      \"_model_name\": \"ButtonStyleModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/base\",\n      \"_view_module_version\": \"1.2.0\",\n      \"_view_name\": \"StyleView\",\n      \"button_color\": null,\n      \"font_weight\": \"\"\n     }\n    },\n    \"9f56a2d9979c4bd8928c644c22c3ecdf\": {\n     \"model_module\": \"@jupyter-widgets/controls\",\n     \"model_module_version\": \"1.5.0\",\n     \"model_name\": \"DescriptionStyleModel\",\n     \"state\": {\n      \"_model_module\": \"@jupyter-widgets/controls\",\n      \"_model_module_version\": \"1.5.0\",\n      \"_model_name\": \"DescriptionStyleModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/base\",\n      \"_view_module_version\": \"1.2.0\",\n      \"_view_name\": \"StyleView\",\n      \"description_width\": \"\"\n     }\n    },\n   
 \"a0a11e929edd4189b79723d618522c33\": {\n     \"model_module\": \"@jupyter-widgets/base\",\n     \"model_module_version\": \"1.2.0\",\n     \"model_name\": \"LayoutModel\",\n     \"state\": {\n      \"_model_module\": \"@jupyter-widgets/base\",\n      \"_model_module_version\": \"1.2.0\",\n      \"_model_name\": \"LayoutModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/base\",\n      \"_view_module_version\": \"1.2.0\",\n      \"_view_name\": \"LayoutView\",\n      \"align_content\": null,\n      \"align_items\": null,\n      \"align_self\": null,\n      \"border\": null,\n      \"bottom\": null,\n      \"display\": null,\n      \"flex\": null,\n      \"flex_flow\": null,\n      \"grid_area\": null,\n      \"grid_auto_columns\": null,\n      \"grid_auto_flow\": null,\n      \"grid_auto_rows\": null,\n      \"grid_column\": null,\n      \"grid_gap\": null,\n      \"grid_row\": null,\n      \"grid_template_areas\": null,\n      \"grid_template_columns\": null,\n      \"grid_template_rows\": null,\n      \"height\": null,\n      \"justify_content\": null,\n      \"justify_items\": null,\n      \"left\": null,\n      \"margin\": null,\n      \"max_height\": null,\n      \"max_width\": null,\n      \"min_height\": null,\n      \"min_width\": null,\n      \"object_fit\": null,\n      \"object_position\": null,\n      \"order\": null,\n      \"overflow\": null,\n      \"overflow_x\": null,\n      \"overflow_y\": null,\n      \"padding\": null,\n      \"right\": null,\n      \"top\": null,\n      \"visibility\": null,\n      \"width\": null\n     }\n    },\n    \"a10d0a76010f4e508c65a9b69ebc5156\": {\n     \"model_module\": \"@jupyter-widgets/controls\",\n     \"model_module_version\": \"1.5.0\",\n     \"model_name\": \"DescriptionStyleModel\",\n     \"state\": {\n      \"_model_module\": \"@jupyter-widgets/controls\",\n      \"_model_module_version\": \"1.5.0\",\n      \"_model_name\": \"DescriptionStyleModel\",\n      \"_view_count\": 
null,\n      \"_view_module\": \"@jupyter-widgets/base\",\n      \"_view_module_version\": \"1.2.0\",\n      \"_view_name\": \"StyleView\",\n      \"description_width\": \"\"\n     }\n    },\n    \"a138859f19b74fc0928dc236ab5359db\": {\n     \"model_module\": \"@jupyter-widgets/controls\",\n     \"model_module_version\": \"1.5.0\",\n     \"model_name\": \"HBoxModel\",\n     \"state\": {\n      \"_dom_classes\": [],\n      \"_model_module\": \"@jupyter-widgets/controls\",\n      \"_model_module_version\": \"1.5.0\",\n      \"_model_name\": \"HBoxModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/controls\",\n      \"_view_module_version\": \"1.5.0\",\n      \"_view_name\": \"HBoxView\",\n      \"box_style\": \"\",\n      \"children\": [\n       \"IPY_MODEL_9b42e08b3c9548818488268768a118b1\",\n       \"IPY_MODEL_12b56912736849fea2ad8124456fdc5c\",\n       \"IPY_MODEL_879c8ab5873847a8833bd74123be90a4\"\n      ],\n      \"layout\": \"IPY_MODEL_20352e5f58d24bb8b1f3940efd14fe4a\"\n     }\n    },\n    \"a1959759c5424da9961fb2a308d4dee4\": {\n     \"model_module\": \"@jupyter-widgets/controls\",\n     \"model_module_version\": \"1.5.0\",\n     \"model_name\": \"HTMLModel\",\n     \"state\": {\n      \"_dom_classes\": [],\n      \"_model_module\": \"@jupyter-widgets/controls\",\n      \"_model_module_version\": \"1.5.0\",\n      \"_model_name\": \"HTMLModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/controls\",\n      \"_view_module_version\": \"1.5.0\",\n      \"_view_name\": \"HTMLView\",\n      \"description\": \"\",\n      \"description_tooltip\": null,\n      \"layout\": \"IPY_MODEL_3aaecbf540f54a2db9ab0931e3b1fe57\",\n      \"placeholder\": \"​\",\n      \"style\": \"IPY_MODEL_9e333ed3b5014069ac1dd969255dd591\",\n      \"value\": \" 239/239 [00:00&lt;00:00, 30.9kB/s]\"\n     }\n    },\n    \"a20927bf5f2c41f58c1e31ac858ab36c\": {\n     \"model_module\": \"@jupyter-widgets/controls\",\n     
\"model_module_version\": \"1.5.0\",\n     \"model_name\": \"HTMLModel\",\n     \"state\": {\n      \"_dom_classes\": [],\n      \"_model_module\": \"@jupyter-widgets/controls\",\n      \"_model_module_version\": \"1.5.0\",\n      \"_model_name\": \"HTMLModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/controls\",\n      \"_view_module_version\": \"1.5.0\",\n      \"_view_name\": \"HTMLView\",\n      \"description\": \"\",\n      \"description_tooltip\": null,\n      \"layout\": \"IPY_MODEL_1811cda0644e4190a9469d1774435d82\",\n      \"placeholder\": \"​\",\n      \"style\": \"IPY_MODEL_35c811d2ae8e43f3b5cecbdd3cfa857f\",\n      \"value\": \"tokenizer.json: 100%\"\n     }\n    },\n    \"a3a945817f684328b34651fe052393ec\": {\n     \"model_module\": \"@jupyter-widgets/controls\",\n     \"model_module_version\": \"1.5.0\",\n     \"model_name\": \"DescriptionStyleModel\",\n     \"state\": {\n      \"_model_module\": \"@jupyter-widgets/controls\",\n      \"_model_module_version\": \"1.5.0\",\n      \"_model_name\": \"DescriptionStyleModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/base\",\n      \"_view_module_version\": \"1.2.0\",\n      \"_view_name\": \"StyleView\",\n      \"description_width\": \"\"\n     }\n    },\n    \"a44f630e099e43899f20a77084ae60cd\": {\n     \"model_module\": \"@jupyter-widgets/controls\",\n     \"model_module_version\": \"1.5.0\",\n     \"model_name\": \"HTMLModel\",\n     \"state\": {\n      \"_dom_classes\": [],\n      \"_model_module\": \"@jupyter-widgets/controls\",\n      \"_model_module_version\": \"1.5.0\",\n      \"_model_name\": \"HTMLModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/controls\",\n      \"_view_module_version\": \"1.5.0\",\n      \"_view_name\": \"HTMLView\",\n      \"description\": \"\",\n      \"description_tooltip\": null,\n      \"layout\": \"IPY_MODEL_ed5ca967ad5342929e578ac6aa4dc4c0\",\n      \"placeholder\": 
\"​\",\n      \"style\": \"IPY_MODEL_af401d117d5047629d3a6e2361757b62\",\n      \"value\": \"model-00001-of-00008.safetensors: 100%\"\n     }\n    },\n    \"a4e5789584564049b83df7c6c54a3e08\": {\n     \"model_module\": \"@jupyter-widgets/base\",\n     \"model_module_version\": \"1.2.0\",\n     \"model_name\": \"LayoutModel\",\n     \"state\": {\n      \"_model_module\": \"@jupyter-widgets/base\",\n      \"_model_module_version\": \"1.2.0\",\n      \"_model_name\": \"LayoutModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/base\",\n      \"_view_module_version\": \"1.2.0\",\n      \"_view_name\": \"LayoutView\",\n      \"align_content\": null,\n      \"align_items\": null,\n      \"align_self\": null,\n      \"border\": null,\n      \"bottom\": null,\n      \"display\": null,\n      \"flex\": null,\n      \"flex_flow\": null,\n      \"grid_area\": null,\n      \"grid_auto_columns\": null,\n      \"grid_auto_flow\": null,\n      \"grid_auto_rows\": null,\n      \"grid_column\": null,\n      \"grid_gap\": null,\n      \"grid_row\": null,\n      \"grid_template_areas\": null,\n      \"grid_template_columns\": null,\n      \"grid_template_rows\": null,\n      \"height\": null,\n      \"justify_content\": null,\n      \"justify_items\": null,\n      \"left\": null,\n      \"margin\": null,\n      \"max_height\": null,\n      \"max_width\": null,\n      \"min_height\": null,\n      \"min_width\": null,\n      \"object_fit\": null,\n      \"object_position\": null,\n      \"order\": null,\n      \"overflow\": null,\n      \"overflow_x\": null,\n      \"overflow_y\": null,\n      \"padding\": null,\n      \"right\": null,\n      \"top\": null,\n      \"visibility\": null,\n      \"width\": null\n     }\n    },\n    \"a5434ee714f9498d83870544b67c0cb7\": {\n     \"model_module\": \"@jupyter-widgets/controls\",\n     \"model_module_version\": \"1.5.0\",\n     \"model_name\": \"ProgressStyleModel\",\n     \"state\": {\n      \"_model_module\": 
\"@jupyter-widgets/controls\",\n      \"_model_module_version\": \"1.5.0\",\n      \"_model_name\": \"ProgressStyleModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/base\",\n      \"_view_module_version\": \"1.2.0\",\n      \"_view_name\": \"StyleView\",\n      \"bar_color\": null,\n      \"description_width\": \"\"\n     }\n    },\n    \"a55060adc3564407ac81ad7297d34aaa\": {\n     \"model_module\": \"@jupyter-widgets/controls\",\n     \"model_module_version\": \"1.5.0\",\n     \"model_name\": \"DescriptionStyleModel\",\n     \"state\": {\n      \"_model_module\": \"@jupyter-widgets/controls\",\n      \"_model_module_version\": \"1.5.0\",\n      \"_model_name\": \"DescriptionStyleModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/base\",\n      \"_view_module_version\": \"1.2.0\",\n      \"_view_name\": \"StyleView\",\n      \"description_width\": \"\"\n     }\n    },\n    \"a6f48410b9964fefba0c3009a77dc838\": {\n     \"model_module\": \"@jupyter-widgets/controls\",\n     \"model_module_version\": \"1.5.0\",\n     \"model_name\": \"DescriptionStyleModel\",\n     \"state\": {\n      \"_model_module\": \"@jupyter-widgets/controls\",\n      \"_model_module_version\": \"1.5.0\",\n      \"_model_name\": \"DescriptionStyleModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/base\",\n      \"_view_module_version\": \"1.2.0\",\n      \"_view_name\": \"StyleView\",\n      \"description_width\": \"\"\n     }\n    },\n    \"a7cf477e80fc43e0ad82c7997b076dce\": {\n     \"model_module\": \"@jupyter-widgets/controls\",\n     \"model_module_version\": \"1.5.0\",\n     \"model_name\": \"DescriptionStyleModel\",\n     \"state\": {\n      \"_model_module\": \"@jupyter-widgets/controls\",\n      \"_model_module_version\": \"1.5.0\",\n      \"_model_name\": \"DescriptionStyleModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/base\",\n      
\"_view_module_version\": \"1.2.0\",\n      \"_view_name\": \"StyleView\",\n      \"description_width\": \"\"\n     }\n    },\n    \"a80410b919e442c49aea15acc1ce1a72\": {\n     \"model_module\": \"@jupyter-widgets/controls\",\n     \"model_module_version\": \"1.5.0\",\n     \"model_name\": \"HTMLModel\",\n     \"state\": {\n      \"_dom_classes\": [],\n      \"_model_module\": \"@jupyter-widgets/controls\",\n      \"_model_module_version\": \"1.5.0\",\n      \"_model_name\": \"HTMLModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/controls\",\n      \"_view_module_version\": \"1.5.0\",\n      \"_view_name\": \"HTMLView\",\n      \"description\": \"\",\n      \"description_tooltip\": null,\n      \"layout\": \"IPY_MODEL_fa1282ccc7544e4f818e2f03ccffe4a5\",\n      \"placeholder\": \"​\",\n      \"style\": \"IPY_MODEL_bbbf575d2a4b4c6ea8389be79b2a6039\",\n      \"value\": \"model.safetensors.index.json: 100%\"\n     }\n    },\n    \"ab93eabd7cea4b94b4b7a387f101e8a1\": {\n     \"model_module\": \"@jupyter-widgets/base\",\n     \"model_module_version\": \"1.2.0\",\n     \"model_name\": \"LayoutModel\",\n     \"state\": {\n      \"_model_module\": \"@jupyter-widgets/base\",\n      \"_model_module_version\": \"1.2.0\",\n      \"_model_name\": \"LayoutModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/base\",\n      \"_view_module_version\": \"1.2.0\",\n      \"_view_name\": \"LayoutView\",\n      \"align_content\": null,\n      \"align_items\": null,\n      \"align_self\": null,\n      \"border\": null,\n      \"bottom\": null,\n      \"display\": null,\n      \"flex\": null,\n      \"flex_flow\": null,\n      \"grid_area\": null,\n      \"grid_auto_columns\": null,\n      \"grid_auto_flow\": null,\n      \"grid_auto_rows\": null,\n      \"grid_column\": null,\n      \"grid_gap\": null,\n      \"grid_row\": null,\n      \"grid_template_areas\": null,\n      \"grid_template_columns\": null,\n      
\"grid_template_rows\": null,\n      \"height\": null,\n      \"justify_content\": null,\n      \"justify_items\": null,\n      \"left\": null,\n      \"margin\": null,\n      \"max_height\": null,\n      \"max_width\": null,\n      \"min_height\": null,\n      \"min_width\": null,\n      \"object_fit\": null,\n      \"object_position\": null,\n      \"order\": null,\n      \"overflow\": null,\n      \"overflow_x\": null,\n      \"overflow_y\": null,\n      \"padding\": null,\n      \"right\": null,\n      \"top\": null,\n      \"visibility\": null,\n      \"width\": null\n     }\n    },\n    \"ac764024cf1c4e08ba7749afd2cd20ac\": {\n     \"model_module\": \"@jupyter-widgets/controls\",\n     \"model_module_version\": \"1.5.0\",\n     \"model_name\": \"DescriptionStyleModel\",\n     \"state\": {\n      \"_model_module\": \"@jupyter-widgets/controls\",\n      \"_model_module_version\": \"1.5.0\",\n      \"_model_name\": \"DescriptionStyleModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/base\",\n      \"_view_module_version\": \"1.2.0\",\n      \"_view_name\": \"StyleView\",\n      \"description_width\": \"\"\n     }\n    },\n    \"ad1236893754446881e153adc9d5c962\": {\n     \"model_module\": \"@jupyter-widgets/controls\",\n     \"model_module_version\": \"1.5.0\",\n     \"model_name\": \"HTMLModel\",\n     \"state\": {\n      \"_dom_classes\": [],\n      \"_model_module\": \"@jupyter-widgets/controls\",\n      \"_model_module_version\": \"1.5.0\",\n      \"_model_name\": \"HTMLModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/controls\",\n      \"_view_module_version\": \"1.5.0\",\n      \"_view_name\": \"HTMLView\",\n      \"description\": \"\",\n      \"description_tooltip\": null,\n      \"layout\": \"IPY_MODEL_62e302ebdad64aada0ffe64ae1c873f3\",\n      \"placeholder\": \"​\",\n      \"style\": \"IPY_MODEL_bd1b0dfed6d34d16af33a4a58330f5ec\",\n      \"value\": \"Saving the dataset (1/1 shards): 
100%\"\n     }\n    },\n    \"ad7599de524549c48bf2d3124ad4b299\": {\n     \"model_module\": \"@jupyter-widgets/controls\",\n     \"model_module_version\": \"1.5.0\",\n     \"model_name\": \"DescriptionStyleModel\",\n     \"state\": {\n      \"_model_module\": \"@jupyter-widgets/controls\",\n      \"_model_module_version\": \"1.5.0\",\n      \"_model_name\": \"DescriptionStyleModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/base\",\n      \"_view_module_version\": \"1.2.0\",\n      \"_view_name\": \"StyleView\",\n      \"description_width\": \"\"\n     }\n    },\n    \"adacfdcc1b0140efac56918e9ccf064e\": {\n     \"model_module\": \"@jupyter-widgets/base\",\n     \"model_module_version\": \"1.2.0\",\n     \"model_name\": \"LayoutModel\",\n     \"state\": {\n      \"_model_module\": \"@jupyter-widgets/base\",\n      \"_model_module_version\": \"1.2.0\",\n      \"_model_name\": \"LayoutModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/base\",\n      \"_view_module_version\": \"1.2.0\",\n      \"_view_name\": \"LayoutView\",\n      \"align_content\": null,\n      \"align_items\": null,\n      \"align_self\": null,\n      \"border\": null,\n      \"bottom\": null,\n      \"display\": null,\n      \"flex\": null,\n      \"flex_flow\": null,\n      \"grid_area\": null,\n      \"grid_auto_columns\": null,\n      \"grid_auto_flow\": null,\n      \"grid_auto_rows\": null,\n      \"grid_column\": null,\n      \"grid_gap\": null,\n      \"grid_row\": null,\n      \"grid_template_areas\": null,\n      \"grid_template_columns\": null,\n      \"grid_template_rows\": null,\n      \"height\": null,\n      \"justify_content\": null,\n      \"justify_items\": null,\n      \"left\": null,\n      \"margin\": null,\n      \"max_height\": null,\n      \"max_width\": null,\n      \"min_height\": null,\n      \"min_width\": null,\n      \"object_fit\": null,\n      \"object_position\": null,\n      \"order\": null,\n      
\"overflow\": null,\n      \"overflow_x\": null,\n      \"overflow_y\": null,\n      \"padding\": null,\n      \"right\": null,\n      \"top\": null,\n      \"visibility\": null,\n      \"width\": null\n     }\n    },\n    \"af401d117d5047629d3a6e2361757b62\": {\n     \"model_module\": \"@jupyter-widgets/controls\",\n     \"model_module_version\": \"1.5.0\",\n     \"model_name\": \"DescriptionStyleModel\",\n     \"state\": {\n      \"_model_module\": \"@jupyter-widgets/controls\",\n      \"_model_module_version\": \"1.5.0\",\n      \"_model_name\": \"DescriptionStyleModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/base\",\n      \"_view_module_version\": \"1.2.0\",\n      \"_view_name\": \"StyleView\",\n      \"description_width\": \"\"\n     }\n    },\n    \"b191ac001a2e4962bc9a245fcdf26e6b\": {\n     \"model_module\": \"@jupyter-widgets/base\",\n     \"model_module_version\": \"1.2.0\",\n     \"model_name\": \"LayoutModel\",\n     \"state\": {\n      \"_model_module\": \"@jupyter-widgets/base\",\n      \"_model_module_version\": \"1.2.0\",\n      \"_model_name\": \"LayoutModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/base\",\n      \"_view_module_version\": \"1.2.0\",\n      \"_view_name\": \"LayoutView\",\n      \"align_content\": null,\n      \"align_items\": null,\n      \"align_self\": null,\n      \"border\": null,\n      \"bottom\": null,\n      \"display\": null,\n      \"flex\": null,\n      \"flex_flow\": null,\n      \"grid_area\": null,\n      \"grid_auto_columns\": null,\n      \"grid_auto_flow\": null,\n      \"grid_auto_rows\": null,\n      \"grid_column\": null,\n      \"grid_gap\": null,\n      \"grid_row\": null,\n      \"grid_template_areas\": null,\n      \"grid_template_columns\": null,\n      \"grid_template_rows\": null,\n      \"height\": null,\n      \"justify_content\": null,\n      \"justify_items\": null,\n      \"left\": null,\n      \"margin\": null,\n      
\"max_height\": null,\n      \"max_width\": null,\n      \"min_height\": null,\n      \"min_width\": null,\n      \"object_fit\": null,\n      \"object_position\": null,\n      \"order\": null,\n      \"overflow\": null,\n      \"overflow_x\": null,\n      \"overflow_y\": null,\n      \"padding\": null,\n      \"right\": null,\n      \"top\": null,\n      \"visibility\": null,\n      \"width\": null\n     }\n    },\n    \"b195f160ca20442fadd8b5aed0ee41af\": {\n     \"model_module\": \"@jupyter-widgets/base\",\n     \"model_module_version\": \"1.2.0\",\n     \"model_name\": \"LayoutModel\",\n     \"state\": {\n      \"_model_module\": \"@jupyter-widgets/base\",\n      \"_model_module_version\": \"1.2.0\",\n      \"_model_name\": \"LayoutModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/base\",\n      \"_view_module_version\": \"1.2.0\",\n      \"_view_name\": \"LayoutView\",\n      \"align_content\": null,\n      \"align_items\": null,\n      \"align_self\": null,\n      \"border\": null,\n      \"bottom\": null,\n      \"display\": null,\n      \"flex\": null,\n      \"flex_flow\": null,\n      \"grid_area\": null,\n      \"grid_auto_columns\": null,\n      \"grid_auto_flow\": null,\n      \"grid_auto_rows\": null,\n      \"grid_column\": null,\n      \"grid_gap\": null,\n      \"grid_row\": null,\n      \"grid_template_areas\": null,\n      \"grid_template_columns\": null,\n      \"grid_template_rows\": null,\n      \"height\": null,\n      \"justify_content\": null,\n      \"justify_items\": null,\n      \"left\": null,\n      \"margin\": null,\n      \"max_height\": null,\n      \"max_width\": null,\n      \"min_height\": null,\n      \"min_width\": null,\n      \"object_fit\": null,\n      \"object_position\": null,\n      \"order\": null,\n      \"overflow\": null,\n      \"overflow_x\": null,\n      \"overflow_y\": null,\n      \"padding\": null,\n      \"right\": null,\n      \"top\": null,\n      \"visibility\": null,\n      
\"width\": null\n     }\n    },\n    \"b1bea589efa14258a9982071b87938bf\": {\n     \"model_module\": \"@jupyter-widgets/base\",\n     \"model_module_version\": \"1.2.0\",\n     \"model_name\": \"LayoutModel\",\n     \"state\": {\n      \"_model_module\": \"@jupyter-widgets/base\",\n      \"_model_module_version\": \"1.2.0\",\n      \"_model_name\": \"LayoutModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/base\",\n      \"_view_module_version\": \"1.2.0\",\n      \"_view_name\": \"LayoutView\",\n      \"align_content\": null,\n      \"align_items\": null,\n      \"align_self\": null,\n      \"border\": null,\n      \"bottom\": null,\n      \"display\": null,\n      \"flex\": null,\n      \"flex_flow\": null,\n      \"grid_area\": null,\n      \"grid_auto_columns\": null,\n      \"grid_auto_flow\": null,\n      \"grid_auto_rows\": null,\n      \"grid_column\": null,\n      \"grid_gap\": null,\n      \"grid_row\": null,\n      \"grid_template_areas\": null,\n      \"grid_template_columns\": null,\n      \"grid_template_rows\": null,\n      \"height\": null,\n      \"justify_content\": null,\n      \"justify_items\": null,\n      \"left\": null,\n      \"margin\": null,\n      \"max_height\": null,\n      \"max_width\": null,\n      \"min_height\": null,\n      \"min_width\": null,\n      \"object_fit\": null,\n      \"object_position\": null,\n      \"order\": null,\n      \"overflow\": null,\n      \"overflow_x\": null,\n      \"overflow_y\": null,\n      \"padding\": null,\n      \"right\": null,\n      \"top\": null,\n      \"visibility\": null,\n      \"width\": null\n     }\n    },\n    \"b5b65414154544aa8a71b1a39164aad7\": {\n     \"model_module\": \"@jupyter-widgets/base\",\n     \"model_module_version\": \"1.2.0\",\n     \"model_name\": \"LayoutModel\",\n     \"state\": {\n      \"_model_module\": \"@jupyter-widgets/base\",\n      \"_model_module_version\": \"1.2.0\",\n      \"_model_name\": \"LayoutModel\",\n      
\"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/base\",\n      \"_view_module_version\": \"1.2.0\",\n      \"_view_name\": \"LayoutView\",\n      \"align_content\": null,\n      \"align_items\": null,\n      \"align_self\": null,\n      \"border\": null,\n      \"bottom\": null,\n      \"display\": null,\n      \"flex\": null,\n      \"flex_flow\": null,\n      \"grid_area\": null,\n      \"grid_auto_columns\": null,\n      \"grid_auto_flow\": null,\n      \"grid_auto_rows\": null,\n      \"grid_column\": null,\n      \"grid_gap\": null,\n      \"grid_row\": null,\n      \"grid_template_areas\": null,\n      \"grid_template_columns\": null,\n      \"grid_template_rows\": null,\n      \"height\": null,\n      \"justify_content\": null,\n      \"justify_items\": null,\n      \"left\": null,\n      \"margin\": null,\n      \"max_height\": null,\n      \"max_width\": null,\n      \"min_height\": null,\n      \"min_width\": null,\n      \"object_fit\": null,\n      \"object_position\": null,\n      \"order\": null,\n      \"overflow\": null,\n      \"overflow_x\": null,\n      \"overflow_y\": null,\n      \"padding\": null,\n      \"right\": null,\n      \"top\": null,\n      \"visibility\": null,\n      \"width\": null\n     }\n    },\n    \"b634bb73cfa743d09a5999101b840976\": {\n     \"model_module\": \"@jupyter-widgets/base\",\n     \"model_module_version\": \"1.2.0\",\n     \"model_name\": \"LayoutModel\",\n     \"state\": {\n      \"_model_module\": \"@jupyter-widgets/base\",\n      \"_model_module_version\": \"1.2.0\",\n      \"_model_name\": \"LayoutModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/base\",\n      \"_view_module_version\": \"1.2.0\",\n      \"_view_name\": \"LayoutView\",\n      \"align_content\": null,\n      \"align_items\": null,\n      \"align_self\": null,\n      \"border\": null,\n      \"bottom\": null,\n      \"display\": null,\n      \"flex\": null,\n      \"flex_flow\": null,\n      
\"grid_area\": null,\n      \"grid_auto_columns\": null,\n      \"grid_auto_flow\": null,\n      \"grid_auto_rows\": null,\n      \"grid_column\": null,\n      \"grid_gap\": null,\n      \"grid_row\": null,\n      \"grid_template_areas\": null,\n      \"grid_template_columns\": null,\n      \"grid_template_rows\": null,\n      \"height\": null,\n      \"justify_content\": null,\n      \"justify_items\": null,\n      \"left\": null,\n      \"margin\": null,\n      \"max_height\": null,\n      \"max_width\": null,\n      \"min_height\": null,\n      \"min_width\": null,\n      \"object_fit\": null,\n      \"object_position\": null,\n      \"order\": null,\n      \"overflow\": null,\n      \"overflow_x\": null,\n      \"overflow_y\": null,\n      \"padding\": null,\n      \"right\": null,\n      \"top\": null,\n      \"visibility\": null,\n      \"width\": null\n     }\n    },\n    \"b82aa8c57f7c422a9a9c90f333ed2a99\": {\n     \"model_module\": \"@jupyter-widgets/controls\",\n     \"model_module_version\": \"1.5.0\",\n     \"model_name\": \"HBoxModel\",\n     \"state\": {\n      \"_dom_classes\": [],\n      \"_model_module\": \"@jupyter-widgets/controls\",\n      \"_model_module_version\": \"1.5.0\",\n      \"_model_name\": \"HBoxModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/controls\",\n      \"_view_module_version\": \"1.5.0\",\n      \"_view_name\": \"HBoxView\",\n      \"box_style\": \"\",\n      \"children\": [\n       \"IPY_MODEL_c0991cf63ee6458b96e9a75e7a88b61a\",\n       \"IPY_MODEL_71c8af139cd248b1b51101fd46a93f35\",\n       \"IPY_MODEL_1d5117195d4b49eb8f1a73b18419f7ce\"\n      ],\n      \"layout\": \"IPY_MODEL_3c21e4a511b4441192c03b7f1d0976e9\"\n     }\n    },\n    \"b8766a88716948cf968f4563531a76d9\": {\n     \"model_module\": \"@jupyter-widgets/controls\",\n     \"model_module_version\": \"1.5.0\",\n     \"model_name\": \"HTMLModel\",\n     \"state\": {\n      \"_dom_classes\": [],\n      \"_model_module\": 
\"@jupyter-widgets/controls\",\n      \"_model_module_version\": \"1.5.0\",\n      \"_model_name\": \"HTMLModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/controls\",\n      \"_view_module_version\": \"1.5.0\",\n      \"_view_name\": \"HTMLView\",\n      \"description\": \"\",\n      \"description_tooltip\": null,\n      \"layout\": \"IPY_MODEL_2b3a2659b12244bd8548320320016dbf\",\n      \"placeholder\": \"​\",\n      \"style\": \"IPY_MODEL_0cd7efffbb3c4c4b972e63749f61ab97\",\n      \"value\": \"Generating train split: \"\n     }\n    },\n    \"b87c84de30e84b3abf4871461fb9cbd3\": {\n     \"model_module\": \"@jupyter-widgets/controls\",\n     \"model_module_version\": \"1.5.0\",\n     \"model_name\": \"DescriptionStyleModel\",\n     \"state\": {\n      \"_model_module\": \"@jupyter-widgets/controls\",\n      \"_model_module_version\": \"1.5.0\",\n      \"_model_name\": \"DescriptionStyleModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/base\",\n      \"_view_module_version\": \"1.2.0\",\n      \"_view_name\": \"StyleView\",\n      \"description_width\": \"\"\n     }\n    },\n    \"b8e39e4dddc3497fbc29ae45c66da759\": {\n     \"model_module\": \"@jupyter-widgets/base\",\n     \"model_module_version\": \"1.2.0\",\n     \"model_name\": \"LayoutModel\",\n     \"state\": {\n      \"_model_module\": \"@jupyter-widgets/base\",\n      \"_model_module_version\": \"1.2.0\",\n      \"_model_name\": \"LayoutModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/base\",\n      \"_view_module_version\": \"1.2.0\",\n      \"_view_name\": \"LayoutView\",\n      \"align_content\": null,\n      \"align_items\": null,\n      \"align_self\": null,\n      \"border\": null,\n      \"bottom\": null,\n      \"display\": null,\n      \"flex\": null,\n      \"flex_flow\": null,\n      \"grid_area\": null,\n      \"grid_auto_columns\": null,\n      \"grid_auto_flow\": null,\n      \"grid_auto_rows\": 
null,\n      \"grid_column\": null,\n      \"grid_gap\": null,\n      \"grid_row\": null,\n      \"grid_template_areas\": null,\n      \"grid_template_columns\": null,\n      \"grid_template_rows\": null,\n      \"height\": null,\n      \"justify_content\": null,\n      \"justify_items\": null,\n      \"left\": null,\n      \"margin\": null,\n      \"max_height\": null,\n      \"max_width\": null,\n      \"min_height\": null,\n      \"min_width\": null,\n      \"object_fit\": null,\n      \"object_position\": null,\n      \"order\": null,\n      \"overflow\": null,\n      \"overflow_x\": null,\n      \"overflow_y\": null,\n      \"padding\": null,\n      \"right\": null,\n      \"top\": null,\n      \"visibility\": null,\n      \"width\": null\n     }\n    },\n    \"bb33aec33a6447078c31bfd728942994\": {\n     \"model_module\": \"@jupyter-widgets/controls\",\n     \"model_module_version\": \"1.5.0\",\n     \"model_name\": \"DescriptionStyleModel\",\n     \"state\": {\n      \"_model_module\": \"@jupyter-widgets/controls\",\n      \"_model_module_version\": \"1.5.0\",\n      \"_model_name\": \"DescriptionStyleModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/base\",\n      \"_view_module_version\": \"1.2.0\",\n      \"_view_name\": \"StyleView\",\n      \"description_width\": \"\"\n     }\n    },\n    \"bbbf575d2a4b4c6ea8389be79b2a6039\": {\n     \"model_module\": \"@jupyter-widgets/controls\",\n     \"model_module_version\": \"1.5.0\",\n     \"model_name\": \"DescriptionStyleModel\",\n     \"state\": {\n      \"_model_module\": \"@jupyter-widgets/controls\",\n      \"_model_module_version\": \"1.5.0\",\n      \"_model_name\": \"DescriptionStyleModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/base\",\n      \"_view_module_version\": \"1.2.0\",\n      \"_view_name\": \"StyleView\",\n      \"description_width\": \"\"\n     }\n    },\n    \"bca2c7185b6749fd899c06a2ba4c5e46\": {\n     \"model_module\": 
\"@jupyter-widgets/controls\",\n     \"model_module_version\": \"1.5.0\",\n     \"model_name\": \"HTMLModel\",\n     \"state\": {\n      \"_dom_classes\": [],\n      \"_model_module\": \"@jupyter-widgets/controls\",\n      \"_model_module_version\": \"1.5.0\",\n      \"_model_name\": \"HTMLModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/controls\",\n      \"_view_module_version\": \"1.5.0\",\n      \"_view_name\": \"HTMLView\",\n      \"description\": \"\",\n      \"description_tooltip\": null,\n      \"layout\": \"IPY_MODEL_0f480e3a0b0a45d2a2d2dec3cad923f3\",\n      \"placeholder\": \"​\",\n      \"style\": \"IPY_MODEL_fcb30372e7404c5d8a1ad4df91e6c7b2\",\n      \"value\": \" 1.91G/1.91G [00:05&lt;00:00, 444MB/s]\"\n     }\n    },\n    \"bd1b0dfed6d34d16af33a4a58330f5ec\": {\n     \"model_module\": \"@jupyter-widgets/controls\",\n     \"model_module_version\": \"1.5.0\",\n     \"model_name\": \"DescriptionStyleModel\",\n     \"state\": {\n      \"_model_module\": \"@jupyter-widgets/controls\",\n      \"_model_module_version\": \"1.5.0\",\n      \"_model_name\": \"DescriptionStyleModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/base\",\n      \"_view_module_version\": \"1.2.0\",\n      \"_view_name\": \"StyleView\",\n      \"description_width\": \"\"\n     }\n    },\n    \"be724f04b03942b2a033a7e8898bb4fd\": {\n     \"model_module\": \"@jupyter-widgets/base\",\n     \"model_module_version\": \"1.2.0\",\n     \"model_name\": \"LayoutModel\",\n     \"state\": {\n      \"_model_module\": \"@jupyter-widgets/base\",\n      \"_model_module_version\": \"1.2.0\",\n      \"_model_name\": \"LayoutModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/base\",\n      \"_view_module_version\": \"1.2.0\",\n      \"_view_name\": \"LayoutView\",\n      \"align_content\": null,\n      \"align_items\": null,\n      \"align_self\": null,\n      \"border\": null,\n      \"bottom\": null,\n  
    \"display\": null,\n      \"flex\": null,\n      \"flex_flow\": null,\n      \"grid_area\": null,\n      \"grid_auto_columns\": null,\n      \"grid_auto_flow\": null,\n      \"grid_auto_rows\": null,\n      \"grid_column\": null,\n      \"grid_gap\": null,\n      \"grid_row\": null,\n      \"grid_template_areas\": null,\n      \"grid_template_columns\": null,\n      \"grid_template_rows\": null,\n      \"height\": null,\n      \"justify_content\": null,\n      \"justify_items\": null,\n      \"left\": null,\n      \"margin\": null,\n      \"max_height\": null,\n      \"max_width\": null,\n      \"min_height\": null,\n      \"min_width\": null,\n      \"object_fit\": null,\n      \"object_position\": null,\n      \"order\": null,\n      \"overflow\": null,\n      \"overflow_x\": null,\n      \"overflow_y\": null,\n      \"padding\": null,\n      \"right\": null,\n      \"top\": null,\n      \"visibility\": null,\n      \"width\": null\n     }\n    },\n    \"bed8726b8069434687c75452e21f19e5\": {\n     \"model_module\": \"@jupyter-widgets/controls\",\n     \"model_module_version\": \"1.5.0\",\n     \"model_name\": \"FloatProgressModel\",\n     \"state\": {\n      \"_dom_classes\": [],\n      \"_model_module\": \"@jupyter-widgets/controls\",\n      \"_model_module_version\": \"1.5.0\",\n      \"_model_name\": \"FloatProgressModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/controls\",\n      \"_view_module_version\": \"1.5.0\",\n      \"_view_name\": \"ProgressView\",\n      \"bar_style\": \"success\",\n      \"description\": \"\",\n      \"description_tooltip\": null,\n      \"layout\": \"IPY_MODEL_fa864b41586f4a7aa56aeafd1d84eb75\",\n      \"max\": 9985,\n      \"min\": 0,\n      \"orientation\": \"horizontal\",\n      \"style\": \"IPY_MODEL_3225603166b54e7aab766b9964a2f660\",\n      \"value\": 9985\n     }\n    },\n    \"bee3501b2a17427784a717e50a85e7fa\": {\n     \"model_module\": \"@jupyter-widgets/base\",\n     
\"model_module_version\": \"1.2.0\",\n     \"model_name\": \"LayoutModel\",\n     \"state\": {\n      \"_model_module\": \"@jupyter-widgets/base\",\n      \"_model_module_version\": \"1.2.0\",\n      \"_model_name\": \"LayoutModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/base\",\n      \"_view_module_version\": \"1.2.0\",\n      \"_view_name\": \"LayoutView\",\n      \"align_content\": null,\n      \"align_items\": null,\n      \"align_self\": null,\n      \"border\": null,\n      \"bottom\": null,\n      \"display\": null,\n      \"flex\": null,\n      \"flex_flow\": null,\n      \"grid_area\": null,\n      \"grid_auto_columns\": null,\n      \"grid_auto_flow\": null,\n      \"grid_auto_rows\": null,\n      \"grid_column\": null,\n      \"grid_gap\": null,\n      \"grid_row\": null,\n      \"grid_template_areas\": null,\n      \"grid_template_columns\": null,\n      \"grid_template_rows\": null,\n      \"height\": null,\n      \"justify_content\": null,\n      \"justify_items\": null,\n      \"left\": null,\n      \"margin\": null,\n      \"max_height\": null,\n      \"max_width\": null,\n      \"min_height\": null,\n      \"min_width\": null,\n      \"object_fit\": null,\n      \"object_position\": null,\n      \"order\": null,\n      \"overflow\": null,\n      \"overflow_x\": null,\n      \"overflow_y\": null,\n      \"padding\": null,\n      \"right\": null,\n      \"top\": null,\n      \"visibility\": null,\n      \"width\": null\n     }\n    },\n    \"bfcdbba993b74972a9e3e575f86908ff\": {\n     \"model_module\": \"@jupyter-widgets/base\",\n     \"model_module_version\": \"1.2.0\",\n     \"model_name\": \"LayoutModel\",\n     \"state\": {\n      \"_model_module\": \"@jupyter-widgets/base\",\n      \"_model_module_version\": \"1.2.0\",\n      \"_model_name\": \"LayoutModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/base\",\n      \"_view_module_version\": \"1.2.0\",\n      \"_view_name\": 
\"LayoutView\",\n      \"align_content\": null,\n      \"align_items\": null,\n      \"align_self\": null,\n      \"border\": null,\n      \"bottom\": null,\n      \"display\": null,\n      \"flex\": null,\n      \"flex_flow\": null,\n      \"grid_area\": null,\n      \"grid_auto_columns\": null,\n      \"grid_auto_flow\": null,\n      \"grid_auto_rows\": null,\n      \"grid_column\": null,\n      \"grid_gap\": null,\n      \"grid_row\": null,\n      \"grid_template_areas\": null,\n      \"grid_template_columns\": null,\n      \"grid_template_rows\": null,\n      \"height\": null,\n      \"justify_content\": null,\n      \"justify_items\": null,\n      \"left\": null,\n      \"margin\": null,\n      \"max_height\": null,\n      \"max_width\": null,\n      \"min_height\": null,\n      \"min_width\": null,\n      \"object_fit\": null,\n      \"object_position\": null,\n      \"order\": null,\n      \"overflow\": null,\n      \"overflow_x\": null,\n      \"overflow_y\": null,\n      \"padding\": null,\n      \"right\": null,\n      \"top\": null,\n      \"visibility\": null,\n      \"width\": null\n     }\n    },\n    \"bff139df987d4a62abec6456cb27f3d4\": {\n     \"model_module\": \"@jupyter-widgets/controls\",\n     \"model_module_version\": \"1.5.0\",\n     \"model_name\": \"FloatProgressModel\",\n     \"state\": {\n      \"_dom_classes\": [],\n      \"_model_module\": \"@jupyter-widgets/controls\",\n      \"_model_module_version\": \"1.5.0\",\n      \"_model_name\": \"FloatProgressModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/controls\",\n      \"_view_module_version\": \"1.5.0\",\n      \"_view_name\": \"ProgressView\",\n      \"bar_style\": \"danger\",\n      \"description\": \"\",\n      \"description_tooltip\": null,\n      \"layout\": \"IPY_MODEL_c1f9c267ba3f40039cdb5eb3267e8043\",\n      \"max\": 3963750880,\n      \"min\": 0,\n      \"orientation\": \"horizontal\",\n      \"style\": 
\"IPY_MODEL_33b3b1d0295646edaac7b4822761aeb0\",\n      \"value\": 3963750502\n     }\n    },\n    \"c0892a1881de4eb4bfabc6a68f87ae99\": {\n     \"model_module\": \"@jupyter-widgets/controls\",\n     \"model_module_version\": \"1.5.0\",\n     \"model_name\": \"HTMLModel\",\n     \"state\": {\n      \"_dom_classes\": [],\n      \"_model_module\": \"@jupyter-widgets/controls\",\n      \"_model_module_version\": \"1.5.0\",\n      \"_model_name\": \"HTMLModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/controls\",\n      \"_view_module_version\": \"1.5.0\",\n      \"_view_name\": \"HTMLView\",\n      \"description\": \"\",\n      \"description_tooltip\": null,\n      \"layout\": \"IPY_MODEL_158c8b85dbf34de6a94b4e35e2fc7d5a\",\n      \"placeholder\": \"​\",\n      \"style\": \"IPY_MODEL_0b4c9753a7cb4354b8e5f187e6e1ad7c\",\n      \"value\": \" 3.96G/3.96G [00:15&lt;00:00, 564MB/s]\"\n     }\n    },\n    \"c0991cf63ee6458b96e9a75e7a88b61a\": {\n     \"model_module\": \"@jupyter-widgets/controls\",\n     \"model_module_version\": \"1.5.0\",\n     \"model_name\": \"HTMLModel\",\n     \"state\": {\n      \"_dom_classes\": [],\n      \"_model_module\": \"@jupyter-widgets/controls\",\n      \"_model_module_version\": \"1.5.0\",\n      \"_model_name\": \"HTMLModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/controls\",\n      \"_view_module_version\": \"1.5.0\",\n      \"_view_name\": \"HTMLView\",\n      \"description\": \"\",\n      \"description_tooltip\": null,\n      \"layout\": \"IPY_MODEL_ed28e2e0410d4e0b855467e798e53d66\",\n      \"placeholder\": \"​\",\n      \"style\": \"IPY_MODEL_d93f134f802b4b69b575bdaf07dbd27c\",\n      \"value\": \"tokenizer_config.json: 100%\"\n     }\n    },\n    \"c12ea43372ac4d57bb9605f1a429b397\": {\n     \"model_module\": \"@jupyter-widgets/controls\",\n     \"model_module_version\": \"1.5.0\",\n     \"model_name\": \"VBoxModel\",\n     \"state\": {\n      \"_dom_classes\": 
[],\n      \"_model_module\": \"@jupyter-widgets/controls\",\n      \"_model_module_version\": \"1.5.0\",\n      \"_model_name\": \"VBoxModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/controls\",\n      \"_view_module_version\": \"1.5.0\",\n      \"_view_name\": \"VBoxView\",\n      \"box_style\": \"\",\n      \"children\": [],\n      \"layout\": \"IPY_MODEL_131065f118274a1586ac38e39ed84ef0\"\n     }\n    },\n    \"c1314f241a434c41b45d84dc4d3b30f8\": {\n     \"model_module\": \"@jupyter-widgets/controls\",\n     \"model_module_version\": \"1.5.0\",\n     \"model_name\": \"ProgressStyleModel\",\n     \"state\": {\n      \"_model_module\": \"@jupyter-widgets/controls\",\n      \"_model_module_version\": \"1.5.0\",\n      \"_model_name\": \"ProgressStyleModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/base\",\n      \"_view_module_version\": \"1.2.0\",\n      \"_view_name\": \"StyleView\",\n      \"bar_color\": null,\n      \"description_width\": \"\"\n     }\n    },\n    \"c1f9c267ba3f40039cdb5eb3267e8043\": {\n     \"model_module\": \"@jupyter-widgets/base\",\n     \"model_module_version\": \"1.2.0\",\n     \"model_name\": \"LayoutModel\",\n     \"state\": {\n      \"_model_module\": \"@jupyter-widgets/base\",\n      \"_model_module_version\": \"1.2.0\",\n      \"_model_name\": \"LayoutModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/base\",\n      \"_view_module_version\": \"1.2.0\",\n      \"_view_name\": \"LayoutView\",\n      \"align_content\": null,\n      \"align_items\": null,\n      \"align_self\": null,\n      \"border\": null,\n      \"bottom\": null,\n      \"display\": null,\n      \"flex\": null,\n      \"flex_flow\": null,\n      \"grid_area\": null,\n      \"grid_auto_columns\": null,\n      \"grid_auto_flow\": null,\n      \"grid_auto_rows\": null,\n      \"grid_column\": null,\n      \"grid_gap\": null,\n      \"grid_row\": null,\n      
\"grid_template_areas\": null,\n      \"grid_template_columns\": null,\n      \"grid_template_rows\": null,\n      \"height\": null,\n      \"justify_content\": null,\n      \"justify_items\": null,\n      \"left\": null,\n      \"margin\": null,\n      \"max_height\": null,\n      \"max_width\": null,\n      \"min_height\": null,\n      \"min_width\": null,\n      \"object_fit\": null,\n      \"object_position\": null,\n      \"order\": null,\n      \"overflow\": null,\n      \"overflow_x\": null,\n      \"overflow_y\": null,\n      \"padding\": null,\n      \"right\": null,\n      \"top\": null,\n      \"visibility\": null,\n      \"width\": null\n     }\n    },\n    \"c33ced495f70464aa4a3a91922090853\": {\n     \"model_module\": \"@jupyter-widgets/base\",\n     \"model_module_version\": \"1.2.0\",\n     \"model_name\": \"LayoutModel\",\n     \"state\": {\n      \"_model_module\": \"@jupyter-widgets/base\",\n      \"_model_module_version\": \"1.2.0\",\n      \"_model_name\": \"LayoutModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/base\",\n      \"_view_module_version\": \"1.2.0\",\n      \"_view_name\": \"LayoutView\",\n      \"align_content\": null,\n      \"align_items\": null,\n      \"align_self\": null,\n      \"border\": null,\n      \"bottom\": null,\n      \"display\": null,\n      \"flex\": null,\n      \"flex_flow\": null,\n      \"grid_area\": null,\n      \"grid_auto_columns\": null,\n      \"grid_auto_flow\": null,\n      \"grid_auto_rows\": null,\n      \"grid_column\": null,\n      \"grid_gap\": null,\n      \"grid_row\": null,\n      \"grid_template_areas\": null,\n      \"grid_template_columns\": null,\n      \"grid_template_rows\": null,\n      \"height\": null,\n      \"justify_content\": null,\n      \"justify_items\": null,\n      \"left\": null,\n      \"margin\": null,\n      \"max_height\": null,\n      \"max_width\": null,\n      \"min_height\": null,\n      \"min_width\": null,\n      \"object_fit\": 
null,\n      \"object_position\": null,\n      \"order\": null,\n      \"overflow\": null,\n      \"overflow_x\": null,\n      \"overflow_y\": null,\n      \"padding\": null,\n      \"right\": null,\n      \"top\": null,\n      \"visibility\": null,\n      \"width\": null\n     }\n    },\n    \"c3725c7f79fe415fbd1ea336f0cc9cf1\": {\n     \"model_module\": \"@jupyter-widgets/controls\",\n     \"model_module_version\": \"1.5.0\",\n     \"model_name\": \"FloatProgressModel\",\n     \"state\": {\n      \"_dom_classes\": [],\n      \"_model_module\": \"@jupyter-widgets/controls\",\n      \"_model_module_version\": \"1.5.0\",\n      \"_model_name\": \"FloatProgressModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/controls\",\n      \"_view_module_version\": \"1.5.0\",\n      \"_view_name\": \"ProgressView\",\n      \"bar_style\": \"danger\",\n      \"description\": \"\",\n      \"description_tooltip\": null,\n      \"layout\": \"IPY_MODEL_b191ac001a2e4962bc9a245fcdf26e6b\",\n      \"max\": 3841788544,\n      \"min\": 0,\n      \"orientation\": \"horizontal\",\n      \"style\": \"IPY_MODEL_054c8dffadba48c6b895a6cc62448ecc\",\n      \"value\": 3841788178\n     }\n    },\n    \"c3be9109d63c485d9c0ef4f9bc0f9218\": {\n     \"model_module\": \"@jupyter-widgets/base\",\n     \"model_module_version\": \"1.2.0\",\n     \"model_name\": \"LayoutModel\",\n     \"state\": {\n      \"_model_module\": \"@jupyter-widgets/base\",\n      \"_model_module_version\": \"1.2.0\",\n      \"_model_name\": \"LayoutModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/base\",\n      \"_view_module_version\": \"1.2.0\",\n      \"_view_name\": \"LayoutView\",\n      \"align_content\": null,\n      \"align_items\": null,\n      \"align_self\": null,\n      \"border\": null,\n      \"bottom\": null,\n      \"display\": null,\n      \"flex\": null,\n      \"flex_flow\": null,\n      \"grid_area\": null,\n      \"grid_auto_columns\": null,\n  
    \"grid_auto_flow\": null,\n      \"grid_auto_rows\": null,\n      \"grid_column\": null,\n      \"grid_gap\": null,\n      \"grid_row\": null,\n      \"grid_template_areas\": null,\n      \"grid_template_columns\": null,\n      \"grid_template_rows\": null,\n      \"height\": null,\n      \"justify_content\": null,\n      \"justify_items\": null,\n      \"left\": null,\n      \"margin\": null,\n      \"max_height\": null,\n      \"max_width\": null,\n      \"min_height\": null,\n      \"min_width\": null,\n      \"object_fit\": null,\n      \"object_position\": null,\n      \"order\": null,\n      \"overflow\": null,\n      \"overflow_x\": null,\n      \"overflow_y\": null,\n      \"padding\": null,\n      \"right\": null,\n      \"top\": null,\n      \"visibility\": null,\n      \"width\": null\n     }\n    },\n    \"c42acf646f344a88b8c11f81e67f7206\": {\n     \"model_module\": \"@jupyter-widgets/controls\",\n     \"model_module_version\": \"1.5.0\",\n     \"model_name\": \"HTMLModel\",\n     \"state\": {\n      \"_dom_classes\": [],\n      \"_model_module\": \"@jupyter-widgets/controls\",\n      \"_model_module_version\": \"1.5.0\",\n      \"_model_name\": \"HTMLModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/controls\",\n      \"_view_module_version\": \"1.5.0\",\n      \"_view_name\": \"HTMLView\",\n      \"description\": \"\",\n      \"description_tooltip\": null,\n      \"layout\": \"IPY_MODEL_8bc9d8ba866c442b9118d9630009939c\",\n      \"placeholder\": \"​\",\n      \"style\": \"IPY_MODEL_9f56a2d9979c4bd8928c644c22c3ecdf\",\n      \"value\": \"model-00003-of-00008.safetensors: 100%\"\n     }\n    },\n    \"c6164e05a1914ae48083db9ad7f4ef7c\": {\n     \"model_module\": \"@jupyter-widgets/base\",\n     \"model_module_version\": \"1.2.0\",\n     \"model_name\": \"LayoutModel\",\n     \"state\": {\n      \"_model_module\": \"@jupyter-widgets/base\",\n      \"_model_module_version\": \"1.2.0\",\n      \"_model_name\": 
\"LayoutModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/base\",\n      \"_view_module_version\": \"1.2.0\",\n      \"_view_name\": \"LayoutView\",\n      \"align_content\": null,\n      \"align_items\": null,\n      \"align_self\": null,\n      \"border\": null,\n      \"bottom\": null,\n      \"display\": null,\n      \"flex\": null,\n      \"flex_flow\": null,\n      \"grid_area\": null,\n      \"grid_auto_columns\": null,\n      \"grid_auto_flow\": null,\n      \"grid_auto_rows\": null,\n      \"grid_column\": null,\n      \"grid_gap\": null,\n      \"grid_row\": null,\n      \"grid_template_areas\": null,\n      \"grid_template_columns\": null,\n      \"grid_template_rows\": null,\n      \"height\": null,\n      \"justify_content\": null,\n      \"justify_items\": null,\n      \"left\": null,\n      \"margin\": null,\n      \"max_height\": null,\n      \"max_width\": null,\n      \"min_height\": null,\n      \"min_width\": null,\n      \"object_fit\": null,\n      \"object_position\": null,\n      \"order\": null,\n      \"overflow\": null,\n      \"overflow_x\": null,\n      \"overflow_y\": null,\n      \"padding\": null,\n      \"right\": null,\n      \"top\": null,\n      \"visibility\": null,\n      \"width\": null\n     }\n    },\n    \"c65dc74c7d6f4bab8f7dd28455161dd8\": {\n     \"model_module\": \"@jupyter-widgets/controls\",\n     \"model_module_version\": \"1.5.0\",\n     \"model_name\": \"ProgressStyleModel\",\n     \"state\": {\n      \"_model_module\": \"@jupyter-widgets/controls\",\n      \"_model_module_version\": \"1.5.0\",\n      \"_model_name\": \"ProgressStyleModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/base\",\n      \"_view_module_version\": \"1.2.0\",\n      \"_view_name\": \"StyleView\",\n      \"bar_color\": null,\n      \"description_width\": \"\"\n     }\n    },\n    \"c6e00f5224364822bc4239b176686919\": {\n     \"model_module\": \"@jupyter-widgets/controls\",\n    
 \"model_module_version\": \"1.5.0\",\n     \"model_name\": \"FloatProgressModel\",\n     \"state\": {\n      \"_dom_classes\": [],\n      \"_model_module\": \"@jupyter-widgets/controls\",\n      \"_model_module_version\": \"1.5.0\",\n      \"_model_name\": \"FloatProgressModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/controls\",\n      \"_view_module_version\": \"1.5.0\",\n      \"_view_name\": \"ProgressView\",\n      \"bar_style\": \"success\",\n      \"description\": \"\",\n      \"description_tooltip\": null,\n      \"layout\": \"IPY_MODEL_2a51b36be41745468e4c2d7a21b1c0d2\",\n      \"max\": 36514,\n      \"min\": 0,\n      \"orientation\": \"horizontal\",\n      \"style\": \"IPY_MODEL_4fd114abe9f5494ab59858949f5055f1\",\n      \"value\": 36514\n     }\n    },\n    \"c73055099c084dca996159e23e162d0b\": {\n     \"model_module\": \"@jupyter-widgets/controls\",\n     \"model_module_version\": \"1.5.0\",\n     \"model_name\": \"HTMLModel\",\n     \"state\": {\n      \"_dom_classes\": [],\n      \"_model_module\": \"@jupyter-widgets/controls\",\n      \"_model_module_version\": \"1.5.0\",\n      \"_model_name\": \"HTMLModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/controls\",\n      \"_view_module_version\": \"1.5.0\",\n      \"_view_name\": \"HTMLView\",\n      \"description\": \"\",\n      \"description_tooltip\": null,\n      \"layout\": \"IPY_MODEL_e40d1c1ac9494b3bade9858324e7ffdf\",\n      \"placeholder\": \"​\",\n      \"style\": \"IPY_MODEL_d65b6b060d9845779299491ac5599c31\",\n      \"value\": \" 9985/9985 [01:04&lt;00:00, 189.08 examples/s]\"\n     }\n    },\n    \"c7433acd3c4841e6958ae8f7e87b1808\": {\n     \"model_module\": \"@jupyter-widgets/controls\",\n     \"model_module_version\": \"1.5.0\",\n     \"model_name\": \"CheckboxModel\",\n     \"state\": {\n      \"_dom_classes\": [],\n      \"_model_module\": \"@jupyter-widgets/controls\",\n      \"_model_module_version\": \"1.5.0\",\n 
     \"_model_name\": \"CheckboxModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/controls\",\n      \"_view_module_version\": \"1.5.0\",\n      \"_view_name\": \"CheckboxView\",\n      \"description\": \"Add token as git credential?\",\n      \"description_tooltip\": null,\n      \"disabled\": false,\n      \"indent\": true,\n      \"layout\": \"IPY_MODEL_62c028fdef904dedb9cdeca2b3bda725\",\n      \"style\": \"IPY_MODEL_a7cf477e80fc43e0ad82c7997b076dce\",\n      \"value\": false\n     }\n    },\n    \"c84cc07789be48aebb322c23d355289e\": {\n     \"model_module\": \"@jupyter-widgets/controls\",\n     \"model_module_version\": \"1.5.0\",\n     \"model_name\": \"HTMLModel\",\n     \"state\": {\n      \"_dom_classes\": [],\n      \"_model_module\": \"@jupyter-widgets/controls\",\n      \"_model_module_version\": \"1.5.0\",\n      \"_model_name\": \"HTMLModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/controls\",\n      \"_view_module_version\": \"1.5.0\",\n      \"_view_name\": \"HTMLView\",\n      \"description\": \"\",\n      \"description_tooltip\": null,\n      \"layout\": \"IPY_MODEL_0077aedc3d174560bce924ee89e9c006\",\n      \"placeholder\": \"​\",\n      \"style\": \"IPY_MODEL_00321cce58884f6f9b3855a21fcd9187\",\n      \"value\": \"Add position_id column (Sample Packing) (num_proc=2): 100%\"\n     }\n    },\n    \"ca65e32eb52f48c09a84b33cb18f22cd\": {\n     \"model_module\": \"@jupyter-widgets/controls\",\n     \"model_module_version\": \"1.5.0\",\n     \"model_name\": \"DescriptionStyleModel\",\n     \"state\": {\n      \"_model_module\": \"@jupyter-widgets/controls\",\n      \"_model_module_version\": \"1.5.0\",\n      \"_model_name\": \"DescriptionStyleModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/base\",\n      \"_view_module_version\": \"1.2.0\",\n      \"_view_name\": \"StyleView\",\n      \"description_width\": \"\"\n     }\n    },\n    
\"cc587493c33c4f118d1b1170f85be24c\": {\n     \"model_module\": \"@jupyter-widgets/controls\",\n     \"model_module_version\": \"1.5.0\",\n     \"model_name\": \"ProgressStyleModel\",\n     \"state\": {\n      \"_model_module\": \"@jupyter-widgets/controls\",\n      \"_model_module_version\": \"1.5.0\",\n      \"_model_name\": \"ProgressStyleModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/base\",\n      \"_view_module_version\": \"1.2.0\",\n      \"_view_name\": \"StyleView\",\n      \"bar_color\": null,\n      \"description_width\": \"\"\n     }\n    },\n    \"cc94432d08464affa3e58b560bdad194\": {\n     \"model_module\": \"@jupyter-widgets/controls\",\n     \"model_module_version\": \"1.5.0\",\n     \"model_name\": \"FloatProgressModel\",\n     \"state\": {\n      \"_dom_classes\": [],\n      \"_model_module\": \"@jupyter-widgets/controls\",\n      \"_model_module_version\": \"1.5.0\",\n      \"_model_name\": \"FloatProgressModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/controls\",\n      \"_view_module_version\": \"1.5.0\",\n      \"_view_name\": \"ProgressView\",\n      \"bar_style\": \"danger\",\n      \"description\": \"\",\n      \"description_tooltip\": null,\n      \"layout\": \"IPY_MODEL_b5b65414154544aa8a71b1a39164aad7\",\n      \"max\": 3963750816,\n      \"min\": 0,\n      \"orientation\": \"horizontal\",\n      \"style\": \"IPY_MODEL_f0a58fbd0fca4340890041f99fa2f8c8\",\n      \"value\": 3963750438\n     }\n    },\n    \"ccfcdc95baf646f8aeb3d516742383f2\": {\n     \"model_module\": \"@jupyter-widgets/base\",\n     \"model_module_version\": \"1.2.0\",\n     \"model_name\": \"LayoutModel\",\n     \"state\": {\n      \"_model_module\": \"@jupyter-widgets/base\",\n      \"_model_module_version\": \"1.2.0\",\n      \"_model_name\": \"LayoutModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/base\",\n      \"_view_module_version\": \"1.2.0\",\n      
\"_view_name\": \"LayoutView\",\n      \"align_content\": null,\n      \"align_items\": null,\n      \"align_self\": null,\n      \"border\": null,\n      \"bottom\": null,\n      \"display\": null,\n      \"flex\": null,\n      \"flex_flow\": null,\n      \"grid_area\": null,\n      \"grid_auto_columns\": null,\n      \"grid_auto_flow\": null,\n      \"grid_auto_rows\": null,\n      \"grid_column\": null,\n      \"grid_gap\": null,\n      \"grid_row\": null,\n      \"grid_template_areas\": null,\n      \"grid_template_columns\": null,\n      \"grid_template_rows\": null,\n      \"height\": null,\n      \"justify_content\": null,\n      \"justify_items\": null,\n      \"left\": null,\n      \"margin\": null,\n      \"max_height\": null,\n      \"max_width\": null,\n      \"min_height\": null,\n      \"min_width\": null,\n      \"object_fit\": null,\n      \"object_position\": null,\n      \"order\": null,\n      \"overflow\": null,\n      \"overflow_x\": null,\n      \"overflow_y\": null,\n      \"padding\": null,\n      \"right\": null,\n      \"top\": null,\n      \"visibility\": null,\n      \"width\": null\n     }\n    },\n    \"cdebbc55a1164c018546c2ac6f8c620c\": {\n     \"model_module\": \"@jupyter-widgets/controls\",\n     \"model_module_version\": \"1.5.0\",\n     \"model_name\": \"HBoxModel\",\n     \"state\": {\n      \"_dom_classes\": [],\n      \"_model_module\": \"@jupyter-widgets/controls\",\n      \"_model_module_version\": \"1.5.0\",\n      \"_model_name\": \"HBoxModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/controls\",\n      \"_view_module_version\": \"1.5.0\",\n      \"_view_name\": \"HBoxView\",\n      \"box_style\": \"\",\n      \"children\": [\n       \"IPY_MODEL_a44f630e099e43899f20a77084ae60cd\",\n       \"IPY_MODEL_c3725c7f79fe415fbd1ea336f0cc9cf1\",\n       \"IPY_MODEL_0e50870ed0c643e0b6c18cc5d7ddae7f\"\n      ],\n      \"layout\": \"IPY_MODEL_c33ced495f70464aa4a3a91922090853\"\n     }\n    },\n    
\"d02274afd47b462291c745f261209d42\": {\n     \"model_module\": \"@jupyter-widgets/base\",\n     \"model_module_version\": \"1.2.0\",\n     \"model_name\": \"LayoutModel\",\n     \"state\": {\n      \"_model_module\": \"@jupyter-widgets/base\",\n      \"_model_module_version\": \"1.2.0\",\n      \"_model_name\": \"LayoutModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/base\",\n      \"_view_module_version\": \"1.2.0\",\n      \"_view_name\": \"LayoutView\",\n      \"align_content\": null,\n      \"align_items\": null,\n      \"align_self\": null,\n      \"border\": null,\n      \"bottom\": null,\n      \"display\": null,\n      \"flex\": null,\n      \"flex_flow\": null,\n      \"grid_area\": null,\n      \"grid_auto_columns\": null,\n      \"grid_auto_flow\": null,\n      \"grid_auto_rows\": null,\n      \"grid_column\": null,\n      \"grid_gap\": null,\n      \"grid_row\": null,\n      \"grid_template_areas\": null,\n      \"grid_template_columns\": null,\n      \"grid_template_rows\": null,\n      \"height\": null,\n      \"justify_content\": null,\n      \"justify_items\": null,\n      \"left\": null,\n      \"margin\": null,\n      \"max_height\": null,\n      \"max_width\": null,\n      \"min_height\": null,\n      \"min_width\": null,\n      \"object_fit\": null,\n      \"object_position\": null,\n      \"order\": null,\n      \"overflow\": null,\n      \"overflow_x\": null,\n      \"overflow_y\": null,\n      \"padding\": null,\n      \"right\": null,\n      \"top\": null,\n      \"visibility\": null,\n      \"width\": null\n     }\n    },\n    \"d07c8b97d3314f1c852e44bdd40f61ed\": {\n     \"model_module\": \"@jupyter-widgets/base\",\n     \"model_module_version\": \"1.2.0\",\n     \"model_name\": \"LayoutModel\",\n     \"state\": {\n      \"_model_module\": \"@jupyter-widgets/base\",\n      \"_model_module_version\": \"1.2.0\",\n      \"_model_name\": \"LayoutModel\",\n      \"_view_count\": null,\n      \"_view_module\": 
\"@jupyter-widgets/base\",\n      \"_view_module_version\": \"1.2.0\",\n      \"_view_name\": \"LayoutView\",\n      \"align_content\": null,\n      \"align_items\": null,\n      \"align_self\": null,\n      \"border\": null,\n      \"bottom\": null,\n      \"display\": null,\n      \"flex\": null,\n      \"flex_flow\": null,\n      \"grid_area\": null,\n      \"grid_auto_columns\": null,\n      \"grid_auto_flow\": null,\n      \"grid_auto_rows\": null,\n      \"grid_column\": null,\n      \"grid_gap\": null,\n      \"grid_row\": null,\n      \"grid_template_areas\": null,\n      \"grid_template_columns\": null,\n      \"grid_template_rows\": null,\n      \"height\": null,\n      \"justify_content\": null,\n      \"justify_items\": null,\n      \"left\": null,\n      \"margin\": null,\n      \"max_height\": null,\n      \"max_width\": null,\n      \"min_height\": null,\n      \"min_width\": null,\n      \"object_fit\": null,\n      \"object_position\": null,\n      \"order\": null,\n      \"overflow\": null,\n      \"overflow_x\": null,\n      \"overflow_y\": null,\n      \"padding\": null,\n      \"right\": null,\n      \"top\": null,\n      \"visibility\": null,\n      \"width\": null\n     }\n    },\n    \"d0e9dce55cec4c1ca619a0ccf209d924\": {\n     \"model_module\": \"@jupyter-widgets/base\",\n     \"model_module_version\": \"1.2.0\",\n     \"model_name\": \"LayoutModel\",\n     \"state\": {\n      \"_model_module\": \"@jupyter-widgets/base\",\n      \"_model_module_version\": \"1.2.0\",\n      \"_model_name\": \"LayoutModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/base\",\n      \"_view_module_version\": \"1.2.0\",\n      \"_view_name\": \"LayoutView\",\n      \"align_content\": null,\n      \"align_items\": null,\n      \"align_self\": null,\n      \"border\": null,\n      \"bottom\": null,\n      \"display\": null,\n      \"flex\": null,\n      \"flex_flow\": null,\n      \"grid_area\": null,\n      \"grid_auto_columns\": 
null,\n      \"grid_auto_flow\": null,\n      \"grid_auto_rows\": null,\n      \"grid_column\": null,\n      \"grid_gap\": null,\n      \"grid_row\": null,\n      \"grid_template_areas\": null,\n      \"grid_template_columns\": null,\n      \"grid_template_rows\": null,\n      \"height\": null,\n      \"justify_content\": null,\n      \"justify_items\": null,\n      \"left\": null,\n      \"margin\": null,\n      \"max_height\": null,\n      \"max_width\": null,\n      \"min_height\": null,\n      \"min_width\": null,\n      \"object_fit\": null,\n      \"object_position\": null,\n      \"order\": null,\n      \"overflow\": null,\n      \"overflow_x\": null,\n      \"overflow_y\": null,\n      \"padding\": null,\n      \"right\": null,\n      \"top\": null,\n      \"visibility\": null,\n      \"width\": null\n     }\n    },\n    \"d1f9b10c130542f094c8fd3d1e23b5e9\": {\n     \"model_module\": \"@jupyter-widgets/base\",\n     \"model_module_version\": \"1.2.0\",\n     \"model_name\": \"LayoutModel\",\n     \"state\": {\n      \"_model_module\": \"@jupyter-widgets/base\",\n      \"_model_module_version\": \"1.2.0\",\n      \"_model_name\": \"LayoutModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/base\",\n      \"_view_module_version\": \"1.2.0\",\n      \"_view_name\": \"LayoutView\",\n      \"align_content\": null,\n      \"align_items\": null,\n      \"align_self\": null,\n      \"border\": null,\n      \"bottom\": null,\n      \"display\": null,\n      \"flex\": null,\n      \"flex_flow\": null,\n      \"grid_area\": null,\n      \"grid_auto_columns\": null,\n      \"grid_auto_flow\": null,\n      \"grid_auto_rows\": null,\n      \"grid_column\": null,\n      \"grid_gap\": null,\n      \"grid_row\": null,\n      \"grid_template_areas\": null,\n      \"grid_template_columns\": null,\n      \"grid_template_rows\": null,\n      \"height\": null,\n      \"justify_content\": null,\n      \"justify_items\": null,\n      \"left\": null,\n 
     \"margin\": null,\n      \"max_height\": null,\n      \"max_width\": null,\n      \"min_height\": null,\n      \"min_width\": null,\n      \"object_fit\": null,\n      \"object_position\": null,\n      \"order\": null,\n      \"overflow\": null,\n      \"overflow_x\": null,\n      \"overflow_y\": null,\n      \"padding\": null,\n      \"right\": null,\n      \"top\": null,\n      \"visibility\": null,\n      \"width\": null\n     }\n    },\n    \"d262c82138024169b9f3aa034ca756fa\": {\n     \"model_module\": \"@jupyter-widgets/base\",\n     \"model_module_version\": \"1.2.0\",\n     \"model_name\": \"LayoutModel\",\n     \"state\": {\n      \"_model_module\": \"@jupyter-widgets/base\",\n      \"_model_module_version\": \"1.2.0\",\n      \"_model_name\": \"LayoutModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/base\",\n      \"_view_module_version\": \"1.2.0\",\n      \"_view_name\": \"LayoutView\",\n      \"align_content\": null,\n      \"align_items\": null,\n      \"align_self\": null,\n      \"border\": null,\n      \"bottom\": null,\n      \"display\": null,\n      \"flex\": null,\n      \"flex_flow\": null,\n      \"grid_area\": null,\n      \"grid_auto_columns\": null,\n      \"grid_auto_flow\": null,\n      \"grid_auto_rows\": null,\n      \"grid_column\": null,\n      \"grid_gap\": null,\n      \"grid_row\": null,\n      \"grid_template_areas\": null,\n      \"grid_template_columns\": null,\n      \"grid_template_rows\": null,\n      \"height\": null,\n      \"justify_content\": null,\n      \"justify_items\": null,\n      \"left\": null,\n      \"margin\": null,\n      \"max_height\": null,\n      \"max_width\": null,\n      \"min_height\": null,\n      \"min_width\": null,\n      \"object_fit\": null,\n      \"object_position\": null,\n      \"order\": null,\n      \"overflow\": null,\n      \"overflow_x\": null,\n      \"overflow_y\": null,\n      \"padding\": null,\n      \"right\": null,\n      \"top\": null,\n     
 \"visibility\": null,\n      \"width\": null\n     }\n    },\n    \"d3de2662c7964f1ba96e58da382af720\": {\n     \"model_module\": \"@jupyter-widgets/controls\",\n     \"model_module_version\": \"1.5.0\",\n     \"model_name\": \"DescriptionStyleModel\",\n     \"state\": {\n      \"_model_module\": \"@jupyter-widgets/controls\",\n      \"_model_module_version\": \"1.5.0\",\n      \"_model_name\": \"DescriptionStyleModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/base\",\n      \"_view_module_version\": \"1.2.0\",\n      \"_view_name\": \"StyleView\",\n      \"description_width\": \"\"\n     }\n    },\n    \"d4183e9715f34d249942b8271cca3bdf\": {\n     \"model_module\": \"@jupyter-widgets/controls\",\n     \"model_module_version\": \"1.5.0\",\n     \"model_name\": \"HTMLModel\",\n     \"state\": {\n      \"_dom_classes\": [],\n      \"_model_module\": \"@jupyter-widgets/controls\",\n      \"_model_module_version\": \"1.5.0\",\n      \"_model_name\": \"HTMLModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/controls\",\n      \"_view_module_version\": \"1.5.0\",\n      \"_view_name\": \"HTMLView\",\n      \"description\": \"\",\n      \"description_tooltip\": null,\n      \"layout\": \"IPY_MODEL_63580b6fb30642479fe3000915bf551a\",\n      \"placeholder\": \"​\",\n      \"style\": \"IPY_MODEL_8f726dbfb45d4528afa33e36a6313267\",\n      \"value\": \" 27.3M/27.3M [00:00&lt;00:00, 31.0MB/s]\"\n     }\n    },\n    \"d43c6df07ddb466587807d6dbe1ff614\": {\n     \"model_module\": \"@jupyter-widgets/controls\",\n     \"model_module_version\": \"1.5.0\",\n     \"model_name\": \"HTMLModel\",\n     \"state\": {\n      \"_dom_classes\": [],\n      \"_model_module\": \"@jupyter-widgets/controls\",\n      \"_model_module_version\": \"1.5.0\",\n      \"_model_name\": \"HTMLModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/controls\",\n      \"_view_module_version\": \"1.5.0\",\n      
\"_view_name\": \"HTMLView\",\n      \"description\": \"\",\n      \"description_tooltip\": null,\n      \"layout\": \"IPY_MODEL_8c4d4fc5a30f4e7cb3be53fe2adda33d\",\n      \"placeholder\": \"​\",\n      \"style\": \"IPY_MODEL_e90658f4bcb642baa78426012f863152\",\n      \"value\": \"model-00004-of-00008.safetensors: 100%\"\n     }\n    },\n    \"d65b6b060d9845779299491ac5599c31\": {\n     \"model_module\": \"@jupyter-widgets/controls\",\n     \"model_module_version\": \"1.5.0\",\n     \"model_name\": \"DescriptionStyleModel\",\n     \"state\": {\n      \"_model_module\": \"@jupyter-widgets/controls\",\n      \"_model_module_version\": \"1.5.0\",\n      \"_model_name\": \"DescriptionStyleModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/base\",\n      \"_view_module_version\": \"1.2.0\",\n      \"_view_name\": \"StyleView\",\n      \"description_width\": \"\"\n     }\n    },\n    \"d6fe74e4255444368f8f90a62157d869\": {\n     \"model_module\": \"@jupyter-widgets/base\",\n     \"model_module_version\": \"1.2.0\",\n     \"model_name\": \"LayoutModel\",\n     \"state\": {\n      \"_model_module\": \"@jupyter-widgets/base\",\n      \"_model_module_version\": \"1.2.0\",\n      \"_model_name\": \"LayoutModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/base\",\n      \"_view_module_version\": \"1.2.0\",\n      \"_view_name\": \"LayoutView\",\n      \"align_content\": null,\n      \"align_items\": null,\n      \"align_self\": null,\n      \"border\": null,\n      \"bottom\": null,\n      \"display\": null,\n      \"flex\": null,\n      \"flex_flow\": null,\n      \"grid_area\": null,\n      \"grid_auto_columns\": null,\n      \"grid_auto_flow\": null,\n      \"grid_auto_rows\": null,\n      \"grid_column\": null,\n      \"grid_gap\": null,\n      \"grid_row\": null,\n      \"grid_template_areas\": null,\n      \"grid_template_columns\": null,\n      \"grid_template_rows\": null,\n      \"height\": null,\n      
\"justify_content\": null,\n      \"justify_items\": null,\n      \"left\": null,\n      \"margin\": null,\n      \"max_height\": null,\n      \"max_width\": null,\n      \"min_height\": null,\n      \"min_width\": null,\n      \"object_fit\": null,\n      \"object_position\": null,\n      \"order\": null,\n      \"overflow\": null,\n      \"overflow_x\": null,\n      \"overflow_y\": null,\n      \"padding\": null,\n      \"right\": null,\n      \"top\": null,\n      \"visibility\": null,\n      \"width\": null\n     }\n    },\n    \"d93f134f802b4b69b575bdaf07dbd27c\": {\n     \"model_module\": \"@jupyter-widgets/controls\",\n     \"model_module_version\": \"1.5.0\",\n     \"model_name\": \"DescriptionStyleModel\",\n     \"state\": {\n      \"_model_module\": \"@jupyter-widgets/controls\",\n      \"_model_module_version\": \"1.5.0\",\n      \"_model_name\": \"DescriptionStyleModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/base\",\n      \"_view_module_version\": \"1.2.0\",\n      \"_view_name\": \"StyleView\",\n      \"description_width\": \"\"\n     }\n    },\n    \"d955dcaa0e944e719f3a06139dd54a03\": {\n     \"model_module\": \"@jupyter-widgets/base\",\n     \"model_module_version\": \"1.2.0\",\n     \"model_name\": \"LayoutModel\",\n     \"state\": {\n      \"_model_module\": \"@jupyter-widgets/base\",\n      \"_model_module_version\": \"1.2.0\",\n      \"_model_name\": \"LayoutModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/base\",\n      \"_view_module_version\": \"1.2.0\",\n      \"_view_name\": \"LayoutView\",\n      \"align_content\": null,\n      \"align_items\": null,\n      \"align_self\": null,\n      \"border\": null,\n      \"bottom\": null,\n      \"display\": null,\n      \"flex\": null,\n      \"flex_flow\": null,\n      \"grid_area\": null,\n      \"grid_auto_columns\": null,\n      \"grid_auto_flow\": null,\n      \"grid_auto_rows\": null,\n      \"grid_column\": null,\n      
\"grid_gap\": null,\n      \"grid_row\": null,\n      \"grid_template_areas\": null,\n      \"grid_template_columns\": null,\n      \"grid_template_rows\": null,\n      \"height\": null,\n      \"justify_content\": null,\n      \"justify_items\": null,\n      \"left\": null,\n      \"margin\": null,\n      \"max_height\": null,\n      \"max_width\": null,\n      \"min_height\": null,\n      \"min_width\": null,\n      \"object_fit\": null,\n      \"object_position\": null,\n      \"order\": null,\n      \"overflow\": null,\n      \"overflow_x\": null,\n      \"overflow_y\": null,\n      \"padding\": null,\n      \"right\": null,\n      \"top\": null,\n      \"visibility\": null,\n      \"width\": null\n     }\n    },\n    \"da2347ac94764a3fa2743343cf0d3cd2\": {\n     \"model_module\": \"@jupyter-widgets/base\",\n     \"model_module_version\": \"1.2.0\",\n     \"model_name\": \"LayoutModel\",\n     \"state\": {\n      \"_model_module\": \"@jupyter-widgets/base\",\n      \"_model_module_version\": \"1.2.0\",\n      \"_model_name\": \"LayoutModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/base\",\n      \"_view_module_version\": \"1.2.0\",\n      \"_view_name\": \"LayoutView\",\n      \"align_content\": null,\n      \"align_items\": null,\n      \"align_self\": null,\n      \"border\": null,\n      \"bottom\": null,\n      \"display\": null,\n      \"flex\": null,\n      \"flex_flow\": null,\n      \"grid_area\": null,\n      \"grid_auto_columns\": null,\n      \"grid_auto_flow\": null,\n      \"grid_auto_rows\": null,\n      \"grid_column\": null,\n      \"grid_gap\": null,\n      \"grid_row\": null,\n      \"grid_template_areas\": null,\n      \"grid_template_columns\": null,\n      \"grid_template_rows\": null,\n      \"height\": null,\n      \"justify_content\": null,\n      \"justify_items\": null,\n      \"left\": null,\n      \"margin\": null,\n      \"max_height\": null,\n      \"max_width\": null,\n      \"min_height\": 
null,\n      \"min_width\": null,\n      \"object_fit\": null,\n      \"object_position\": null,\n      \"order\": null,\n      \"overflow\": null,\n      \"overflow_x\": null,\n      \"overflow_y\": null,\n      \"padding\": null,\n      \"right\": null,\n      \"top\": null,\n      \"visibility\": null,\n      \"width\": null\n     }\n    },\n    \"da6e93f3e4984780b930fe7a706983ea\": {\n     \"model_module\": \"@jupyter-widgets/controls\",\n     \"model_module_version\": \"1.5.0\",\n     \"model_name\": \"ProgressStyleModel\",\n     \"state\": {\n      \"_model_module\": \"@jupyter-widgets/controls\",\n      \"_model_module_version\": \"1.5.0\",\n      \"_model_name\": \"ProgressStyleModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/base\",\n      \"_view_module_version\": \"1.2.0\",\n      \"_view_name\": \"StyleView\",\n      \"bar_color\": null,\n      \"description_width\": \"\"\n     }\n    },\n    \"daee63fd167e4441a32324b51b00ad2b\": {\n     \"model_module\": \"@jupyter-widgets/controls\",\n     \"model_module_version\": \"1.5.0\",\n     \"model_name\": \"FloatProgressModel\",\n     \"state\": {\n      \"_dom_classes\": [],\n      \"_model_module\": \"@jupyter-widgets/controls\",\n      \"_model_module_version\": \"1.5.0\",\n      \"_model_name\": \"FloatProgressModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/controls\",\n      \"_view_module_version\": \"1.5.0\",\n      \"_view_name\": \"ProgressView\",\n      \"bar_style\": \"success\",\n      \"description\": \"\",\n      \"description_tooltip\": null,\n      \"layout\": \"IPY_MODEL_d07c8b97d3314f1c852e44bdd40f61ed\",\n      \"max\": 9985,\n      \"min\": 0,\n      \"orientation\": \"horizontal\",\n      \"style\": \"IPY_MODEL_ebb69a2c3d0a4299a484698287b3087c\",\n      \"value\": 9985\n     }\n    },\n    \"dc892a596f6942d7973c616c38f0eebb\": {\n     \"model_module\": \"@jupyter-widgets/controls\",\n     \"model_module_version\": 
\"1.5.0\",\n     \"model_name\": \"HBoxModel\",\n     \"state\": {\n      \"_dom_classes\": [],\n      \"_model_module\": \"@jupyter-widgets/controls\",\n      \"_model_module_version\": \"1.5.0\",\n      \"_model_name\": \"HBoxModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/controls\",\n      \"_view_module_version\": \"1.5.0\",\n      \"_view_name\": \"HBoxView\",\n      \"box_style\": \"\",\n      \"children\": [\n       \"IPY_MODEL_c84cc07789be48aebb322c23d355289e\",\n       \"IPY_MODEL_bed8726b8069434687c75452e21f19e5\",\n       \"IPY_MODEL_16a188a0b06d45f980dcf3933509fe0a\"\n      ],\n      \"layout\": \"IPY_MODEL_60c1a0d765c14a1d888317e6a507e4ea\"\n     }\n    },\n    \"dd0e646fad3f4a89ba23b39d162bd8d9\": {\n     \"model_module\": \"@jupyter-widgets/controls\",\n     \"model_module_version\": \"1.5.0\",\n     \"model_name\": \"HBoxModel\",\n     \"state\": {\n      \"_dom_classes\": [],\n      \"_model_module\": \"@jupyter-widgets/controls\",\n      \"_model_module_version\": \"1.5.0\",\n      \"_model_name\": \"HBoxModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/controls\",\n      \"_view_module_version\": \"1.5.0\",\n      \"_view_name\": \"HBoxView\",\n      \"box_style\": \"\",\n      \"children\": [\n       \"IPY_MODEL_d43c6df07ddb466587807d6dbe1ff614\",\n       \"IPY_MODEL_e0e8b840b8ea4d0d9db09afe99fa287d\",\n       \"IPY_MODEL_9327977822be4b1294f80e876552e305\"\n      ],\n      \"layout\": \"IPY_MODEL_77304d1a46b3468a98483e02ec0ac4a4\"\n     }\n    },\n    \"de7c37ee83e24f0c889e84d07279c2ec\": {\n     \"model_module\": \"@jupyter-widgets/controls\",\n     \"model_module_version\": \"1.5.0\",\n     \"model_name\": \"HTMLModel\",\n     \"state\": {\n      \"_dom_classes\": [],\n      \"_model_module\": \"@jupyter-widgets/controls\",\n      \"_model_module_version\": \"1.5.0\",\n      \"_model_name\": \"HTMLModel\",\n      \"_view_count\": null,\n      \"_view_module\": 
\"@jupyter-widgets/controls\",\n      \"_view_module_version\": \"1.5.0\",\n      \"_view_name\": \"HTMLView\",\n      \"description\": \"\",\n      \"description_tooltip\": null,\n      \"layout\": \"IPY_MODEL_34cf3df51fbc41cabfdbba153c007f0e\",\n      \"placeholder\": \"​\",\n      \"style\": \"IPY_MODEL_ac764024cf1c4e08ba7749afd2cd20ac\",\n      \"value\": \"vocab.json: 100%\"\n     }\n    },\n    \"dfd2a2649b8341ef913207526708aff1\": {\n     \"model_module\": \"@jupyter-widgets/base\",\n     \"model_module_version\": \"1.2.0\",\n     \"model_name\": \"LayoutModel\",\n     \"state\": {\n      \"_model_module\": \"@jupyter-widgets/base\",\n      \"_model_module_version\": \"1.2.0\",\n      \"_model_name\": \"LayoutModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/base\",\n      \"_view_module_version\": \"1.2.0\",\n      \"_view_name\": \"LayoutView\",\n      \"align_content\": null,\n      \"align_items\": null,\n      \"align_self\": null,\n      \"border\": null,\n      \"bottom\": null,\n      \"display\": null,\n      \"flex\": null,\n      \"flex_flow\": null,\n      \"grid_area\": null,\n      \"grid_auto_columns\": null,\n      \"grid_auto_flow\": null,\n      \"grid_auto_rows\": null,\n      \"grid_column\": null,\n      \"grid_gap\": null,\n      \"grid_row\": null,\n      \"grid_template_areas\": null,\n      \"grid_template_columns\": null,\n      \"grid_template_rows\": null,\n      \"height\": null,\n      \"justify_content\": null,\n      \"justify_items\": null,\n      \"left\": null,\n      \"margin\": null,\n      \"max_height\": null,\n      \"max_width\": null,\n      \"min_height\": null,\n      \"min_width\": null,\n      \"object_fit\": null,\n      \"object_position\": null,\n      \"order\": null,\n      \"overflow\": null,\n      \"overflow_x\": null,\n      \"overflow_y\": null,\n      \"padding\": null,\n      \"right\": null,\n      \"top\": null,\n      \"visibility\": null,\n      \"width\": null\n   
  }\n    },\n    \"e09f1bcbb9d94c09be53e5e1303642c2\": {\n     \"model_module\": \"@jupyter-widgets/controls\",\n     \"model_module_version\": \"1.5.0\",\n     \"model_name\": \"FloatProgressModel\",\n     \"state\": {\n      \"_dom_classes\": [],\n      \"_model_module\": \"@jupyter-widgets/controls\",\n      \"_model_module_version\": \"1.5.0\",\n      \"_model_name\": \"FloatProgressModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/controls\",\n      \"_view_module_version\": \"1.5.0\",\n      \"_view_name\": \"ProgressView\",\n      \"bar_style\": \"success\",\n      \"description\": \"\",\n      \"description_tooltip\": null,\n      \"layout\": \"IPY_MODEL_e7d8e4fe58384e93a106de546068c65e\",\n      \"max\": 8,\n      \"min\": 0,\n      \"orientation\": \"horizontal\",\n      \"style\": \"IPY_MODEL_0aa8ab56b85f4171a79c3bc210594025\",\n      \"value\": 8\n     }\n    },\n    \"e0e8b840b8ea4d0d9db09afe99fa287d\": {\n     \"model_module\": \"@jupyter-widgets/controls\",\n     \"model_module_version\": \"1.5.0\",\n     \"model_name\": \"FloatProgressModel\",\n     \"state\": {\n      \"_dom_classes\": [],\n      \"_model_module\": \"@jupyter-widgets/controls\",\n      \"_model_module_version\": \"1.5.0\",\n      \"_model_name\": \"FloatProgressModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/controls\",\n      \"_view_module_version\": \"1.5.0\",\n      \"_view_name\": \"ProgressView\",\n      \"bar_style\": \"danger\",\n      \"description\": \"\",\n      \"description_tooltip\": null,\n      \"layout\": \"IPY_MODEL_f7434f3e03124a1c938a39af79d7fa59\",\n      \"max\": 3963750880,\n      \"min\": 0,\n      \"orientation\": \"horizontal\",\n      \"style\": \"IPY_MODEL_c1314f241a434c41b45d84dc4d3b30f8\",\n      \"value\": 3963750502\n     }\n    },\n    \"e21e180307e5485cbbe908672fd6639a\": {\n     \"model_module\": \"@jupyter-widgets/controls\",\n     \"model_module_version\": \"1.5.0\",\n     
\"model_name\": \"HBoxModel\",\n     \"state\": {\n      \"_dom_classes\": [],\n      \"_model_module\": \"@jupyter-widgets/controls\",\n      \"_model_module_version\": \"1.5.0\",\n      \"_model_name\": \"HBoxModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/controls\",\n      \"_view_module_version\": \"1.5.0\",\n      \"_view_name\": \"HBoxView\",\n      \"box_style\": \"\",\n      \"children\": [\n       \"IPY_MODEL_2e2b0c1599c341a198f632f46a40c90e\",\n       \"IPY_MODEL_bff139df987d4a62abec6456cb27f3d4\",\n       \"IPY_MODEL_ebe1cc366d324ad59b264c8b3c431441\"\n      ],\n      \"layout\": \"IPY_MODEL_114dece49dba437c8572ef94b23c3b1e\"\n     }\n    },\n    \"e366ae3fceec4566b9ed303d6c5f90af\": {\n     \"model_module\": \"@jupyter-widgets/base\",\n     \"model_module_version\": \"1.2.0\",\n     \"model_name\": \"LayoutModel\",\n     \"state\": {\n      \"_model_module\": \"@jupyter-widgets/base\",\n      \"_model_module_version\": \"1.2.0\",\n      \"_model_name\": \"LayoutModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/base\",\n      \"_view_module_version\": \"1.2.0\",\n      \"_view_name\": \"LayoutView\",\n      \"align_content\": null,\n      \"align_items\": null,\n      \"align_self\": null,\n      \"border\": null,\n      \"bottom\": null,\n      \"display\": null,\n      \"flex\": null,\n      \"flex_flow\": null,\n      \"grid_area\": null,\n      \"grid_auto_columns\": null,\n      \"grid_auto_flow\": null,\n      \"grid_auto_rows\": null,\n      \"grid_column\": null,\n      \"grid_gap\": null,\n      \"grid_row\": null,\n      \"grid_template_areas\": null,\n      \"grid_template_columns\": null,\n      \"grid_template_rows\": null,\n      \"height\": null,\n      \"justify_content\": null,\n      \"justify_items\": null,\n      \"left\": null,\n      \"margin\": null,\n      \"max_height\": null,\n      \"max_width\": null,\n      \"min_height\": null,\n      \"min_width\": null,\n 
     \"object_fit\": null,\n      \"object_position\": null,\n      \"order\": null,\n      \"overflow\": null,\n      \"overflow_x\": null,\n      \"overflow_y\": null,\n      \"padding\": null,\n      \"right\": null,\n      \"top\": null,\n      \"visibility\": null,\n      \"width\": null\n     }\n    },\n    \"e3fb3fc6afe04b3c9b7ac61809ce78fa\": {\n     \"model_module\": \"@jupyter-widgets/controls\",\n     \"model_module_version\": \"1.5.0\",\n     \"model_name\": \"HTMLModel\",\n     \"state\": {\n      \"_dom_classes\": [],\n      \"_model_module\": \"@jupyter-widgets/controls\",\n      \"_model_module_version\": \"1.5.0\",\n      \"_model_name\": \"HTMLModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/controls\",\n      \"_view_module_version\": \"1.5.0\",\n      \"_view_name\": \"HTMLView\",\n      \"description\": \"\",\n      \"description_tooltip\": null,\n      \"layout\": \"IPY_MODEL_c6164e05a1914ae48083db9ad7f4ef7c\",\n      \"placeholder\": \"​\",\n      \"style\": \"IPY_MODEL_813621384dc748b0ad06775e22761c0b\",\n      \"value\": \" 9985/9985 [00:03&lt;00:00, 3622.89 examples/s]\"\n     }\n    },\n    \"e400cbf14bcc446a9d33b210cd93550b\": {\n     \"model_module\": \"@jupyter-widgets/base\",\n     \"model_module_version\": \"1.2.0\",\n     \"model_name\": \"LayoutModel\",\n     \"state\": {\n      \"_model_module\": \"@jupyter-widgets/base\",\n      \"_model_module_version\": \"1.2.0\",\n      \"_model_name\": \"LayoutModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/base\",\n      \"_view_module_version\": \"1.2.0\",\n      \"_view_name\": \"LayoutView\",\n      \"align_content\": null,\n      \"align_items\": null,\n      \"align_self\": null,\n      \"border\": null,\n      \"bottom\": null,\n      \"display\": null,\n      \"flex\": null,\n      \"flex_flow\": null,\n      \"grid_area\": null,\n      \"grid_auto_columns\": null,\n      \"grid_auto_flow\": null,\n      
\"grid_auto_rows\": null,\n      \"grid_column\": null,\n      \"grid_gap\": null,\n      \"grid_row\": null,\n      \"grid_template_areas\": null,\n      \"grid_template_columns\": null,\n      \"grid_template_rows\": null,\n      \"height\": null,\n      \"justify_content\": null,\n      \"justify_items\": null,\n      \"left\": null,\n      \"margin\": null,\n      \"max_height\": null,\n      \"max_width\": null,\n      \"min_height\": null,\n      \"min_width\": null,\n      \"object_fit\": null,\n      \"object_position\": null,\n      \"order\": null,\n      \"overflow\": null,\n      \"overflow_x\": null,\n      \"overflow_y\": null,\n      \"padding\": null,\n      \"right\": null,\n      \"top\": null,\n      \"visibility\": null,\n      \"width\": null\n     }\n    },\n    \"e40d1c1ac9494b3bade9858324e7ffdf\": {\n     \"model_module\": \"@jupyter-widgets/base\",\n     \"model_module_version\": \"1.2.0\",\n     \"model_name\": \"LayoutModel\",\n     \"state\": {\n      \"_model_module\": \"@jupyter-widgets/base\",\n      \"_model_module_version\": \"1.2.0\",\n      \"_model_name\": \"LayoutModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/base\",\n      \"_view_module_version\": \"1.2.0\",\n      \"_view_name\": \"LayoutView\",\n      \"align_content\": null,\n      \"align_items\": null,\n      \"align_self\": null,\n      \"border\": null,\n      \"bottom\": null,\n      \"display\": null,\n      \"flex\": null,\n      \"flex_flow\": null,\n      \"grid_area\": null,\n      \"grid_auto_columns\": null,\n      \"grid_auto_flow\": null,\n      \"grid_auto_rows\": null,\n      \"grid_column\": null,\n      \"grid_gap\": null,\n      \"grid_row\": null,\n      \"grid_template_areas\": null,\n      \"grid_template_columns\": null,\n      \"grid_template_rows\": null,\n      \"height\": null,\n      \"justify_content\": null,\n      \"justify_items\": null,\n      \"left\": null,\n      \"margin\": null,\n      \"max_height\": 
null,\n      \"max_width\": null,\n      \"min_height\": null,\n      \"min_width\": null,\n      \"object_fit\": null,\n      \"object_position\": null,\n      \"order\": null,\n      \"overflow\": null,\n      \"overflow_x\": null,\n      \"overflow_y\": null,\n      \"padding\": null,\n      \"right\": null,\n      \"top\": null,\n      \"visibility\": null,\n      \"width\": null\n     }\n    },\n    \"e575d87a7efe4ec7b1efde489839d4a6\": {\n     \"model_module\": \"@jupyter-widgets/controls\",\n     \"model_module_version\": \"1.5.0\",\n     \"model_name\": \"DescriptionStyleModel\",\n     \"state\": {\n      \"_model_module\": \"@jupyter-widgets/controls\",\n      \"_model_module_version\": \"1.5.0\",\n      \"_model_name\": \"DescriptionStyleModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/base\",\n      \"_view_module_version\": \"1.2.0\",\n      \"_view_name\": \"StyleView\",\n      \"description_width\": \"\"\n     }\n    },\n    \"e5a82df528bb4e408797a3b6c2758f4a\": {\n     \"model_module\": \"@jupyter-widgets/base\",\n     \"model_module_version\": \"1.2.0\",\n     \"model_name\": \"LayoutModel\",\n     \"state\": {\n      \"_model_module\": \"@jupyter-widgets/base\",\n      \"_model_module_version\": \"1.2.0\",\n      \"_model_name\": \"LayoutModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/base\",\n      \"_view_module_version\": \"1.2.0\",\n      \"_view_name\": \"LayoutView\",\n      \"align_content\": null,\n      \"align_items\": null,\n      \"align_self\": null,\n      \"border\": null,\n      \"bottom\": null,\n      \"display\": null,\n      \"flex\": null,\n      \"flex_flow\": null,\n      \"grid_area\": null,\n      \"grid_auto_columns\": null,\n      \"grid_auto_flow\": null,\n      \"grid_auto_rows\": null,\n      \"grid_column\": null,\n      \"grid_gap\": null,\n      \"grid_row\": null,\n      \"grid_template_areas\": null,\n      \"grid_template_columns\": null,\n      
\"grid_template_rows\": null,\n      \"height\": null,\n      \"justify_content\": null,\n      \"justify_items\": null,\n      \"left\": null,\n      \"margin\": null,\n      \"max_height\": null,\n      \"max_width\": null,\n      \"min_height\": null,\n      \"min_width\": null,\n      \"object_fit\": null,\n      \"object_position\": null,\n      \"order\": null,\n      \"overflow\": null,\n      \"overflow_x\": null,\n      \"overflow_y\": null,\n      \"padding\": null,\n      \"right\": null,\n      \"top\": null,\n      \"visibility\": null,\n      \"width\": null\n     }\n    },\n    \"e6e969610738449887259063967f82b0\": {\n     \"model_module\": \"@jupyter-widgets/controls\",\n     \"model_module_version\": \"1.5.0\",\n     \"model_name\": \"DescriptionStyleModel\",\n     \"state\": {\n      \"_model_module\": \"@jupyter-widgets/controls\",\n      \"_model_module_version\": \"1.5.0\",\n      \"_model_name\": \"DescriptionStyleModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/base\",\n      \"_view_module_version\": \"1.2.0\",\n      \"_view_name\": \"StyleView\",\n      \"description_width\": \"\"\n     }\n    },\n    \"e7d8e4fe58384e93a106de546068c65e\": {\n     \"model_module\": \"@jupyter-widgets/base\",\n     \"model_module_version\": \"1.2.0\",\n     \"model_name\": \"LayoutModel\",\n     \"state\": {\n      \"_model_module\": \"@jupyter-widgets/base\",\n      \"_model_module_version\": \"1.2.0\",\n      \"_model_name\": \"LayoutModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/base\",\n      \"_view_module_version\": \"1.2.0\",\n      \"_view_name\": \"LayoutView\",\n      \"align_content\": null,\n      \"align_items\": null,\n      \"align_self\": null,\n      \"border\": null,\n      \"bottom\": null,\n      \"display\": null,\n      \"flex\": null,\n      \"flex_flow\": null,\n      \"grid_area\": null,\n      \"grid_auto_columns\": null,\n      \"grid_auto_flow\": null,\n      
\"grid_auto_rows\": null,\n      \"grid_column\": null,\n      \"grid_gap\": null,\n      \"grid_row\": null,\n      \"grid_template_areas\": null,\n      \"grid_template_columns\": null,\n      \"grid_template_rows\": null,\n      \"height\": null,\n      \"justify_content\": null,\n      \"justify_items\": null,\n      \"left\": null,\n      \"margin\": null,\n      \"max_height\": null,\n      \"max_width\": null,\n      \"min_height\": null,\n      \"min_width\": null,\n      \"object_fit\": null,\n      \"object_position\": null,\n      \"order\": null,\n      \"overflow\": null,\n      \"overflow_x\": null,\n      \"overflow_y\": null,\n      \"padding\": null,\n      \"right\": null,\n      \"top\": null,\n      \"visibility\": null,\n      \"width\": null\n     }\n    },\n    \"e87ea87fcff247b5bbcc331ba79a8dc2\": {\n     \"model_module\": \"@jupyter-widgets/controls\",\n     \"model_module_version\": \"1.5.0\",\n     \"model_name\": \"ProgressStyleModel\",\n     \"state\": {\n      \"_model_module\": \"@jupyter-widgets/controls\",\n      \"_model_module_version\": \"1.5.0\",\n      \"_model_name\": \"ProgressStyleModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/base\",\n      \"_view_module_version\": \"1.2.0\",\n      \"_view_name\": \"StyleView\",\n      \"bar_color\": null,\n      \"description_width\": \"\"\n     }\n    },\n    \"e8b7a81040904c1e89e58978223b1737\": {\n     \"model_module\": \"@jupyter-widgets/controls\",\n     \"model_module_version\": \"1.5.0\",\n     \"model_name\": \"ProgressStyleModel\",\n     \"state\": {\n      \"_model_module\": \"@jupyter-widgets/controls\",\n      \"_model_module_version\": \"1.5.0\",\n      \"_model_name\": \"ProgressStyleModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/base\",\n      \"_view_module_version\": \"1.2.0\",\n      \"_view_name\": \"StyleView\",\n      \"bar_color\": null,\n      \"description_width\": \"\"\n     }\n    },\n    
\"e90658f4bcb642baa78426012f863152\": {\n     \"model_module\": \"@jupyter-widgets/controls\",\n     \"model_module_version\": \"1.5.0\",\n     \"model_name\": \"DescriptionStyleModel\",\n     \"state\": {\n      \"_model_module\": \"@jupyter-widgets/controls\",\n      \"_model_module_version\": \"1.5.0\",\n      \"_model_name\": \"DescriptionStyleModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/base\",\n      \"_view_module_version\": \"1.2.0\",\n      \"_view_name\": \"StyleView\",\n      \"description_width\": \"\"\n     }\n    },\n    \"eb1c9535e6a546098b760528b2ea387c\": {\n     \"model_module\": \"@jupyter-widgets/controls\",\n     \"model_module_version\": \"1.5.0\",\n     \"model_name\": \"HBoxModel\",\n     \"state\": {\n      \"_dom_classes\": [],\n      \"_model_module\": \"@jupyter-widgets/controls\",\n      \"_model_module_version\": \"1.5.0\",\n      \"_model_name\": \"HBoxModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/controls\",\n      \"_view_module_version\": \"1.5.0\",\n      \"_view_name\": \"HBoxView\",\n      \"box_style\": \"\",\n      \"children\": [\n       \"IPY_MODEL_18357b321ce44d7b8bd9d1c886f69275\",\n       \"IPY_MODEL_279937fe03bc4e4eb25b472d7e9df163\",\n       \"IPY_MODEL_bca2c7185b6749fd899c06a2ba4c5e46\"\n      ],\n      \"layout\": \"IPY_MODEL_1f7d30f71bbd4547a9150d21da071055\"\n     }\n    },\n    \"ebb69a2c3d0a4299a484698287b3087c\": {\n     \"model_module\": \"@jupyter-widgets/controls\",\n     \"model_module_version\": \"1.5.0\",\n     \"model_name\": \"ProgressStyleModel\",\n     \"state\": {\n      \"_model_module\": \"@jupyter-widgets/controls\",\n      \"_model_module_version\": \"1.5.0\",\n      \"_model_name\": \"ProgressStyleModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/base\",\n      \"_view_module_version\": \"1.2.0\",\n      \"_view_name\": \"StyleView\",\n      \"bar_color\": null,\n      \"description_width\": 
\"\"\n     }\n    },\n    \"ebc80d1a55fa47f4a5ea2756588569ec\": {\n     \"model_module\": \"@jupyter-widgets/base\",\n     \"model_module_version\": \"1.2.0\",\n     \"model_name\": \"LayoutModel\",\n     \"state\": {\n      \"_model_module\": \"@jupyter-widgets/base\",\n      \"_model_module_version\": \"1.2.0\",\n      \"_model_name\": \"LayoutModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/base\",\n      \"_view_module_version\": \"1.2.0\",\n      \"_view_name\": \"LayoutView\",\n      \"align_content\": null,\n      \"align_items\": null,\n      \"align_self\": null,\n      \"border\": null,\n      \"bottom\": null,\n      \"display\": null,\n      \"flex\": null,\n      \"flex_flow\": null,\n      \"grid_area\": null,\n      \"grid_auto_columns\": null,\n      \"grid_auto_flow\": null,\n      \"grid_auto_rows\": null,\n      \"grid_column\": null,\n      \"grid_gap\": null,\n      \"grid_row\": null,\n      \"grid_template_areas\": null,\n      \"grid_template_columns\": null,\n      \"grid_template_rows\": null,\n      \"height\": null,\n      \"justify_content\": null,\n      \"justify_items\": null,\n      \"left\": null,\n      \"margin\": null,\n      \"max_height\": null,\n      \"max_width\": null,\n      \"min_height\": null,\n      \"min_width\": null,\n      \"object_fit\": null,\n      \"object_position\": null,\n      \"order\": null,\n      \"overflow\": null,\n      \"overflow_x\": null,\n      \"overflow_y\": null,\n      \"padding\": null,\n      \"right\": null,\n      \"top\": null,\n      \"visibility\": null,\n      \"width\": null\n     }\n    },\n    \"ebe1cc366d324ad59b264c8b3c431441\": {\n     \"model_module\": \"@jupyter-widgets/controls\",\n     \"model_module_version\": \"1.5.0\",\n     \"model_name\": \"HTMLModel\",\n     \"state\": {\n      \"_dom_classes\": [],\n      \"_model_module\": \"@jupyter-widgets/controls\",\n      \"_model_module_version\": \"1.5.0\",\n      \"_model_name\": 
\"HTMLModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/controls\",\n      \"_view_module_version\": \"1.5.0\",\n      \"_view_name\": \"HTMLView\",\n      \"description\": \"\",\n      \"description_tooltip\": null,\n      \"layout\": \"IPY_MODEL_fba7aa824b38467ab3061b226114cdec\",\n      \"placeholder\": \"​\",\n      \"style\": \"IPY_MODEL_f3075dccbd2747b4a7913b66f44f2596\",\n      \"value\": \" 3.96G/3.96G [00:13&lt;00:00, 398MB/s]\"\n     }\n    },\n    \"ec030fc3c346426f9abc3a89892258d3\": {\n     \"model_module\": \"@jupyter-widgets/controls\",\n     \"model_module_version\": \"1.5.0\",\n     \"model_name\": \"FloatProgressModel\",\n     \"state\": {\n      \"_dom_classes\": [],\n      \"_model_module\": \"@jupyter-widgets/controls\",\n      \"_model_module_version\": \"1.5.0\",\n      \"_model_name\": \"FloatProgressModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/controls\",\n      \"_view_module_version\": \"1.5.0\",\n      \"_view_name\": \"ProgressView\",\n      \"bar_style\": \"success\",\n      \"description\": \"\",\n      \"description_tooltip\": null,\n      \"layout\": \"IPY_MODEL_dfd2a2649b8341ef913207526708aff1\",\n      \"max\": 9985,\n      \"min\": 0,\n      \"orientation\": \"horizontal\",\n      \"style\": \"IPY_MODEL_4f1977d7e4824ef1a14b65f0f42bba10\",\n      \"value\": 9985\n     }\n    },\n    \"ec11d1e5ae7b42c883d9b1f38a65356e\": {\n     \"model_module\": \"@jupyter-widgets/controls\",\n     \"model_module_version\": \"1.5.0\",\n     \"model_name\": \"HTMLModel\",\n     \"state\": {\n      \"_dom_classes\": [],\n      \"_model_module\": \"@jupyter-widgets/controls\",\n      \"_model_module_version\": \"1.5.0\",\n      \"_model_name\": \"HTMLModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/controls\",\n      \"_view_module_version\": \"1.5.0\",\n      \"_view_name\": \"HTMLView\",\n      \"description\": \"\",\n      
\"description_tooltip\": null,\n      \"layout\": \"IPY_MODEL_936d04b5fe1b4c63bf0b080e423d051b\",\n      \"placeholder\": \"​\",\n      \"style\": \"IPY_MODEL_f1cef8e8dc2646fb9fd09f3b09081074\",\n      \"value\": \" 36.5k/36.5k [00:00&lt;00:00, 4.32MB/s]\"\n     }\n    },\n    \"ed28e2e0410d4e0b855467e798e53d66\": {\n     \"model_module\": \"@jupyter-widgets/base\",\n     \"model_module_version\": \"1.2.0\",\n     \"model_name\": \"LayoutModel\",\n     \"state\": {\n      \"_model_module\": \"@jupyter-widgets/base\",\n      \"_model_module_version\": \"1.2.0\",\n      \"_model_name\": \"LayoutModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/base\",\n      \"_view_module_version\": \"1.2.0\",\n      \"_view_name\": \"LayoutView\",\n      \"align_content\": null,\n      \"align_items\": null,\n      \"align_self\": null,\n      \"border\": null,\n      \"bottom\": null,\n      \"display\": null,\n      \"flex\": null,\n      \"flex_flow\": null,\n      \"grid_area\": null,\n      \"grid_auto_columns\": null,\n      \"grid_auto_flow\": null,\n      \"grid_auto_rows\": null,\n      \"grid_column\": null,\n      \"grid_gap\": null,\n      \"grid_row\": null,\n      \"grid_template_areas\": null,\n      \"grid_template_columns\": null,\n      \"grid_template_rows\": null,\n      \"height\": null,\n      \"justify_content\": null,\n      \"justify_items\": null,\n      \"left\": null,\n      \"margin\": null,\n      \"max_height\": null,\n      \"max_width\": null,\n      \"min_height\": null,\n      \"min_width\": null,\n      \"object_fit\": null,\n      \"object_position\": null,\n      \"order\": null,\n      \"overflow\": null,\n      \"overflow_x\": null,\n      \"overflow_y\": null,\n      \"padding\": null,\n      \"right\": null,\n      \"top\": null,\n      \"visibility\": null,\n      \"width\": null\n     }\n    },\n    \"ed5ca967ad5342929e578ac6aa4dc4c0\": {\n     \"model_module\": \"@jupyter-widgets/base\",\n     
\"model_module_version\": \"1.2.0\",\n     \"model_name\": \"LayoutModel\",\n     \"state\": {\n      \"_model_module\": \"@jupyter-widgets/base\",\n      \"_model_module_version\": \"1.2.0\",\n      \"_model_name\": \"LayoutModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/base\",\n      \"_view_module_version\": \"1.2.0\",\n      \"_view_name\": \"LayoutView\",\n      \"align_content\": null,\n      \"align_items\": null,\n      \"align_self\": null,\n      \"border\": null,\n      \"bottom\": null,\n      \"display\": null,\n      \"flex\": null,\n      \"flex_flow\": null,\n      \"grid_area\": null,\n      \"grid_auto_columns\": null,\n      \"grid_auto_flow\": null,\n      \"grid_auto_rows\": null,\n      \"grid_column\": null,\n      \"grid_gap\": null,\n      \"grid_row\": null,\n      \"grid_template_areas\": null,\n      \"grid_template_columns\": null,\n      \"grid_template_rows\": null,\n      \"height\": null,\n      \"justify_content\": null,\n      \"justify_items\": null,\n      \"left\": null,\n      \"margin\": null,\n      \"max_height\": null,\n      \"max_width\": null,\n      \"min_height\": null,\n      \"min_width\": null,\n      \"object_fit\": null,\n      \"object_position\": null,\n      \"order\": null,\n      \"overflow\": null,\n      \"overflow_x\": null,\n      \"overflow_y\": null,\n      \"padding\": null,\n      \"right\": null,\n      \"top\": null,\n      \"visibility\": null,\n      \"width\": null\n     }\n    },\n    \"edc99591b9c747b689b94d0052fec14c\": {\n     \"model_module\": \"@jupyter-widgets/base\",\n     \"model_module_version\": \"1.2.0\",\n     \"model_name\": \"LayoutModel\",\n     \"state\": {\n      \"_model_module\": \"@jupyter-widgets/base\",\n      \"_model_module_version\": \"1.2.0\",\n      \"_model_name\": \"LayoutModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/base\",\n      \"_view_module_version\": \"1.2.0\",\n      \"_view_name\": 
\"LayoutView\",\n      \"align_content\": null,\n      \"align_items\": null,\n      \"align_self\": null,\n      \"border\": null,\n      \"bottom\": null,\n      \"display\": null,\n      \"flex\": null,\n      \"flex_flow\": null,\n      \"grid_area\": null,\n      \"grid_auto_columns\": null,\n      \"grid_auto_flow\": null,\n      \"grid_auto_rows\": null,\n      \"grid_column\": null,\n      \"grid_gap\": null,\n      \"grid_row\": null,\n      \"grid_template_areas\": null,\n      \"grid_template_columns\": null,\n      \"grid_template_rows\": null,\n      \"height\": null,\n      \"justify_content\": null,\n      \"justify_items\": null,\n      \"left\": null,\n      \"margin\": null,\n      \"max_height\": null,\n      \"max_width\": null,\n      \"min_height\": null,\n      \"min_width\": null,\n      \"object_fit\": null,\n      \"object_position\": null,\n      \"order\": null,\n      \"overflow\": null,\n      \"overflow_x\": null,\n      \"overflow_y\": null,\n      \"padding\": null,\n      \"right\": null,\n      \"top\": null,\n      \"visibility\": null,\n      \"width\": null\n     }\n    },\n    \"ef0a3c7a6f14460fb4da096928ae249e\": {\n     \"model_module\": \"@jupyter-widgets/controls\",\n     \"model_module_version\": \"1.5.0\",\n     \"model_name\": \"HBoxModel\",\n     \"state\": {\n      \"_dom_classes\": [],\n      \"_model_module\": \"@jupyter-widgets/controls\",\n      \"_model_module_version\": \"1.5.0\",\n      \"_model_name\": \"HBoxModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/controls\",\n      \"_view_module_version\": \"1.5.0\",\n      \"_view_name\": \"HBoxView\",\n      \"box_style\": \"\",\n      \"children\": [\n       \"IPY_MODEL_07fb3a2c8315494e97b447e672dfae06\",\n       \"IPY_MODEL_ec030fc3c346426f9abc3a89892258d3\",\n       \"IPY_MODEL_e3fb3fc6afe04b3c9b7ac61809ce78fa\"\n      ],\n      \"layout\": \"IPY_MODEL_c3be9109d63c485d9c0ef4f9bc0f9218\"\n     }\n    },\n    
\"ef223e8504b64e3592589880326aaf41\": {\n     \"model_module\": \"@jupyter-widgets/base\",\n     \"model_module_version\": \"1.2.0\",\n     \"model_name\": \"LayoutModel\",\n     \"state\": {\n      \"_model_module\": \"@jupyter-widgets/base\",\n      \"_model_module_version\": \"1.2.0\",\n      \"_model_name\": \"LayoutModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/base\",\n      \"_view_module_version\": \"1.2.0\",\n      \"_view_name\": \"LayoutView\",\n      \"align_content\": null,\n      \"align_items\": null,\n      \"align_self\": null,\n      \"border\": null,\n      \"bottom\": null,\n      \"display\": null,\n      \"flex\": null,\n      \"flex_flow\": null,\n      \"grid_area\": null,\n      \"grid_auto_columns\": null,\n      \"grid_auto_flow\": null,\n      \"grid_auto_rows\": null,\n      \"grid_column\": null,\n      \"grid_gap\": null,\n      \"grid_row\": null,\n      \"grid_template_areas\": null,\n      \"grid_template_columns\": null,\n      \"grid_template_rows\": null,\n      \"height\": null,\n      \"justify_content\": null,\n      \"justify_items\": null,\n      \"left\": null,\n      \"margin\": null,\n      \"max_height\": null,\n      \"max_width\": null,\n      \"min_height\": null,\n      \"min_width\": null,\n      \"object_fit\": null,\n      \"object_position\": null,\n      \"order\": null,\n      \"overflow\": null,\n      \"overflow_x\": null,\n      \"overflow_y\": null,\n      \"padding\": null,\n      \"right\": null,\n      \"top\": null,\n      \"visibility\": null,\n      \"width\": null\n     }\n    },\n    \"f0a58fbd0fca4340890041f99fa2f8c8\": {\n     \"model_module\": \"@jupyter-widgets/controls\",\n     \"model_module_version\": \"1.5.0\",\n     \"model_name\": \"ProgressStyleModel\",\n     \"state\": {\n      \"_model_module\": \"@jupyter-widgets/controls\",\n      \"_model_module_version\": \"1.5.0\",\n      \"_model_name\": \"ProgressStyleModel\",\n      \"_view_count\": null,\n   
   \"_view_module\": \"@jupyter-widgets/base\",\n      \"_view_module_version\": \"1.2.0\",\n      \"_view_name\": \"StyleView\",\n      \"bar_color\": null,\n      \"description_width\": \"\"\n     }\n    },\n    \"f113ebd8c1c34806bea4dd7ed3035173\": {\n     \"model_module\": \"@jupyter-widgets/controls\",\n     \"model_module_version\": \"1.5.0\",\n     \"model_name\": \"DescriptionStyleModel\",\n     \"state\": {\n      \"_model_module\": \"@jupyter-widgets/controls\",\n      \"_model_module_version\": \"1.5.0\",\n      \"_model_name\": \"DescriptionStyleModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/base\",\n      \"_view_module_version\": \"1.2.0\",\n      \"_view_name\": \"StyleView\",\n      \"description_width\": \"\"\n     }\n    },\n    \"f1cef8e8dc2646fb9fd09f3b09081074\": {\n     \"model_module\": \"@jupyter-widgets/controls\",\n     \"model_module_version\": \"1.5.0\",\n     \"model_name\": \"DescriptionStyleModel\",\n     \"state\": {\n      \"_model_module\": \"@jupyter-widgets/controls\",\n      \"_model_module_version\": \"1.5.0\",\n      \"_model_name\": \"DescriptionStyleModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/base\",\n      \"_view_module_version\": \"1.2.0\",\n      \"_view_name\": \"StyleView\",\n      \"description_width\": \"\"\n     }\n    },\n    \"f3075dccbd2747b4a7913b66f44f2596\": {\n     \"model_module\": \"@jupyter-widgets/controls\",\n     \"model_module_version\": \"1.5.0\",\n     \"model_name\": \"DescriptionStyleModel\",\n     \"state\": {\n      \"_model_module\": \"@jupyter-widgets/controls\",\n      \"_model_module_version\": \"1.5.0\",\n      \"_model_name\": \"DescriptionStyleModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/base\",\n      \"_view_module_version\": \"1.2.0\",\n      \"_view_name\": \"StyleView\",\n      \"description_width\": \"\"\n     }\n    },\n    \"f365820a3d3c42b2948abfe32065de14\": {\n     
\"model_module\": \"@jupyter-widgets/controls\",\n     \"model_module_version\": \"1.5.0\",\n     \"model_name\": \"HTMLModel\",\n     \"state\": {\n      \"_dom_classes\": [],\n      \"_model_module\": \"@jupyter-widgets/controls\",\n      \"_model_module_version\": \"1.5.0\",\n      \"_model_name\": \"HTMLModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/controls\",\n      \"_view_module_version\": \"1.5.0\",\n      \"_view_name\": \"HTMLView\",\n      \"description\": \"\",\n      \"description_tooltip\": null,\n      \"layout\": \"IPY_MODEL_735d4f225b24414294fc1b213c61223c\",\n      \"placeholder\": \"​\",\n      \"style\": \"IPY_MODEL_5e5e15b0569b474c9620083b3ec6af55\",\n      \"value\": \"generation_config.json: 100%\"\n     }\n    },\n    \"f4667818b9d34a09891cd727a429a610\": {\n     \"model_module\": \"@jupyter-widgets/controls\",\n     \"model_module_version\": \"1.5.0\",\n     \"model_name\": \"HTMLModel\",\n     \"state\": {\n      \"_dom_classes\": [],\n      \"_model_module\": \"@jupyter-widgets/controls\",\n      \"_model_module_version\": \"1.5.0\",\n      \"_model_name\": \"HTMLModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/controls\",\n      \"_view_module_version\": \"1.5.0\",\n      \"_view_name\": \"HTMLView\",\n      \"description\": \"\",\n      \"description_tooltip\": null,\n      \"layout\": \"IPY_MODEL_4b27c267393640f28f6eae0875bd2ed9\",\n      \"placeholder\": \"​\",\n      \"style\": \"IPY_MODEL_9858cb74a09748a39e8149baac96702c\",\n      \"value\": \" 3.96G/3.96G [00:11&lt;00:00, 457MB/s]\"\n     }\n    },\n    \"f4a1795dc7514a718f478245f521f0ba\": {\n     \"model_module\": \"@jupyter-widgets/base\",\n     \"model_module_version\": \"1.2.0\",\n     \"model_name\": \"LayoutModel\",\n     \"state\": {\n      \"_model_module\": \"@jupyter-widgets/base\",\n      \"_model_module_version\": \"1.2.0\",\n      \"_model_name\": \"LayoutModel\",\n      \"_view_count\": null,\n    
  \"_view_module\": \"@jupyter-widgets/base\",\n      \"_view_module_version\": \"1.2.0\",\n      \"_view_name\": \"LayoutView\",\n      \"align_content\": null,\n      \"align_items\": null,\n      \"align_self\": null,\n      \"border\": null,\n      \"bottom\": null,\n      \"display\": null,\n      \"flex\": null,\n      \"flex_flow\": null,\n      \"grid_area\": null,\n      \"grid_auto_columns\": null,\n      \"grid_auto_flow\": null,\n      \"grid_auto_rows\": null,\n      \"grid_column\": null,\n      \"grid_gap\": null,\n      \"grid_row\": null,\n      \"grid_template_areas\": null,\n      \"grid_template_columns\": null,\n      \"grid_template_rows\": null,\n      \"height\": null,\n      \"justify_content\": null,\n      \"justify_items\": null,\n      \"left\": null,\n      \"margin\": null,\n      \"max_height\": null,\n      \"max_width\": null,\n      \"min_height\": null,\n      \"min_width\": null,\n      \"object_fit\": null,\n      \"object_position\": null,\n      \"order\": null,\n      \"overflow\": null,\n      \"overflow_x\": null,\n      \"overflow_y\": null,\n      \"padding\": null,\n      \"right\": null,\n      \"top\": null,\n      \"visibility\": null,\n      \"width\": null\n     }\n    },\n    \"f60a2bdb6b6b4e0e8c3508580e247132\": {\n     \"model_module\": \"@jupyter-widgets/controls\",\n     \"model_module_version\": \"1.5.0\",\n     \"model_name\": \"FloatProgressModel\",\n     \"state\": {\n      \"_dom_classes\": [],\n      \"_model_module\": \"@jupyter-widgets/controls\",\n      \"_model_module_version\": \"1.5.0\",\n      \"_model_name\": \"FloatProgressModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/controls\",\n      \"_view_module_version\": \"1.5.0\",\n      \"_view_name\": \"ProgressView\",\n      \"bar_style\": \"danger\",\n      \"description\": \"\",\n      \"description_tooltip\": null,\n      \"layout\": \"IPY_MODEL_edc99591b9c747b689b94d0052fec14c\",\n      \"max\": 3963750880,\n  
    \"min\": 0,\n      \"orientation\": \"horizontal\",\n      \"style\": \"IPY_MODEL_35cc989ca3374e7dba0cb166febc4bde\",\n      \"value\": 3963750502\n     }\n    },\n    \"f7434f3e03124a1c938a39af79d7fa59\": {\n     \"model_module\": \"@jupyter-widgets/base\",\n     \"model_module_version\": \"1.2.0\",\n     \"model_name\": \"LayoutModel\",\n     \"state\": {\n      \"_model_module\": \"@jupyter-widgets/base\",\n      \"_model_module_version\": \"1.2.0\",\n      \"_model_name\": \"LayoutModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/base\",\n      \"_view_module_version\": \"1.2.0\",\n      \"_view_name\": \"LayoutView\",\n      \"align_content\": null,\n      \"align_items\": null,\n      \"align_self\": null,\n      \"border\": null,\n      \"bottom\": null,\n      \"display\": null,\n      \"flex\": null,\n      \"flex_flow\": null,\n      \"grid_area\": null,\n      \"grid_auto_columns\": null,\n      \"grid_auto_flow\": null,\n      \"grid_auto_rows\": null,\n      \"grid_column\": null,\n      \"grid_gap\": null,\n      \"grid_row\": null,\n      \"grid_template_areas\": null,\n      \"grid_template_columns\": null,\n      \"grid_template_rows\": null,\n      \"height\": null,\n      \"justify_content\": null,\n      \"justify_items\": null,\n      \"left\": null,\n      \"margin\": null,\n      \"max_height\": null,\n      \"max_width\": null,\n      \"min_height\": null,\n      \"min_width\": null,\n      \"object_fit\": null,\n      \"object_position\": null,\n      \"order\": null,\n      \"overflow\": null,\n      \"overflow_x\": null,\n      \"overflow_y\": null,\n      \"padding\": null,\n      \"right\": null,\n      \"top\": null,\n      \"visibility\": null,\n      \"width\": null\n     }\n    },\n    \"f8ef805b776145c3bfa9ba8d90972058\": {\n     \"model_module\": \"@jupyter-widgets/base\",\n     \"model_module_version\": \"1.2.0\",\n     \"model_name\": \"LayoutModel\",\n     \"state\": {\n      
\"_model_module\": \"@jupyter-widgets/base\",\n      \"_model_module_version\": \"1.2.0\",\n      \"_model_name\": \"LayoutModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/base\",\n      \"_view_module_version\": \"1.2.0\",\n      \"_view_name\": \"LayoutView\",\n      \"align_content\": null,\n      \"align_items\": null,\n      \"align_self\": null,\n      \"border\": null,\n      \"bottom\": null,\n      \"display\": null,\n      \"flex\": null,\n      \"flex_flow\": null,\n      \"grid_area\": null,\n      \"grid_auto_columns\": null,\n      \"grid_auto_flow\": null,\n      \"grid_auto_rows\": null,\n      \"grid_column\": null,\n      \"grid_gap\": null,\n      \"grid_row\": null,\n      \"grid_template_areas\": null,\n      \"grid_template_columns\": null,\n      \"grid_template_rows\": null,\n      \"height\": null,\n      \"justify_content\": null,\n      \"justify_items\": null,\n      \"left\": null,\n      \"margin\": null,\n      \"max_height\": null,\n      \"max_width\": null,\n      \"min_height\": null,\n      \"min_width\": null,\n      \"object_fit\": null,\n      \"object_position\": null,\n      \"order\": null,\n      \"overflow\": null,\n      \"overflow_x\": null,\n      \"overflow_y\": null,\n      \"padding\": null,\n      \"right\": null,\n      \"top\": null,\n      \"visibility\": null,\n      \"width\": null\n     }\n    },\n    \"fa1282ccc7544e4f818e2f03ccffe4a5\": {\n     \"model_module\": \"@jupyter-widgets/base\",\n     \"model_module_version\": \"1.2.0\",\n     \"model_name\": \"LayoutModel\",\n     \"state\": {\n      \"_model_module\": \"@jupyter-widgets/base\",\n      \"_model_module_version\": \"1.2.0\",\n      \"_model_name\": \"LayoutModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/base\",\n      \"_view_module_version\": \"1.2.0\",\n      \"_view_name\": \"LayoutView\",\n      \"align_content\": null,\n      \"align_items\": null,\n      \"align_self\": 
null,\n      \"border\": null,\n      \"bottom\": null,\n      \"display\": null,\n      \"flex\": null,\n      \"flex_flow\": null,\n      \"grid_area\": null,\n      \"grid_auto_columns\": null,\n      \"grid_auto_flow\": null,\n      \"grid_auto_rows\": null,\n      \"grid_column\": null,\n      \"grid_gap\": null,\n      \"grid_row\": null,\n      \"grid_template_areas\": null,\n      \"grid_template_columns\": null,\n      \"grid_template_rows\": null,\n      \"height\": null,\n      \"justify_content\": null,\n      \"justify_items\": null,\n      \"left\": null,\n      \"margin\": null,\n      \"max_height\": null,\n      \"max_width\": null,\n      \"min_height\": null,\n      \"min_width\": null,\n      \"object_fit\": null,\n      \"object_position\": null,\n      \"order\": null,\n      \"overflow\": null,\n      \"overflow_x\": null,\n      \"overflow_y\": null,\n      \"padding\": null,\n      \"right\": null,\n      \"top\": null,\n      \"visibility\": null,\n      \"width\": null\n     }\n    },\n    \"fa864b41586f4a7aa56aeafd1d84eb75\": {\n     \"model_module\": \"@jupyter-widgets/base\",\n     \"model_module_version\": \"1.2.0\",\n     \"model_name\": \"LayoutModel\",\n     \"state\": {\n      \"_model_module\": \"@jupyter-widgets/base\",\n      \"_model_module_version\": \"1.2.0\",\n      \"_model_name\": \"LayoutModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/base\",\n      \"_view_module_version\": \"1.2.0\",\n      \"_view_name\": \"LayoutView\",\n      \"align_content\": null,\n      \"align_items\": null,\n      \"align_self\": null,\n      \"border\": null,\n      \"bottom\": null,\n      \"display\": null,\n      \"flex\": null,\n      \"flex_flow\": null,\n      \"grid_area\": null,\n      \"grid_auto_columns\": null,\n      \"grid_auto_flow\": null,\n      \"grid_auto_rows\": null,\n      \"grid_column\": null,\n      \"grid_gap\": null,\n      \"grid_row\": null,\n      \"grid_template_areas\": null,\n 
     \"grid_template_columns\": null,\n      \"grid_template_rows\": null,\n      \"height\": null,\n      \"justify_content\": null,\n      \"justify_items\": null,\n      \"left\": null,\n      \"margin\": null,\n      \"max_height\": null,\n      \"max_width\": null,\n      \"min_height\": null,\n      \"min_width\": null,\n      \"object_fit\": null,\n      \"object_position\": null,\n      \"order\": null,\n      \"overflow\": null,\n      \"overflow_x\": null,\n      \"overflow_y\": null,\n      \"padding\": null,\n      \"right\": null,\n      \"top\": null,\n      \"visibility\": null,\n      \"width\": null\n     }\n    },\n    \"fba7aa824b38467ab3061b226114cdec\": {\n     \"model_module\": \"@jupyter-widgets/base\",\n     \"model_module_version\": \"1.2.0\",\n     \"model_name\": \"LayoutModel\",\n     \"state\": {\n      \"_model_module\": \"@jupyter-widgets/base\",\n      \"_model_module_version\": \"1.2.0\",\n      \"_model_name\": \"LayoutModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/base\",\n      \"_view_module_version\": \"1.2.0\",\n      \"_view_name\": \"LayoutView\",\n      \"align_content\": null,\n      \"align_items\": null,\n      \"align_self\": null,\n      \"border\": null,\n      \"bottom\": null,\n      \"display\": null,\n      \"flex\": null,\n      \"flex_flow\": null,\n      \"grid_area\": null,\n      \"grid_auto_columns\": null,\n      \"grid_auto_flow\": null,\n      \"grid_auto_rows\": null,\n      \"grid_column\": null,\n      \"grid_gap\": null,\n      \"grid_row\": null,\n      \"grid_template_areas\": null,\n      \"grid_template_columns\": null,\n      \"grid_template_rows\": null,\n      \"height\": null,\n      \"justify_content\": null,\n      \"justify_items\": null,\n      \"left\": null,\n      \"margin\": null,\n      \"max_height\": null,\n      \"max_width\": null,\n      \"min_height\": null,\n      \"min_width\": null,\n      \"object_fit\": null,\n      \"object_position\": 
null,\n      \"order\": null,\n      \"overflow\": null,\n      \"overflow_x\": null,\n      \"overflow_y\": null,\n      \"padding\": null,\n      \"right\": null,\n      \"top\": null,\n      \"visibility\": null,\n      \"width\": null\n     }\n    },\n    \"fcb30372e7404c5d8a1ad4df91e6c7b2\": {\n     \"model_module\": \"@jupyter-widgets/controls\",\n     \"model_module_version\": \"1.5.0\",\n     \"model_name\": \"DescriptionStyleModel\",\n     \"state\": {\n      \"_model_module\": \"@jupyter-widgets/controls\",\n      \"_model_module_version\": \"1.5.0\",\n      \"_model_name\": \"DescriptionStyleModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/base\",\n      \"_view_module_version\": \"1.2.0\",\n      \"_view_name\": \"StyleView\",\n      \"description_width\": \"\"\n     }\n    },\n    \"fcbab4d8dced41a18dfccce81e3a45a0\": {\n     \"model_module\": \"@jupyter-widgets/controls\",\n     \"model_module_version\": \"1.5.0\",\n     \"model_name\": \"DescriptionStyleModel\",\n     \"state\": {\n      \"_model_module\": \"@jupyter-widgets/controls\",\n      \"_model_module_version\": \"1.5.0\",\n      \"_model_name\": \"DescriptionStyleModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/base\",\n      \"_view_module_version\": \"1.2.0\",\n      \"_view_name\": \"StyleView\",\n      \"description_width\": \"\"\n     }\n    },\n    \"fd4f333f7ece4450b04e1a9af1f9d2f6\": {\n     \"model_module\": \"@jupyter-widgets/controls\",\n     \"model_module_version\": \"1.5.0\",\n     \"model_name\": \"HTMLModel\",\n     \"state\": {\n      \"_dom_classes\": [],\n      \"_model_module\": \"@jupyter-widgets/controls\",\n      \"_model_module_version\": \"1.5.0\",\n      \"_model_name\": \"HTMLModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/controls\",\n      \"_view_module_version\": \"1.5.0\",\n      \"_view_name\": \"HTMLView\",\n      \"description\": \"\",\n      
\"description_tooltip\": null,\n      \"layout\": \"IPY_MODEL_d1f9b10c130542f094c8fd3d1e23b5e9\",\n      \"placeholder\": \"​\",\n      \"style\": \"IPY_MODEL_e575d87a7efe4ec7b1efde489839d4a6\",\n      \"value\": \"model-00006-of-00008.safetensors: 100%\"\n     }\n    },\n    \"fe18bba7f3fb4c31bf840541f36b3425\": {\n     \"model_module\": \"@jupyter-widgets/controls\",\n     \"model_module_version\": \"1.5.0\",\n     \"model_name\": \"HBoxModel\",\n     \"state\": {\n      \"_dom_classes\": [],\n      \"_model_module\": \"@jupyter-widgets/controls\",\n      \"_model_module_version\": \"1.5.0\",\n      \"_model_name\": \"HBoxModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/controls\",\n      \"_view_module_version\": \"1.5.0\",\n      \"_view_name\": \"HBoxView\",\n      \"box_style\": \"\",\n      \"children\": [\n       \"IPY_MODEL_fd4f333f7ece4450b04e1a9af1f9d2f6\",\n       \"IPY_MODEL_f60a2bdb6b6b4e0e8c3508580e247132\",\n       \"IPY_MODEL_c0892a1881de4eb4bfabc6a68f87ae99\"\n      ],\n      \"layout\": \"IPY_MODEL_1bec6297c90242a88672d195bc09d429\"\n     }\n    },\n    \"fe41858c6bd04c58840112b67c19a336\": {\n     \"model_module\": \"@jupyter-widgets/controls\",\n     \"model_module_version\": \"1.5.0\",\n     \"model_name\": \"HTMLModel\",\n     \"state\": {\n      \"_dom_classes\": [],\n      \"_model_module\": \"@jupyter-widgets/controls\",\n      \"_model_module_version\": \"1.5.0\",\n      \"_model_name\": \"HTMLModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/controls\",\n      \"_view_module_version\": \"1.5.0\",\n      \"_view_name\": \"HTMLView\",\n      \"description\": \"\",\n      \"description_tooltip\": null,\n      \"layout\": \"IPY_MODEL_e5a82df528bb4e408797a3b6c2758f4a\",\n      \"placeholder\": \"​\",\n      \"style\": \"IPY_MODEL_f113ebd8c1c34806bea4dd7ed3035173\",\n      \"value\": \" 9985/9985 [00:00&lt;00:00, 44264.88 examples/s]\"\n     }\n    },\n    
\"fea1b70fb46745feb5111b3929175b5d\": {\n     \"model_module\": \"@jupyter-widgets/controls\",\n     \"model_module_version\": \"1.5.0\",\n     \"model_name\": \"HBoxModel\",\n     \"state\": {\n      \"_dom_classes\": [],\n      \"_model_module\": \"@jupyter-widgets/controls\",\n      \"_model_module_version\": \"1.5.0\",\n      \"_model_name\": \"HBoxModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/controls\",\n      \"_view_module_version\": \"1.5.0\",\n      \"_view_name\": \"HBoxView\",\n      \"box_style\": \"\",\n      \"children\": [\n       \"IPY_MODEL_f365820a3d3c42b2948abfe32065de14\",\n       \"IPY_MODEL_823f1c78f15043e38bbd4dca3932a86a\",\n       \"IPY_MODEL_a1959759c5424da9961fb2a308d4dee4\"\n      ],\n      \"layout\": \"IPY_MODEL_34c9c0137b504cd799c6bd6de69507c2\"\n     }\n    },\n    \"ff3a94b146a948b6907f5d80c7157f99\": {\n     \"model_module\": \"@jupyter-widgets/controls\",\n     \"model_module_version\": \"1.5.0\",\n     \"model_name\": \"DescriptionStyleModel\",\n     \"state\": {\n      \"_model_module\": \"@jupyter-widgets/controls\",\n      \"_model_module_version\": \"1.5.0\",\n      \"_model_name\": \"DescriptionStyleModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/base\",\n      \"_view_module_version\": \"1.2.0\",\n      \"_view_name\": \"StyleView\",\n      \"description_width\": \"\"\n     }\n    },\n    \"ffdbb12a2f2c4d14911685e7683e0ef0\": {\n     \"model_module\": \"@jupyter-widgets/controls\",\n     \"model_module_version\": \"1.5.0\",\n     \"model_name\": \"HTMLModel\",\n     \"state\": {\n      \"_dom_classes\": [],\n      \"_model_module\": \"@jupyter-widgets/controls\",\n      \"_model_module_version\": \"1.5.0\",\n      \"_model_name\": \"HTMLModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/controls\",\n      \"_view_module_version\": \"1.5.0\",\n      \"_view_name\": \"HTMLView\",\n      \"description\": \"\",\n      
\"description_tooltip\": null,\n      \"layout\": \"IPY_MODEL_ab93eabd7cea4b94b4b7a387f101e8a1\",\n      \"placeholder\": \"​\",\n      \"style\": \"IPY_MODEL_704f2f5a9b1c49d5a75a0025a5dda11b\",\n      \"value\": \" 3.96G/3.96G [00:12&lt;00:00, 656MB/s]\"\n     }\n    }\n   }\n  }\n },\n \"nbformat\": 4,\n \"nbformat_minor\": 0\n}\n"
  },
  {
    "path": "examples/deepcogito/cogito-v1-preview-llama-3B-lora.yml",
    "content": "base_model: deepcogito/cogito-v1-preview-llama-3B\n# Automatically upload checkpoint and final model to HF\n# hub_model_id: username/custom_model_name\n\nload_in_8bit: true\nload_in_4bit: false\nstrict: false\n\ndatasets:\n  - path: fozziethebeat/alpaca_messages_2k_test\n    type: chat_template\n\ndataset_prepared_path:\nval_set_size: 0.05\noutput_dir: ./outputs/lora-out\n\nsequence_len: 4096\nsample_packing: true\neval_sample_packing: false\n\n\nadapter: lora\nlora_model_dir:\nlora_r: 32\nlora_alpha: 16\nlora_dropout: 0.05\nlora_target_linear: true\n\nwandb_project:\nwandb_entity:\nwandb_watch:\nwandb_name:\nwandb_log_model:\n\ngradient_accumulation_steps: 2\nmicro_batch_size: 2\nnum_epochs: 1\noptimizer: adamw_bnb_8bit\nlr_scheduler: cosine\nlearning_rate: 0.0002\n\nbf16: auto\ntf32: true\n\ngradient_checkpointing: true\nresume_from_checkpoint:\nlogging_steps: 1\nflash_attention: true\n\nwarmup_ratio: 0.1\nevals_per_epoch: 1\nsaves_per_epoch: 1\nweight_decay: 0.0\nspecial_tokens:\n\n# save_first_step: true  # uncomment this to validate checkpoint saving works with your config\n"
  },
  {
    "path": "examples/deepcogito/cogito-v1-preview-qwen-14B-lora.yml",
    "content": "base_model: deepcogito/cogito-v1-preview-qwen-14B\n# Automatically upload checkpoint and final model to HF\n# hub_model_id: username/custom_model_name\n\nload_in_8bit: true\nload_in_4bit: false\nstrict: false\n\ndatasets:\n  - path: fozziethebeat/alpaca_messages_2k_test\n    type: chat_template\n\ndataset_prepared_path:\nval_set_size: 0.05\noutput_dir: ./outputs/lora-out\n\nsequence_len: 4096\nsample_packing: true\neval_sample_packing: false\n\n\nadapter: lora\nlora_model_dir:\nlora_r: 32\nlora_alpha: 16\nlora_dropout: 0.05\nlora_target_linear: true\n\nwandb_project:\nwandb_entity:\nwandb_watch:\nwandb_name:\nwandb_log_model:\n\ngradient_accumulation_steps: 2\nmicro_batch_size: 2\nnum_epochs: 1\noptimizer: adamw_bnb_8bit\nlr_scheduler: cosine\nlearning_rate: 0.0002\n\nbf16: auto\ntf32: true\n\ngradient_checkpointing: true\nresume_from_checkpoint:\nlogging_steps: 1\nflash_attention: true\n\nwarmup_ratio: 0.1\nevals_per_epoch: 1\nsaves_per_epoch: 1\nweight_decay: 0.0\nspecial_tokens:\n\n# save_first_step: true  # uncomment this to validate checkpoint saving works with your config\n"
  },
  {
    "path": "examples/deepseek-v2/fft-fsdp-16b.yaml",
    "content": "base_model: deepseek-ai/DeepSeek-V2-Lite\n# Automatically upload checkpoint and final model to HF\n# hub_model_id: username/custom_model_name\ntrust_remote_code: true\n\ndatasets:\n  - path: tatsu-lab/alpaca\n    type: alpaca\ndataset_prepared_path: last_run_prepared\nval_set_size: 0.0\noutput_dir: ./outputs/out\n\nsequence_len: 2048\nsample_packing: true\n\n\nwandb_project:\nwandb_entity:\nwandb_watch:\nwandb_name:\nwandb_log_model:\n\ngradient_accumulation_steps: 8\nmicro_batch_size: 1\nnum_epochs: 1\noptimizer: adamw_torch_fused\nlr_scheduler: cosine\nlearning_rate: 2e-5\n\nbf16: auto\ntf32: false\n\ngradient_checkpointing: true\ngradient_checkpointing_kwargs:\n  use_reentrant: false\nresume_from_checkpoint:\nlogging_steps: 1\nflash_attention: true\n\nwarmup_ratio: 0.1\nevals_per_epoch: 2\nsaves_per_epoch: 1\nweight_decay: 0.0\nspecial_tokens:\nfsdp:\n  - full_shard\n  - auto_wrap\nfsdp_config:\n  fsdp_limit_all_gathers: true\n  fsdp_sync_module_states: true\n  fsdp_offload_params: true\n  fsdp_use_orig_params: false\n  fsdp_cpu_ram_efficient_loading: true\n  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP\n  fsdp_transformer_layer_cls_to_wrap: DeepseekV2DecoderLayer\n  fsdp_state_dict_type: FULL_STATE_DICT\n  fsdp_sharding_strategy: FULL_SHARD\n\n# save_first_step: true  # uncomment this to validate checkpoint saving works with your config\n"
  },
  {
    "path": "examples/deepseek-v2/qlora-fsdp-2_5.yaml",
    "content": "base_model: axolotl-quants/DeepSeek-V2.5-bnb-nf4-bf16\n# Automatically upload checkpoint and final model to HF\n# hub_model_id: username/custom_model_name\n\ntrust_remote_code: true\n\nload_in_8bit: false\nload_in_4bit: true\n\n\nplugins:\n  - axolotl.integrations.liger.LigerPlugin\nliger_rms_norm: true\nliger_glu_activation: true\nliger_fused_linear_cross_entropy: true\n\nchat_template: deepseek_v2\ndatasets:\n  - path: mlabonne/FineTome-100k\n    type: chat_template\n    split: train[:20%]\n    field_messages: conversations\n    message_property_mappings:\n      role: from\n      content: value\n\ndataset_prepared_path: last_run_prepared\nval_set_size: 0.0\noutput_dir: ./outputs/out\n\nsequence_len: 4096\nsample_packing: true\n\n\nwandb_project:\nwandb_entity:\nwandb_watch:\nwandb_name:\nwandb_log_model:\n\nadapter: qlora\nlora_r: 256\nlora_alpha: 256\nlora_target_linear: true\npeft_use_rslora: true\n\ngradient_accumulation_steps: 1\nmicro_batch_size: 8\nnum_epochs: 1\noptimizer: adamw_torch_fused\nlr_scheduler: cosine\nlearning_rate: 2e-5\n\nbf16: auto\ntf32: false\n\ngradient_checkpointing: true\ngradient_checkpointing_kwargs:\n  use_reentrant: false\nresume_from_checkpoint:\nlogging_steps: 1\nflash_attention: true\n\nwarmup_ratio: 0.1\nevals_per_epoch: 2\nsaves_per_epoch: 1\nweight_decay: 0.0\nspecial_tokens:\nfsdp:\n  - full_shard\n  - auto_wrap\nfsdp_config:\n  fsdp_limit_all_gathers: true\n  fsdp_sync_module_states: true\n  fsdp_offload_params: true\n  fsdp_use_orig_params: false\n  fsdp_cpu_ram_efficient_loading: true\n  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP\n  fsdp_transformer_layer_cls_to_wrap: DeepseekV2DecoderLayer\n  fsdp_state_dict_type: FULL_STATE_DICT\n  fsdp_sharding_strategy: FULL_SHARD\n\n# save_first_step: true  # uncomment this to validate checkpoint saving works with your config\n"
  },
  {
    "path": "examples/devstral/README.md",
    "content": "# Finetune Devstral with Axolotl\n\nDevstral Small is a 24B-parameter open-source model from MistralAI, available on Hugging Face as [Devstral-Small-2505](https://huggingface.co/mistralai/Devstral-Small-2505) and [Devstral-Small-2507](https://huggingface.co/mistralai/Devstral-Small-2507). `Devstral-Small-2507` is the latest version of the model and supports [function calling](https://mistralai.github.io/mistral-common/usage/tools/).\n\nThis guide shows how to fine-tune it with Axolotl on multi-turn conversations with proper masking.\n\nThe model was fine-tuned on top of [Mistral-Small-3.1](https://huggingface.co/mistralai/Mistral-Small-3.1-24B-Base-2503) without the vision layer and has a context window of up to 128k tokens.\n\nThanks to the team at MistralAI for giving us early access to prepare for this release.\n\n## Getting started\n\n1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html).\n\n    Here is an example of how to install from pip:\n\n```bash\n# Ensure you have PyTorch installed (PyTorch 2.6.0 minimum)\npip3 install packaging==26.0 setuptools==75.8.0 wheel ninja\npip3 install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'\n```\n\n2. Install [Cut Cross Entropy](https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy) to reduce training VRAM usage:\n\n```bash\npython scripts/cutcrossentropy_install.py | sh\n```\n\n3. Run the finetuning example:\n\n```bash\naxolotl train examples/devstral/devstral-small-qlora.yml\n```\n\nThis config uses about 21GB of VRAM.\n\nLet us know how it goes. Happy finetuning! 
🚀\n\n### TIPS\n\n- You can run a full fine-tune by removing `adapter: qlora` and `load_in_4bit: true` from the config.\n- Read more about loading your own dataset in the [docs](https://docs.axolotl.ai/docs/dataset_loading.html).\n- The dataset format follows the OpenAI Messages format as shown [here](https://docs.axolotl.ai/docs/dataset-formats/conversation.html#chat_template).\n- Learn how to use function calling with Axolotl in the [docs](https://docs.axolotl.ai/docs/dataset-formats/conversation.html#using-tool-use).\n\n## Optimization Guides\n\n- [Multi-GPU Training](https://docs.axolotl.ai/docs/multi-gpu.html)\n- [Multi-Node Training](https://docs.axolotl.ai/docs/multi-node.html)\n- [LoRA Optimizations](https://docs.axolotl.ai/docs/lora_optims.html)\n- [Cut Cross Entropy](https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy)\n- [Liger Kernel](https://docs.axolotl.ai/docs/custom_integrations.html#liger-kernels)\n\n## Limitations\n\nAt the moment, the `mistral-common` tokenizer is supported only for supervised fine-tuning, and only with `type: chat_template`.\n\nIn addition, overriding tokens is not supported yet.\n\n## Related Resources\n\n- [MistralAI Devstral Blog](https://mistral.ai/news/devstral)\n- [MistralAI Devstral 1.1 Blog](https://mistral.ai/news/devstral-2507)\n- [Axolotl Docs](https://docs.axolotl.ai)\n- [Axolotl GitHub](https://github.com/axolotl-ai-cloud/axolotl)\n- [Axolotl Website](https://axolotl.ai)\n- [Axolotl Discord](https://discord.gg/7m9sfhzaf3)\n\n\n## Future Work\n\n- Add parity for Preference Tuning, RL, multimodal, etc.\n- Add parity for other tokenizer configs, such as overriding tokens.\n"
  },
  {
    "path": "examples/devstral/devstral-small-qlora.yml",
    "content": "base_model: mistralai/Devstral-Small-2507\n\n# Automatically upload checkpoint and final model to HF\n# hub_model_id: username/custom_model_name\n\n# Enable to use mistral-common tokenizer\ntokenizer_use_mistral_common: true\n\nload_in_8bit: false\nload_in_4bit: true\n\nplugins:\n  - axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin\n\ndatasets:\n  - path: fozziethebeat/alpaca_messages_2k_test\n    type: chat_template\n\ndataset_prepared_path: last_run_prepared\nval_set_size: 0.1\noutput_dir: ./outputs/qlora-out\n\nadapter: qlora\nlora_model_dir:\n\nsequence_len: 2048\nsample_packing: true\n\n\nlora_r: 32\nlora_alpha: 16\nlora_dropout: 0\nlora_target_linear: true\n\nwandb_project:\nwandb_entity:\nwandb_watch:\nwandb_name:\nwandb_log_model:\n\ngradient_accumulation_steps: 4\nmicro_batch_size: 2\nnum_epochs: 1\noptimizer: adamw_torch\nlr_scheduler: cosine\nlearning_rate: 0.0002\n\nbf16: auto\ntf32: false\n\ngradient_checkpointing: true\nresume_from_checkpoint:\nlogging_steps: 1\nflash_attention: true\nscaling_softmax: true\n\nloss_watchdog_threshold: 5.0\nloss_watchdog_patience: 3\n\nwarmup_ratio: 0.05\nevals_per_epoch: 4\nsaves_per_epoch: 1\n\nweight_decay: 0.0\nspecial_tokens:\n\n# save_first_step: true  # uncomment this to validate checkpoint saving works with your config\n"
  },
  {
    "path": "examples/distributed-parallel/README.md",
    "content": "# ND Parallelism Examples\n\nThis directory contains example configurations for training models using ND Parallelism in Axolotl. These examples demonstrate how to compose different parallelism strategies (FSDP, TP, CP, HSDP) for efficient multi-GPU training.\n\n## Quick Start\n\n1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html).\n\n2. Run the commands below:\n\n```bash\n# Train Qwen3 8B with FSDP + TP + CP on a single 8-GPU node\naxolotl train examples/distributed-parallel/qwen3-8b-fsdp-tp-cp.yaml\n\n# Train Llama 3.1 8B with HSDP + TP on 2 nodes (16 GPUs total)\naxolotl train examples/distributed-parallel/llama-3_1-8b-hsdp-tp.yaml\n```\n\n## Example Configurations\n\n### Single Node (8 GPUs)\n\n**Qwen3 8B with FSDP + TP + CP** ([qwen3-8b-fsdp-tp-cp.yaml](./qwen3-8b-fsdp-tp-cp.yaml))\n- Uses all 3 parallelism dimensions on a single node\n- Ideal for: cases where model weights, activations, and/or context are too large to fit on a single GPU\n\n```yaml\ndp_shard_size: 2         # FSDP across 2 GPUs\ntensor_parallel_size: 2  # TP across 2 GPUs\ncontext_parallel_size: 2 # CP across 2 GPUs\n# Total: 2 × 2 × 2 = 8 GPUs\n```\n\n### Multi-Node\n\n**Llama 3.1 8B with HSDP + TP** ([llama-3_1-8b-hsdp-tp.yaml](./llama-3_1-8b-hsdp-tp.yaml))\n- FSDP & TP within nodes, DDP across nodes to minimize inter-node communication\n- Ideal for: scaling to multiple nodes while maintaining training efficiency\n\n```yaml\ndp_shard_size: 4        # FSDP within each 4-GPU group\ntensor_parallel_size: 2 # TP within each node\ndp_replicate_size: 2    # DDP across 2 groups\n# Total: (4 × 2) × 2 = 16 GPUs (2 nodes)\n```\n\n## Learn More\n\n- [ND Parallelism Documentation](https://docs.axolotl.ai/docs/nd_parallelism.html)\n- [Blog: Accelerate ND-Parallel Guide](https://huggingface.co/blog/accelerate-nd-parallel)\n- [Multi-GPU Training Guide](https://docs.axolotl.ai/docs/multi-gpu.html)\n- [Axolotl Discord](https://discord.gg/7m9sfhzaf3)\n"
  },
  {
    "path": "examples/distributed-parallel/llama-3_1-8b-hsdp-tp.yaml",
    "content": "base_model: meta-llama/Llama-3.1-8B\n\nplugins:\n  - axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin\n\ndp_shard_size: 4\ndp_replicate_size: 2\ntensor_parallel_size: 2\n# context_parallel_size: 2\n\ndataset_prepared_path: last_run_prepared\n\nspecial_tokens:\n  pad_token: <|end_of_text|>\n\nfsdp_version: 2\nfsdp_config:\n  offload_params: false\n  state_dict_type: FULL_STATE_DICT\n  auto_wrap_policy: TRANSFORMER_BASED_WRAP\n  transformer_layer_cls_to_wrap: LlamaDecoderLayer\n  reshard_after_forward: true\n\ndatasets:\n  - path: tatsu-lab/alpaca\n    type: alpaca\n\noutput_dir: ./outputs/ndp-out/\n\nsequence_len: 2048\nsample_packing: true\nflash_attention: true\n\ngradient_accumulation_steps: 1\nmicro_batch_size: 1\nnum_epochs: 2\noptimizer: adamw_torch_fused\nlr_scheduler: constant_with_warmup\nlearning_rate: 2e-6\n\nbf16: true\ntf32: true\n\nlogging_steps: 1\nsaves_per_epoch: 1\n\nwarmup_ratio: 0.1\n"
  },
  {
    "path": "examples/distributed-parallel/qwen3-8b-fsdp-tp-cp.yaml",
    "content": "base_model: Qwen/Qwen3-8B\n\nplugins:\n  - axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin\n\ndp_shard_size: 2\n# dp_replicate_size: 1\ncontext_parallel_size: 2\ntensor_parallel_size: 2\n\ndataset_prepared_path: last_run_prepared\n\nfsdp_version: 2\nfsdp_config:\n  offload_params: false\n  state_dict_type: FULL_STATE_DICT\n  auto_wrap_policy: TRANSFORMER_BASED_WRAP\n  transformer_layer_cls_to_wrap: Qwen3DecoderLayer\n  reshard_after_forward: true\n\ndatasets:\n  - path: tatsu-lab/alpaca\n    type: alpaca\n\noutput_dir: ./outputs/ndp-out/\n\nsequence_len: 8192\nsample_packing: true\nflash_attention: true\n\ngradient_accumulation_steps: 1\nmicro_batch_size: 1  # must be 1 when using context parallel\nnum_epochs: 2\noptimizer: adamw_torch_fused\nlr_scheduler: constant_with_warmup\nlearning_rate: 2e-6\n\nbf16: true\ntf32: true\n\nlogging_steps: 1\nsaves_per_epoch: 1\n\nwarmup_ratio: 0.1\n\nspecial_tokens:\n"
  },
  {
    "path": "examples/eaft/eaft-example.yml",
    "content": "base_model: google/gemma-3-1b-it\n\nmodel_type: Gemma3ForCausalLM\ncls_model_config: Gemma3TextConfig\n\n# gemma3 doesn't seem to play nice with ddp\nddp_find_unused_parameters: true\n\nchat_template: gemma3\neot_tokens:\n  - <end_of_turn>\n\nload_in_8bit: false\nload_in_4bit: false\nstrict: false\n\ndatasets:\n  - path: cgato/SlimOrcaDedupCleaned\n    type: chat_template\n    field_messages: conversations\n    message_property_mappings:\n      role: from\n      content: value\n\ndataset_prepared_path:\nval_set_size: 0\noutput_dir: ./outputs/eaft-gemma-3-1b\n\nuse_eaft: true\neaft_alpha: 1.0\neaft_k: 20\n\nsequence_len: 1024\nsample_packing: false\n\nadapter:\nlora_model_dir:\n\nwandb_project:\nwandb_entity:\nwandb_watch:\nwandb_name:\nwandb_log_model:\n\ngradient_accumulation_steps: 4\nmicro_batch_size: 1\neval_batch_size: 1\nmax_steps: 1000\nevaluation_strategy: \"no\"\noptimizer: adamw_torch_fused\nlr_scheduler: cosine\nlearning_rate: 5e-5\n\ntrain_on_inputs: false\ngroup_by_length: false\nbf16: auto\nfp16:\ntf32: true\n\ngradient_checkpointing: true\ngradient_checkpointing_kwargs:\n  use_reentrant: false\n\nearly_stopping_patience:\nresume_from_checkpoint:\nlocal_rank:\nlogging_steps: 1\nxformers_attention:\nflash_attention: true\n\nwarmup_ratio: 0.1\nweight_decay: 0.0\ndebug:\ndeepspeed:\nfsdp:\nfsdp_config:\nspecial_tokens:\n"
  },
  {
    "path": "examples/falcon-h1/falcon-h1-1b-deep-qlora.yaml",
    "content": "base_model: tiiuae/Falcon-H1-1.5B-Deep-Base\n# optionally might have model_type or tokenizer_type\nmodel_type: AutoModelForCausalLM\ntokenizer_type: AutoTokenizer\n# Automatically upload checkpoint and final model to HF\n# hub_model_id: username/custom_model_name\n\nload_in_8bit: false\nload_in_4bit: true\n\n# huggingface repo\nchat_template: falcon_h1\ndatasets:\n  - path: cgato/SlimOrcaDedupCleaned\n    type: chat_template\n    field_messages: conversations\n    message_property_mappings:\n      role: from\n      content: value\n\nval_set_size: 0.0\noutput_dir: ./outputs/out\n\nadapter: qlora\nlora_r: 32\nlora_alpha: 16\nlora_dropout: 0.05\nlora_target_modules:\n  - q_proj\n  - k_proj\n  - v_proj\n  - o_proj\n  - in_proj\n  - gate_proj\n  - up_proj\n  - down_proj\n\nsequence_len: 2048\nsample_packing: false\neval_sample_packing: false\n\n\nwandb_project:\nwandb_entity:\nwandb_watch:\nwandb_name:\nwandb_log_model:\n\n\ngradient_accumulation_steps: 4\nmicro_batch_size: 1\nnum_epochs: 4\noptimizer: adamw_bnb_8bit\nlr_scheduler: cosine\nlearning_rate: 0.0002\n\nbf16: auto\ntf32: true\n\ngradient_checkpointing: true\ngradient_checkpointing_kwargs:\n  use_reentrant: false\nresume_from_checkpoint:\nlogging_steps: 1\nflash_attention: true\n\nwarmup_ratio: 0.1\nevals_per_epoch:\nsaves_per_epoch: 1\nweight_decay: 0.0\nspecial_tokens:\n\n# save_first_step: true  # uncomment this to validate checkpoint saving works with your config\n"
  },
  {
    "path": "examples/falcon-h1/falcon-h1-1b-qlora.yaml",
    "content": "base_model: tiiuae/Falcon-H1-1.5B-Base\n# optionally might have model_type or tokenizer_type\nmodel_type: AutoModelForCausalLM\ntokenizer_type: AutoTokenizer\n# Automatically upload checkpoint and final model to HF\n# hub_model_id: username/custom_model_name\n\nload_in_8bit: false\nload_in_4bit: true\n\n# huggingface repo\nchat_template: falcon_h1\ndatasets:\n  - path: cgato/SlimOrcaDedupCleaned\n    type: chat_template\n    field_messages: conversations\n    message_property_mappings:\n      role: from\n      content: value\n\nval_set_size: 0.0\noutput_dir: ./outputs/out\n\nadapter: qlora\nlora_r: 32\nlora_alpha: 16\nlora_dropout: 0.05\nlora_target_modules:\n  - q_proj\n  - k_proj\n  - v_proj\n  - o_proj\n  - in_proj\n  - gate_proj\n  - up_proj\n  - down_proj\n\nsequence_len: 2048\nsample_packing: false\neval_sample_packing: false\n\n\nwandb_project:\nwandb_entity:\nwandb_watch:\nwandb_name:\nwandb_log_model:\n\ngradient_accumulation_steps: 4\nmicro_batch_size: 1\nnum_epochs: 4\noptimizer: adamw_bnb_8bit\nlr_scheduler: cosine\nlearning_rate: 0.0002\n\nbf16: auto\ntf32: true\n\ngradient_checkpointing: true\ngradient_checkpointing_kwargs:\n  use_reentrant: false\nresume_from_checkpoint:\nlogging_steps: 1\nflash_attention: true\n\nwarmup_ratio: 0.1\nevals_per_epoch:\nsaves_per_epoch: 1\nweight_decay: 0.0\nspecial_tokens:\n\n# save_first_step: true  # uncomment this to validate checkpoint saving works with your config\n"
  },
  {
    "path": "examples/falcon-h1/falcon-h1-34b-qlora.yaml",
    "content": "base_model: tiiuae/Falcon-H1-34B-Base\n# optionally might have model_type or tokenizer_type\nmodel_type: AutoModelForCausalLM\ntokenizer_type: AutoTokenizer\n# Automatically upload checkpoint and final model to HF\n# hub_model_id: username/custom_model_name\n\nload_in_8bit: false\nload_in_4bit: true\n\n# huggingface repo\nchat_template: falcon_h1\ndatasets:\n  - path: cgato/SlimOrcaDedupCleaned\n    type: chat_template\n    field_messages: conversations\n    message_property_mappings:\n      role: from\n      content: value\n\nval_set_size: 0.0\noutput_dir: ./outputs/out\n\nadapter: qlora\nlora_r: 32\nlora_alpha: 16\nlora_dropout: 0.05\nlora_target_modules:\n  - q_proj\n  - k_proj\n  - v_proj\n  - o_proj\n  - in_proj\n  - gate_proj\n  - up_proj\n  - down_proj\n\nsequence_len: 2048\nsample_packing: false\neval_sample_packing: false\n\n\nwandb_project:\nwandb_entity:\nwandb_watch:\nwandb_name:\nwandb_log_model:\n\n\ngradient_accumulation_steps: 4\nmicro_batch_size: 1\nnum_epochs: 4\noptimizer: adamw_bnb_8bit\nlr_scheduler: cosine\nlearning_rate: 0.0002\n\nbf16: auto\ntf32: true\n\ngradient_checkpointing: true\ngradient_checkpointing_kwargs:\n  use_reentrant: false\nresume_from_checkpoint:\nlogging_steps: 1\nflash_attention: true\n\nwarmup_ratio: 0.1\nevals_per_epoch:\nsaves_per_epoch: 1\nweight_decay: 0.0\nspecial_tokens:\n\n# save_first_step: true  # uncomment this to validate checkpoint saving works with your config\n"
  },
  {
    "path": "examples/falcon-h1/falcon-h1-3b-qlora.yaml",
    "content": "base_model: tiiuae/Falcon-H1-3B-Base\n# optionally might have model_type or tokenizer_type\nmodel_type: AutoModelForCausalLM\ntokenizer_type: AutoTokenizer\n# Automatically upload checkpoint and final model to HF\n# hub_model_id: username/custom_model_name\n\nload_in_8bit: false\nload_in_4bit: true\n\n# huggingface repo\nchat_template: falcon_h1\ndatasets:\n  - path: cgato/SlimOrcaDedupCleaned\n    type: chat_template\n    field_messages: conversations\n    message_property_mappings:\n      role: from\n      content: value\n\nval_set_size: 0.0\noutput_dir: ./outputs/out\n\nadapter: qlora\nlora_r: 32\nlora_alpha: 16\nlora_dropout: 0.05\nlora_target_modules:\n  - q_proj\n  - k_proj\n  - v_proj\n  - o_proj\n  - in_proj\n  - gate_proj\n  - up_proj\n  - down_proj\n\nsequence_len: 2048\nsample_packing: false\neval_sample_packing: false\n\n\nwandb_project:\nwandb_entity:\nwandb_watch:\nwandb_name:\nwandb_log_model:\n\n\ngradient_accumulation_steps: 4\nmicro_batch_size: 1\nnum_epochs: 4\noptimizer: adamw_bnb_8bit\nlr_scheduler: cosine\nlearning_rate: 0.0002\n\nbf16: auto\ntf32: true\n\ngradient_checkpointing: true\ngradient_checkpointing_kwargs:\n  use_reentrant: false\nresume_from_checkpoint:\nlogging_steps: 1\nflash_attention: true\n\nwarmup_ratio: 0.1\nevals_per_epoch: 1\nsaves_per_epoch: 1\nweight_decay: 0.0\nspecial_tokens:\n\n# save_first_step: true  # uncomment this to validate checkpoint saving works with your config\n"
  },
  {
    "path": "examples/falcon-h1/falcon-h1-500m-qlora.yaml",
    "content": "base_model: tiiuae/Falcon-H1-0.5B-Instruct\n# optionally might have model_type or tokenizer_type\nmodel_type: AutoModelForCausalLM\ntokenizer_type: AutoTokenizer\n# Automatically upload checkpoint and final model to HF\n# hub_model_id: username/custom_model_name\n\nload_in_8bit: false\nload_in_4bit: true\n\n# huggingface repo\nchat_template: falcon_h1\ndatasets:\n  - path: cgato/SlimOrcaDedupCleaned\n    type: chat_template\n    field_messages: conversations\n    message_property_mappings:\n      role: from\n      content: value\n\nval_set_size: 0.0\noutput_dir: ./outputs/out\n\nadapter: qlora\nlora_r: 32\nlora_alpha: 16\nlora_dropout: 0.05\nlora_target_modules:\n  - q_proj\n  - k_proj\n  - v_proj\n  - o_proj\n  - in_proj\n  - gate_proj\n  - up_proj\n  - down_proj\n\nsequence_len: 2048\nsample_packing: false\neval_sample_packing: false\n\n\nwandb_project:\nwandb_entity:\nwandb_watch:\nwandb_name:\nwandb_log_model:\n\n\ngradient_accumulation_steps: 4\nmicro_batch_size: 1\nnum_epochs: 4\noptimizer: adamw_bnb_8bit\nlr_scheduler: cosine\nlearning_rate: 0.0002\n\nbf16: auto\ntf32: true\n\ngradient_checkpointing: true\ngradient_checkpointing_kwargs:\n  use_reentrant: false\nresume_from_checkpoint:\nlogging_steps: 1\nflash_attention: true\n\nwarmup_ratio: 0.1\nevals_per_epoch:\nsaves_per_epoch: 1\nweight_decay: 0.0\nspecial_tokens:\n\n# save_first_step: true  # uncomment this to validate checkpoint saving works with your config\n"
  },
  {
    "path": "examples/falcon-h1/falcon-h1-7b-qlora.yaml",
    "content": "base_model: tiiuae/Falcon-H1-7B-Base\n# optionally might have model_type or tokenizer_type\nmodel_type: AutoModelForCausalLM\ntokenizer_type: AutoTokenizer\n# Automatically upload checkpoint and final model to HF\n# hub_model_id: username/custom_model_name\n\nload_in_8bit: false\nload_in_4bit: true\n\n# huggingface repo\nchat_template: falcon_h1\ndatasets:\n  - path: cgato/SlimOrcaDedupCleaned\n    type: chat_template\n    field_messages: conversations\n    message_property_mappings:\n      role: from\n      content: value\n\nval_set_size: 0.0\noutput_dir: ./outputs/out\n\nadapter: qlora\nlora_r: 32\nlora_alpha: 16\nlora_dropout: 0.05\nlora_target_modules:\n  - q_proj\n  - k_proj\n  - v_proj\n  - o_proj\n  - in_proj\n  - gate_proj\n  - up_proj\n  - down_proj\n\nsequence_len: 2048\nsample_packing: false\neval_sample_packing: false\n\n\nwandb_project:\nwandb_entity:\nwandb_watch:\nwandb_name:\nwandb_log_model:\n\n\ngradient_accumulation_steps: 4\nmicro_batch_size: 1\nnum_epochs: 4\noptimizer: adamw_bnb_8bit\nlr_scheduler: cosine\nlearning_rate: 0.0002\n\nbf16: auto\ntf32: true\n\ngradient_checkpointing: true\ngradient_checkpointing_kwargs:\n  use_reentrant: false\nresume_from_checkpoint:\nlogging_steps: 1\nflash_attention: true\n\nwarmup_ratio: 0.1\nevals_per_epoch: 1\nsaves_per_epoch: 1\nweight_decay: 0.0\nspecial_tokens:\n\n# save_first_step: true  # uncomment this to validate checkpoint saving works with your config\n"
  },
  {
    "path": "examples/gemma2/qlora.yml",
    "content": "base_model: google/gemma-2-9b\n# optionally might have model_type or tokenizer_type\nmodel_type: AutoModelForCausalLM\ntokenizer_type: AutoTokenizer\n# Automatically upload checkpoint and final model to HF\n# hub_model_id: username/custom_model_name\n\nload_in_8bit: false\nload_in_4bit: true\n\n# huggingface repo\nchat_template: gemma\ndatasets:\n  - path: cgato/SlimOrcaDedupCleaned\n    type: chat_template\n    drop_system_message: true\n    field_messages: conversations\n    message_property_mappings:\n      role: from\n      content: value\n\nval_set_size: 0.0\noutput_dir: ./outputs/out\n\nadapter: qlora\nlora_r: 32\nlora_alpha: 16\nlora_dropout: 0.05\nlora_target_linear: true\n\nsequence_len: 2048\nsample_packing: true\neval_sample_packing: false\n\n\nwandb_project:\nwandb_entity:\nwandb_watch:\nwandb_name:\nwandb_log_model:\n\n\ngradient_accumulation_steps: 4\nmicro_batch_size: 1\nnum_epochs: 4\noptimizer: adamw_bnb_8bit\nlr_scheduler: cosine\nlearning_rate: 0.0002\n\nbf16: auto\ntf32: true\n\ngradient_checkpointing: true\nresume_from_checkpoint:\nlogging_steps: 1\nflash_attention: true\n\nwarmup_ratio: 0.1\nevals_per_epoch:\nsaves_per_epoch: 1\nweight_decay: 0.0\nspecial_tokens:\n\n# save_first_step: true  # uncomment this to validate checkpoint saving works with your config\n"
  },
  {
    "path": "examples/gemma2/reward-model.yaml",
    "content": "base_model: google/gemma-2-2b\n# optionally might have model_type or tokenizer_type\nmodel_type: AutoModelForSequenceClassification\nnum_labels: 1\ntokenizer_type: AutoTokenizer\n# Automatically upload checkpoint and final model to HF\n# hub_model_id: username/custom_model_name\n\nreward_model: true\nchat_template: gemma\ndatasets:\n  - path: argilla/distilabel-intel-orca-dpo-pairs\n    type: bradley_terry.chat_template\nval_set_size: 0.0\noutput_dir: ./outputs/out\nremove_unused_columns: false\n\nsequence_len: 2048\nsample_packing: false\neval_sample_packing: false\n\n\nwandb_project:\nwandb_entity:\nwandb_watch:\nwandb_name:\nwandb_log_model:\n\n\ngradient_accumulation_steps: 4\nmicro_batch_size: 2\nnum_epochs: 4\noptimizer: adamw_bnb_8bit\nlr_scheduler: cosine\nlearning_rate: 0.0002\n\nbf16: true\nfp16:\ntf32: true\n\ngradient_checkpointing: true\ngradient_checkpointing_kwargs:\n  use_reentrant: false\nresume_from_checkpoint:\nlogging_steps: 1\nflash_attention: true\n\nwarmup_ratio: 0.1\nevals_per_epoch:\nsaves_per_epoch: 1\nweight_decay: 0.0\nspecial_tokens:\n\n# save_first_step: true  # uncomment this to validate checkpoint saving works with your config\n"
  },
  {
    "path": "examples/gemma3/gemma-3-1b-qlora.yml",
    "content": "base_model: google/gemma-3-1b-it\n\n# Automatically upload checkpoint and final model to HF\n# hub_model_id: username/custom_model_name\n\n# gemma3 doesn't seem to play nice with ddp\nddp_find_unused_parameters: true\n\nload_in_8bit: false\nload_in_4bit: true\n\n# huggingface repo\nchat_template: gemma3\neot_tokens:\n  - <end_of_turn>\ndatasets:\n  - path: cgato/SlimOrcaDedupCleaned\n    type: chat_template\n    field_messages: conversations\n    message_property_mappings:\n      role: from\n      content: value\n\nval_set_size: 0.0\noutput_dir: ./outputs/out\n\n# Freeze vision tower\nunfrozen_parameters:\n  - ^model\\.language_model\\..*\n  - ^lm_head\\..*\n\nadapter: qlora\nlora_r: 32\nlora_alpha: 16\nlora_dropout: 0\nlora_target_linear: true\n\nsequence_len: 2048\nsample_packing: true\neval_sample_packing: false\n\n\nwandb_project:\nwandb_entity:\nwandb_watch:\nwandb_name:\nwandb_log_model:\n\n\ngradient_accumulation_steps: 4\nmicro_batch_size: 1\nnum_epochs: 4\noptimizer: adamw_bnb_8bit\nlr_scheduler: cosine\nlearning_rate: 0.0002\n\nbf16: auto\ntf32: true\n\ngradient_checkpointing: true\ngradient_checkpointing_kwargs:\n  use_reentrant: false\nresume_from_checkpoint:\nlogging_steps: 1\nflash_attention: true\n\nwarmup_ratio: 0.1\nevals_per_epoch:\nsaves_per_epoch: 1\nweight_decay: 0.0\nspecial_tokens:\n\n# save_first_step: true  # uncomment this to validate checkpoint saving works with your config\n"
  },
  {
    "path": "examples/gemma3/gemma-3-270m-qlora.yml",
    "content": "base_model: google/gemma-3-270m-it\n\n# Automatically upload checkpoint and final model to HF\n# hub_model_id: username/custom_model_name\n\n# gemma3 doesn't seem to play nice with ddp\nddp_find_unused_parameters: true\n\nload_in_8bit: false\nload_in_4bit: true\n\n# huggingface repo\nchat_template: gemma3\neot_tokens:\n  - <end_of_turn>\ndatasets:\n  - path: cgato/SlimOrcaDedupCleaned\n    type: chat_template\n    field_messages: conversations\n    message_property_mappings:\n      role: from\n      content: value\n\nval_set_size: 0.0\noutput_dir: ./outputs/out\n\n# Freeze vision tower\nunfrozen_parameters:\n  - ^model\\.language_model\\..*\n  - ^lm_head\\..*\n\nadapter: qlora\nlora_r: 32\nlora_alpha: 16\nlora_dropout: 0\nlora_target_linear: true\n\nsequence_len: 2048\nsample_packing: true\neval_sample_packing: false\n\n\nwandb_project:\nwandb_entity:\nwandb_watch:\nwandb_name:\nwandb_log_model:\n\n\ngradient_accumulation_steps: 4\nmicro_batch_size: 1\nnum_epochs: 1\noptimizer: adamw_bnb_8bit\nlr_scheduler: cosine\nlearning_rate: 0.0002\n\nbf16: auto\ntf32: true\n\ngradient_checkpointing: true\ngradient_checkpointing_kwargs:\n  use_reentrant: false\nresume_from_checkpoint:\nlogging_steps: 1\nflash_attention: true\n\nwarmup_ratio: 0.1\nevals_per_epoch:\nsaves_per_epoch: 1\nweight_decay: 0.0\nspecial_tokens:\n"
  },
  {
    "path": "examples/gemma3/gemma-3-4b-qlora.yml",
    "content": "base_model: google/gemma-3-4b-it\n\nload_in_4bit: true\n\n# gemma3 doesn't seem to play nice with ddp\nddp_find_unused_parameters: true\n\nchat_template: gemma3\neot_tokens:\n  - <end_of_turn>\ndatasets:\n  - path: cgato/SlimOrcaDedupCleaned\n    type: chat_template\n    field_messages: conversations\n    message_property_mappings:\n      role: from\n      content: value\n\ndataset_prepared_path: last_run_prepared\nval_set_size: 0.01\noutput_dir: ./outputs/out\n\n# Freeze vision tower\nunfrozen_parameters:\n  - ^model\\.language_model\\..*\n  - ^lm_head\\..*\n\nadapter: qlora\nlora_model_dir:\n\nsequence_len: 2048\nsample_packing: true\n\n\nlora_r: 32\nlora_alpha: 16\nlora_dropout: 0\nlora_target_linear: true\n\nwandb_project:\nwandb_entity:\nwandb_watch:\nwandb_name:\nwandb_log_model:\n\ngradient_accumulation_steps: 4\nmicro_batch_size: 2\nnum_epochs: 1\noptimizer: adamw_bnb_8bit\nlr_scheduler: cosine\nlearning_rate: 0.0002\n\nbf16: true\nfp16:\ntf32: true\n\ngradient_checkpointing: true\ngradient_checkpointing_kwargs:\n  use_reentrant: false\nlogging_steps: 1\nflash_attention: true\neager_attention:\n\nwarmup_ratio: 0.1\nevals_per_epoch: 1\nsaves_per_epoch: 1\nweight_decay: 0.0\n\n# save_first_step: true  # uncomment this to validate checkpoint saving works with your config\n"
  },
  {
    "path": "examples/gemma3/gemma-3-4b-vision-qlora.yml",
    "content": "base_model: google/gemma-3-4b-it\nprocessor_type: AutoProcessor\n\nload_in_4bit: true\n\n# these 3 lines are needed for now to handle vision chat templates w images\nskip_prepare_dataset: true\nremove_unused_columns: false\nsample_packing: false\n\n# gemma3 doesn't seem to play nice with ddp\nddp_find_unused_parameters: true\n\nchat_template: gemma3\neot_tokens:\n  - <end_of_turn>\ndatasets:\n  - path: HuggingFaceH4/llava-instruct-mix-vsft\n    type: chat_template\n    split: train[:1%]\n\ndataset_prepared_path: last_run_prepared\nval_set_size: 0.01\noutput_dir: ./outputs/out\n\nadapter: qlora\nlora_model_dir:\n\nsequence_len: 2048\npad_to_sequence_len: false\n\nlora_r: 32\nlora_alpha: 16\nlora_dropout: 0\nlora_target_modules: 'model.language_model.layers.[\\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'\n\nwandb_project:\nwandb_entity:\nwandb_watch:\nwandb_name:\nwandb_log_model:\n\ngradient_accumulation_steps: 4\nmicro_batch_size: 2\nnum_epochs: 1\noptimizer: adamw_bnb_8bit\nlr_scheduler: cosine\nlearning_rate: 0.0002\n\nbf16: true\nfp16:\ntf32: true\n\ngradient_checkpointing: true\ngradient_checkpointing_kwargs:\n  use_reentrant: false\nlogging_steps: 1\nflash_attention: true\neager_attention:\n\nwarmup_ratio: 0.1\nevals_per_epoch: 1\nsaves_per_epoch: 1\nweight_decay: 0.0\n\n# save_first_step: true  # uncomment this to validate checkpoint saving works with your config\n"
  },
  {
    "path": "examples/gemma3n/README.md",
    "content": "# Finetune Gemma-3n with Axolotl\n\nGemma-3n is a family of multimodal models from Google found on [HuggingFace](https://huggingface.co/collections/google/gemma-3n-685065323f5984ef315c93f4). This guide shows how to fine-tune it with Axolotl.\n\n## Getting started\n\n1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html).\n\n    Here is an example of how to install from pip:\n\n```bash\n# Ensure you have PyTorch installed (PyTorch 2.6.0 minimum)\npip3 install packaging==26.0 setuptools==75.8.0 wheel ninja\npip3 install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'\n```\n\n2. In addition to Axolotl's requirements, Gemma-3n requires:\n\n```bash\npip3 install timm==1.0.17\n\n# for loading audio data\npip3 install librosa==0.11.0\n```\n\n3. Download the sample dataset files:\n\n```bash\n# only needed for the text + vision + audio example\nwget https://huggingface.co/datasets/Nanobit/text-vision-audio-2k-test/resolve/main/African_elephant.jpg\nwget https://huggingface.co/datasets/Nanobit/text-vision-audio-2k-test/resolve/main/En-us-African_elephant.oga\n```\n\n4. Run the finetuning example:\n\n```bash\n# text only\naxolotl train examples/gemma3n/gemma-3n-e2b-qlora.yml\n\n# text + vision\naxolotl train examples/gemma3n/gemma-3n-e2b-vision-qlora.yml\n\n# text + vision + audio\naxolotl train examples/gemma3n/gemma-3n-e2b-vision-audio-qlora.yml\n```\n\nLet us know how it goes. Happy finetuning! 🚀\n\nWARNING: The loss and grad norm will be much higher than normal. We suspect this is inherent to the model at the moment. If anyone would like to submit a fix for this, we are happy to take a look.\n\n### TIPS\n\n- You can run a full finetuning by removing `adapter: qlora` and `load_in_4bit: true` from the config.\n- Read more on how to load your own dataset at [docs](https://docs.axolotl.ai/docs/dataset_loading.html).\n- The text dataset format follows the OpenAI Messages format as seen [here](https://docs.axolotl.ai/docs/dataset-formats/conversation.html#chat_template).\n- The multimodal dataset format follows the OpenAI multi-content Messages format as seen [here](https://docs.axolotl.ai/docs/multimodal.html#dataset-format).\n\n## Optimization Guides\n\n- [Multi-GPU Training](https://docs.axolotl.ai/docs/multi-gpu.html)\n- [Multi-Node Training](https://docs.axolotl.ai/docs/multi-node.html)\n- [LoRA Optimizations](https://docs.axolotl.ai/docs/lora_optims.html)\n\n## Related Resources\n\n- [Gemma 3n Blog](https://ai.google.dev/gemma/docs/gemma-3n)\n- [Axolotl Docs](https://docs.axolotl.ai)\n- [Axolotl Website](https://axolotl.ai)\n- [Axolotl GitHub](https://github.com/axolotl-ai-cloud/axolotl)\n- [Axolotl Discord](https://discord.gg/7m9sfhzaf3)\n"
  },
  {
    "path": "examples/gemma3n/gemma-3n-e2b-qlora.yml",
    "content": "base_model: google/gemma-3n-E2B-it\n\n# Automatically upload checkpoint and final model to HF\n# hub_model_id: username/custom_model_name\n\nplugins:\n  - axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin\ncut_cross_entropy: true\n\nload_in_8bit: false\nload_in_4bit: true\n\n# for use with fft to only train on language model layers\n# unfrozen_parameters:\n  # - model.language_model.*\n  # - lm_head\n  # - embed_tokens\n\n\nchat_template: gemma3n\neot_tokens:\n  - <end_of_turn>\ndatasets:\n  - path: cgato/SlimOrcaDedupCleaned\n    type: chat_template\n    split: train[:1%]\n    field_messages: conversations\n    message_property_mappings:\n      role: from\n      content: value\n\nval_set_size: 0.0\noutput_dir: ./outputs/out\n\nadapter: qlora\nlora_r: 32\nlora_alpha: 16\nlora_dropout: 0.05\n# lora_target_linear: # Does not work with gemma3n currently\nlora_target_modules: 'model.language_model.layers.[\\d]+.(mlp|self_attn).(up|down|gate|q|k|v|o)_proj'\n\nsequence_len: 2048\nsample_packing: true\neval_sample_packing: true\npad_to_sequence_len: true\n\nwandb_project:\nwandb_entity:\nwandb_watch:\nwandb_name:\nwandb_log_model:\n\ngradient_accumulation_steps: 1\nmicro_batch_size: 1\nnum_epochs: 4\noptimizer: adamw_bnb_8bit\nlr_scheduler: cosine\nlearning_rate: 0.0002\n\nbf16: auto\ntf32: true\n\ngradient_checkpointing: true\ngradient_checkpointing_kwargs:\n  use_reentrant: false\nresume_from_checkpoint:\nlogging_steps: 1\n# flash_attention: true  # Any attention impl does not work with gemma3n now\n\nwarmup_ratio: 0.1\nevals_per_epoch:\nsaves_per_epoch: 1\nweight_decay: 0.0\nspecial_tokens:\n"
  },
  {
    "path": "examples/gemma3n/gemma-3n-e2b-vision-audio-qlora.yml",
    "content": "base_model: google/gemma-3n-E2B-it\nprocessor_type: AutoProcessor\n\n# Automatically upload checkpoint and final model to HF\n# hub_model_id: username/custom_model_name\n\nplugins:\n  - axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin\ncut_cross_entropy: true\n\n# for use with fft to only train on language model layers\n# unfrozen_parameters:\n  # - model.language_model.*\n  # - lm_head\n  # - embed_tokens\n\nload_in_4bit: true\n\n# these 3 lines are needed for now to handle vision chat templates w images\nskip_prepare_dataset: true\nremove_unused_columns: false\nsample_packing: false\n\n# gemma3 doesn't seem to play nice with ddp\nddp_find_unused_parameters: true\n\nchat_template: gemma3n\neot_tokens:\n  - <end_of_turn>\n\n# sample dataset below requires downloading audio/image in advance\n# wget https://huggingface.co/datasets/Nanobit/text-vision-audio-2k-test/resolve/main/African_elephant.jpg\n# wget https://huggingface.co/datasets/Nanobit/text-vision-audio-2k-test/resolve/main/En-us-African_elephant.oga\ndatasets:\n  - path: Nanobit/text-vision-audio-2k-test\n    type: chat_template\ndataset_prepared_path:\nval_set_size: 0.01\noutput_dir: ./outputs/out\n\nadapter: qlora\nlora_model_dir:\n\nsequence_len: 2048\npad_to_sequence_len: false\n\nlora_r: 32\nlora_alpha: 16\nlora_dropout: 0.05\nlora_target_modules: 'model.language_model.layers.[\\d]+.(mlp|self_attn).(up|down|gate|q|k|v|o)_proj'\n\nwandb_project:\nwandb_entity:\nwandb_watch:\nwandb_name:\nwandb_log_model:\n\ngradient_accumulation_steps: 4\nmicro_batch_size: 2\nnum_epochs: 1\noptimizer: adamw_bnb_8bit\nlr_scheduler: cosine\nlearning_rate: 0.0002\n\nbf16: true\nfp16:\ntf32: true\n\ngradient_checkpointing: true\ngradient_checkpointing_kwargs:\n  use_reentrant: false\nlogging_steps: 1\n# flash_attention: true  # Any attention impl does not work with gemma3n now\n\nwarmup_ratio: 0.1\nevals_per_epoch: 1\nsaves_per_epoch: 1\nweight_decay: 0.0\n"
  },
  {
    "path": "examples/gemma3n/gemma-3n-e2b-vision-qlora.yml",
    "content": "base_model: google/gemma-3n-E2B-it\nprocessor_type: AutoProcessor\n\n# Automatically upload checkpoint and final model to HF\n# hub_model_id: username/custom_model_name\n\nplugins:\n  - axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin\ncut_cross_entropy: true\n\n# for use with fft to only train on language model layers\n# unfrozen_parameters:\n  # - model.language_model.*\n  # - lm_head\n  # - embed_tokens\n\nload_in_4bit: true\n\n# these 3 lines are needed for now to handle vision chat templates w images\nskip_prepare_dataset: true\nremove_unused_columns: false\nsample_packing: false\n\n# gemma3 doesn't seem to play nice with ddp\nddp_find_unused_parameters: true\n\nchat_template: gemma3n\neot_tokens:\n  - <end_of_turn>\ndatasets:\n  - path: HuggingFaceH4/llava-instruct-mix-vsft\n    type: chat_template\n    split: train[:1%]\ndataset_prepared_path:\nval_set_size: 0.01\noutput_dir: ./outputs/out\n\nadapter: qlora\nlora_model_dir:\n\nsequence_len: 2048\npad_to_sequence_len: false\n\nlora_r: 32\nlora_alpha: 16\nlora_dropout: 0.05\nlora_target_modules: 'model.language_model.layers.[\\d]+.(mlp|self_attn).(up|down|gate|q|k|v|o)_proj'\n\nwandb_project:\nwandb_entity:\nwandb_watch:\nwandb_name:\nwandb_log_model:\n\ngradient_accumulation_steps: 4\nmicro_batch_size: 2\nnum_epochs: 1\noptimizer: adamw_bnb_8bit\nlr_scheduler: cosine\nlearning_rate: 0.0002\n\nbf16: true\nfp16:\ntf32: true\n\ngradient_checkpointing: true\ngradient_checkpointing_kwargs:\n  use_reentrant: false\nlogging_steps: 1\n# flash_attention: true  # Any attention impl does not work with gemma3n now\n\nwarmup_ratio: 0.1\nevals_per_epoch: 1\nsaves_per_epoch: 1\nweight_decay: 0.0\n"
  },
  {
    "path": "examples/glm4/qlora-32b.yaml",
    "content": "base_model: THUDM/GLM-4-32B-0414\n# Automatically upload checkpoint and final model to HF\n# hub_model_id: username/custom_model_name\n\nload_in_4bit: true\n\ndatasets:\n  - path: teknium/GPT4-LLM-Cleaned\n    type: alpaca\ndataset_prepared_path: last_run_prepared\nval_set_size: 0\noutput_dir: ./outputs/qlora-out\n\nadapter: qlora\nlora_model_dir:\n\nsequence_len: 2048\nsample_packing: true\neval_sample_packing: true\n\n\nlora_r: 16\nlora_alpha: 32\nlora_dropout: 0.05\nlora_target_modules:\n  - gate_proj\n  - down_proj\n  - up_proj\n  - q_proj\n  - v_proj\n  - k_proj\n  - o_proj\n\nwandb_project:\nwandb_entity:\nwandb_watch:\nwandb_name:\nwandb_log_model:\n\ngradient_accumulation_steps: 2\nmicro_batch_size: 2\nnum_epochs: 1\noptimizer: adamw_8bit\nlr_scheduler: cosine\nlearning_rate: 0.0002\n\nbf16: auto\ntf32: false\n\ngradient_checkpointing: true\nresume_from_checkpoint:\nlogging_steps: 1\nflash_attention: true\n\nloss_watchdog_threshold: 5.0\nloss_watchdog_patience: 3\n\nwarmup_ratio: 0.1\nevals_per_epoch: 1\nsaves_per_epoch: 1\nweight_decay: 0.0\nspecial_tokens:\n\n# save_first_step: true  # uncomment this to validate checkpoint saving works with your config\n"
  },
  {
    "path": "examples/glm45/README.md",
    "content": "# Finetune Z.ai's GLM-4.5-Air with Axolotl\n\n[GLM-4.5-Air](https://huggingface.co/zai-org/GLM-4.5-Air) is a MoE model by Z.ai.\n\nThis guide shows how to fine-tune it with Axolotl.\n\n## Getting started\n\n1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html).\n\n2. Install [Cut Cross Entropy](https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy) to reduce training VRAM usage.\n\n3. Run the finetuning example:\n\n```bash\n# QLoRA (1x80GB @ ~63.4GiB/GPU)\naxolotl train examples/glm45/glm-45-air-qlora.yaml\n```\n\n### Dataset\n\nIn addition to the standard OpenAI Messages format, GLM-4.5 supports an extra `reasoning_content` field for thinking in assistant messages.\n\n```json\n{\n    \"role\": \"assistant\",\n    \"reasoning_content\": \"...\",  // or wrap the reasoning in <think>...</think> inside `content`\n    \"content\": \"...\"\n}\n```\n\nMake sure you set the extra attributes below if needed:\n\n```yaml\ndatasets:\n  - path: ...\n    type: chat_template\n    message_property_mappings:\n      role: role\n      content: content\n    #   tool_calls: tool_calls  # uncomment if using tools\n    #   reasoning_content: reasoning_content  # uncomment if you have reasoning\n\n# Uncomment if training on the tool role (you will rarely, if ever, need this)\n# eot_tokens:\n#   - <|observation|>\n```\n\n### Tips\n\n- The role name for tools in this template is `tool`.\n- You will see this Axolotl WARNING — this is expected as the template does not use EOS:\n  ```\n  EOS token '<|endoftext|>' not found in chat_template. Please check if your template/EOS token is correct.\n  ```\n- You can run a full finetuning by removing `adapter: qlora`, `load_in_4bit: true`, and `quantize_moe_experts: true` from the config.\n- **LoRA kernels**: Incompatible with this model. Must be explicitly disabled (`lora_*_kernel: false`).\n- Read more on how to load your own dataset at [docs](https://docs.axolotl.ai/docs/dataset_loading.html).\n\n## Optimization Guides\n\nPlease check the [Optimizations doc](https://docs.axolotl.ai/docs/optimizations.html).\n\n## Related Resources\n\n- [GLM-4.5-Air on HuggingFace](https://huggingface.co/zai-org/GLM-4.5-Air)\n- [GLM-4.5 Blog](https://z.ai/blog/glm-4.5)\n- [Axolotl Docs](https://docs.axolotl.ai)\n- [Axolotl Website](https://axolotl.ai)\n- [Axolotl GitHub](https://github.com/axolotl-ai-cloud/axolotl)\n- [Axolotl Discord](https://discord.gg/7m9sfhzaf3)\n"
  },
  {
    "path": "examples/glm45/glm-45-air-qlora.yaml",
    "content": "base_model: zai-org/GLM-4.5-Air\n\n# Automatically upload checkpoint and final model to HF\n# hub_model_id: username/custom_model_name\n\nplugins:\n  - axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin\n\nload_in_8bit: false\nload_in_4bit: true\n\nquantize_moe_experts: true # important\n\ndatasets:\n  - path: fozziethebeat/alpaca_messages_2k_test\n    type: chat_template\n\ndataset_prepared_path: last_run_prepared\nval_set_size: 0.1\noutput_dir: ./outputs/lora-out\n\nadapter: qlora\nlora_model_dir:\n\nsequence_len: 2048\nsample_packing: true\n\nlora_r: 16\nlora_alpha: 8\nlora_dropout: 0\nlora_target_modules:\n  - q_proj\n  - v_proj\n  - k_proj\n  - o_proj\n\n# lora_target_parameters:\n#   - mlp.experts.gate_up_proj\n#   - mlp.experts.down_proj\n\nlora_mlp_kernel: false\nlora_qkv_kernel: false\nlora_o_kernel: false\n\ngradient_accumulation_steps: 2\nmicro_batch_size: 2\nnum_epochs: 1\noptimizer: adamw_bnb_8bit\nlr_scheduler: cosine\nlearning_rate: 0.0002\n\nbf16: auto\ntf32: false\n\ngradient_checkpointing: true\nresume_from_checkpoint:\nlogging_steps: 1\nflash_attention: true\n\nwarmup_ratio: 0.1\nevals_per_epoch: 1\nsaves_per_epoch: 1\n\n# save_first_step: true  # uncomment this to validate checkpoint saving works with your config\n"
  },
  {
    "path": "examples/glm46v/README.md",
    "content": "# Finetune GLM-4.6V with Axolotl\n\nGLM-4.6V is a family of vision-language models from ZhipuAI found on [HuggingFace](https://huggingface.co/zai-org/GLM-4.6V). This guide shows how to fine-tune it with Axolotl for vision-language tasks.\n\n## Getting started\n\n1. Install Axolotl from source following the [installation guide](https://docs.axolotl.ai/docs/installation.html#sec-edge-build).\n\n2. Install [Cut Cross Entropy](https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy) to reduce training VRAM usage.\n\n3. Run the fine-tuning:\n\n    GLM-4.6V-Flash (9B):\n    ```bash\n    axolotl train examples/glm46v/glm-4-6v-flash-qlora.yaml\n    ```\n\nLet us know how it goes. Happy finetuning! 🚀\n\n## Tips\n\n- Vision datasets should follow the format described in the [multimodal docs](https://docs.axolotl.ai/docs/multimodal.html#dataset-format).\n- You can run a **full finetuning** by removing `adapter: qlora` and `load_in_4bit: true` from the config.\n- Read more on how to load your own dataset in the [dataset loading docs](https://docs.axolotl.ai/docs/dataset_loading.html).\n\n## Supported Models\n\n- **GLM-4.6V**: Full vision-language model (`zai-org/GLM-4.6V`)\n- **GLM-4.6V-Flash**: Faster variant (`zai-org/GLM-4.6V-Flash`)\n\n## Optimization Guides\n\nPlease check the [Optimizations doc](https://docs.axolotl.ai/docs/optimizations.html).\n\n## Related Resources\n\n- [ZhipuAI GLM-4.6V](https://huggingface.co/zai-org/GLM-4.6V)\n- [Axolotl Docs](https://docs.axolotl.ai)\n- [Axolotl Website](https://axolotl.ai)\n- [Axolotl GitHub](https://github.com/axolotl-ai-cloud/axolotl)\n- [Axolotl Discord](https://discord.gg/7m9sfhzaf3)\n"
  },
  {
    "path": "examples/glm46v/glm-4-6v-flash-ddp.yaml",
    "content": "base_model: zai-org/GLM-4.6V-Flash\ntrust_remote_code: true\n\nprocessor_type: AutoProcessor\nload_in_4bit: true\n\n# these 3 lines are needed for now to handle vision chat templates w images\nskip_prepare_dataset: true\nremove_unused_columns: false\nsample_packing: false\nddp_find_unused_parameters: true\n\noutput_dir: ./outputs/glm-4-6v-flash-qlora\ndatasets:\n  - path: HuggingFaceH4/llava-instruct-mix-vsft\n    type: chat_template\n    split: train[:1%]\n\nadapter: qlora\nlora_r: 16\nlora_alpha: 32\nlora_dropout: 0.05\nlora_target_modules:\n  - gate_proj\n  - down_proj\n  - up_proj\n  - q_proj\n  - v_proj\n  - k_proj\n  - o_proj\n\nsequence_len: 2048\n\ngradient_accumulation_steps: 4\nmicro_batch_size: 1\nnum_epochs: 1\noptimizer: adamw_8bit\nlr_scheduler: cosine\nlearning_rate: 0.0002\n\nbf16: auto\ntf32: false\n\ngradient_checkpointing: true\ngradient_checkpointing_kwargs:\n  use_reentrant: false\nlogging_steps: 1\nsdp_attention: true\n\nwarmup_ratio: 0.1\nevals_per_epoch: 0\nsaves_per_epoch: 1\nweight_decay: 0.0\n"
  },
  {
    "path": "examples/glm46v/glm-4-6v-flash-qlora.yaml",
    "content": "base_model: zai-org/GLM-4.6V-Flash\ntrust_remote_code: true\n\nprocessor_type: AutoProcessor\nload_in_4bit: true\n\n# these 3 lines are needed for now to handle vision chat templates w images\nskip_prepare_dataset: true\nremove_unused_columns: false\nsample_packing: false\n\noutput_dir: ./outputs/glm-4-6v-flash-qlora\ndatasets:\n  - path: HuggingFaceH4/llava-instruct-mix-vsft\n    type: chat_template\n    split: train[:1%]\n\nadapter: qlora\nlora_r: 16\nlora_alpha: 32\nlora_dropout: 0.05\nlora_target_modules:\n  - gate_proj\n  - down_proj\n  - up_proj\n  - q_proj\n  - v_proj\n  - k_proj\n  - o_proj\n\nsequence_len: 2048\n\ngradient_accumulation_steps: 4\nmicro_batch_size: 1\nnum_epochs: 1\noptimizer: adamw_8bit\nlr_scheduler: cosine\nlearning_rate: 0.0002\n\nbf16: auto\ntf32: false\n\ngradient_checkpointing: true\nlogging_steps: 1\nsdp_attention: true\n\nwarmup_ratio: 0.1\nevals_per_epoch: 0\nsaves_per_epoch: 1\nweight_decay: 0.0\n"
  },
  {
    "path": "examples/glm47-flash/README.md",
    "content": "# Finetune Z.ai's GLM-4.7-Flash with Axolotl\n\n[GLM-4.7-Flash](https://huggingface.co/zai-org/GLM-4.7-Flash) is a 30B-A3B MoE model by Z.ai.\n\nThis guide shows how to fine-tune it with Axolotl.\n\n## Getting started\n\n1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html).\n\n2. Install [Cut Cross Entropy](https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy) to reduce training VRAM usage.\n\n3. Run the finetuning example:\n\n```bash\n# QLoRA\n# - no target experts (1x48GB @ ~24GiB/GPU)\n# - target experts (1x48GB @ ~34GiB/GPU)\naxolotl train examples/glm47-flash/qlora.yaml\n\n# QLoRA FSDP2 no target experts (2x48GB @ ~29GiB/GPU)\naxolotl train examples/glm47-flash/qlora_fsdp.yaml\n```\n\n```bash\n# LoRA\n# - no target experts (1x48GB @ ~35GiB/GPU)\n# - target experts (1x48GB @ OOM. Projected ~45-50GiB/GPU)\naxolotl train examples/glm47-flash/lora.yaml\n\n# LoRA FSDP2 no target experts (2x48GB @ ~43GiB/GPU)\naxolotl train examples/glm47-flash/lora_fsdp.yaml\n```\n\n### MoE Expert Quantization & Expert LoRA\n\nThe example configs quantize the expert weights on load via `quantize_moe_experts: true`. To learn about expert quantization, expert LoRA targeting, and related limitations, see the [MoE Expert Quantization](https://docs.axolotl.ai/docs/expert_quantization.html) docs.\n\n## Limitations\n\n- **lora_target_linear**: Incompatible with this model.\n- **LoRA kernels**: Incompatible with this model due to non-standard attention projections (DSA). Must be explicitly disabled (`lora_*_kernel: false`).\n\n### TIPS\n\n- For inference, the Z.ai team officially recommends these default settings for most tasks:\n  - `temperature: 1.0`\n  - `top_p: 0.95`\n  - `max_new_tokens: 131072`\n- You can run a full finetuning by removing `adapter`, `load_in_4bit`/`load_in_8bit`, and `quantize_moe_experts: true` from the config. This is heavy, so we have not tested it.\n- Read more on how to load your own dataset at [docs](https://docs.axolotl.ai/docs/dataset_loading.html).\n\n## Optimization Guides\n\nPlease check the [Optimizations doc](https://docs.axolotl.ai/docs/optimizations.html).\n\n## Related Resources\n\n- [GLM-4.7-Flash on HuggingFace](https://huggingface.co/zai-org/GLM-4.7-Flash)\n- [GLM-4.7 Blog](https://z.ai/blog/glm-4.7)\n- [Axolotl Docs](https://docs.axolotl.ai)\n- [Axolotl Website](https://axolotl.ai)\n- [Axolotl GitHub](https://github.com/axolotl-ai-cloud/axolotl)\n- [Axolotl Discord](https://discord.gg/7m9sfhzaf3)\n"
  },
  {
    "path": "examples/glm47-flash/lora.yaml",
    "content": "base_model: zai-org/GLM-4.7-Flash\n\nplugins:\n  - axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin\n\nload_in_8bit: true\nquantize_moe_experts: true\n\ndatasets:\n  - path: fozziethebeat/alpaca_messages_2k_test\n    type: chat_template\n\ndataset_prepared_path: last_run_prepared\nval_set_size: 0.1\noutput_dir: ./outputs/glm4.7-flash-lora-8bit-out\n\nadapter: lora\nlora_model_dir:\n\nsequence_len: 2048\nsample_packing: true\n\nlora_r: 32\nlora_alpha: 16\nlora_dropout: 0\nlora_target_modules:\n  - q_proj\n  - v_proj\n  - k_proj\n  - o_proj\n\n# Uncomment to also target MoE expert weights:\n# lora_target_parameters:\n#   - mlp.experts.gate_up_proj\n#   - mlp.experts.down_proj\n\n# LoRA kernels incompatible with DSA attention\nlora_mlp_kernel: false\nlora_qkv_kernel: false\nlora_o_kernel: false\n\nwandb_project:\nwandb_entity:\nwandb_watch:\nwandb_name:\nwandb_log_model:\n\ngradient_accumulation_steps: 4\nmicro_batch_size: 2\nnum_epochs: 1\noptimizer: adamw_torch_8bit\nlr_scheduler: cosine\nlearning_rate: 0.0002\n\nbf16: auto\ntf32: false\n\ngradient_checkpointing: true\nresume_from_checkpoint:\nlogging_steps: 1\nflash_attention: true\n\nwarmup_ratio: 0.1\nevals_per_epoch: 1\nsaves_per_epoch: 1\n"
  },
  {
    "path": "examples/glm47-flash/lora_fsdp.yaml",
    "content": "base_model: zai-org/GLM-4.7-Flash\n\nplugins:\n  - axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin\n\nload_in_8bit: true\nquantize_moe_experts: true\n\ndatasets:\n  - path: fozziethebeat/alpaca_messages_2k_test\n    type: chat_template\n\ndataset_prepared_path: last_run_prepared\nval_set_size: 0.1\noutput_dir: ./outputs/glm4.7-flash-lora-8bit-fsdp-out\n\nadapter: lora\nlora_model_dir:\n\nsequence_len: 2048\nsample_packing: true\n\nlora_r: 32\nlora_alpha: 16\nlora_dropout: 0\nlora_target_modules:\n  - q_proj\n  - v_proj\n  - k_proj\n  - o_proj\n\n# Uncomment to also target MoE expert weights:\n# lora_target_parameters:\n#   - mlp.experts.gate_up_proj\n#   - mlp.experts.down_proj\n\n# LoRA kernels incompatible with DSA attention\nlora_mlp_kernel: false\nlora_qkv_kernel: false\nlora_o_kernel: false\n\nwandb_project:\nwandb_entity:\nwandb_watch:\nwandb_name:\nwandb_log_model:\n\ngradient_accumulation_steps: 4\nmicro_batch_size: 2\nnum_epochs: 1\noptimizer: adamw_torch_8bit\nlr_scheduler: cosine\nlearning_rate: 0.0002\n\nbf16: auto\ntf32: false\n\nresume_from_checkpoint:\nlogging_steps: 1\nflash_attention: true\n\nwarmup_ratio: 0.1\nevals_per_epoch: 1\nsaves_per_epoch: 1\n\nfsdp_config:\n  fsdp_version: 2\n  offload_params: false\n  cpu_ram_efficient_loading: false\n  auto_wrap_policy: TRANSFORMER_BASED_WRAP\n  transformer_layer_cls_to_wrap: Glm4MoeLiteDecoderLayer\n  state_dict_type: FULL_STATE_DICT\n  sharding_strategy: FULL_SHARD\n  reshard_after_forward: true\n  activation_checkpointing: true\n"
  },
  {
    "path": "examples/glm47-flash/qlora.yaml",
    "content": "base_model: zai-org/GLM-4.7-Flash\n\nplugins:\n  - axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin\n\nload_in_4bit: true\nquantize_moe_experts: true\n\ndatasets:\n  - path: fozziethebeat/alpaca_messages_2k_test\n    type: chat_template\n\ndataset_prepared_path: last_run_prepared\nval_set_size: 0.1\noutput_dir: ./outputs/glm4.7-flash-qlora-out\n\nadapter: qlora\nlora_model_dir:\n\nsequence_len: 2048\nsample_packing: true\n\nlora_r: 32\nlora_alpha: 16\nlora_dropout: 0\nlora_target_modules:\n  - q_proj\n  - v_proj\n  - k_proj\n  - o_proj\n\n# Uncomment to also target MoE expert weights:\n# lora_target_parameters:\n#   - mlp.experts.gate_up_proj\n#   - mlp.experts.down_proj\n\n# LoRA kernels incompatible with DSA attention\nlora_mlp_kernel: false\nlora_qkv_kernel: false\nlora_o_kernel: false\n\nwandb_project:\nwandb_entity:\nwandb_watch:\nwandb_name:\nwandb_log_model:\n\ngradient_accumulation_steps: 4\nmicro_batch_size: 2\nnum_epochs: 1\noptimizer: adamw_torch_8bit\nlr_scheduler: cosine\nlearning_rate: 0.0002\n\nbf16: auto\ntf32: false\n\ngradient_checkpointing: true\nresume_from_checkpoint:\nlogging_steps: 1\nflash_attention: true\n\nwarmup_ratio: 0.1\nevals_per_epoch: 1\nsaves_per_epoch: 1\n"
  },
  {
    "path": "examples/glm47-flash/qlora_fsdp.yaml",
    "content": "base_model: zai-org/GLM-4.7-Flash\n\nplugins:\n  - axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin\n\nload_in_4bit: true\nquantize_moe_experts: true\n\ndatasets:\n  - path: fozziethebeat/alpaca_messages_2k_test\n    type: chat_template\n\ndataset_prepared_path: last_run_prepared\nval_set_size: 0.1\noutput_dir: ./outputs/glm4.7-flash-qlora-fsdp-out\n\nadapter: qlora\nlora_model_dir:\n\nsequence_len: 2048\nsample_packing: true\n\nlora_r: 32\nlora_alpha: 16\nlora_dropout: 0\nlora_target_modules:\n  - q_proj\n  - v_proj\n  - k_proj\n  - o_proj\n\n# Uncomment to also target MoE expert weights:\n# lora_target_parameters:\n#   - mlp.experts.gate_up_proj\n#   - mlp.experts.down_proj\n\n# LoRA kernels incompatible with DSA attention\nlora_mlp_kernel: false\nlora_qkv_kernel: false\nlora_o_kernel: false\n\nwandb_project:\nwandb_entity:\nwandb_watch:\nwandb_name:\nwandb_log_model:\n\ngradient_accumulation_steps: 4\nmicro_batch_size: 2\nnum_epochs: 1\noptimizer: adamw_torch_8bit\nlr_scheduler: cosine\nlearning_rate: 0.0002\n\nbf16: auto\ntf32: false\n\nresume_from_checkpoint:\nlogging_steps: 1\nflash_attention: true\n\nwarmup_ratio: 0.1\nevals_per_epoch: 1\nsaves_per_epoch: 1\n\nfsdp_config:\n  fsdp_version: 2\n  offload_params: false\n  cpu_ram_efficient_loading: false\n  auto_wrap_policy: TRANSFORMER_BASED_WRAP\n  transformer_layer_cls_to_wrap: Glm4MoeLiteDecoderLayer\n  state_dict_type: FULL_STATE_DICT\n  sharding_strategy: FULL_SHARD\n  reshard_after_forward: true\n  activation_checkpointing: true\n"
  },
  {
    "path": "examples/gpt-oss/README.md",
    "content": "# Finetune OpenAI's GPT-OSS with Axolotl\n\n[GPT-OSS](https://huggingface.co/collections/openai/gpt-oss-68911959590a1634ba11c7a4) are a family of open-weight MoE models trained by OpenAI, released in August 2025. There are two variants: 20B and 120B.\n\nIn October 2025, OpenAI released safeguard models built upon GPT-OSS called [GPT-OSS-Safeguard](https://huggingface.co/collections/openai/gpt-oss-safeguard). They use the same architecture, so the same examples below can be re-used.\n\nThis guide shows how to fine-tune it with Axolotl with multi-turn conversations and proper masking.\n\n## Getting started\n\n1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html).\n\n    Here is an example of how to install from pip:\n\n```bash\n# Ensure you have Pytorch installed (Pytorch 2.6.0 min)\npip3 install packaging==26.0 setuptools==75.8.0 wheel ninja\npip3 install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'\n```\n\n2. Choose one of the following configs below for training the 20B model. (for 120B, see [below](#training-120b))\n\n```bash\n# LoRA SFT linear layers (1x48GB @ ~44GiB)\naxolotl train examples/gpt-oss/gpt-oss-20b-sft-lora-singlegpu.yaml\n\n# FFT SFT with offloading (2x24GB @ ~21GiB/GPU)\naxolotl train examples/gpt-oss/gpt-oss-20b-fft-fsdp2-offload.yaml\n\n# FFT SFT (8x48GB @ ~36GiB/GPU or 4x80GB @ ~46GiB/GPU)\naxolotl train examples/gpt-oss/gpt-oss-20b-fft-fsdp2.yaml\n```\n\nNote: Memory usage taken from `device_mem_reserved(gib)` from logs.\n\n### Training 120B\n\nOn 8xH100s, make sure you have ~3TB of free disk space. 
With each checkpoint clocking in at ~720GB, along with the base\nmodel, and final model output, you may need at least 3TB of free disk space to keep at least 2 checkpoints.\n\n```bash\n# FFT SFT with offloading (8x80GB @ ~49GiB/GPU)\naxolotl train examples/gpt-oss/gpt-oss-120b-fft-fsdp2-offload.yaml\n```\n\nTo simplify fine-tuning across 2 nodes × 8x H100 (80GB) GPUs, we've partnered with [Baseten](https://baseten.co) to showcase multi-node\ntraining of the 120B model using Baseten Truss. You can read more about this recipe on\n[Baseten's blog](https://www.baseten.co/blog/how-to-fine-tune-gpt-oss-120b-with-baseten-and-axolotl/). The recipe can\nbe found on their\n[GitHub](https://github.com/basetenlabs/ml-cookbook/tree/main/examples/oss-gpt-120b-axolotl/training).\n\nERRATA: Transformers saves the model Architecture prefixed with `FSDP` which needs to be manually renamed in `config.json`.\nSee https://github.com/huggingface/transformers/pull/40207 for the status of this issue.\n\n```bash\nsed -i 's/FSDPGptOssForCausalLM/GptOssForCausalLM/g' ./outputs/gpt-oss-out/config.json\n```\n\nWhen using SHARDED_STATE_DICT with FSDP, the final checkpoint should automatically merge the sharded weights to your\nconfigured `output_dir`. However, if that step fails due to a disk space error, you can take an additional step to\nmerge the sharded weights.  This step will automatically determine the last checkpoint directory and merge the sharded\nweights to `{output_dir}/merged`.\n\n```bash\naxolotl merge-sharded-fsdp-weights examples/gpt-oss/gpt-oss-120b-fft-fsdp2-offload.yaml\nmv ./outputs/gpt-oss-out/merged/* ./outputs/gpt-oss-out/\n```\n\n### How to set reasoning_effort in template?\n\nThe harmony template has a feature to set the `reasoning_effort` during prompt building. The default is `medium`. 
If you would like to adjust this, you can add the following to your config:\n\n```yaml\nchat_template_kwargs:\n  reasoning_effort: \"high\"  # low | medium | high\n```\n\nCurrently, this applies globally. There is no method to apply per sample yet. If you are interested in adding this, please feel free to create an Issue to discuss.\n\n### Inferencing your fine-tuned model\n\n#### vLLM\n\nGPT-OSS support in vLLM does not exist in a stable release yet. See https://x.com/MaziyarPanahi/status/1955741905515323425\nfor more information about using a special vllm-openai docker image for inferencing with vLLM.\n\nOptionally, vLLM can be installed from nightly:\n\n```bash\npip install --no-build-isolation --pre -U vllm --extra-index-url https://wheels.vllm.ai/nightly\n```\nand the vLLM server can be started with the following command (modify `--tensor-parallel-size 8` to match your environment):\n```bash\nvllm serve ./outputs/gpt-oss-out/ --served-model-name axolotl/gpt-oss-20b --host 0.0.0.0 --port 8888  --tensor-parallel-size 8\n```\n\n#### SGLang\n\nSGLang has 0-day support in main, see https://github.com/sgl-project/sglang/issues/8833 for information on installing\nSGLang from source. Once you've installed SGLang, run the following command to launch an SGLang server:\n\n```bash\npython3 -m sglang.launch_server --model ./outputs/gpt-oss-out/ --served-model-name axolotl/gpt-oss-120b --host 0.0.0.0 --port 8888 --tp 8\n```\n\n### Tool use\n\nGPT-OSS has strong built-in tool-calling capabilities. 
Axolotl supports tool calling datasets for Supervised Fine-tuning.\n\nHere is an example dataset config:\n```yaml\ndatasets:\n  - path: Nanobit/text-tools-2k-test\n    type: chat_template\n```\n\nSee [Nanobit/text-tools-2k-test](https://huggingface.co/datasets/Nanobit/text-tools-2k-test) for the sample dataset.\n\nRefer to [our docs](https://docs.axolotl.ai/docs/dataset-formats/conversation.html#using-tool-use) for more info.\n\n### Thinking and chat_template masking conflict\n\nOpenAI’s Harmony template hides `thinking` in all non-final turns, which conflicts with Axolotl’s `chat_template` masking.\n\nIf your dataset has `thinking` content mid-turn, there are two paths we recommend:\n\n- Train only on the last turn. This can be accomplished via chat_template's [train on last doc](https://docs.axolotl.ai/docs/dataset-formats/conversation.html#training-on-last-message).\n\n- Adjust your dataset to only have `thinking` content in the last turn.\n\n### TIPS\n\n- Read more on how to load your own dataset at [docs](https://docs.axolotl.ai/docs/dataset_loading.html).\n- The dataset format follows the OpenAI Messages format as seen [here](https://docs.axolotl.ai/docs/dataset-formats/conversation.html#chat_template).\n\n## Optimization Guides\n\n- [Multi-GPU Training](https://docs.axolotl.ai/docs/multi-gpu.html)\n- [Multi-Node Training](https://docs.axolotl.ai/docs/multi-node.html)\n\n## Related Resources\n\n- [GPT-OSS Blog](https://openai.com/index/introducing-gpt-oss/)\n- [Axolotl Docs](https://docs.axolotl.ai)\n- [Axolotl Website](https://axolotl.ai)\n- [Axolotl GitHub](https://github.com/axolotl-ai-cloud/axolotl)\n- [Axolotl Discord](https://discord.gg/7m9sfhzaf3)\n"
  },
  {
    "path": "examples/gpt-oss/gpt-oss-120b-fft-fsdp2-offload.yaml",
    "content": "# the original mxfp4 quantized model is not supported with FSDP cpu_ram_efficient_loading\n# FSDP cpu_ram_efficient_loading is used to reduce the initial CPU memory usage when loading the model\nbase_model: axolotl-ai-co/gpt-oss-120b-dequantized\n\nuse_kernels: false\n\ndp_shard_size: 16  # requires 2x8xH100 nodes\n\nplugins:\n  - axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin\n\nexperimental_skip_move_to_device: true  # prevent OOM by NOT putting model to GPU before sharding\n\ndatasets:\n  - path: HuggingFaceH4/Multilingual-Thinking\n    type: chat_template\n    field_thinking: thinking\n    template_thinking_key: thinking\n\ndataset_prepared_path: last_run_prepared\nval_set_size: 0\noutput_dir: ./outputs/gpt-oss-out/\nsave_total_limit: 2  # the 120B model can use up to 720GB of disk space per checkpoint, so let's only keep the last 2\n\nsequence_len: 4096\nsample_packing: true\npad_to_sequence_len: true\n\nwandb_project:\nwandb_entity:\nwandb_watch:\nwandb_name:\nwandb_log_model:\n\ntrackio_project_name:\ntrackio_run_name:\ntrackio_space_id:\n\ngradient_accumulation_steps: 2\nmicro_batch_size: 1\nnum_epochs: 1\n\noptimizer: adamw_torch_fused  # 8bit optimizers do not work with FSDP2 offload\nlr_scheduler: constant_with_warmup\nlearning_rate: 2e-5\n\nbf16: true\ntf32: true\n\nflash_attention: true\nattn_implementation: kernels-community/vllm-flash-attn3  # this is not needed if using flash_attn >= 2.8.3\n\ngradient_checkpointing: true\nactivation_offloading: true\n\nlogging_steps: 1\nsaves_per_epoch: 1\n\nwarmup_ratio: 0.03\n\nspecial_tokens:\neot_tokens:\n  - \"<|end|>\"\n\nfsdp_version: 2\nfsdp_config:\n  offload_params: true\n  state_dict_type: SHARDED_STATE_DICT\n  auto_wrap_policy: TRANSFORMER_BASED_WRAP\n  transformer_layer_cls_to_wrap: GptOssDecoderLayer\n  reshard_after_forward: true\n  cpu_ram_efficient_loading: true\n"
  },
  {
    "path": "examples/gpt-oss/gpt-oss-20b-fft-deepspeed-zero3.yaml",
    "content": "base_model: openai/gpt-oss-20b\nuse_kernels: false\nmodel_quantization_config: Mxfp4Config\nmodel_quantization_config_kwargs:\n  dequantize: true\n\nplugins:\n  - axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin\n\nexperimental_skip_move_to_device: true  # prevent OOM by NOT putting model to GPU before sharding\n\ndatasets:\n  - path: HuggingFaceH4/Multilingual-Thinking\n    type: chat_template\n    field_thinking: thinking\n    template_thinking_key: thinking\n\ndataset_prepared_path: last_run_prepared\nval_set_size: 0\noutput_dir: ./outputs/gpt-oss-out/\n\nsequence_len: 4096\nsample_packing: true\n\nwandb_project:\nwandb_entity:\nwandb_watch:\nwandb_name:\nwandb_log_model:\n\ntrackio_project_name:\ntrackio_run_name:\ntrackio_space_id:\n\ngradient_accumulation_steps: 2\nmicro_batch_size: 1\nnum_epochs: 1\n\noptimizer: adamw_torch_8bit\nlr_scheduler: constant_with_warmup\nlearning_rate: 2e-5\n\nbf16: true\ntf32: true\n\nflash_attention: true\nattn_implementation: kernels-community/vllm-flash-attn3  # this is not needed if using flash_attn >= 2.8.3\n\ngradient_checkpointing: true\nactivation_offloading: true\n\nlogging_steps: 1\nsaves_per_epoch: 1\n\nwarmup_ratio: 0.03\n\nspecial_tokens:\neot_tokens:\n  - \"<|end|>\"\n\n# choose the zero3 configuration that best fits your system capabilities\ndeepspeed: deepspeed_configs/zero3_bf16.json\n"
  },
  {
    "path": "examples/gpt-oss/gpt-oss-20b-fft-fsdp2-offload.yaml",
    "content": "base_model: openai/gpt-oss-20b\nuse_kernels: true\nmodel_quantization_config: Mxfp4Config\nmodel_quantization_config_kwargs:\n  dequantize: true\n\nplugins:\n  - axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin\n\nexperimental_skip_move_to_device: true  # prevent OOM by NOT putting model to GPU before sharding\n\ndatasets:\n  - path: HuggingFaceH4/Multilingual-Thinking\n    type: chat_template\n    field_thinking: thinking\n    template_thinking_key: thinking\n\ndataset_prepared_path: ./outputs/last_run_prepared\nval_set_size: 0\noutput_dir: ./outputs/gpt-oss-out/\n\nsequence_len: 4096\nsample_packing: true\npad_to_sequence_len: true\n\nwandb_project:\nwandb_entity:\nwandb_watch:\nwandb_name:\nwandb_log_model:\n\ntrackio_project_name:\ntrackio_run_name:\ntrackio_space_id:\n\ngradient_accumulation_steps: 2\nmicro_batch_size: 1\nnum_epochs: 1\n\noptimizer: adamw_torch_fused  # 8bit optimizers do not work with FSDP2 offload\nlr_scheduler: constant_with_warmup\nlearning_rate: 2e-5\n\nbf16: true\ntf32: true\n\nflash_attention: true\nattn_implementation: kernels-community/vllm-flash-attn3  # this is not needed if using flash_attn >= 2.8.3\n\ngradient_checkpointing: true\nactivation_offloading: true\n\nlogging_steps: 1\nsaves_per_epoch: 1\n\nwarmup_ratio: 0.03\n\nspecial_tokens:\neot_tokens:\n  - \"<|end|>\"\n\nfsdp_version: 2\nfsdp_config:\n  offload_params: true\n  state_dict_type: SHARDED_STATE_DICT\n  auto_wrap_policy: TRANSFORMER_BASED_WRAP\n  transformer_layer_cls_to_wrap: GptOssDecoderLayer\n  reshard_after_forward: true\n  #  cpu_ram_efficient_loading: true\n\n# cpu_ram_efficient_loading cannot be used with MXFP4 model quantization.\n# It can only be used with a dequantized model like `axolotl-ai-co/gpt-oss-120b-dequantized`\n"
  },
  {
    "path": "examples/gpt-oss/gpt-oss-20b-fft-fsdp2.yaml",
    "content": "base_model: openai/gpt-oss-20b\nuse_kernels: false\nmodel_quantization_config: Mxfp4Config\nmodel_quantization_config_kwargs:\n  dequantize: true\n\nplugins:\n  - axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin\n\nexperimental_skip_move_to_device: true  # prevent OOM by NOT putting model to GPU before sharding\n\ndatasets:\n  - path: HuggingFaceH4/Multilingual-Thinking\n    type: chat_template\n    field_thinking: thinking\n    template_thinking_key: thinking\n\ndataset_prepared_path: ./outputs/last_run_prepared\nval_set_size: 0\noutput_dir: ./outputs/gpt-oss-out/\n\nsequence_len: 4096\nsample_packing: true\n\nwandb_project:\nwandb_entity:\nwandb_watch:\nwandb_name:\nwandb_log_model:\n\ntrackio_project_name:\ntrackio_run_name:\ntrackio_space_id:\n\ngradient_accumulation_steps: 2\nmicro_batch_size: 1\nnum_epochs: 1\n\noptimizer: adamw_torch_8bit\nlr_scheduler: constant_with_warmup\nlearning_rate: 2e-5\n\nbf16: true\ntf32: true\n\nflash_attention: true\nattn_implementation: kernels-community/vllm-flash-attn3  # this is not needed if using flash_attn >= 2.8.3\n\ngradient_checkpointing: true\nactivation_offloading: true\n\nlogging_steps: 1\nsaves_per_epoch: 1\n\nwarmup_ratio: 0.03\n\nspecial_tokens:\neot_tokens:\n  - \"<|end|>\"\n\nfsdp_version: 2\nfsdp_config:\n  offload_params: false\n  state_dict_type: SHARDED_STATE_DICT\n  auto_wrap_policy: TRANSFORMER_BASED_WRAP\n  transformer_layer_cls_to_wrap: GptOssDecoderLayer\n  reshard_after_forward: true\n#  cpu_ram_efficient_loading: true\n"
  },
  {
    "path": "examples/gpt-oss/gpt-oss-20b-sft-lora-singlegpu.yaml",
    "content": "base_model: openai/gpt-oss-20b\nuse_kernels: true\nmodel_quantization_config: Mxfp4Config\nmodel_quantization_config_kwargs:\n  dequantize: true\n\nplugins:\n  - axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin\n\nexperimental_skip_move_to_device: true  # prevent OOM by not putting model to GPU before sharding\n\ndatasets:\n  - path: HuggingFaceH4/Multilingual-Thinking\n    type: chat_template\n    field_thinking: thinking\n    template_thinking_key: thinking\n\ndataset_prepared_path: last_run_prepared\nval_set_size: 0\noutput_dir: ./outputs/gpt-oss-out/\n\nsequence_len: 4096\nsample_packing: true\n\nadapter: lora\nlora_r: 8\nlora_alpha: 16\nlora_dropout: 0.0  # dropout not supported when using LoRA over expert parameters\nlora_target_linear: true\n\n# TODO: not supported for now, see peft#2710\n#lora_target_parameters:  # target the experts in the last two layers\n#  - \"22._checkpoint_wrapped_module.mlp.experts.gate_up_proj\"\n#  - \"22._checkpoint_wrapped_module.mlp.experts.down_proj\"\n#  - \"23._checkpoint_wrapped_module.mlp.experts.gate_up_proj\"\n#  - \"23._checkpoint_wrapped_module.mlp.experts.down_proj\"\n\nwandb_project:\nwandb_entity:\nwandb_watch:\nwandb_name:\nwandb_log_model:\n\ntrackio_project_name:\ntrackio_run_name:\ntrackio_space_id:\n\ngradient_accumulation_steps: 8\nmicro_batch_size: 1\nnum_epochs: 1\n\noptimizer: adamw_torch_8bit\nlr_scheduler: constant_with_warmup\nlearning_rate: 2e-4\n\nbf16: true\ntf32: true\n\nflash_attention: true\nattn_implementation: kernels-community/vllm-flash-attn3  # this is not needed if using flash_attn >= 2.8.3\n\ngradient_checkpointing: true\nactivation_offloading: true\n\nlogging_steps: 1\nsaves_per_epoch: 1\nwarmup_ratio: 0.1\n\nspecial_tokens:\neot_tokens:\n  - \"<|end|>\"\n"
  },
  {
    "path": "examples/gpt-oss/gpt-oss-safeguard-20b-sft-lora-singlegpu.yaml",
    "content": "base_model: openai/gpt-oss-safeguard-20b\nuse_kernels: true\nmodel_quantization_config: Mxfp4Config\nmodel_quantization_config_kwargs:\n  dequantize: true\n\nplugins:\n  - axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin\n\nexperimental_skip_move_to_device: true  # prevent OOM by not putting model to GPU before sharding\n\ndatasets:\n  - path: HuggingFaceH4/Multilingual-Thinking\n    type: chat_template\n    field_thinking: thinking\n    template_thinking_key: thinking\n\ndataset_prepared_path: last_run_prepared\nval_set_size: 0\noutput_dir: ./outputs/gpt-oss-safeguard-out/\n\nsequence_len: 4096\nsample_packing: true\n\nadapter: lora\nlora_r: 8\nlora_alpha: 16\nlora_dropout: 0.0  # dropout not supported when using LoRA over expert parameters\nlora_target_linear: true\n\n# TODO: not supported for now, see peft#2710\n#lora_target_parameters:  # target the experts in the last two layers\n#  - \"22._checkpoint_wrapped_module.mlp.experts.gate_up_proj\"\n#  - \"22._checkpoint_wrapped_module.mlp.experts.down_proj\"\n#  - \"23._checkpoint_wrapped_module.mlp.experts.gate_up_proj\"\n#  - \"23._checkpoint_wrapped_module.mlp.experts.down_proj\"\n\nwandb_project:\nwandb_entity:\nwandb_watch:\nwandb_name:\nwandb_log_model:\n\ntrackio_project_name:\ntrackio_run_name:\ntrackio_space_id:\n\ngradient_accumulation_steps: 8\nmicro_batch_size: 1\nnum_epochs: 1\n\noptimizer: adamw_torch_8bit\nlr_scheduler: constant_with_warmup\nlearning_rate: 2e-4\n\nbf16: true\ntf32: true\n\nflash_attention: true\nattn_implementation: kernels-community/vllm-flash-attn3  # this is not needed if using flash_attn >= 2.8.3\n\ngradient_checkpointing: true\nactivation_offloading: true\n\nlogging_steps: 1\nsaves_per_epoch: 1\nwarmup_ratio: 0.1\n\nspecial_tokens:\neot_tokens:\n  - \"<|end|>\"\n"
  },
  {
    "path": "examples/granite4/README.md",
    "content": "# Finetune IBM's Granite 4.0 with Axolotl\n\n[Granite 4.0](https://huggingface.co/collections/ibm-granite/granite-40-language-models) are a family of open source models trained by IBM Research.\n\nThis guide shows how to fine-tune it with Axolotl with multi-turn conversations and proper masking.\n\n## Getting started\n\n1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html). You need to install from main as Granite4 is only on nightly or use our latest [Docker images](https://docs.axolotl.ai/docs/docker.html).\n\n    Here is an example of how to install from main for pip:\n\n```bash\n# Ensure you have Pytorch installed (Pytorch 2.7.1 min)\ngit clone https://github.com/axolotl-ai-cloud/axolotl.git\ncd axolotl\n\npip3 install packaging==26.0 setuptools==75.8.0 wheel ninja\npip3 install --no-build-isolation -e '.[flash-attn]'\n\n# Install CCE https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy\npython scripts/cutcrossentropy_install.py | sh\n```\n\n2. Run the finetuning example:\n\n```bash\naxolotl train examples/granite4/granite-4.0-tiny-fft.yaml\n```\n\nThis config uses about 40.8GiB VRAM.\n\nLet us know how it goes. Happy finetuning! 🚀\n\n### TIPS\n\n- Read more on how to load your own dataset at [docs](https://docs.axolotl.ai/docs/dataset_loading.html).\n- The dataset format follows the OpenAI Messages format as seen [here](https://docs.axolotl.ai/docs/dataset-formats/conversation.html#chat_template).\n\n### Limitation\n\nAdapter finetuning does not work at the moment. 
It fails with:\n\n```bash\nRuntimeError: mat1 and mat2 shapes cannot be multiplied (4096x3072 and 1x1179648)\n```\n\nIn addition, even if adapter training worked, `lora_target_linear: true` would still fail with:\n```bash\nValueError: Target module GraniteMoeHybridParallelExperts() is not supported.\n```\n\n## Optimization Guides\n\n- [Multi-GPU Training](https://docs.axolotl.ai/docs/multi-gpu.html)\n- [Multi-Node Training](https://docs.axolotl.ai/docs/multi-node.html)\n- [LoRA Optimizations](https://docs.axolotl.ai/docs/lora_optims.html)\n\n## Related Resources\n\n- [Granite Docs](https://www.ibm.com/granite/docs/models/granite)\n- [Axolotl Docs](https://docs.axolotl.ai)\n- [Axolotl Website](https://axolotl.ai)\n- [Axolotl GitHub](https://github.com/axolotl-ai-cloud/axolotl)\n- [Axolotl Discord](https://discord.gg/7m9sfhzaf3)\n"
  },
  {
    "path": "examples/granite4/granite-4.0-tiny-fft.yaml",
    "content": "base_model: ibm-granite/granite-4.0-tiny-preview\n\n# Automatically upload checkpoint and final model to HF\n# hub_model_id: username/custom_model_name\n\nplugins:\n  - axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin\n\ndatasets:\n  - path: fozziethebeat/alpaca_messages_2k_test\n    type: chat_template\n\ndataset_prepared_path: last_run_prepared\nval_set_size: 0.1\noutput_dir: ./outputs/model-out\n\nsequence_len: 2048\nsample_packing: true\n\nwandb_project:\nwandb_entity:\nwandb_watch:\nwandb_name:\nwandb_log_model:\n\ngradient_accumulation_steps: 4\nmicro_batch_size: 2\nnum_epochs: 1\noptimizer: adamw_bnb_8bit\nlr_scheduler: cosine\nlearning_rate: 0.0002\n\nbf16: auto\ntf32: false\n\ngradient_checkpointing: true\nresume_from_checkpoint:\nlogging_steps: 1\nflash_attention: true\n\nwarmup_ratio: 0.1\nevals_per_epoch: 1\nsaves_per_epoch: 1\n\n# save_first_step: true  # uncomment this to validate checkpoint saving works with your config\n"
  },
  {
    "path": "examples/hunyuan/README.md",
    "content": "# Finetune HunYuan with Axolotl\n\nTencent released a family of open-source models called HunYuan in 0.5B, 1.8B, 4B, and 7B parameter sizes, each with both Pre-trained and Instruct variants. The models can be found at [HuggingFace](https://huggingface.co/collections/tencent/hunyuan-dense-model-6890632cda26b19119c9c5e7). This guide shows how to fine-tune them with Axolotl with multi-turn conversations and proper masking.\n\n## Getting started\n\n1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html). HunYuan support is only available on nightly, so you need to either install from main or use our latest [Docker images](https://docs.axolotl.ai/docs/docker.html).\n\n    Here is an example of how to install from main for pip:\n\n```bash\n# Ensure you have Pytorch installed (Pytorch 2.6.0 min)\ngit clone https://github.com/axolotl-ai-cloud/axolotl.git\ncd axolotl\n\npip3 install packaging==26.0 setuptools==75.8.0 wheel ninja\npip3 install --no-build-isolation -e '.[flash-attn]'\n\n# Install CCE https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy\npython scripts/cutcrossentropy_install.py | sh\n```\n\n2. Run the finetuning example:\n\n```bash\naxolotl train examples/hunyuan/hunyuan-v1-dense-qlora.yaml\n```\n\nThis config uses about 4.7 GB VRAM.\n\nLet us know how it goes. Happy finetuning! 🚀\n\n### Dataset\n\nHunYuan Instruct models can choose to enter a slow think or fast think pattern. 
For best results when fine-tuning their Instruct models, adjust your dataset to match the corresponding pattern.\n\n```python\n# fast think pattern\nmessages = [\n    {\"role\": \"system\", \"content\": \"You are a helpful assistant.\"},\n    {\"role\": \"user\", \"content\": \"/no_think What color is the sun?\" },\n    {\"role\": \"assistant\", \"content\": \"<think>\\n\\n</think>\\n<answer>\\nThe sun is yellow.\\n</answer>\"}\n]\n\n# slow think pattern\nmessages = [\n    {\"role\": \"system\", \"content\": \"You are a helpful assistant.\"},\n    {\"role\": \"user\", \"content\": \"/think What color is the sun?\" },\n    {\"role\": \"assistant\", \"content\": \"<think>\\nThe user is asking about the color of the sun. I need to ...\\n</think>\\n<answer>\\nThe sun is yellow.\\n</answer>\"}\n]\n```\n\n### TIPS\n\n- For inference, the official Tencent team recommends the following sampling settings:\n\n```json\n\n{\n  \"do_sample\": true,\n  \"top_k\": 20,\n  \"top_p\": 0.8,\n  \"repetition_penalty\": 1.05,\n  \"temperature\": 0.7\n}\n\n```\n\n- You can run a full finetuning by removing the `adapter: qlora` and `load_in_4bit: true` from the config.\n- Read more on how to load your own dataset at [docs](https://docs.axolotl.ai/docs/dataset_loading.html).\n- The dataset format follows the OpenAI Messages format as seen [here](https://docs.axolotl.ai/docs/dataset-formats/conversation.html#chat_template).\n\n## Optimization Guides\n\n- [Multi-GPU Training](https://docs.axolotl.ai/docs/multi-gpu.html)\n- [Multi-Node Training](https://docs.axolotl.ai/docs/multi-node.html)\n- [LoRA Optimizations](https://docs.axolotl.ai/docs/lora_optims.html)\n\n## Related Resources\n\n- [Tencent HunYuan Blog](https://hunyuan.tencent.com/)\n- [Axolotl Docs](https://docs.axolotl.ai)\n- [Axolotl Website](https://axolotl.ai)\n- [Axolotl GitHub](https://github.com/axolotl-ai-cloud/axolotl)\n- [Axolotl Discord](https://discord.gg/7m9sfhzaf3)\n"
  },
  {
    "path": "examples/hunyuan/hunyuan-v1-dense-qlora.yaml",
    "content": "base_model: tencent/Hunyuan-0.5B-Instruct\n\n# Automatically upload checkpoint and final model to HF\n# hub_model_id: username/custom_model_name\n\nplugins:\n  - axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin\n\nload_in_8bit: false\nload_in_4bit: true\n\ndatasets:\n  - path: fozziethebeat/alpaca_messages_2k_test\n    type: chat_template\n\ndataset_prepared_path: last_run_prepared\nval_set_size: 0.1\noutput_dir: ./outputs/lora-out\n\nadapter: qlora\nlora_model_dir:\n\nsequence_len: 2048\nsample_packing: true\n\nlora_r: 32\nlora_alpha: 16\nlora_dropout: 0.05\nlora_target_linear: true\nlora_target_modules:\n  - gate_proj\n  - down_proj\n  - up_proj\n  - q_proj\n  - v_proj\n  - k_proj\n  - o_proj\n\nwandb_project:\nwandb_entity:\nwandb_watch:\nwandb_name:\nwandb_log_model:\n\ngradient_accumulation_steps: 4\nmicro_batch_size: 2\nnum_epochs: 1\noptimizer: adamw_bnb_8bit\nlr_scheduler: cosine\nlearning_rate: 0.0002\n\nbf16: auto\ntf32: false\n\ngradient_checkpointing: true\nresume_from_checkpoint:\nlogging_steps: 1\nflash_attention: true\n\nwarmup_ratio: 0.1\nevals_per_epoch: 1\nsaves_per_epoch: 1\n\n# save_first_step: true  # uncomment this to validate checkpoint saving works with your config\n"
  },
  {
    "path": "examples/internvl3_5/README.md",
    "content": "# Finetune OpenGVLab's InternVL with Axolotl\n\n[InternVL 3.5](https://huggingface.co/OpenGVLab/InternVL3_5-8B-HF) is a family of powerful vision-language models by OpenGVLab, supporting dynamic resolution and multi-image understanding. It features a ViT-style vision encoder and a strong language model backbone for tasks like visual question answering, OCR, and scene text understanding.\n\nThis guide shows how to fine-tune it with Axolotl.\n\n## Getting started\n\n1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html).\n\n2. Install `timm` for vision model support:\n\n    ```bash\n    pip install timm==1.0.19\n    ```\n\n3. Install [Cut Cross Entropy](https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy) to reduce training VRAM usage.\n\n4. Run the finetuning example:\n\n    ```bash\n    axolotl train examples/internvl3_5/internvl3_5-8b-qlora.yml\n    ```\n\nThis config uses about 8.21 GiB VRAM. Let us know how it goes. Happy finetuning! 🚀\n\n### Tips\n\n- You can run a full finetuning by removing the `adapter: qlora` and `load_in_4bit: true` from the config.\n- Read more on how to load your own dataset at [docs](https://docs.axolotl.ai/docs/dataset_loading.html).\n- The dataset format follows the multi-modal format as seen [here](https://docs.axolotl.ai/docs/multimodal.html#dataset-format).\n\n## Optimization Guides\n\nPlease check the [Optimizations doc](https://docs.axolotl.ai/docs/optimizations.html).\n\n## Related Resources\n\n- [InternVL Paper](https://huggingface.co/papers/2508.18265)\n- [Axolotl Docs](https://docs.axolotl.ai)\n- [Axolotl Website](https://axolotl.ai)\n- [Axolotl GitHub](https://github.com/axolotl-ai-cloud/axolotl)\n- [Axolotl Discord](https://discord.gg/7m9sfhzaf3)\n"
  },
  {
    "path": "examples/internvl3_5/internvl3_5-8b-qlora.yml",
    "content": "base_model: OpenGVLab/InternVL3_5-8B-HF\nprocessor_type: AutoProcessor\n\nplugins:\n  - axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin\n\nload_in_4bit: true\n\n# these 3 lines are needed for now to handle vision chat templates w images\nskip_prepare_dataset: true\nremove_unused_columns: false\nsample_packing: false\n\ndatasets:\n  - path: HuggingFaceH4/llava-instruct-mix-vsft\n    type: chat_template\n    split: train[:1%]\n    field_messages: messages\n\ndataset_prepared_path: last_run_prepared\nval_set_size: 0.01\noutput_dir: ./outputs/out\n\nadapter: qlora\nlora_model_dir:\n\nsequence_len: 2048\n\nlora_r: 32\nlora_alpha: 16\nlora_dropout: 0.05\nlora_target_modules: 'model.language_model.layers.[\\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'\n\nwandb_project:\nwandb_entity:\nwandb_watch:\nwandb_name:\nwandb_log_model:\n\ngradient_accumulation_steps: 4\nmicro_batch_size: 2\nnum_epochs: 1\noptimizer: adamw_bnb_8bit\nlr_scheduler: cosine\nlearning_rate: 0.0002\n\nbf16: true\nfp16:\ntf32: true\n\ngradient_checkpointing: true\nlogging_steps: 1\nflash_attention: true\neager_attention:\n\nwarmup_ratio: 0.1\nevals_per_epoch: 1\nsaves_per_epoch: 1\nweight_decay: 0.0\n\n# save_first_step: true  # uncomment this to validate checkpoint saving works with your config\n"
  },
  {
    "path": "examples/jamba/README.md",
    "content": "# Jamba\n\n- ✅ QLoRA with DeepSpeed ZeRO-2 needs at least 2 GPUs and\n  - 35GiB VRAM per GPU with minimal context length\n  - 56GiB VRAM per GPU (with multipack enabled)\n- ✅ QLoRA with DeepSpeed ZeRO-3 needs at least 2 GPUs and 67GiB VRAM (surprisingly, more than ZeRO-2)\n- ✅ QLoRA on a single GPU, ~51GiB VRAM\n- ✅ multipack\n- ✅ FSDP\n- ❓ 8-bit LoRA\n"
  },
  {
    "path": "examples/jamba/qlora.yaml",
    "content": "base_model: ai21labs/Jamba-v0.1\n# Automatically upload checkpoint and final model to HF\n# hub_model_id: username/custom_model_name\n\ntrust_remote_code: true\n\nload_in_8bit: false\nload_in_4bit: true\n\ndatasets:\n  - path: mhenrichsen/alpaca_2k_test\n    type: alpaca\ndataset_prepared_path:\nval_set_size: 0.0\noutput_dir: ./outputs/out\n\nsequence_len: 4096\nsample_packing: false\npad_to_sequence_len: false\neval_sample_packing: false\n\nwandb_project:\nwandb_entity:\nwandb_watch:\nwandb_name:\nwandb_log_model:\n\nadapter: qlora\nlora_r: 8\nlora_alpha: 16\nlora_dropout: 0.05\nlora_target_linear: true\n\nlow_cpu_mem_usage: true\ngradient_accumulation_steps: 4\nmicro_batch_size: 1\nnum_epochs: 2\noptimizer: paged_adamw_8bit\nlr_scheduler: cosine\nlearning_rate: 0.00001\n\nbf16: auto\ntf32: false\n\ngradient_checkpointing: true\ngradient_checkpointing_kwargs:\n  use_reentrant: false\nresume_from_checkpoint:\nlogging_steps: 1\nflash_attention: true\n\nwarmup_ratio: 0.1\nevals_per_epoch:\nsaves_per_epoch: 1\nweight_decay: 0.0\nspecial_tokens:\n\n# save_first_step: true  # uncomment this to validate checkpoint saving works with your config\n"
  },
  {
    "path": "examples/jamba/qlora_deepspeed.yaml",
    "content": "base_model: ai21labs/Jamba-v0.1\n# Automatically upload checkpoint and final model to HF\n# hub_model_id: username/custom_model_name\ntrust_remote_code: true\n\nload_in_8bit: false\nload_in_4bit: true\n\ndatasets:\n  - path: mhenrichsen/alpaca_2k_test\n    type: alpaca\ndataset_prepared_path:\nval_set_size: 0.0\noutput_dir: ./outputs/out\n\nsequence_len: 4096\nsample_packing: false\npad_to_sequence_len: false\neval_sample_packing: false\n\nwandb_project:\nwandb_entity:\nwandb_watch:\nwandb_name:\nwandb_log_model:\n\nadapter: qlora\nlora_r: 8\nlora_alpha: 16\nlora_dropout: 0.05\nlora_target_linear: true\n\nlow_cpu_mem_usage: true\ngradient_accumulation_steps: 4\nmicro_batch_size: 1\nnum_epochs: 2\noptimizer: paged_adamw_8bit\nlr_scheduler: cosine\nlearning_rate: 0.00001\n\nbf16: auto\ntf32: false\n\ngradient_checkpointing: true\ngradient_checkpointing_kwargs:\n  use_reentrant: false\nresume_from_checkpoint:\nlogging_steps: 1\nflash_attention: true\n\nwarmup_ratio: 0.1\nevals_per_epoch:\nsaves_per_epoch: 1\n\ndeepspeed: deepspeed_configs/zero2.json\nweight_decay: 0.0\nspecial_tokens:\n\n# save_first_step: true  # uncomment this to validate checkpoint saving works with your config\n"
  },
  {
    "path": "examples/jamba/qlora_fsdp_large.yaml",
    "content": "base_model: ai21labs/AI21-Jamba-1.5-Large\n# optionally might have model_type or tokenizer_type\ntokenizer_type: AutoTokenizer\n# Automatically upload checkpoint and final model to HF\n# hub_model_id: username/custom_model_name\n\nload_in_4bit: true\nuse_tensorboard: true\nchat_template: jamba\ndatasets:\n  - path: cgato/SlimOrcaDedupCleaned\n    type: chat_template\n    drop_system_message: true\n    field_messages: conversations\n    message_property_mappings:\n      role: from\n      content: value\n\ndataset_prepared_path: last_run_prepared\nval_set_size: 0.0\noutput_dir: jamba-large-fsdp-qlora-ft\nadapter: qlora\nsequence_len: 2048\nsample_packing: true\n\n\nlora_r: 16\nlora_alpha: 16\nlora_dropout: 0.05\nlora_target_modules: [down_proj,gate_proj,in_proj,k_proj,o_proj,out_proj,q_proj,up_proj,v_proj,x_proj]\nlora_target_linear: false\n\ngradient_accumulation_steps: 4\nmicro_batch_size: 1\nnum_epochs: 2\noptimizer: adamw_torch_fused\nlr_scheduler: cosine\nlearning_rate: 0.00001\n\nbf16: true\ntf32: true\n\ngradient_checkpointing: true\ngradient_checkpointing_kwargs:\n  use_reentrant: true\nlogging_steps: 1\nflash_attention: true\n\nwarmup_ratio: 0.1\nevals_per_epoch: 1\nsaves_per_epoch: 1\nweight_decay: 0.0\nfsdp:\n  - full_shard\n  - auto_wrap\nfsdp_config:\n  fsdp_limit_all_gathers: true\n  fsdp_sync_module_states: true\n  fsdp_offload_params: false\n  fsdp_use_orig_params: false\n  fsdp_cpu_ram_efficient_loading: true\n  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP\n  fsdp_transformer_layer_cls_to_wrap: JambaAttentionDecoderLayer,JambaMambaDecoderLayer\n  fsdp_state_dict_type: FULL_STATE_DICT\n  fsdp_sharding_strategy: FULL_SHARD\n\n# save_first_step: true  # uncomment this to validate checkpoint saving works with your config\n"
  },
  {
    "path": "examples/kimi-linear/README.md",
    "content": "# Finetune MoonshotAI's Kimi Linear with Axolotl\n\n[Kimi Linear](https://huggingface.co/collections/moonshotai/kimi-linear-a3b) is a MoE model (48B total, 3B active) by MoonshotAI using a hybrid linear attention architecture to achieve a 1M token context length. It uses Kimi Delta Attention (KDA), a refined version of Gated DeltaNet that reduces KV cache size by up to 75% and boosts decoding throughput by up to 6x for long contexts.\n\nThis guide shows how to fine-tune it with Axolotl with multi-turn conversations and proper masking.\n\n**Note:** Axolotl uses experimental training code for Kimi Linear as their original modeling code is inference-only.\n\n## Getting started\n\n1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html).\n\n2. Install CCE via [docs](https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy)\n\n3. Run the finetuning example:\n\n    ```bash\n    axolotl train examples/kimi-linear/kimi-48b-lora.yaml\n    ```\n\nThis config uses about 98.7GiB VRAM.\n\nLet us know how it goes. 
Happy finetuning!\n\n### TIPS\n\n- Kimi Linear requires `trust_remote_code: true`.\n- You can run a full finetuning by removing the `adapter: lora` and `load_in_8bit: true`.\n- Read more on how to load your own dataset at [docs](https://docs.axolotl.ai/docs/dataset_loading.html)\n- The dataset format follows the OpenAI Messages format as seen [here](https://docs.axolotl.ai/docs/dataset-formats/conversation.html#chat_template)\n\n## Optimization Guides\n\nSee 👉 [docs](https://docs.axolotl.ai/docs/optimizations.html).\n\n## Limitations\n\nThis is not yet compatible with MoE kernels from transformers v5.\n\n## Related Resources\n\n- [Kimi Linear Paper](https://huggingface.co/papers/2510.26692)\n- [Kimi Linear GitHub](https://github.com/MoonshotAI/Kimi-Linear)\n- [Axolotl Docs](https://docs.axolotl.ai)\n- [Axolotl Website](https://axolotl.ai)\n- [Axolotl GitHub](https://github.com/axolotl-ai-cloud/axolotl)\n- [Axolotl Discord](https://discord.gg/7m9sfhzaf3)\n"
  },
  {
    "path": "examples/kimi-linear/kimi-48b-lora.yaml",
    "content": "base_model: moonshotai/Kimi-Linear-48B-A3B-Instruct\n\n# Automatically upload checkpoint and final model to HF\n# hub_model_id: username/custom_model_name\n\ntrust_remote_code: true\n\nplugins:\n  - axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin\n\nload_in_8bit: true\nload_in_4bit: false\nstrict: false\n\ndatasets:\n  - path: fozziethebeat/alpaca_messages_2k_test\n    type: chat_template\n    split: train\n\ndataset_prepared_path: last_run_prepared\nval_set_size: 0.2\noutput_dir: ./outputs/lora-out\n\nadapter: lora\nlora_model_dir:\n\nsequence_len: 2048\nsample_packing: true\npad_to_sequence_len: true\n\nlora_r: 16\nlora_alpha: 32\nlora_dropout: 0.05\nlora_fan_in_fan_out:\nlora_target_modules:\n  - gate_proj\n  - down_proj\n  - up_proj\n  - q_proj\n  - v_proj\n  - k_proj\n  - o_proj\n\nwandb_project:\nwandb_entity:\nwandb_watch:\nwandb_name:\nwandb_log_model:\n\ngradient_accumulation_steps: 2\nmicro_batch_size: 2\nnum_epochs: 1\noptimizer: adamw_8bit\nlr_scheduler: cosine\nlearning_rate: 0.0002\n\ntrain_on_inputs: false\ngroup_by_length: false\nbf16: auto\nfp16:\ntf32: false\n\ngradient_checkpointing: true\nearly_stopping_patience:\nresume_from_checkpoint:\nlocal_rank:\nlogging_steps: 1\nflash_attention: true\n\nloss_watchdog_threshold: 5.0\nloss_watchdog_patience: 3\n\nwarmup_ratio: 0.1\nevals_per_epoch: 2\nsaves_per_epoch: 1\ndebug:\ndeepspeed:\nweight_decay: 0.0\nfsdp:\nfsdp_config:\nspecial_tokens:\n"
  },
  {
    "path": "examples/llama-2/README.md",
    "content": "# Overview\n\nThis is an example of a llama-2 configuration for 7b and 13b. The yaml files contain configurations for the 7b variant, but you can just as well use the same settings for 13b.\n\nThe 7b variant fits on any 24GB VRAM GPU and will take up about 17 GB of VRAM during training if using qlora and 20 GB if using lora. On an RTX 4090 it trains 3 epochs of the default dataset in about 15 minutes.\n\nThe 13b variant will fit if you change these settings to the following values:\n\n```yaml\ngradient_accumulation_steps: 2\nmicro_batch_size: 1\n```\n\n```shell\naccelerate launch -m axolotl.cli.train examples/llama-2/qlora.yml\n```\nor\n\n```shell\naccelerate launch -m axolotl.cli.train examples/llama-2/lora.yml\n```\n\nTo launch a full finetuning with 16-bit precision:\n\n```shell\naccelerate launch -m axolotl.cli.train examples/llama-2/fft_optimized.yml\n```\n"
  },
  {
    "path": "examples/llama-2/fft_optimized.yml",
    "content": "base_model: NousResearch/Llama-2-7b-hf\n# optionally might have model_type or tokenizer_type\nmodel_type: LlamaForCausalLM\ntokenizer_type: LlamaTokenizer\n# Automatically upload checkpoint and final model to HF\n# hub_model_id: username/custom_model_name\n\ndatasets:\n  - path: mhenrichsen/alpaca_2k_test\n    type: alpaca\ndataset_prepared_path: last_run_prepared\nval_set_size: 0.05\noutput_dir: ./outputs/out\n\nsequence_len: 4096\nsample_packing: true\n\n\nadapter:\nlora_model_dir:\nlora_r:\nlora_alpha:\nlora_dropout:\nlora_target_linear:\n\nwandb_project:\nwandb_entity:\nwandb_watch:\nwandb_name:\nwandb_log_model:\n\ngradient_accumulation_steps: 1\nmicro_batch_size: 1\nnum_epochs: 1\noptimizer: adamw_bnb_8bit\nlr_scheduler: cosine\nlearning_rate: 0.0002\n\nbf16: auto\ntf32: false\n\ngradient_checkpointing: true\nresume_from_checkpoint:\nlogging_steps: 1\nflash_attention: true\nflash_attn_cross_entropy: false\nflash_attn_rms_norm: true\nflash_attn_fuse_mlp: true\n\nwarmup_ratio: 0.1\nevals_per_epoch: 4\nsaves_per_epoch: 1\n\ndeepspeed: #deepspeed_configs/zero2.json # multi-gpu only\nweight_decay: 0.1\nspecial_tokens:\n\n# save_first_step: true  # uncomment this to validate checkpoint saving works with your config\n"
  },
  {
    "path": "examples/llama-2/gptq-lora.yml",
    "content": "base_model: TheBloke/Llama-2-7B-GPTQ\n# optionally might have model_type or tokenizer_type\nmodel_type: AutoModelForCausalLM\ntokenizer_type: LlamaTokenizer\n# Automatically upload checkpoint and final model to HF\n# hub_model_id: username/custom_model_name\n\ngptq: true\ngptq_disable_exllama: true\n\ntokenizer_use_fast: true\ntokenizer_legacy: true\npush_dataset_to_hub:\nhf_use_auth_token: true\ndatasets:\n  - path: mhenrichsen/alpaca_2k_test\n    type: alpaca\ndataset_prepared_path:\nval_set_size: 0.05\nadapter: lora\nlora_model_dir:\nsequence_len: 4096\nsample_packing:\nlora_r: 8\nlora_alpha: 32\nlora_dropout: 0.05\nlora_target_modules:\n  - k_proj\n  - o_proj\n  - q_proj\n  - v_proj\nlora_target_linear:\nwandb_project:\nwandb_watch:\nwandb_name:\nwandb_log_model:\noutput_dir: ./outputs/model-out\ngradient_accumulation_steps: 1\nmicro_batch_size: 1\nnum_epochs: 4\noptimizer: adamw_torch_fused\nadam_beta2: 0.95\nadam_eps: 0.00001\nmax_grad_norm: 1.0\ntorchdistx_path:\nlr_scheduler: cosine\nlr_quadratic_warmup: true\nlearning_rate: 0.000017\nbf16: false\nfp16: false\nfloat16: true\ntf32: true\ngradient_checkpointing: true\nresume_from_checkpoint:\nlogging_steps: 1\nflash_attention:\nsdp_attention:\nflash_optimum:\nwarmup_ratio: 0.1\nevals_per_epoch: 4\nsaves_per_epoch: 1\nweight_decay: 0.1\nspecial_tokens:\n  bos_token: \"<s>\"\n  eos_token: \"</s>\"\n  unk_token: \"<unk>\"\n\n# save_first_step: true  # uncomment this to validate checkpoint saving works with your config\n"
  },
  {
    "path": "examples/llama-2/lisa.yml",
    "content": "base_model: NousResearch/Llama-2-7b-hf\n# optionally might have model_type or tokenizer_type\nmodel_type: LlamaForCausalLM\ntokenizer_type: LlamaTokenizer\n# Automatically upload checkpoint and final model to HF\n# hub_model_id: username/custom_model_name\n\ndatasets:\n  - path: teknium/GPT4-LLM-Cleaned\n    type: alpaca\ndataset_prepared_path: last_run_prepared\nval_set_size: 0.05\noutput_dir: ./outputs/lisa-out\n\nsequence_len: 4096\nsample_packing: true\n\n\nadapter:\nlora_model_dir:\nlora_r:\nlora_alpha:\nlora_dropout:\nlora_target_linear:\n\nlisa_n_layers: 4\nlisa_step_interval: 20\nlisa_layers_attribute: model.layers\n\nwandb_project:\nwandb_entity:\nwandb_watch:\nwandb_name:\nwandb_log_model:\n\ngradient_accumulation_steps: 2\nmicro_batch_size: 1\nnum_epochs: 1\noptimizer: adamw_bnb_8bit\nlr_scheduler: cosine\nlearning_rate: 5e-5 # recommendation from lisa paper for 7b\n\nbf16: auto\ntf32: false\n\ngradient_checkpointing: true\nresume_from_checkpoint:\nlogging_steps: 1\nflash_attention: true\nflash_attn_cross_entropy: false\nflash_attn_rms_norm: true\nflash_attn_fuse_mlp: true\n\nwarmup_ratio: 0.1\nevals_per_epoch: 4\nsaves_per_epoch: 1\nweight_decay: 0.1\nspecial_tokens:\n  bos_token: \"<s>\"\n  eos_token: \"</s>\"\n  unk_token: \"<unk>\"\n\n# save_first_step: true  # uncomment this to validate checkpoint saving works with your config\n"
  },
  {
    "path": "examples/llama-2/loftq.yml",
    "content": "base_model: NousResearch/Llama-2-7b-hf\n# optionally might have model_type or tokenizer_type\nmodel_type: LlamaForCausalLM\ntokenizer_type: LlamaTokenizer\n# Automatically upload checkpoint and final model to HF\n# hub_model_id: username/custom_model_name\n\ndatasets:\n  - path: mhenrichsen/alpaca_2k_test\n    type: alpaca\ndataset_prepared_path:\nval_set_size: 0.05\noutput_dir: ./outputs/lora-out\n\nsequence_len: 4096\nsample_packing: true\n\n\nadapter: lora\nlora_model_dir:\nlora_r: 32\nlora_alpha: 16\nlora_dropout: 0.05\nlora_target_linear: true\npeft:\n  loftq_config:\n    loftq_bits: 4\n\nwandb_project:\nwandb_entity:\nwandb_watch:\nwandb_name:\nwandb_log_model:\n\ngradient_accumulation_steps: 4\nmicro_batch_size: 2\nnum_epochs: 4\noptimizer: adamw_bnb_8bit\nlr_scheduler: cosine\nlearning_rate: 0.0002\n\nbf16: auto\ntf32: false\n\ngradient_checkpointing: true\nresume_from_checkpoint:\nlogging_steps: 1\nflash_attention: true\n\nwarmup_ratio: 0.1\nevals_per_epoch: 4\nsaves_per_epoch: 1\nweight_decay: 0.0\nspecial_tokens:\n\n# save_first_step: true  # uncomment this to validate checkpoint saving works with your config\n"
  },
  {
    "path": "examples/llama-2/lora.yml",
    "content": "base_model: NousResearch/Llama-2-7b-hf\n# optionally might have model_type or tokenizer_type\nmodel_type: LlamaForCausalLM\ntokenizer_type: LlamaTokenizer\n# Automatically upload checkpoint and final model to HF\n# hub_model_id: username/custom_model_name\n\nload_in_8bit: true\nload_in_4bit: false\n\ndatasets:\n  - path: mhenrichsen/alpaca_2k_test\n    type: alpaca\ndataset_prepared_path:\nval_set_size: 0.05\noutput_dir: ./outputs/lora-out\n\nsequence_len: 4096\nsample_packing: true\n\n\nadapter: lora\nlora_model_dir:\nlora_r: 32\nlora_alpha: 16\nlora_dropout: 0.05\nlora_target_linear: true\n\nwandb_project:\nwandb_entity:\nwandb_watch:\nwandb_name:\nwandb_log_model:\n\ngradient_accumulation_steps: 4\nmicro_batch_size: 2\nnum_epochs: 4\noptimizer: adamw_bnb_8bit\nlr_scheduler: cosine\nlearning_rate: 0.0002\n\nbf16: auto\ntf32: false\n\ngradient_checkpointing: true\nresume_from_checkpoint:\nlogging_steps: 1\nflash_attention: true\n\nwarmup_ratio: 0.1\nevals_per_epoch: 4\nsaves_per_epoch: 1\nweight_decay: 0.0\nspecial_tokens:\n\n# save_first_step: true  # uncomment this to validate checkpoint saving works with your config\n"
  },
  {
    "path": "examples/llama-2/qlora-fsdp.yml",
    "content": "base_model: NousResearch/Llama-2-7b-hf\n# optionally might have model_type or tokenizer_type\nmodel_type: LlamaForCausalLM\ntokenizer_type: LlamaTokenizer\n# Automatically upload checkpoint and final model to HF\n# hub_model_id: username/custom_model_name\n\nload_in_8bit: false\nload_in_4bit: true\n\ndatasets:\n  - path: yahma/alpaca-cleaned\n    type: alpaca\ndataset_prepared_path: last_run_prepared\nval_set_size: 0.05\noutput_dir: ./outputs/qlora-out\n\nadapter: qlora\nlora_model_dir:\n\nsequence_len: 512\nsample_packing: false\n\n\nlora_r: 32\nlora_alpha: 16\nlora_dropout: 0.05\nlora_target_linear: true\n\nwandb_project:\nwandb_entity:\nwandb_watch:\nwandb_name:\nwandb_log_model:\n\ngradient_accumulation_steps: 4\nmicro_batch_size: 4\nnum_epochs: 4\noptimizer: adamw_torch_fused\nlr_scheduler: cosine\nlearning_rate: 0.00001\n\nbf16: auto\ntf32: false\n\ngradient_checkpointing: true\ngradient_checkpointing_kwargs:\n  use_reentrant: true\nresume_from_checkpoint:\nlogging_steps: 1\nflash_attention: true\n\nwarmup_ratio: 0.1\nevals_per_epoch: 4\nsaves_per_epoch: 1\nweight_decay: 0.0\nfsdp:\n  - full_shard\n  - auto_wrap\nfsdp_config:\n  fsdp_limit_all_gathers: true\n  fsdp_sync_module_states: true\n  fsdp_offload_params: true\n  fsdp_use_orig_params: false\n  fsdp_cpu_ram_efficient_loading: true\n  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP\n  fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer\n  fsdp_state_dict_type: FULL_STATE_DICT\n  # fsdp_cpu_offload_pin_memory: false  # uncomment to enable swap memory usage when RAM is insufficient\nspecial_tokens:\n\n# save_first_step: true  # uncomment this to validate checkpoint saving works with your config\n"
  },
  {
    "path": "examples/llama-2/qlora.yml",
    "content": "base_model: NousResearch/Llama-2-7b-hf\n# optionally might have model_type or tokenizer_type\nmodel_type: LlamaForCausalLM\ntokenizer_type: LlamaTokenizer\n# Automatically upload checkpoint and final model to HF\n# hub_model_id: username/custom_model_name\n\nload_in_8bit: false\nload_in_4bit: true\n\ndatasets:\n  - path: mhenrichsen/alpaca_2k_test\n    type: alpaca\ndataset_prepared_path:\nval_set_size: 0.05\noutput_dir: ./outputs/qlora-out\n\nadapter: qlora\nlora_model_dir:\n\nsequence_len: 4096\nsample_packing: true\n\n\nlora_r: 32\nlora_alpha: 16\nlora_dropout: 0.05\nlora_target_linear: true\n\nwandb_project:\nwandb_entity:\nwandb_watch:\nwandb_name:\nwandb_log_model:\n\ngradient_accumulation_steps: 4\nmicro_batch_size: 2\nnum_epochs: 4\noptimizer: paged_adamw_32bit\nlr_scheduler: cosine\nlearning_rate: 0.0002\n\nbf16: auto\ntf32: false\n\ngradient_checkpointing: true\nresume_from_checkpoint:\nlogging_steps: 1\nflash_attention: true\n\nwarmup_ratio: 0.1\nevals_per_epoch: 4\nsaves_per_epoch: 1\nweight_decay: 0.0\nspecial_tokens:\n\n# save_first_step: true  # uncomment this to validate checkpoint saving works with your config\n"
  },
  {
    "path": "examples/llama-2/relora.yml",
    "content": "base_model: NousResearch/Llama-2-7b-hf\nmodel_type: LlamaForCausalLM\ntokenizer_type: LlamaTokenizer\n\n\nload_in_8bit: false\nload_in_4bit: true\n\ndatasets:\n  - path: teknium/GPT4-LLM-Cleaned\n    type: alpaca\ndataset_prepared_path:\nval_set_size: 0.05\noutput_dir: ./outputs/relora-out\n\nadapter: qlora\nlora_model_dir:\n\nsequence_len: 4096\nsample_packing: true\n\n\nlora_r: 8\nlora_alpha: 16\nlora_dropout: 0.05\nlora_target_linear: true\n\nrelora: true\nrelora_prune_ratio: 0.9\nrelora_cpu_offload: false\njagged_restart_steps: 150\njagged_restart_warmup_steps: 10\njagged_restart_anneal_steps: false\n\nwandb_project:\nwandb_entity:\nwandb_watch:\nwandb_name:\nwandb_log_model:\n\ngradient_accumulation_steps: 4\nmicro_batch_size: 4\nnum_epochs: 4\noptimizer: adamw_bnb_8bit\nlr_scheduler: cosine\nlearning_rate: 0.0002\n\nbf16: auto\ntf32: false\n\ngradient_checkpointing: true\nresume_from_checkpoint:\nlogging_steps: 1\nflash_attention: true\n\nwarmup_ratio: 0.1\nevals_per_epoch: 4\nsaves_per_epoch: 1\nweight_decay: 0.0\nspecial_tokens:\n  bos_token: \"<s>\"\n  eos_token: \"</s>\"\n  unk_token: \"<unk>\"\n\n# save_first_step: true  # uncomment this to validate checkpoint saving works with your config\n"
  },
  {
    "path": "examples/llama-3/3b-fp8-fsdp2.yaml",
    "content": "base_model: meta-llama/Llama-3.2-3B\n# Automatically upload checkpoint and final model to HF\n# hub_model_id: username/custom_model_name\n\nload_in_8bit: false\nload_in_4bit: false\nstrict: false\n\nplugins:\n  - axolotl.integrations.liger.LigerPlugin\n\nliger_rope: true\nliger_rms_norm: true\nliger_glu_activation: true\nliger_layer_norm: true\nliger_fused_linear_cross_entropy: true\n\ndatasets:\n  - path: yahma/alpaca-cleaned\n    type: alpaca\n\noutput_dir: ./outputs/fp8_out/\n\nsample_packing: true\npad_to_sequence_len: true\nsequence_len: 512\n\nflex_attention: true\nflex_attn_compile_kwargs:\n  dynamic: false\n  mode: max-autotune-no-cudagraphs\ntorch_compile: true\n\nwandb_project:\nwandb_entity:\nwandb_watch:\nwandb_name:\nwandb_log_model:\n\ngradient_accumulation_steps: 1\nmicro_batch_size: 16\nnum_epochs: 1\noptimizer: adamw_torch_fused\n\ncosine_constant_lr_ratio: 0\ncosine_min_lr_ratio: 1.0\nlearning_rate: 2e-5\nsave_only_model: true\n\nfp8: true\nfp8_enable_fsdp_float8_all_gather: true\n\nresume_from_checkpoint:\nlogging_steps: 1\n\nevals_per_epoch: 1\nsaves_per_epoch: 1\n\nwarmup_steps: 10\nweight_decay: 0.0\n\nfsdp_version: 2\nfsdp_config:\n  offload_params: false\n  auto_wrap_policy: TRANSFORMER_BASED_WRAP\n  transformer_layer_cls_to_wrap: LlamaDecoderLayer\n  state_dict_type: FULL_STATE_DICT\n  sharding_strategy: FULL_SHARD\n  reshard_after_forward: true\n  activation_checkpointing: false\n\nspecial_tokens:\n  pad_token: <|end_of_text|>\n\n# save_first_step: true  # uncomment this to validate checkpoint saving works with your config\n"
  },
  {
    "path": "examples/llama-3/3b-qat-fsdp2.yaml",
    "content": "base_model: meta-llama/Llama-3.2-3B\n# Automatically upload checkpoint and final model to HF\n# hub_model_id: username/custom_model_name\n\nload_in_8bit: false\nload_in_4bit: false\nstrict: false\n\nplugins:\n  - axolotl.integrations.liger.LigerPlugin\n\nliger_rope: true\nliger_rms_norm: true\nliger_glu_activation: true\nliger_layer_norm: true\nliger_fused_linear_cross_entropy: true\n\n\ndatasets:\n  - path: yahma/alpaca-cleaned\n    type: alpaca\n    split: train[:95%]\n\noutput_dir: ./outputs/qat_out/\ndataset_prepared_path: ./outputs/qat_out/dataset_prepared\n\nsample_packing: false\nsequence_len: 8192\nflash_attention: true\n\nqat:\n  activation_dtype: int8\n  weight_dtype: int4\n  group_size: 32\n\nwandb_project:\nwandb_entity:\nwandb_watch:\nwandb_name:\nwandb_log_model:\n\ngradient_accumulation_steps: 1\nmicro_batch_size: 16\nnum_epochs: 1\noptimizer: adamw_torch_fused\n\ncosine_constant_lr_ratio: 0\ncosine_min_lr_ratio: 1.0\nlearning_rate: 2e-5\nsave_only_model: true\nbf16: true\n\nresume_from_checkpoint:\nlogging_steps: 1\n\nevals_per_epoch: 1\nsaves_per_epoch: 1\n\nwarmup_ratio: 0.1\nweight_decay: 0.0\nfsdp:\n  - full_shard\n  - auto_wrap\n\nfsdp_config:\n  fsdp_version: 2\n  fsdp_offload_params: false\n  fsdp_cpu_ram_efficient_loading: false\n  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP\n  fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer\n  fsdp_state_dict_type: FULL_STATE_DICT\n  fsdp_sharding_strategy: FULL_SHARD\n  fsdp_reshard_after_forward: true\n  fsdp_activation_checkpointing: true\n\nspecial_tokens:\n  pad_token: <|finetune_right_pad_id|>\n\n# save_first_step: true  # uncomment this to validate checkpoint saving works with your config\n"
  },
  {
    "path": "examples/llama-3/3b-qat-mxfp4.yaml",
    "content": "base_model: meta-llama/Llama-3.2-3B\n# Automatically upload checkpoint and final model to HF\n# hub_model_id: username/custom_model_name\n\nload_in_8bit: false\nload_in_4bit: false\nstrict: false\n\nplugins:\n  - axolotl.integrations.liger.LigerPlugin\n\nliger_rope: true\nliger_rms_norm: true\nliger_glu_activation: true\nliger_layer_norm: true\nliger_fused_linear_cross_entropy: true\n\ndatasets:\n  - path: yahma/alpaca-cleaned\n    type: alpaca\n    split: train[:95%]\n\noutput_dir: ./outputs/qat_out/\ndataset_prepared_path: ./outputs/dataset_prepared\n\nsequence_len: 2048\nflash_attention: true\n\nqat:\n  activation_dtype: mxfp4\n  weight_dtype: mxfp4\n  group_size: 32\n\nwandb_project:\nwandb_entity:\nwandb_watch:\nwandb_name:\nwandb_log_model:\n\ngradient_checkpointing: true\nactivation_offloading: true\ngradient_accumulation_steps: 4\nmicro_batch_size: 1\nnum_epochs: 1\noptimizer: adamw_torch_8bit\n\ncosine_constant_lr_ratio: 0\ncosine_min_lr_ratio: 1.0\nlearning_rate: 2e-5\nsave_only_model: true\nbf16: true\n\nresume_from_checkpoint:\nlogging_steps: 1\n\nevals_per_epoch: 1\nsaves_per_epoch: 1\n\nwarmup_ratio: 0.1\nweight_decay: 0.0\n\nspecial_tokens:\n  pad_token: <|finetune_right_pad_id|>\n\n# save_first_step: true  # uncomment this to validate checkpoint saving works with your config\n"
  },
  {
    "path": "examples/llama-3/3b-qat-nvfp4.yaml",
    "content": "base_model: meta-llama/Llama-3.2-3B\n# Automatically upload checkpoint and final model to HF\n# hub_model_id: username/custom_model_name\n\nload_in_8bit: false\nload_in_4bit: false\nstrict: false\n\nplugins:\n  - axolotl.integrations.liger.LigerPlugin\n\nliger_rope: true\nliger_rms_norm: true\nliger_glu_activation: true\nliger_layer_norm: true\nliger_fused_linear_cross_entropy: true\n\ndatasets:\n  - path: yahma/alpaca-cleaned\n    type: alpaca\n    split: train[:95%]\n\noutput_dir: ./outputs/qat_out/\ndataset_prepared_path: ./outputs/dataset_prepared\n\nsequence_len: 8192\nflash_attention: true\n\nqat:\n  activation_dtype: nvfp4\n  weight_dtype: nvfp4\n  group_size: 16 # only group_size of 16 is supported with nvfp4\n\nwandb_project:\nwandb_entity:\nwandb_watch:\nwandb_name:\nwandb_log_model:\n\ngradient_checkpointing: true\ngradient_accumulation_steps: 1\nmicro_batch_size: 64\nnum_epochs: 1\noptimizer: adamw_torch_fused\n\ncosine_constant_lr_ratio: 0\ncosine_min_lr_ratio: 1.0\nlearning_rate: 2e-5\nsave_only_model: true\nbf16: true\n\nresume_from_checkpoint:\nlogging_steps: 1\n\nevals_per_epoch: 1\nsaves_per_epoch: 1\n\nwarmup_ratio: 0.1\nweight_decay: 0.0\n\nspecial_tokens:\n  pad_token: <|finetune_right_pad_id|>\n\n# save_first_step: true  # uncomment this to validate checkpoint saving works with your config\n"
  },
  {
    "path": "examples/llama-3/README.md",
    "content": "# Llama-3\n\nhttps://llama.meta.com/llama3/\n\n[8B Base Model](https://huggingface.co/meta-llama/Meta-Llama-3-8B)\n - [Full Fine-Tune](./fft-8b.yaml)\n   - Single GPU @ 48GB VRAM\n - [LoRA](./lora-8b.yml)\n   - Single GPU @ 11GB VRAM\n\n[70B Base Model](https://huggingface.co/meta-llama/Meta-Llama-3-70B)\n - [QLoRA + FSDP](./qlora-fsdp-70b.yaml)\n   - Dual GPU @ 21GB VRAM\n"
  },
  {
    "path": "examples/llama-3/diffusion/pretrain-1b.yaml",
    "content": "base_model: meta-llama/Llama-3.2-1B\n# Automatically upload checkpoint and final model to HF\n# hub_model_id: username/custom_model_name\n\npretraining_dataset:\n  - path: wikitext\n    name: wikitext-103-raw-v1\n    type: completion\n    field: text\n\nplugins:\n  - axolotl.integrations.diffusion.DiffusionPlugin\n\ndiffusion:\n  noise_schedule: cosine\n  min_mask_ratio: 0.15\n  max_mask_ratio: 0.85\n  num_diffusion_steps: 128\n  eps: 5e-4\n  importance_weighting: true\n  mask_token_id: 128002\n  generate_samples: true\n  generation_interval: 250\n\noutput_dir: ./outputs/model-out\n\nsequence_len: 512\nsample_packing: true\n\ngradient_accumulation_steps: 8\nmicro_batch_size: 4\nmax_steps: 10000\nwarmup_ratio: 0.1\n\noptimizer: adamw_8bit\nlr_scheduler: cosine\nlearning_rate: 3e-4\nsdp_attention: true\n\nbf16: auto\ntf32: true\n\nlogging_steps: 1\nsave_strategy: steps\nsave_steps: 1000\n\nspecial_tokens:\n  pad_token: \"<|end_of_text|>\"\n\nwandb_project:\nwandb_entity:\nwandb_watch:\nwandb_name:\nwandb_log_model:\n\n# save_first_step: true  # uncomment this to validate checkpoint saving works with your config\n"
  },
  {
    "path": "examples/llama-3/diffusion/sft-1b.yaml",
    "content": "base_model: meta-llama/Llama-3.2-1B\n# Automatically upload checkpoint and final model to HF\n# hub_model_id: username/custom_model_name\n\ndatasets:\n  - path: teknium/GPT4-LLM-Cleaned\n    type: alpaca\nval_set_size: 0.05\n\nplugins:\n  - axolotl.integrations.diffusion.DiffusionPlugin\n\ndiffusion:\n  noise_schedule: cosine\n  min_mask_ratio: 0.1\n  max_mask_ratio: 0.9\n  num_diffusion_steps: 128\n  eps: 1e-3\n  importance_weighting: true\n  mask_token_id: 128002\n  generate_samples: true\n  generation_interval: 250\n\noutput_dir: ./outputs/model-out\n\nsequence_len: 512\nsample_packing: true\neval_sample_packing: true\n\ngradient_accumulation_steps: 4\nmicro_batch_size: 4\nnum_epochs: 1\nwarmup_ratio: 0.1\n\noptimizer: adamw_8bit\nlr_scheduler: cosine\nlearning_rate: 1e-5\n\nbf16: auto\ntf32: true\n\ngradient_checkpointing: true\nresume_from_checkpoint:\nsdp_attention: true\n\nlogging_steps: 1\nsave_strategy: best\neval_strategy: epoch\n\nspecial_tokens:\n  pad_token: \"<|end_of_text|>\"\n\nwandb_project:\nwandb_entity:\nwandb_watch:\nwandb_name:\nwandb_log_model:\n\n# save_first_step: true  # uncomment this to validate checkpoint saving works with your config\n"
  },
  {
    "path": "examples/llama-3/fft-8b-liger-fsdp.yaml",
    "content": "base_model: NousResearch/Meta-Llama-3.1-8B\n# Automatically upload checkpoint and final model to HF\n# hub_model_id: username/custom_model_name\n\nplugins:\n  - axolotl.integrations.liger.LigerPlugin\nliger_rope: true\nliger_rms_norm: true\nliger_glu_activation: true\nliger_fused_linear_cross_entropy: true\n\n\nchat_template: llama3\ndatasets:\n  - path: mlabonne/FineTome-100k\n    type: chat_template\n    split: train[:20%]\n    field_messages: conversations\n    message_property_mappings:\n      role: from\n      content: value\n\ndataset_prepared_path: last_run_prepared\nval_set_size: 0.02\noutput_dir: ./outputs/out\n\nsequence_len: 4096\nsample_packing: true\n\n\nwandb_project:\nwandb_entity:\nwandb_watch:\nwandb_name:\nwandb_log_model:\n\ngradient_accumulation_steps: 4\nmicro_batch_size: 2\nnum_epochs: 1\noptimizer: adamw_torch_fused\nlr_scheduler: cosine\nlearning_rate: 2e-5\n\nbf16: auto\ntf32: false\n\ngradient_checkpointing: true\ngradient_checkpointing_kwargs:\n  use_reentrant: false\nresume_from_checkpoint:\nlogging_steps: 1\nflash_attention: true\n\nwarmup_ratio: 0.1\nevals_per_epoch: 2\nsaves_per_epoch: 1\nweight_decay: 0.0\nfsdp:\n  - full_shard\n  - auto_wrap\nfsdp_config:\n  fsdp_limit_all_gathers: true\n  fsdp_sync_module_states: true\n  fsdp_offload_params: true\n  fsdp_use_orig_params: false\n  fsdp_cpu_ram_efficient_loading: true\n  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP\n  fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer\n  fsdp_state_dict_type: FULL_STATE_DICT\n  fsdp_sharding_strategy: FULL_SHARD\n  fsdp_backward_prefetch: BACKWARD_PRE\nspecial_tokens:\n  pad_token: <|finetune_right_pad_id|>\n  eos_token: <|eot_id|>\n\n# save_first_step: true  # uncomment this to validate checkpoint saving works with your config\n"
  },
  {
    "path": "examples/llama-3/fft-8b.yaml",
    "content": "base_model: NousResearch/Meta-Llama-3.1-8B\n# Automatically upload checkpoint and final model to HF\n# hub_model_id: username/custom_model_name\n\ndatasets:\n  - path: tatsu-lab/alpaca\n    type: alpaca\ndataset_prepared_path: last_run_prepared\nval_set_size: 0.05\noutput_dir: ./outputs/out\n\nsequence_len: 8192\nsample_packing: true\n\n\nwandb_project:\nwandb_entity:\nwandb_watch:\nwandb_name:\nwandb_log_model:\n\ngradient_accumulation_steps: 8\nmicro_batch_size: 1\nnum_epochs: 1\noptimizer: paged_adamw_8bit\nlr_scheduler: cosine\nlearning_rate: 2e-5\n\nbf16: auto\ntf32: false\n\ngradient_checkpointing: true\ngradient_checkpointing_kwargs:\n  use_reentrant: false\nresume_from_checkpoint:\nlogging_steps: 1\nflash_attention: true\n\nwarmup_ratio: 0.1\nevals_per_epoch: 2\nsaves_per_epoch: 1\nweight_decay: 0.0\nspecial_tokens:\n  pad_token: <|end_of_text|>\n\n# save_first_step: true  # uncomment this to validate checkpoint saving works with your config\n"
  },
  {
    "path": "examples/llama-3/instruct-dpo-lora-8b.yml",
    "content": "base_model: meta-llama/Meta-Llama-3-8B-Instruct\n# optionally might have model_type or tokenizer_type\nmodel_type: LlamaForCausalLM\ntokenizer_type: AutoTokenizer\n# Automatically upload checkpoint and final model to HF\n# hub_model_id: username/custom_model_name\n\nspecial_tokens:\n  pad_token: <|finetune_right_pad_id|>\n  eos_token: <|eot_id|>\n\nload_in_8bit: true\nload_in_4bit: false\n\nchat_template: llama3\nrl: dpo\ndatasets:\n  - path: fozziethebeat/alpaca_messages_2k_dpo_test\n    type: chat_template.default\n    field_messages: conversation\n    field_chosen: chosen\n    field_rejected: rejected\n    message_property_mappings:\n      role: role\n      content: content\n    roles:\n      system:\n        - system\n      user:\n        - user\n      assistant:\n        - assistant\n\ndataset_prepared_path:\nval_set_size: 0.05\noutput_dir: ./outputs/lora-out\n\nsequence_len: 4096\nsample_packing: false\n\n\nadapter: lora\nlora_model_dir:\nlora_r: 32\nlora_alpha: 16\nlora_dropout: 0.05\nlora_target_linear: true\n\nwandb_project:\nwandb_entity:\nwandb_watch:\nwandb_name:\nwandb_log_model:\n\ngradient_accumulation_steps: 4\nmicro_batch_size: 2\nnum_epochs: 4\noptimizer: adamw_bnb_8bit\nlr_scheduler: cosine\nlearning_rate: 0.0002\n\nbf16: auto\ntf32: false\n\ngradient_checkpointing: true\nresume_from_checkpoint:\nlogging_steps: 1\nflash_attention: true\n\nwarmup_ratio: 0.1\nevals_per_epoch: 4\nsaves_per_epoch: 1\nweight_decay: 0.0\n\n# save_first_step: true  # uncomment this to validate checkpoint saving works with your config\n"
  },
  {
    "path": "examples/llama-3/instruct-lora-8b.yml",
    "content": "base_model: NousResearch/Meta-Llama-3-8B-Instruct\n# optionally might have model_type or tokenizer_type\nmodel_type: LlamaForCausalLM\ntokenizer_type: AutoTokenizer\n# Automatically upload checkpoint and final model to HF\n# hub_model_id: username/custom_model_name\n\nload_in_8bit: true\nload_in_4bit: false\n\nchat_template: llama3\ndatasets:\n  - path: fozziethebeat/alpaca_messages_2k_test\n    type: chat_template\n\ndataset_prepared_path:\nval_set_size: 0.05\noutput_dir: ./outputs/lora-out\n\nsequence_len: 4096\nsample_packing: false\n\n\nadapter: lora\nlora_model_dir:\nlora_r: 32\nlora_alpha: 16\nlora_dropout: 0.05\nlora_target_linear: true\n\nwandb_project:\nwandb_entity:\nwandb_watch:\nwandb_name:\nwandb_log_model:\n\ngradient_accumulation_steps: 4\nmicro_batch_size: 2\nnum_epochs: 4\noptimizer: adamw_bnb_8bit\nlr_scheduler: cosine\nlearning_rate: 0.0002\n\nbf16: auto\ntf32: false\n\ngradient_checkpointing: true\nresume_from_checkpoint:\nlogging_steps: 1\nflash_attention: true\n\nwarmup_ratio: 0.1\nevals_per_epoch: 4\nsaves_per_epoch: 1\nweight_decay: 0.0\nspecial_tokens:\n   pad_token: <|end_of_text|>\n\n# save_first_step: true  # uncomment this to validate checkpoint saving works with your config\n"
  },
  {
    "path": "examples/llama-3/lora-1b-deduplicate-dpo.yml",
    "content": "base_model: meta-llama/Llama-3.2-1B\n# optionally might have model_type or tokenizer_type\nmodel_type: LlamaForCausalLM\ntokenizer_type: AutoTokenizer\n# Automatically upload checkpoint and final model to HF\n# hub_model_id: username/custom_model_name\n\nload_in_8bit: true\nload_in_4bit: false\n\nchat_template: llama3\nrl: dpo\ndatasets:\n  - path: fozziethebeat/alpaca_messages_2k_dpo_test\n    type: chat_template.default\n    field_messages: conversation\n    field_chosen: chosen\n    field_rejected: rejected\n    message_property_mappings:\n      role: role\n      content: content\n    roles:\n      system:\n        - system\n      user:\n        - user\n      assistant:\n        - assistant\n  - path: fozziethebeat/alpaca_messages_2k_dpo_test\n    type: chat_template.default\n    field_messages: conversation\n    field_chosen: chosen\n    field_rejected: rejected\n    message_property_mappings:\n      role: role\n      content: content\n    roles:\n      system:\n        - system\n      user:\n        - user\n      assistant:\n        - assistant\n\ndataset_exact_deduplication: true\ndataset_prepared_path:\nval_set_size: 0\noutput_dir: ./outputs/lora-out\n\nsequence_len: 4096\nsample_packing: false\n\n\nadapter: lora\nlora_model_dir:\nlora_r: 32\nlora_alpha: 16\nlora_dropout: 0.05\nlora_target_linear: true\n\nwandb_project:\nwandb_entity:\nwandb_watch:\nwandb_name:\nwandb_log_model:\n\ngradient_accumulation_steps: 4\nmicro_batch_size: 2\nnum_epochs: 4\noptimizer: adamw_bnb_8bit\nlr_scheduler: cosine\nlearning_rate: 0.0002\n\nbf16: auto\ntf32: false\n\ngradient_checkpointing: true\nresume_from_checkpoint:\nlogging_steps: 1\nflash_attention: true\n\nwarmup_ratio: 0.1\nevals_per_epoch: 4\nsaves_per_epoch: 1\nweight_decay: 0.0\n\n# save_first_step: true  # uncomment this to validate checkpoint saving works with your config\n"
  },
  {
    "path": "examples/llama-3/lora-1b-deduplicate-sft.yml",
    "content": "base_model: meta-llama/Llama-3.2-1B\n# optionally might have model_type or tokenizer_type\nmodel_type: LlamaForCausalLM\ntokenizer_type: AutoTokenizer\n# Automatically upload checkpoint and final model to HF\n# hub_model_id: username/custom_model_name\n\nload_in_8bit: true\nload_in_4bit: false\n\ndatasets:\n  - path: mhenrichsen/alpaca_2k_test\n    type: alpaca\n  - path: mhenrichsen/alpaca_2k_test\n    type: alpaca\ndataset_prepared_path:\nval_set_size: 0.0\noutput_dir: ./outputs/lora-out\n\ndataset_exact_deduplication: true\n\nsequence_len: 4096\nsample_packing: true\neval_sample_packing: false\n\n\nadapter: lora\nlora_model_dir:\nlora_r: 32\nlora_alpha: 16\nlora_dropout: 0.05\nlora_target_linear: true\nlora_modules_to_save:\n  - embed_tokens\n  - lm_head\n\nwandb_project:\nwandb_entity:\nwandb_watch:\nwandb_name:\nwandb_log_model:\n\ngradient_accumulation_steps: 4\nmicro_batch_size: 2\nnum_epochs: 4\noptimizer: adamw_bnb_8bit\nlr_scheduler: cosine\nlearning_rate: 0.0002\n\nbf16: auto\ntf32: false\n\ngradient_checkpointing: true\nresume_from_checkpoint:\nlogging_steps: 1\nflash_attention: true\n\nwarmup_ratio: 0.1\nevals_per_epoch: 4\nsaves_per_epoch: 1\nweight_decay: 0.0\nspecial_tokens:\n   pad_token: <|end_of_text|>\n\n# save_first_step: true  # uncomment this to validate checkpoint saving works with your config\n"
  },
  {
    "path": "examples/llama-3/lora-1b-kernels.yml",
    "content": "base_model: NousResearch/Llama-3.2-1B\n# Automatically upload checkpoint and final model to HF\n# hub_model_id: username/custom_model_name\n\ndatasets:\n  - path: teknium/GPT4-LLM-Cleaned\n    type: alpaca\ndataset_prepared_path: last_run_prepared\nval_set_size: 0.1\noutput_dir: ./outputs/lora-out\n\nadapter: lora\nlora_model_dir:\n\nsequence_len: 2048\nsample_packing: true\n\n\nlora_r: 16\nlora_alpha: 32\n# Currently, we don't support dropout with our custom Triton kernels\n# lora_dropout: 0.05\nlora_target_modules:\n  - gate_proj\n  - down_proj\n  - up_proj\n  - q_proj\n  - v_proj\n  - k_proj\n  - o_proj\n\n# These options enable our custom Triton kernels / autograd\n# functions for MLP and attention calculations\nlora_mlp_kernel: true\nlora_qkv_kernel: true\nlora_o_kernel: true\n\nwandb_project:\nwandb_entity:\nwandb_watch:\nwandb_name:\nwandb_log_model:\n\ngradient_accumulation_steps: 2\nmicro_batch_size: 2\nnum_epochs: 1\noptimizer: adamw_8bit\nlr_scheduler: cosine\nlearning_rate: 0.0002\n\nbf16: auto\ntf32: false\n\ngradient_checkpointing: true\nresume_from_checkpoint:\nlogging_steps: 1\nflash_attention: true\n\nloss_watchdog_threshold: 5.0\nloss_watchdog_patience: 3\n\nwarmup_ratio: 0.1\nevals_per_epoch: 4\nsaves_per_epoch: 1\nweight_decay: 0.0\nspecial_tokens:\n  pad_token: \"<|end_of_text|>\"\n\n# save_first_step: true  # uncomment this to validate checkpoint saving works with your config\n"
  },
  {
    "path": "examples/llama-3/lora-1b-ray.yml",
    "content": "base_model: NousResearch/Llama-3.2-1B\n# Automatically upload checkpoint and final model to HF\n# hub_model_id: username/custom_model_name\n\ndatasets:\n  - path: teknium/GPT4-LLM-Cleaned\n    type: alpaca\ndataset_prepared_path: last_run_prepared\nval_set_size: 0.1\noutput_dir: ./outputs/lora-out\n\nadapter: lora\nlora_model_dir:\n\nsequence_len: 2048\nsample_packing: true\neval_sample_packing: true\n\n\nlora_r: 16\nlora_alpha: 32\nlora_dropout: 0.05\nlora_target_modules:\n  - gate_proj\n  - down_proj\n  - up_proj\n  - q_proj\n  - v_proj\n  - k_proj\n  - o_proj\n\nwandb_project:\nwandb_entity:\nwandb_watch:\nwandb_name:\nwandb_log_model:\n\ngradient_accumulation_steps: 2\nmicro_batch_size: 2\nnum_epochs: 1\noptimizer: adamw_8bit\nlr_scheduler: cosine\nlearning_rate: 0.0002\n\nbf16: auto\ntf32: false\n\ngradient_checkpointing: true\nresume_from_checkpoint:\nlogging_steps: 1\nflash_attention: true\n\nloss_watchdog_threshold: 5.0\nloss_watchdog_patience: 3\n\nwarmup_ratio: 0.1\nevals_per_epoch: 4\nsaves_per_epoch: 1\n\ndeepspeed: deepspeed_configs/zero3.json\nweight_decay: 0.0\nspecial_tokens:\n  pad_token: \"<|end_of_text|>\"\n\nuse_ray: true\nray_num_workers: 4\n\n# save_first_step: true  # uncomment this to validate checkpoint saving works with your config\n"
  },
  {
    "path": "examples/llama-3/lora-1b-sample-packing-sequentially.yml",
    "content": "base_model: meta-llama/Llama-3.2-1B\n# optionally might have model_type or tokenizer_type\nmodel_type: LlamaForCausalLM\ntokenizer_type: AutoTokenizer\n# Automatically upload checkpoint and final model to HF\n# hub_model_id: username/custom_model_name\n\nload_in_8bit: true\nload_in_4bit: false\n\ndatasets:\n  - path: mhenrichsen/alpaca_2k_test\n    type: alpaca\n  - path: mhenrichsen/alpaca_2k_test\n    type: alpaca\ndataset_prepared_path:\nval_set_size: 0.0\noutput_dir: ./outputs/lora-out\n\ntest_value: true\n\nsequence_len: 4096\nsample_packing: true\nsample_packing_sequentially: true\ncurriculum_sampling: true\neval_sample_packing: false\n\n\nadapter: lora\nlora_model_dir:\nlora_r: 32\nlora_alpha: 16\nlora_dropout: 0.05\nlora_target_linear: true\nlora_modules_to_save:\n  - embed_tokens\n  - lm_head\n\nwandb_project:\nwandb_entity:\nwandb_watch:\nwandb_name:\nwandb_log_model:\n\ngradient_accumulation_steps: 4\nmicro_batch_size: 2\nnum_epochs: 4\noptimizer: adamw_bnb_8bit\nlr_scheduler: cosine\nlearning_rate: 0.0002\n\nbf16: auto\ntf32: false\n\ngradient_checkpointing: true\nresume_from_checkpoint:\nlogging_steps: 1\nflash_attention: true\n\nwarmup_ratio: 0.1\nevals_per_epoch: 4\nsaves_per_epoch: 1\nweight_decay: 0.0\nspecial_tokens:\n  pad_token: <|end_of_text|>\n\n# save_first_step: true  # uncomment this to validate checkpoint saving works with your config\n"
  },
  {
    "path": "examples/llama-3/lora-1b.yml",
    "content": "base_model: NousResearch/Llama-3.2-1B\n# Automatically upload checkpoint and final model to HF\n# hub_model_id: username/custom_model_name\n\ndatasets:\n  - path: teknium/GPT4-LLM-Cleaned\n    type: alpaca\n\nval_set_size: 0.1\noutput_dir: ./outputs/lora-out\n\nadapter: lora\nlora_model_dir:\n\nsequence_len: 2048\nsample_packing: true\neval_sample_packing: true\n\n\nlora_r: 16\nlora_alpha: 32\nlora_dropout: 0.05\nlora_target_modules:\n  - gate_proj\n  - down_proj\n  - up_proj\n  - q_proj\n  - v_proj\n  - k_proj\n  - o_proj\n\nwandb_project:\nwandb_entity:\nwandb_watch:\nwandb_name:\nwandb_log_model:\n\ngradient_accumulation_steps: 2\nmicro_batch_size: 2\nnum_epochs: 1\n\noptimizer: adamw_8bit\nlr_scheduler: cosine\nlearning_rate: 0.0002\n\nbf16: auto\ntf32: false\n\ngradient_checkpointing: true\nresume_from_checkpoint:\nlogging_steps: 1\nflash_attention: true\n\nloss_watchdog_threshold: 5.0\nloss_watchdog_patience: 3\n\nwarmup_ratio: 0.1\nevals_per_epoch: 4\nsaves_per_epoch: 1\nweight_decay: 0.0\nspecial_tokens:\n  pad_token: \"<|end_of_text|>\"\n\n# save_first_step: true  # uncomment this to validate checkpoint saving works with your config\n"
  },
  {
    "path": "examples/llama-3/lora-8b.yml",
    "content": "base_model: NousResearch/Meta-Llama-3-8B\n# optionally might have model_type or tokenizer_type\nmodel_type: LlamaForCausalLM\ntokenizer_type: AutoTokenizer\n# Automatically upload checkpoint and final model to HF\n# hub_model_id: username/custom_model_name\n\nload_in_8bit: true\nload_in_4bit: false\n\ndatasets:\n  - path: mhenrichsen/alpaca_2k_test\n    type: alpaca\ndataset_prepared_path:\nval_set_size: 0.05\noutput_dir: ./outputs/lora-out\n\nsequence_len: 4096\nsample_packing: true\neval_sample_packing: false\n\n\nadapter: lora\nlora_model_dir:\nlora_r: 32\nlora_alpha: 16\nlora_dropout: 0.05\nlora_target_linear: true\nlora_modules_to_save:\n  - embed_tokens\n  - lm_head\n\nwandb_project:\nwandb_entity:\nwandb_watch:\nwandb_name:\nwandb_log_model:\n\ngradient_accumulation_steps: 4\nmicro_batch_size: 2\nnum_epochs: 4\noptimizer: adamw_bnb_8bit\nlr_scheduler: cosine\nlearning_rate: 0.0002\n\nbf16: auto\ntf32: false\n\ngradient_checkpointing: true\nresume_from_checkpoint:\nlogging_steps: 1\nflash_attention: true\n\nwarmup_ratio: 0.1\nevals_per_epoch: 4\nsaves_per_epoch: 1\nweight_decay: 0.0\nspecial_tokens:\n   pad_token: <|end_of_text|>\n\n# save_first_step: true  # uncomment this to validate checkpoint saving works with your config\n"
  },
  {
    "path": "examples/llama-3/opentelemetry-qlora.yml",
    "content": "base_model: NousResearch/Llama-3.2-1B\nmodel_type: AutoModelForCausalLM\ntokenizer_type: AutoTokenizer\n\nload_in_4bit: true\n\ndatasets:\n  - path: mhenrichsen/alpaca_2k_test\n    type: alpaca\n\noutput_dir: ./outputs/opentelemetry-example\n\nadapter: qlora\nsequence_len: 512\nsample_packing: false\n\nlora_r: 32\nlora_alpha: 16\nlora_dropout: 0.05\nlora_target_linear: true\n\n# OpenTelemetry Configuration\nuse_otel_metrics: true\notel_metrics_host: \"localhost\"\notel_metrics_port: 8000\n\n# Disable WandB\nuse_wandb: false\n\ngradient_accumulation_steps: 4\nmicro_batch_size: 2\nnum_epochs: 1\noptimizer: paged_adamw_32bit\nlr_scheduler: cosine\nlearning_rate: 0.0002\n\nbf16: auto\ntf32: false\n\ngradient_checkpointing: true\nlogging_steps: 1\nflash_attention: false\n\nwarmup_ratio: 0.1\nevals_per_epoch: 2\nsaves_per_epoch: 1\nweight_decay: 0.0\n\nspecial_tokens:\n  pad_token: \"<|end_of_text|>\"\n"
  },
  {
    "path": "examples/llama-3/qlora-1b-gdpo.yaml",
    "content": "base_model: meta-llama/Llama-3.2-1B-Instruct\n\nchat_template: llama3\n\nrl: gdpo\n\ntrl:\n  beta: 0.001\n  max_completion_length: 128\n  num_generations: 2\n  temperature: 0.7\n  top_p: 0.95\n\n  use_vllm: false\n\n\n  multi_objective_aggregation: normalize_then_sum\n\n  reward_funcs:\n    - rwd.format_reward\n    - rwd.correctness_reward\n  reward_weights: [1.0, 2.0]\n\n  log_completions: true\n  num_completions_to_print: 3\n  scale_rewards: true\n\ndatasets:\n  - path: openai/gsm8k\n    name: main\n    split: train[:1000]\n    type: rwd.gsm8k_transform\n\nval_set_size: 0.0\noutput_dir: ./outputs/llama3-gdpo-out\n\nsequence_len: 512\nsample_packing: false\npad_to_sequence_len: false\n\ngradient_accumulation_steps: 8\nmicro_batch_size: 1\nnum_epochs: 1\nmax_steps: 100\n\noptimizer: adamw_torch_fused\nlr_scheduler: cosine\nlearning_rate: 5e-5\nweight_decay: 0.01\nwarmup_steps: 10\n\nbf16: auto\ntf32: true\n\ngradient_checkpointing: true\ngradient_checkpointing_kwargs:\n  use_reentrant: false\n\nflash_attention: true\nlogging_steps: 1\nsave_steps: 50\nsave_safetensors: true\n\nspecial_tokens:\n  pad_token: \"<|end_of_text|>\"\n\n\nseed: 42\n"
  },
  {
    "path": "examples/llama-3/qlora-1b-kto.yaml",
    "content": "base_model: meta-llama/Llama-3.2-1B\n# Automatically upload checkpoint and final model to HF\n# hub_model_id: username/custom_model_name\n\nload_in_8bit: false\nload_in_4bit: true\n\nrl: kto\nrl_beta: 0.5\nkto_desirable_weight: 0.2\n\ndatasets:\n  - path: argilla/ultrafeedback-binarized-preferences-cleaned-kto\n    type: llama3.ultra\n    split: train\ndataset_prepared_path: last_run_prepared\nval_set_size: 0.0\noutput_dir: ./outputs/qlora-out\n\nremove_unused_columns: false\n\nadapter: qlora\nlora_model_dir:\n\nsequence_len: 2048\nsample_packing: false  # not supported with kto\neval_sample_packing: false\npad_to_sequence_len: false\n\nlora_r: 32\nlora_alpha: 64\nlora_dropout: 0.05\nlora_target_linear: true\n\nwandb_project:\nwandb_entity:\nwandb_watch:\nwandb_name:\nwandb_log_model:\n\ngradient_accumulation_steps: 1\nmicro_batch_size: 2\nnum_epochs: 1\noptimizer: adamw_8bit\nlr_scheduler: cosine\nlearning_rate: 0.0002\n\nbf16: auto\ntf32: true\n\ngradient_checkpointing: true\ngradient_checkpointing_kwargs:\n  use_reentrant: false\nresume_from_checkpoint:\nlogging_steps: 1\nflash_attention: true\n\nwarmup_ratio: 0.1\nevals_per_epoch: 4\nsaves_per_epoch: 1\nweight_decay: 0.0\nspecial_tokens:\n  pad_token: \"<|end_of_text|>\"\n\n# save_first_step: true  # uncomment this to validate checkpoint saving works with your config\n"
  },
  {
    "path": "examples/llama-3/qlora-1b.yml",
    "content": "base_model: NousResearch/Llama-3.2-1B\n# Automatically upload checkpoint and final model to HF\n# hub_model_id: username/custom_model_name\n\nload_in_8bit: false\nload_in_4bit: true\n\ndatasets:\n  - path: teknium/GPT4-LLM-Cleaned\n    type: alpaca\ndataset_prepared_path: last_run_prepared\nval_set_size: 0.1\noutput_dir: ./outputs/qlora-out\n\nadapter: qlora\nlora_model_dir:\n\nsequence_len: 2048\nsample_packing: true\neval_sample_packing: true\n\n\nlora_r: 32\nlora_alpha: 16\nlora_dropout: 0.05\nlora_target_modules:\n  - gate_proj\n  - down_proj\n  - up_proj\n  - q_proj\n  - v_proj\n  - k_proj\n  - o_proj\n\nwandb_project:\nwandb_entity:\nwandb_watch:\nwandb_name:\nwandb_log_model:\n\ngradient_accumulation_steps: 4\nmicro_batch_size: 2\nnum_epochs: 1\noptimizer: adamw_bnb_8bit\nlr_scheduler: cosine\nlearning_rate: 0.0002\n\nbf16: auto\ntf32: false\n\ngradient_checkpointing: true\nresume_from_checkpoint:\nlogging_steps: 1\nflash_attention: true\n\nloss_watchdog_threshold: 5.0\nloss_watchdog_patience: 3\n\nwarmup_ratio: 0.1\nevals_per_epoch: 4\nsaves_per_epoch: 1\nweight_decay: 0.0\nspecial_tokens:\n  pad_token: \"<|end_of_text|>\"\n\n# save_first_step: true  # uncomment this to validate checkpoint saving works with your config\n"
  },
  {
    "path": "examples/llama-3/qlora-fsdp-405b.yaml",
    "content": "base_model: hugging-quants/Meta-Llama-3.1-405B-BNB-NF4-BF16\n# optionally might have model_type or tokenizer_type\ntokenizer_type: AutoTokenizer\n# Automatically upload checkpoint and final model to HF\n# hub_model_id: username/custom_model_name\n\nload_in_4bit: true\n\ndatasets:\n  - path: tatsu-lab/alpaca\n    type: alpaca\ndataset_prepared_path: last_run_prepared\nval_set_size: 0.0\noutput_dir: ./outputs/out/qlora-llama3_1-405b\n\nadapter: qlora\n\nsequence_len: 2048\nsample_packing: true\n\n\nlora_r: 16\nlora_alpha: 16\nlora_dropout: 0.05\nlora_target_linear: true\n\ngradient_accumulation_steps: 4\nmicro_batch_size: 1\nnum_epochs: 2\noptimizer: adamw_torch_fused\nlr_scheduler: cosine\nlearning_rate: 0.00001\n\nbf16: true\ntf32: true\n\ngradient_checkpointing: true\ngradient_checkpointing_kwargs:\n  use_reentrant: true\nlogging_steps: 1\nflash_attention: true\n\nwarmup_ratio: 0.1\nevals_per_epoch: 4\nsaves_per_epoch: 1\nweight_decay: 0.0\nfsdp:\n  - full_shard\n  - auto_wrap\nfsdp_config:\n  fsdp_limit_all_gathers: true\n  fsdp_sync_module_states: true\n  fsdp_offload_params: true\n  fsdp_use_orig_params: false\n  fsdp_cpu_ram_efficient_loading: true\n  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP\n  fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer\n  fsdp_state_dict_type: FULL_STATE_DICT\n  fsdp_sharding_strategy: FULL_SHARD\nspecial_tokens:\n  pad_token: <|finetune_right_pad_id|>\n\n# save_first_step: true  # uncomment this to validate checkpoint saving works with your config\n"
  },
  {
    "path": "examples/llama-3/qlora-fsdp-70b.yaml",
    "content": "base_model: casperhansen/llama-3-70b-fp16\n# optionally might have model_type or tokenizer_type\nmodel_type: LlamaForCausalLM\ntokenizer_type: AutoTokenizer  # PreTrainedTokenizerFast\n# Automatically upload checkpoint and final model to HF\n# hub_model_id: username/custom_model_name\n\nload_in_8bit: false\nload_in_4bit: true\n\ndatasets:\n  - path: tatsu-lab/alpaca\n    type: alpaca\ndataset_prepared_path: last_run_prepared\nval_set_size: 0.05\noutput_dir: ./outputs/out/qlora-llama3-70b\n\nadapter: qlora\nlora_model_dir:\n\nsequence_len: 512\nsample_packing: false\n\n\nlora_r: 8\nlora_alpha: 16\nlora_dropout: 0.05\nlora_target_linear: true\n\nwandb_project:\nwandb_entity:\nwandb_watch:\nwandb_name:\nwandb_log_model:\n\ngradient_accumulation_steps: 4\nmicro_batch_size: 1\nnum_epochs: 4\noptimizer: adamw_torch_fused\nlr_scheduler: cosine\nlearning_rate: 0.00001\n\nbf16: auto\ntf32: false\n\ngradient_checkpointing: true\ngradient_checkpointing_kwargs:\n  use_reentrant: true\nresume_from_checkpoint:\nlogging_steps: 1\nflash_attention: true\n\nwarmup_ratio: 0.1\nevals_per_epoch: 4\nsaves_per_epoch: 1\nweight_decay: 0.0\nfsdp:\n  - full_shard\n  - auto_wrap\nfsdp_config:\n  fsdp_limit_all_gathers: true\n  fsdp_sync_module_states: true\n  fsdp_offload_params: true\n  fsdp_use_orig_params: false\n  fsdp_cpu_ram_efficient_loading: true\n  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP\n  fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer\n  fsdp_state_dict_type: FULL_STATE_DICT\n  fsdp_sharding_strategy: FULL_SHARD\nspecial_tokens:\n  pad_token: <|end_of_text|>\n\n# save_first_step: true  # uncomment this to validate checkpoint saving works with your config\n"
  },
  {
    "path": "examples/llama-3/qlora.yml",
    "content": "base_model: NousResearch/Meta-Llama-3-8B\n# optionally might have model_type or tokenizer_type\nmodel_type: AutoModelForCausalLM\ntokenizer_type: AutoTokenizer\n# Automatically upload checkpoint and final model to HF\n# hub_model_id: username/custom_model_name\n\nload_in_8bit: false\nload_in_4bit: true\n\ndatasets:\n  - path: aaditya/alpaca_subset_1\n    type: alpaca\ndataset_prepared_path:\nval_set_size: 0\noutput_dir: ./outputs/qlora-out\n\nadapter: qlora\nlora_model_dir:\n\nsequence_len: 4096\nsample_packing: true\n\n\nlora_r: 32\nlora_alpha: 16\nlora_dropout: 0.05\nlora_target_linear: true\n\nwandb_project:\nwandb_entity:\nwandb_watch:\nwandb_name:\nwandb_log_model:\n\ngradient_accumulation_steps: 4\nmicro_batch_size: 2\nnum_epochs: 4\noptimizer: paged_adamw_32bit\nlr_scheduler: cosine\nlearning_rate: 0.0002\n\nbf16: auto\ntf32: false\n\ngradient_checkpointing: true\nresume_from_checkpoint:\nlogging_steps: 1\nflash_attention: true\n\nwarmup_ratio: 0.1\nevals_per_epoch: 4\nsaves_per_epoch: 1\nweight_decay: 0.0\nspecial_tokens:\n  pad_token: \"<|end_of_text|>\"\n\n# save_first_step: true  # uncomment this to validate checkpoint saving works with your config\n"
  },
  {
    "path": "examples/llama-3/sparse-finetuning.yaml",
    "content": "base_model: neuralmagic/Sparse-Llama-3.1-8B-2of4\n\nplugins:\n  - axolotl.integrations.llm_compressor.LLMCompressorPlugin\n\nload_in_8bit: false\nload_in_4bit: false\nstrict: false\n\ndatasets:\n  - path: tatsu-lab/alpaca\n    type: alpaca\ndataset_prepared_path: last_run_prepared\nval_set_size: 0.05\noutput_dir: ./outputs/out\n\nsequence_len: 4096\nsample_packing: true\n\neval_sample_packing: false\n\nwandb_project:\nwandb_entity:\nwandb_watch:\nwandb_name:\nwandb_log_model:\n\ngradient_accumulation_steps: 8\nmicro_batch_size: 1\nnum_epochs: 1\noptimizer: paged_adamw_8bit\nlr_scheduler: cosine\nlearning_rate: 2e-5\n\ntrain_on_inputs: false\ngroup_by_length: false\nbf16: auto\nfp16:\ntf32: false\n\ngradient_checkpointing: true\ngradient_checkpointing_kwargs:\n  use_reentrant: false\nearly_stopping_patience:\nresume_from_checkpoint:\nlogging_steps: 1\nxformers_attention:\nflash_attention: true\n\nwarmup_ratio: 0.1\nevals_per_epoch: 2\neval_table_size:\nsaves_per_epoch: 1\ndebug:\ndeepspeed:\nweight_decay: 0.0\nfsdp:\nfsdp_config:\nspecial_tokens:\n  pad_token: <|end_of_text|>\n\nllmcompressor:\n  recipe:\n    finetuning_stage:\n      finetuning_modifiers:\n        ConstantPruningModifier:\n          targets: [\n            're:.*q_proj.weight',\n            're:.*k_proj.weight',\n            're:.*v_proj.weight',\n            're:.*o_proj.weight',\n            're:.*gate_proj.weight',\n            're:.*up_proj.weight',\n            're:.*down_proj.weight',\n          ]\n          start: 0\n  save_compressed: true\n\n# save_first_step: true  # uncomment this to validate checkpoint saving works with your config\n"
  },
  {
    "path": "examples/llama-3-vision/lora-11b.yaml",
    "content": "base_model: alpindale/Llama-3.2-11B-Vision-Instruct\n# optionally might have model_type or tokenizer_type or processor_type\nprocessor_type: AutoProcessor\n# Automatically upload checkpoint and final model to HF\n# hub_model_id: username/custom_model_name\n\n\n# these 3 lines are needed for now to handle vision chat templates w images\nskip_prepare_dataset: true\nremove_unused_columns: false\nsample_packing: false\n\nchat_template: llama3_2_vision\ndatasets:\n  - path: HuggingFaceH4/llava-instruct-mix-vsft\n    type: chat_template\n    split: train[:1%]\ndataset_prepared_path:\nval_set_size: 0.0\noutput_dir: ./outputs/out\n\nadapter: lora\nlora_model_dir:\n\nsequence_len: 8192\npad_to_sequence_len: false\n\nlora_r: 32\nlora_alpha: 16\nlora_dropout: 0.05\nlora_target_modules: 'model.language_model.layers.[\\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'\n\nwandb_project:\nwandb_entity:\nwandb_watch:\nwandb_name:\nwandb_log_model:\n\ngradient_accumulation_steps: 4\nmicro_batch_size: 1\nnum_epochs: 1\noptimizer: adamw_bnb_8bit\nlr_scheduler: cosine\nlearning_rate: 0.0002\n\nbf16: true\nfp16:\ntf32: true\n\ngradient_checkpointing: true\nlogging_steps: 1\n# flash_attention: true  # use for text-only mode\nsdp_attention: true\n\nwarmup_ratio: 0.1\nevals_per_epoch: 1\nsaves_per_epoch: 1\nweight_decay: 0.0\n\n# save_first_step: true  # uncomment this to validate checkpoint saving works with your config\n"
  },
  {
    "path": "examples/llama-4/README.md",
    "content": "# Llama 4 by Meta AI\n\n## Flash Attention vs Flex Attention\n\nWhile Flash Attention support is \"enabled\" for Llama 4, the upstream implementation is not correct, so Flex Attention is recommended instead.\n\n## Available Examples\n\n### Llama 4 Scout 17Bx16Experts (109B)\n\nFlex Attention\n- [Text Single GPU (H100) QLoRA](./scout-qlora-single-h100-flex.yaml)\n- [Text Multi GPU QLoRA w/ FSDP2](./scout-qlora-flexattn-fsdp2.yaml)\n\n[//]: # (Flash Attention &#40;Do not use&#41;)\n\n[//]: # (- [Multi-Modal/Vision QLoRA w/ FSDP1]&#40;./scout-vision-qlora-fsdp.yaml&#41;)\n\n[//]: # (- [Text Single GPU &#40;H100&#41; QLoRA]&#40;./scout-qlora-single-h100.yaml&#41;)\n\n[//]: # (- [Text Multi GPU QLoRA w/ FSDP1]&#40;./scout-qlora-fsdp1.yaml&#41;)\n\nOur single-H100 implementation for Llama 4 Scout uses only 64.5GB VRAM for post-training with 4k context length @ 519 tokens/second. [WandB logs here](https://wandb.ai/axolotl-ai/llama4-flexattn-qlora/runs/wpie7dkj)\n\nMulti-GPU (4xH100) for Llama 4 Scout uses 62.8GB VRAM/GPU @ 4k context length @ 280tps/gpu. [WandB logs here](https://wandb.ai/axolotl-ai/llama4-flexattn-qlora/runs/2lkezdj8)\n\n### Llama 4 Maverick 17Bx128Experts (400B)\n\nComing Soon\n\n## Delinearized Llama 4 Models\n\nWe provide a script to delinearize Llama 4 linearized models into regular HuggingFace Llama 4 models.\n\n```bash\naxolotl delinearize-llama4 --model path/to/model_dir --output path/to/output_dir\n```\n\nNote: This only works with the non-quantized linearized model. If you have an adapter, merge it with the *non-quantized linearized* model before delinearizing.\n"
  },
  {
    "path": "examples/llama-4/do-no-use-fa2/maverick-qlora-fsdp1.yaml",
    "content": "base_model: axolotl-quants/Llama-4-Maverick-17B-128E-Linearized-bnb-nf4-bf16\nmodel_type: Llama4ForConditionalGeneration\n# Automatically upload checkpoint and final model to HF\n# hub_model_id: username/custom_model_name\n\n\nplugins:\n  - axolotl.integrations.liger.LigerPlugin\n\nliger_glu_activation: true\nliger_rms_norm: true\nliger_layer_norm: true\n\nllama4_linearized_experts: true\nload_in_4bit: true\nadapter: qlora\nlora_r: 32\nlora_alpha: 64\nlora_target_modules:\n  - self_attn.q_proj\n  - self_attn.k_proj\n  - self_attn.v_proj\n  - self_attn.o_proj\n  - shared_expert.gate_proj\n  - shared_expert.up_proj\n  - shared_expert.down_proj\n  # - experts.gate_projs.[0-9]+$\n  # - experts.up_projs.[0-9]+$\n  # - experts.down_projs.[0-9]+$\nlora_modules_to_save:\n# - lm_head\n# - embed_tokens\n\nchat_template: llama4\ndatasets:\n  - path: mlabonne/FineTome-100k\n    type: chat_template\n    split: train[:20%]\n    field_messages: conversations\n    message_property_mappings:\n      role: from\n      content: value\n\ndataset_prepared_path: last_run_prepared\nval_set_size: 0.0\noutput_dir: ./outputs/out\n\nsequence_len: 4096\nsample_packing: true\n\n\ngradient_accumulation_steps: 1\nmicro_batch_size: 1\nnum_epochs: 1\noptimizer: adamw_torch_fused\nlr_scheduler: cosine\nlearning_rate: 1e-4\n\nbf16: true\ntf32: true\n\nlogging_steps: 1\nflash_attention: true\n\ngradient_checkpointing: offload\ngradient_checkpointing_kwargs:\n  use_reentrant: false\n\nwarmup_ratio: 0.1\nevals_per_epoch: 1\nsaves_per_epoch: 1\nweight_decay: 0.0\nfsdp:\n  - auto_wrap\n  - full_shard\nfsdp_config:\n  fsdp_transformer_layer_cls_to_wrap: Llama4TextDecoderLayer\n  fsdp_limit_all_gathers: true\n  fsdp_sync_module_states: true\n  fsdp_offload_params: true\n  fsdp_use_orig_params: false\n  fsdp_cpu_ram_efficient_loading: true\n  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP\n  fsdp_state_dict_type: FULL_STATE_DICT\n  fsdp_sharding_strategy: FULL_SHARD\nspecial_tokens:\n  
pad_token: <|finetune_right_pad|>\n  eos_token: <|eot|>\n\n# save_first_step: true  # uncomment this to validate checkpoint saving works with your config\n"
  },
  {
    "path": "examples/llama-4/do-no-use-fa2/scout-qlora-fsdp1.yaml",
    "content": "base_model: axolotl-quants/Llama-4-Scout-17B-16E-Linearized-bnb-nf4-bf16\nmodel_type: Llama4ForConditionalGeneration\n  # Automatically upload checkpoint and final model to HF\n  # hub_model_id: username/custom_model_name\n\n\n# torch_compile: true\nplugins:\n  - axolotl.integrations.liger.LigerPlugin\n\nliger_glu_activation: true\nliger_rms_norm: true\nliger_layer_norm: true\n\nllama4_linearized_experts: true\nload_in_4bit: true\nadapter: qlora\nlora_r: 32\nlora_alpha: 64\nlora_target_modules:\n  - self_attn.q_proj\n  - self_attn.k_proj\n  - self_attn.v_proj\n  - self_attn.o_proj\n  - shared_expert.gate_proj\n  - shared_expert.up_proj\n  - shared_expert.down_proj\n    # - experts.gate_projs.[0-9]+$\n    # - experts.up_projs.[0-9]+$\n    # - experts.down_projs.[0-9]+$\nlora_modules_to_save:\n  - lm_head\n  - embed_tokens\n\nchat_template: llama4\ndatasets:\n  - path: mlabonne/FineTome-100k\n    type: chat_template\n    split: train[:20%]\n    field_messages: conversations\n    message_property_mappings:\n      role: from\n      content: value\n\ndataset_prepared_path: last_run_prepared\nval_set_size: 0.0\noutput_dir: ./outputs/out\n\nsequence_len: 4096\nsample_packing: true\n\n\nwandb_project:\nwandb_entity:\nwandb_watch:\nwandb_name:\nwandb_log_model:\n\ngradient_accumulation_steps: 1\nmicro_batch_size: 1\nnum_epochs: 1\noptimizer: adamw_torch_fused\nlr_scheduler: cosine\nlearning_rate: 2e-5\n\nbf16: true\ntf32: true\n\nlogging_steps: 1\nflash_attention: true\n\nwarmup_ratio: 0.1\nevals_per_epoch: 1\nsaves_per_epoch: 1\nweight_decay: 0.0\nfsdp:\n  - auto_wrap\n  - full_shard\nfsdp_config:\n  fsdp_transformer_layer_cls_to_wrap: Llama4TextDecoderLayer\n  fsdp_limit_all_gathers: true\n  fsdp_sync_module_states: true\n  fsdp_offload_params: true\n  fsdp_use_orig_params: false\n  fsdp_cpu_ram_efficient_loading: true\n  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP\n  fsdp_state_dict_type: FULL_STATE_DICT\n  fsdp_sharding_strategy: FULL_SHARD\n  
fsdp_activation_checkpointing: true\nspecial_tokens:\n  pad_token: <|finetune_right_pad|>\n  eos_token: <|eot|>\n\n# save_first_step: true  # uncomment this to validate checkpoint saving works with your config\n"
  },
  {
    "path": "examples/llama-4/do-no-use-fa2/scout-qlora-single-h100.yaml",
    "content": "base_model: axolotl-quants/Llama-4-Scout-17B-16E-Linearized-bnb-nf4-bf16\nmodel_type: Llama4ForConditionalGeneration\n# Automatically upload checkpoint and final model to HF\n# hub_model_id: username/custom_model_name\n\n\nplugins:\n  - axolotl.integrations.liger.LigerPlugin\n\nliger_glu_activation: true\nliger_rms_norm: true\nliger_layer_norm: true\n\nllama4_linearized_experts: true\nload_in_4bit: true\nadapter: qlora\nlora_r: 32\nlora_alpha: 64\nlora_target_modules:\n  - self_attn.q_proj\n  - self_attn.k_proj\n  - self_attn.v_proj\n  - self_attn.o_proj\n  - shared_expert.gate_proj\n  - shared_expert.up_proj\n  - shared_expert.down_proj\n  # - experts.gate_projs.[0-9]+$\n  # - experts.up_projs.[0-9]+$\n  # - experts.down_projs.[0-9]+$\nlora_modules_to_save:\n  # - lm_head\n  # - embed_tokens\n\nlora_mlp_kernel: true\nlora_qkv_kernel: true\nlora_o_kernel: true\n\nchat_template: llama4\ndatasets:\n  - path: mlabonne/FineTome-100k\n    type: chat_template\n    split: train[:20%]\n    field_messages: conversations\n    message_property_mappings:\n      role: from\n      content: value\n\ndataset_prepared_path: last_run_prepared\nval_set_size: 0.0\noutput_dir: ./outputs/out\n\nsequence_len: 4096  # up to 8k will work on a single H100\nsample_packing: true\n\n\nwandb_project:\nwandb_entity:\nwandb_watch:\nwandb_name:\nwandb_log_model:\n\ngradient_accumulation_steps: 1\nmicro_batch_size: 1\nnum_epochs: 1\noptimizer: adamw_torch_4bit\nlr_scheduler: cosine\nlearning_rate: 1e-4\n\nbf16: true\ntf32: true\n\nlogging_steps: 1\nflash_attention: true\n\ngradient_checkpointing: offload\ngradient_checkpointing_kwargs:\n  use_reentrant: false\n\nwarmup_ratio: 0.1\nevals_per_epoch: 1\nsaves_per_epoch: 1\nweight_decay: 0.0\nspecial_tokens:\n  pad_token: <|finetune_right_pad|>\n  eos_token: <|eot|>\n\n# save_first_step: true  # uncomment this to validate checkpoint saving works with your config\n"
  },
  {
    "path": "examples/llama-4/do-no-use-fa2/scout-vision-qlora-fsdp.yaml",
    "content": "base_model: axolotl-quants/Llama-4-Scout-17B-16E-Linearized-bnb-nf4-bf16\nmodel_type: Llama4ForConditionalGeneration\nprocessor_type: Llama4Processor\n  # Automatically upload checkpoint and final model to HF\n  # hub_model_id: username/custom_model_name\n\n\n# these 3 lines are needed for now to handle vision chat templates w images\nskip_prepare_dataset: true\nremove_unused_columns: false\nsample_packing: false\n\nsequence_len: 4096\n\nplugins:\n  - axolotl.integrations.liger.LigerPlugin\n\nliger_glu_activation: true\nliger_rms_norm: true\nliger_layer_norm: true\n\nllama4_linearized_experts: true  # use Axolotl's customized model\nload_in_4bit: true\nadapter: qlora\nlora_r: 32\nlora_alpha: 64\nlora_target_modules:\n  - self_attn.q_proj\n  - self_attn.k_proj\n  - self_attn.v_proj\n  - self_attn.o_proj\n  - shared_expert.gate_proj\n  - shared_expert.up_proj\n  - shared_expert.down_proj\n  - vision_adapter.mlp.fc1\n  - vision_adapter.mlp.fc2\n  # - experts.gate_projs.[0-9]+$\n  # - experts.up_projs.[0-9]+$\n  # - experts.down_projs.[0-9]+$\nlora_modules_to_save:\n  - lm_head\n  - embed_tokens\n\nchat_template: llama4\ndatasets:\n  - path: HuggingFaceH4/llava-instruct-mix-vsft\n    type: chat_template\n    split: train[:1%]\n\ndataset_prepared_path: last_run_prepared\nval_set_size: 0.0\noutput_dir: ./outputs/out\n\ngradient_accumulation_steps: 1\nmicro_batch_size: 1\nnum_epochs: 1\noptimizer: adamw_torch_4bit\nlr_scheduler: cosine\nlearning_rate: 2e-5\n\nbf16: true\ntf32: true\n\nlogging_steps: 1\nflash_attention: true\n\nwarmup_ratio: 0.1\nevals_per_epoch: 1\nsaves_per_epoch: 1\nweight_decay: 0.0\nfsdp:\n  - auto_wrap\n  - full_shard\nfsdp_config:\n  fsdp_transformer_layer_cls_to_wrap: Llama4TextDecoderLayer\n  fsdp_limit_all_gathers: true\n  fsdp_sync_module_states: true\n  fsdp_offload_params: true\n  fsdp_use_orig_params: false\n  fsdp_cpu_ram_efficient_loading: true\n  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP\n  fsdp_state_dict_type: 
FULL_STATE_DICT\n  fsdp_sharding_strategy: FULL_SHARD\n  fsdp_activation_checkpointing: true\nspecial_tokens:\n  pad_token: <|finetune_right_pad|>\n  eos_token: <|eot|>\n\n# save_first_step: true  # uncomment this to validate checkpoint saving works with your config\n"
  },
  {
    "path": "examples/llama-4/scout-qlora-flexattn-fsdp2.yaml",
    "content": "base_model: axolotl-quants/Llama-4-Scout-17B-16E-Linearized-bnb-nf4-bf16\nmodel_type: Llama4ForConditionalGeneration\n# Automatically upload checkpoint and final model to HF\n# hub_model_id: username/custom_model_name\n\nplugins:\n  - axolotl.integrations.liger.LigerPlugin\n\nliger_glu_activation: true\nliger_rms_norm: true\nliger_layer_norm: true\n\nllama4_linearized_experts: true\nload_in_4bit: true\nadapter: qlora\nlora_r: 32\nlora_alpha: 64\nlora_target_modules:\n  - self_attn.q_proj\n  - self_attn.k_proj\n  - self_attn.v_proj\n  - self_attn.o_proj\n  - shared_expert.gate_proj\n  - shared_expert.up_proj\n  - shared_expert.down_proj\n  # - experts.gate_projs.[0-9]+$\n  # - experts.up_projs.[0-9]+$\n  # - experts.down_projs.[0-9]+$\nlora_modules_to_save:\n  # - lm_head\n  # - embed_tokens\n\nchat_template: llama4\ndatasets:\n  - path: mlabonne/FineTome-100k\n    type: chat_template\n    split: train[:20%]\n    field_messages: conversations\n    message_property_mappings:\n      role: from\n      content: value\n\ndataset_prepared_path: last_run_prepared\nval_set_size: 0.0\noutput_dir: ./outputs/out\n\nsequence_len: 4096\nsample_packing: true\n\n\ngradient_accumulation_steps: 1\nmicro_batch_size: 2\nnum_epochs: 3\noptimizer: adamw_torch_4bit\nlr_scheduler: cosine\nlearning_rate: 1e-4\n\nbf16: true\ntf32: true\n\nlogging_steps: 1\nflex_attention: true\nflex_attn_compile_kwargs:\n  dynamic: false\n  mode: max-autotune-no-cudagraphs\n\nwarmup_ratio: 0.1\nevals_per_epoch: 1\nsaves_per_epoch: 1\nweight_decay: 0.0\nfsdp:\n  - auto_wrap\n  - full_shard\nfsdp_config:\n  fsdp_version: 2\n  fsdp_offload_params: false\n  # fsdp_cpu_ram_efficient_loading: true # does not work with load_in_8bit/4bit\n  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP\n  fsdp_transformer_layer_cls_to_wrap: Llama4TextDecoderLayer\n  fsdp_state_dict_type: SHARDED_STATE_DICT\n  fsdp_sharding_strategy: FULL_SHARD\n  fsdp_reshard_after_forward: true\n  fsdp_activation_checkpointing: 
true\nspecial_tokens:\n  pad_token: <|finetune_right_pad|>\n  eos_token: <|eot|>\n\n# save_first_step: true  # uncomment this to validate checkpoint saving works with your config\n"
  },
  {
    "path": "examples/llama-4/scout-qlora-single-h100-flex.yaml",
    "content": "base_model: axolotl-quants/Llama-4-Scout-17B-16E-Linearized-bnb-nf4-bf16\nmodel_type: Llama4ForConditionalGeneration\n# Automatically upload checkpoint and final model to HF\n# hub_model_id: username/custom_model_name\n\nplugins:\n  - axolotl.integrations.liger.LigerPlugin\n  - axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin\n\nliger_glu_activation: true\nliger_rms_norm: true\nliger_layer_norm: true\n\nllama4_linearized_experts: true  # needed with custom linearized experts model\nload_in_4bit: true\nadapter: qlora\nlora_r: 32\nlora_alpha: 64\nlora_target_modules:\n  - self_attn.q_proj\n  - self_attn.k_proj\n  - self_attn.v_proj\n  - self_attn.o_proj\n  - shared_expert.gate_proj\n  - shared_expert.up_proj\n  - shared_expert.down_proj\n  # - experts.gate_projs.[0-9]+$  # optionally train the moe experts\n  # - experts.up_projs.[0-9]+$\n  # - experts.down_projs.[0-9]+$\nlora_modules_to_save:\n  # - lm_head  # needed if modifying vocabulary\n  # - embed_tokens\n\nlora_mlp_kernel: true\nlora_qkv_kernel: true\nlora_o_kernel: true\n\nchat_template: llama4\ndatasets:\n  - path: mlabonne/FineTome-100k\n    type: chat_template\n    split: train[:20%]\n    field_messages: conversations\n    message_property_mappings:\n      role: from\n      content: value\n\ndataset_prepared_path: last_run_prepared\nval_set_size: 0.0\noutput_dir: ./outputs/out\n\nsequence_len: 4096  # up to 8k will work on a single H100\nsample_packing: true\n\n\ngradient_accumulation_steps: 1\nmicro_batch_size: 1\nnum_epochs: 1\noptimizer: adamw_torch_4bit\nlr_scheduler: cosine\nlearning_rate: 1e-4\n\nbf16: true\ntf32: true\n\ntorch_compile: true\nflex_attention: true\nflex_attn_compile_kwargs:\n  dynamic: false\n  mode: max-autotune-no-cudagraphs\n\ngradient_checkpointing: offload\ngradient_checkpointing_kwargs:\n  use_reentrant: false\n\nlogging_steps: 1\nwarmup_ratio: 0.1\nevals_per_epoch: 1\nsaves_per_epoch: 1\n\nweight_decay: 0.0\nspecial_tokens:\n  pad_token: 
<|finetune_right_pad|>\n  eos_token: <|eot|>\n\n# save_first_step: true  # uncomment this to validate checkpoint saving works with your config\n"
  },
  {
    "path": "examples/llama-4/scout-vision-qlora-fsdp2-flex.yaml",
    "content": "base_model: axolotl-quants/Llama-4-Scout-17B-16E-Linearized-bnb-nf4-bf16\nmodel_type: Llama4ForConditionalGeneration\nprocessor_type: Llama4Processor\n# Automatically upload checkpoint and final model to HF\n# hub_model_id: username/custom_model_name\n\n# these 3 lines are needed for now to handle vision chat templates w images\nskip_prepare_dataset: true\nremove_unused_columns: false\nsample_packing: false\n\nsequence_len: 4096\n\nplugins:\n  - axolotl.integrations.liger.LigerPlugin\n\nliger_glu_activation: true\nliger_rms_norm: true\nliger_layer_norm: true\n\nllama4_linearized_experts: true  # use Axolotl's customized model\nload_in_4bit: true\nadapter: qlora\nlora_r: 32\nlora_alpha: 64\nlora_target_modules:\n  - self_attn.q_proj\n  - self_attn.k_proj\n  - self_attn.v_proj\n  - self_attn.o_proj\n  - shared_expert.gate_proj\n  - shared_expert.up_proj\n  - shared_expert.down_proj\n  - vision_adapter.mlp.fc1\n  - vision_adapter.mlp.fc2\n  # - experts.gate_projs.[0-9]+$\n  # - experts.up_projs.[0-9]+$\n  # - experts.down_projs.[0-9]+$\nlora_modules_to_save:\n  - lm_head\n  - embed_tokens\n\nchat_template: llama4\ndatasets:\n  - path: HuggingFaceH4/llava-instruct-mix-vsft\n    type: chat_template\n    split: train[:1%]\n\ndataset_prepared_path: last_run_prepared\nval_set_size: 0.0\noutput_dir: ./outputs/out\n\ngradient_accumulation_steps: 1\nmicro_batch_size: 1\nnum_epochs: 1\noptimizer: adamw_torch_4bit\nlr_scheduler: cosine\nlearning_rate: 1e-4\n\nbf16: true\ntf32: true\n\nlogging_steps: 1\nflex_attention: true\nflex_attn_compile_kwargs:\n  dynamic: false\n  mode: max-autotune-no-cudagraphs\n\nwarmup_ratio: 0.1\nevals_per_epoch: 1\nsaves_per_epoch: 1\nweight_decay: 0.0\nfsdp:\n  - auto_wrap\n  - full_shard\nfsdp_config:\n  fsdp_version: 2\n  fsdp_offload_params: false\n  fsdp_cpu_ram_efficient_loading: true\n  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP\n  fsdp_transformer_layer_cls_to_wrap: Llama4TextDecoderLayer\n  fsdp_state_dict_type: 
SHARDED_STATE_DICT\n  fsdp_sharding_strategy: FULL_SHARD\n  fsdp_reshard_after_forward: true\n  fsdp_activation_checkpointing: true\nspecial_tokens:\n  pad_token: <|finetune_right_pad|>\n  eos_token: <|eot|>\n\n# save_first_step: true  # uncomment this to validate checkpoint saving works with your config\n"
  },
  {
    "path": "examples/llava/lora-7b.yaml",
    "content": "base_model: llava-hf/llava-1.5-7b-hf\nprocessor_type: AutoProcessor\n\n# these 3 lines are needed for now to handle vision chat templates w images\nskip_prepare_dataset: true\nremove_unused_columns: false\nsample_packing: false\n\nchat_template: llava\ndatasets:\n  - path: HuggingFaceH4/llava-instruct-mix-vsft\n    type: chat_template\n    split: train[:1%]\ndataset_prepared_path:\nval_set_size: 0.0\noutput_dir: ./outputs/out\n\nadapter: lora\nlora_model_dir:\n\nsequence_len: 8192\npad_to_sequence_len: false\n\nlora_r: 32\nlora_alpha: 16\nlora_dropout: 0.05\nlora_target_modules: 'model.language_model.layers.[\\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'\n\nwandb_project:\nwandb_entity:\nwandb_watch:\nwandb_name:\nwandb_log_model:\n\ngradient_accumulation_steps: 4\nmicro_batch_size: 1\nnum_epochs: 1\noptimizer: adamw_bnb_8bit\nlr_scheduler: cosine\nlearning_rate: 0.0002\n\nbf16: true\nfp16:\ntf32: true\n\ngradient_checkpointing: true\nlogging_steps: 1\nflash_attention: true\neager_attention:\n\nwarmup_ratio: 0.1\nevals_per_epoch: 1\nsaves_per_epoch: 1\nweight_decay: 0.0\n\n# save_first_step: true  # uncomment this to validate checkpoint saving works with your config\n"
  },
  {
    "path": "examples/magistral/README.md",
    "content": "# Finetune Magistral Small with Axolotl\n\nMagistral Small is a 24B parameter opensource model from MistralAI found on HuggingFace at [2506](https://huggingface.co/mistralai/Magistral-Small-2506), [2507](https://huggingface.co/mistralai/Magistral-Small-2507) (see [Thinking](#thinking)), and [2509](https://huggingface.co/mistralai/Magistral-Small-2509) (see [Vision](#vision)). This guide shows how to fine-tune it with Axolotl with multi-turn conversations and proper masking.\n\nMistralAI has also released a proprietary medium-sized version called Magistral Medium.\n\nThanks to the team at MistralAI for giving us early access to prepare for these releases.\n\n## Getting started\n\n1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html).\n\n    Here is an example of how to install from pip:\n\n```bash\n# Ensure you have Pytorch installed (Pytorch 2.7.0 min)\npip3 install packaging==26.0 setuptools==75.8.0 wheel ninja\npip3 install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'\n```\n\n2. Install [Cut Cross Entropy](https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy) to reduce training VRAM usage\n\n```bash\npython scripts/cutcrossentropy_install.py | sh\n```\n\n3. Run the finetuning example:\n\n```bash\naxolotl train examples/magistral/magistral-small-qlora.yaml\n```\n\nThis config uses about 24GB VRAM.\n\nLet us know how it goes. Happy finetuning! 
🚀\n\n### Thinking\n\nMistralAI has released their [2507](https://huggingface.co/mistralai/Magistral-Small-2507) model with thinking capabilities, enabling Chain-of-Thought reasoning with explicit thinking steps.\n\n📚 **[See the Thinking fine-tuning guide →](./think/README.md)**\n\n### Vision\n\nMistralAI has released their [2509](https://huggingface.co/mistralai/Magistral-Small-2509) model with vision capabilities.\n\n📚 **[See the Vision fine-tuning guide →](./vision/README.md)**\n\n### Tips\n\n- We recommend using the same (or a similar) system prompt as the one the model was tuned with. You can find it in the model repo's `SYSTEM_PROMPT.txt` file.\n- For inference, the official MistralAI team recommends `top_p: 0.95` and `temperature: 0.7` with `max_tokens: 40960`.\n- You can run a full finetuning by removing the `adapter: qlora` and `load_in_4bit: true` from the config.\n- Read more on how to load your own dataset at [docs](https://docs.axolotl.ai/docs/dataset_loading.html).\n- The text dataset format follows the OpenAI Messages format as seen [here](https://docs.axolotl.ai/docs/dataset-formats/conversation.html#chat_template).\n\n## Optimization Guides\n\n- [Multi-GPU Training](https://docs.axolotl.ai/docs/multi-gpu.html)\n- [Multi-Node Training](https://docs.axolotl.ai/docs/multi-node.html)\n- [LoRA Optimizations](https://docs.axolotl.ai/docs/lora_optims.html)\n\n## Limitations\n\nWe only support the `mistral-common` tokenizer for Supervised Fine-tuning at the moment and for `type: chat_template` only.\n\nIn addition, we do not support overriding tokens yet.\n\n## Related Resources\n\n- [MistralAI Magistral Blog](https://mistral.ai/news/magistral/)\n- [Axolotl Docs](https://docs.axolotl.ai)\n- [Axolotl Website](https://axolotl.ai)\n- [Axolotl GitHub](https://github.com/axolotl-ai-cloud/axolotl)\n- [Axolotl Discord](https://discord.gg/7m9sfhzaf3)\n\n\n## Future Work\n\n- Add parity to Preference Tuning, RL, etc.\n- Add parity to other tokenizer configs like overriding 
tokens.\n"
  },
  {
    "path": "examples/magistral/magistral-small-fsdp-qlora.yaml",
    "content": "base_model: mistralai/Magistral-Small-2506\n\n# Enable to use mistral-common tokenizer\ntokenizer_use_mistral_common: true\n\n# Automatically upload checkpoint and final model to HF\n# hub_model_id: username/custom_model_name\n\nplugins:\n  - axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin\n\nload_in_8bit: false\nload_in_4bit: true\n\ndatasets:\n  - path: fozziethebeat/alpaca_messages_2k_test\n    type: chat_template\n\ndataset_prepared_path: last_run_prepared\nval_set_size: 0.1\noutput_dir: ./outputs/lora-out\n\nadapter: qlora\nlora_model_dir:\n\nsequence_len: 2048\nsample_packing: true\neval_sample_packing: false\n\nlora_r: 32\nlora_alpha: 16\nlora_dropout: 0.05\nlora_target_linear: true\nlora_target_modules:\n  - gate_proj\n  - down_proj\n  - up_proj\n  - q_proj\n  - v_proj\n  - k_proj\n  - o_proj\n\nwandb_project:\nwandb_entity:\nwandb_watch:\nwandb_name:\nwandb_log_model:\n\ngradient_accumulation_steps: 4\nmicro_batch_size: 2\nnum_epochs: 1\noptimizer: adamw_torch_fused\nlr_scheduler: cosine\nlearning_rate: 0.0002\n\nbf16: auto\ntf32: false\n\ngradient_checkpointing:\nresume_from_checkpoint:\nlogging_steps: 1\nflash_attention: true\n\nwarmup_ratio: 0.1\nevals_per_epoch: 1\nsaves_per_epoch: 1\n\nfsdp:\n  - full_shard\n  - auto_wrap\nfsdp_config:\n  fsdp_state_dict_type: FULL_STATE_DICT\n  fsdp_transformer_layer_cls_to_wrap: MistralDecoderLayer\n  fsdp_activation_checkpointing: true\n\n# save_first_step: true  # uncomment this to validate checkpoint saving works with your config\n"
  },
  {
    "path": "examples/magistral/magistral-small-qlora.yaml",
    "content": "base_model: mistralai/Magistral-Small-2506\n\n# Enable to use mistral-common tokenizer\ntokenizer_use_mistral_common: true\n\n# Automatically upload checkpoint and final model to HF\n# hub_model_id: username/custom_model_name\n\nplugins:\n  - axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin\n\nload_in_8bit: false\nload_in_4bit: true\n\ndatasets:\n  - path: fozziethebeat/alpaca_messages_2k_test\n    type: chat_template\n\ndataset_prepared_path: last_run_prepared\nval_set_size: 0.1\noutput_dir: ./outputs/lora-out\n\nadapter: qlora\nlora_model_dir:\n\nsequence_len: 2048\nsample_packing: true\n\nlora_r: 32\nlora_alpha: 16\nlora_dropout: 0.05\nlora_target_linear: true\nlora_target_modules:\n  - gate_proj\n  - down_proj\n  - up_proj\n  - q_proj\n  - v_proj\n  - k_proj\n  - o_proj\n\nwandb_project:\nwandb_entity:\nwandb_watch:\nwandb_name:\nwandb_log_model:\n\ngradient_accumulation_steps: 4\nmicro_batch_size: 2\nnum_epochs: 1\noptimizer: adamw_bnb_8bit\nlr_scheduler: cosine\nlearning_rate: 0.0002\n\nbf16: auto\ntf32: false\n\ngradient_checkpointing: true\nresume_from_checkpoint:\nlogging_steps: 1\nflash_attention: true\n\nwarmup_ratio: 0.1\nevals_per_epoch: 1\nsaves_per_epoch: 1\n\n# save_first_step: true  # uncomment this to validate checkpoint saving works with your config\n"
  },
  {
    "path": "examples/magistral/think/README.md",
    "content": "# Magistral Small Thinking Fine-tuning\n\nThis guide covers fine-tuning [Magistral Small 2507](https://huggingface.co/mistralai/Magistral-Small-2507) with thinking capabilities using Axolotl. The thinking model enables explicit Chain-of-Thought reasoning with separate thinking and response sections.\n\n## Prerequisites\n\nBefore starting, ensure you have:\n\n- Installed Axolotl (see [main README](../README.md))\n\n## Getting Started\n\nRun the thinking model fine-tuning:\n\n```bash\naxolotl train examples/magistral/think/magistral-small-think-qlora.yaml\n```\n\nThis config uses about 19.1 GiB VRAM.\n\n### Tips\n\n- Dataset uses multi-content format with `type: thinking` support. See [Dataset Format](#dataset-format) below.\n- You cannot mix `content: str` and `content: list[dict]`, otherwise, dataset loading will fail. Keep it consistent.\n\n## Dataset Format\n\nThe thinking model requires the multi-content dataset format with support for an extra `role: thinking` within system and assistant messages.\n\nExample format:\n\n```json\n{\n    \"messages\": [\n        {\n            \"role\": \"system\",\n            \"content\": [\n                { \"type\": \"text\", \"text\": \"{SYSTEM_PROMPT}\"}\n            ]\n        },\n        {\n            \"role\": \"user\",\n            \"content\": [\n                { \"type\": \"text\", \"text\": \"Solve this step by step: What is 15% of 240?\"}\n            ]\n        },\n        {\n            \"role\": \"assistant\",\n            \"content\": [\n                {\n                    \"type\": \"thinking\",\n                    \"thinking\": \"I need to calculate 15% of 240. First, I'll convert 15% to decimal: 0.15. 
Then multiply: 0.15 × 240 = 36.\"\n                },\n                {\n                    \"type\": \"text\",\n                    \"text\": \"To find 15% of 240, I'll multiply 240 by 0.15:\\n\\n240 × 0.15 = 36\\n\\nTherefore, 15% of 240 is 36.\"\n                }\n            ]\n        }\n    ]\n}\n```\n\n### Advanced Options\n\nThe `thinking` content item supports an optional boolean `closed` field (default: `true`) that controls whether the closing `[/THINK]` tag is added:\n\n```json\n{\n    \"type\": \"thinking\",\n    \"thinking\": \"Internal reasoning here...\",\n    \"closed\": true\n}\n```\n"
  },
  {
    "path": "examples/magistral/think/magistral-small-think-qlora.yaml",
    "content": "base_model: mistralai/Magistral-Small-2507\n\n# Enable to use mistral-common tokenizer\ntokenizer_use_mistral_common: true\n\n# Automatically upload checkpoint and final model to HF\n# hub_model_id: username/custom_model_name\n\nplugins:\n  - axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin\n\nload_in_8bit: false\nload_in_4bit: true\n\ndatasets:\n  - path: Nanobit/text-think-2k-test\n    type: chat_template\n\ndataset_prepared_path: last_run_prepared\nval_set_size: 0\noutput_dir: ./outputs/lora-out\n\nadapter: qlora\nlora_model_dir:\n\nsequence_len: 2048\nsample_packing: true\n\nlora_r: 32\nlora_alpha: 16\nlora_dropout: 0.05\nlora_target_linear: true\nlora_target_modules:\n  - gate_proj\n  - down_proj\n  - up_proj\n  - q_proj\n  - v_proj\n  - k_proj\n  - o_proj\n\nwandb_project:\nwandb_entity:\nwandb_watch:\nwandb_name:\nwandb_log_model:\n\ngradient_accumulation_steps: 4\nmicro_batch_size: 2\nnum_epochs: 1\noptimizer: adamw_bnb_8bit\nlr_scheduler: cosine\nlearning_rate: 0.0002\n\nbf16: auto\ntf32: false\n\ngradient_checkpointing: true\nresume_from_checkpoint:\nlogging_steps: 1\nflash_attention: true\n\nwarmup_ratio: 0.1\nevals_per_epoch: 1\nsaves_per_epoch: 1\n\n# save_first_step: true  # uncomment this to validate checkpoint saving works with your config\n"
  },
  {
    "path": "examples/magistral/vision/README.md",
    "content": "# Magistral Small Vision Fine-tuning\n\nThis guide covers fine-tuning [Magistral Small 2509](https://huggingface.co/mistralai/Magistral-Small-2509) with vision capabilities using Axolotl.\n\n## Prerequisites\n\nBefore starting, ensure you have:\n\n- Installed Axolotl from source (see [main README](../README.md))\n\n## Getting started\n\n1. Install the required vision lib:\n    ```bash\n    pip install 'mistral-common[opencv]==1.8.5'\n    ```\n\n2. Download the example dataset image:\n   ```bash\n   wget https://huggingface.co/datasets/Nanobit/text-vision-2k-test/resolve/main/African_elephant.jpg\n   ```\n\n3. Run the fine-tuning:\n   ```bash\n   axolotl train examples/magistral/vision/magistral-small-vision-24B-qlora.yml\n   ```\n\nThis config uses about 17GiB VRAM.\n\nWARNING: The loss and grad norm will be much higher than normal at first. We suspect this to be inherent to the model as of the moment. If anyone would like to submit a fix for this, we are happy to take a look.\n\n### Tips\n\nKey differences from text-only model:\n- `max_tokens: 131072` for inference\n- Multi-modal dataset format required\n- Sample packing not supported\n\n## Dataset Format\n\nThe vision model requires multi-modal dataset format as documented [here](https://docs.axolotl.ai/docs/multimodal.html#dataset-format).\n\nOne exception is that, passing `\"image\": PIL.Image` is not supported. 
MistralTokenizer only supports `path`, `url`, and `base64` for now.\n\nExample:\n```json\n{\n    \"messages\": [\n        {\"role\": \"system\", \"content\": [{ \"type\": \"text\", \"text\": \"{SYSTEM_PROMPT}\"}]},\n        {\"role\": \"user\", \"content\": [\n            { \"type\": \"text\", \"text\": \"What's in this image?\"},\n            {\"type\": \"image\", \"path\": \"path/to/image.jpg\" }\n        ]},\n        {\"role\": \"assistant\", \"content\": [{ \"type\": \"text\", \"text\": \"...\" }]}\n    ]\n}\n```\n\n## Limitations\n\n- Sample packing is currently not supported for multimodal training.\n"
  },
  {
    "path": "examples/magistral/vision/magistral-small-vision-24B-qlora.yml",
    "content": "base_model: mistralai/Magistral-Small-2509\nprocessor_type: AutoProcessor\n\n# Enable to use mistral-common tokenizer\ntokenizer_use_mistral_common: true\n\nplugins:\n  - axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin\n\nload_in_4bit: true\n\n# these 3 lines are needed for now to handle vision chat templates w images\nskip_prepare_dataset: true\nremove_unused_columns: false\nsample_packing: false\n\n# sample dataset below requires downloading image in advance\n# wget https://huggingface.co/datasets/Nanobit/text-vision-2k-test/resolve/main/African_elephant.jpg\ndatasets:\n  - path: Nanobit/text-vision-2k-test\n    type: chat_template\n\ndataset_prepared_path: last_run_prepared\nval_set_size: 0.01\noutput_dir: ./outputs/out\n\nadapter: qlora\nlora_model_dir:\n\nsequence_len: 2048\n\nlora_r: 32\nlora_alpha: 16\nlora_dropout: 0.05\nlora_target_modules: 'model.language_model.layers.[\\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'\n\nwandb_project:\nwandb_entity:\nwandb_watch:\nwandb_name:\nwandb_log_model:\n\ngradient_accumulation_steps: 1\nmicro_batch_size: 1\nnum_epochs: 1\noptimizer: adamw_bnb_8bit\nlr_scheduler: cosine\nlearning_rate: 0.0002\n\nbf16: true\nfp16:\ntf32: true\n\ngradient_checkpointing: true\nlogging_steps: 1\nflash_attention: true\n\nwarmup_ratio: 0.1\nevals_per_epoch: 1\nsaves_per_epoch: 1\nweight_decay: 0.0\nspecial_tokens:\n\n# save_first_step: true  # uncomment this to validate checkpoint saving works with your config\n"
  },
  {
    "path": "examples/mamba/config.yml",
    "content": "base_model: state-spaces/mamba-2.8b\n# optionally might have model_type or tokenizer_type or tokenizer_config\nmodel_type: MambaLMHeadModel\ntokenizer_type: AutoTokenizer\ntokenizer_config: EleutherAI/gpt-neox-20b\n# Automatically upload checkpoint and final model to HF\n# hub_model_id: username/custom_model_name\n\ndatasets:\n  - path: mhenrichsen/alpaca_2k_test\n    type: alpaca\ndataset_prepared_path:\nval_set_size: 0.0\noutput_dir: ./outputs/out\n\nsequence_len: 2048\nsample_packing: false\npad_to_sequence_len: false\n\nwandb_project:\nwandb_entity:\nwandb_watch:\nwandb_name:\nwandb_log_model:\n\ngradient_accumulation_steps: 4\nmicro_batch_size: 1\nnum_epochs: 2\noptimizer: paged_adamw_8bit\nlr_scheduler: cosine\nlearning_rate: 5e-5\n\ntrain_on_inputs: false\ngroup_by_length: true\n\nbf16: auto\ntf32: true\n\ngradient_checkpointing: false\nresume_from_checkpoint:\nlogging_steps: 1\nflash_attention:\n\nwarmup_ratio: 0.1\nevals_per_epoch: 4\nsaves_per_epoch: 1\nweight_decay: 0.0\nspecial_tokens:\ntokens:\n\n# save_first_step: true  # uncomment this to validate checkpoint saving works with your config\n"
  },
  {
    "path": "examples/mimo/README.md",
    "content": "# Finetune Xiaomi's MiMo with Axolotl\n\n[MiMo](https://huggingface.co/XiaomiMiMo/MiMo-7B-RL) is a family of models trained from scratch for reasoning tasks, incorporating **Multiple-Token Prediction (MTP)** as an additional training objective for enhanced performance and faster inference. Pre-trained on ~25T tokens with a three-stage data mixture strategy and optimized reasoning pattern density.\n\nThis guide shows how to fine-tune it with Axolotl with multi-turn conversations and proper masking.\n\n## Getting started\n\n1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html).\n\n2. Run the finetuning example:\n\n    ```bash\n    axolotl train examples/mimo/mimo-7b-qlora.yaml\n    ```\n\nThis config uses about 17.2 GiB VRAM. Let us know how it goes. Happy finetuning! 🚀\n\n### Tips\n\n- You can run a full finetuning by removing the `adapter: qlora` and `load_in_4bit: true` from the config.\n- Read more on how to load your own dataset at [docs](https://docs.axolotl.ai/docs/dataset_loading.html).\n- The dataset format follows the OpenAI Messages format as seen [here](https://docs.axolotl.ai/docs/dataset-formats/conversation.html#chat_template).\n\n## Optimization Guides\n\nPlease check the [Optimizations doc](https://docs.axolotl.ai/docs/optimizations.html).\n\n## Limitations\n\n**Cut Cross Entropy (CCE)**: Currently not supported. We plan to include CCE support for MiMo in the near future.\n\n## Related Resources\n\n- [MiMo Paper](https://arxiv.org/abs/2505.07608)\n- [Axolotl Docs](https://docs.axolotl.ai)\n- [Axolotl Website](https://axolotl.ai)\n- [Axolotl GitHub](https://github.com/axolotl-ai-cloud/axolotl)\n- [Axolotl Discord](https://discord.gg/7m9sfhzaf3)\n"
  },
  {
    "path": "examples/mimo/mimo-7b-qlora.yaml",
    "content": "base_model: XiaomiMiMo/MiMo-7B-RL\ntrust_remote_code: true\nrevision_of_model: 6299b5a\n\n# Automatically upload checkpoint and final model to HF\n# hub_model_id: username/custom_model_name\n\n# CCE - N/A as of now\n# plugins:\n#   - axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin\n\nload_in_8bit: false\nload_in_4bit: true\n\ndatasets:\n  - path: fozziethebeat/alpaca_messages_2k_test\n    type: chat_template\n\ndataset_prepared_path: last_run_prepared\nval_set_size: 0.1\noutput_dir: ./outputs/lora-out\n\nadapter: qlora\nlora_model_dir:\n\nsequence_len: 2048\nsample_packing: true\n\nlora_r: 32\nlora_alpha: 16\nlora_dropout: 0.05\nlora_target_linear: true\nlora_target_modules:\n  - gate_proj\n  - down_proj\n  - up_proj\n  - q_proj\n  - v_proj\n  - k_proj\n  - o_proj\n\nwandb_project:\nwandb_entity:\nwandb_watch:\nwandb_name:\nwandb_log_model:\n\ngradient_accumulation_steps: 4\nmicro_batch_size: 2\nnum_epochs: 1\noptimizer: adamw_bnb_8bit\nlr_scheduler: cosine\nlearning_rate: 0.0002\n\nbf16: auto\ntf32: false\n\ngradient_checkpointing: true\nresume_from_checkpoint:\nlogging_steps: 1\nflash_attention: true\n\nwarmup_ratio: 0.1\nevals_per_epoch: 1\nsaves_per_epoch: 1\n\n# save_first_step: true  # uncomment this to validate checkpoint saving works with your config\n"
  },
  {
    "path": "examples/ministral/README.md",
    "content": "# Finetune Ministral with Axolotl\n\nMinistral is a family of open-weight models from MistralAI, available on [HuggingFace](https://huggingface.co/mistralai/Ministral-8B-Instruct-2410). This guide shows how to fine-tune it with Axolotl with multi-turn conversations and proper masking.\n\n## Getting started\n\n1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html).\n\n2. Install [Cut Cross Entropy](https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy) to reduce training VRAM usage.\n\n3. Run the finetuning example:\n\n    ```bash\n    axolotl train examples/ministral/ministral-small-qlora.yaml\n    ```\n\nThis config uses about 8.76 GiB VRAM.\n\nLet us know how it goes. Happy finetuning! 🚀\n\n### Tips\n\n- We recommend adding the same/similar SystemPrompt that the model is tuned for. You can find this within the repo's files titled `SYSTEM_PROMPT.txt`.\n- You can run a full finetuning by removing the `adapter: qlora` and `load_in_4bit: true` from the config.\n- Read more on how to load your own dataset at [docs](https://docs.axolotl.ai/docs/dataset_loading.html).\n- The text dataset format follows the OpenAI Messages format as seen [here](https://docs.axolotl.ai/docs/dataset-formats/conversation.html#chat_template).\n\n## Optimization Guides\n\nPlease check the [Optimizations doc](https://docs.axolotl.ai/docs/optimizations.html).\n\n## Limitations\n\nWe only support the `mistral-common` tokenizer for Supervised Fine-tuning at the moment and for `type: chat_template` only.\n\nIn addition, we do not support overriding tokens yet.\n\n## Related Resources\n\n- [MistralAI Ministral Blog](https://mistral.ai/news/ministraux)\n- [Axolotl Docs](https://docs.axolotl.ai)\n- [Axolotl Website](https://axolotl.ai)\n- [Axolotl GitHub](https://github.com/axolotl-ai-cloud/axolotl)\n- [Axolotl Discord](https://discord.gg/7m9sfhzaf3)\n\n\n## Future Work\n\n- Add parity to Preference Tuning, RL, etc.\n- Add parity to other tokenizer configs like overriding tokens.\n"
  },
  {
    "path": "examples/ministral/ministral-small-qlora.yaml",
    "content": "base_model: mistralai/Ministral-8B-Instruct-2410\n\n# Enable to use mistral-common tokenizer\ntokenizer_use_mistral_common: true\n\n# Automatically upload checkpoint and final model to HF\n# hub_model_id: username/custom_model_name\n\nplugins:\n  - axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin\n\nload_in_8bit: false\nload_in_4bit: true\n\ndatasets:\n  - path: fozziethebeat/alpaca_messages_2k_test\n    type: chat_template\n\ndataset_prepared_path: last_run_prepared\nval_set_size: 0.1\noutput_dir: ./outputs/lora-out\n\nadapter: qlora\nlora_model_dir:\n\nsequence_len: 2048\nsample_packing: true\n\nlora_r: 32\nlora_alpha: 16\nlora_dropout: 0.05\nlora_target_linear: true\nlora_target_modules:\n  - gate_proj\n  - down_proj\n  - up_proj\n  - q_proj\n  - v_proj\n  - k_proj\n  - o_proj\n\nwandb_project:\nwandb_entity:\nwandb_watch:\nwandb_name:\nwandb_log_model:\n\ngradient_accumulation_steps: 4\nmicro_batch_size: 2\nnum_epochs: 1\noptimizer: adamw_bnb_8bit\nlr_scheduler: cosine\nlearning_rate: 0.0002\n\nbf16: auto\ntf32: false\n\ngradient_checkpointing: true\nresume_from_checkpoint:\nlogging_steps: 1\nflash_attention: true\n\nwarmup_ratio: 0.1\nevals_per_epoch: 1\nsaves_per_epoch: 1\n\n# save_first_step: true  # uncomment this to validate checkpoint saving works with your config\n"
  },
  {
    "path": "examples/ministral3/README.md",
    "content": "# Finetune Ministral3 with Axolotl\n\nMinistral3 is a family of open-weight models from MistralAI found on [HuggingFace](https://huggingface.co/collections/mistralai/ministral-3). This guide shows how to fine-tune it with Axolotl with multi-turn conversations and proper masking.\n\nPlease see [Thinking](#thinking) and [Vision](#vision) for their respective fine-tuning.\n\nThanks to the team at MistralAI for giving us early access to prepare for these releases.\n\nNote: This is still experimental given it is based on transformers v5 RC.\n\n## Getting started\n\n1. Install Axolotl from source following the [installation guide](https://docs.axolotl.ai/docs/installation.html#sec-edge-build).\n\n2. Install [Cut Cross Entropy](https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy) to reduce training VRAM usage.\n\n3. Swap to the Axolotl transformers v5 branch\n\n    ```bash\n    cp examples/ministral3/ministral3-3b-qlora.yaml ministral3-3b-qlora.yaml\n\n    git fetch\n    git checkout transformers-v5\n\n    # Install packages for transformers v5\n    pip install -e .\n    ```\n\n4. Run the fine-tuning:\n\n    ```bash\n    axolotl train ministral3-3b-qlora.yaml\n    ```\n\nLet us know how it goes. Happy finetuning! 🚀\n\n\n### Tips\n\n- We recommend adding the same/similar SystemPrompt that the model is tuned for. 
You can find this within the repo's files titled `SYSTEM_PROMPT.txt`.\n- You can run a full finetuning by removing the `adapter: qlora` and `load_in_4bit: true` from the config.\n- Read more on how to load your own dataset at [docs](https://docs.axolotl.ai/docs/dataset_loading.html).\n- The text dataset format follows the OpenAI Messages format as seen [here](https://docs.axolotl.ai/docs/dataset-formats/conversation.html#chat_template).\n\n### Thinking\n\nMinistral3 2512 model supports thinking capabilities, enabling Chain-of-Thought reasoning with explicit thinking steps.\n\n📚 **[See the Thinking fine-tuning guide →](./think/README.md)**\n\n### Vision\n\nMinistral3 2512 model also supports vision capabilities.\n\n📚 **[See the Vision fine-tuning guide →](./vision/README.md)**\n\n## Optimization Guides\n\nPlease check the [Optimizations doc](https://docs.axolotl.ai/docs/optimizations.html).\n\n## Limitations\n\nWe only support the `mistral-common` tokenizer for Supervised Fine-tuning at the moment and for `type: chat_template` only.\n\nIn addition, we do not support overriding tokens yet.\n\n## Related Resources\n\n- [MistralAI Mistral3 Blog](https://mistral.ai/news/mistral-3)\n- [Axolotl Docs](https://docs.axolotl.ai)\n- [Axolotl Website](https://axolotl.ai)\n- [Axolotl GitHub](https://github.com/axolotl-ai-cloud/axolotl)\n- [Axolotl Discord](https://discord.gg/7m9sfhzaf3)\n\n\n## Future Work\n\n- Add parity to Preference Tuning, RL, etc.\n- Add parity to other tokenizer configs like overriding tokens.\n"
  },
  {
    "path": "examples/ministral3/ministral3-3b-qlora.yaml",
    "content": "base_model: mistralai/Ministral-3-3B-Reasoning-2512\n\n# Enable to use mistral-common tokenizer\ntokenizer_use_mistral_common: true\n\n# Automatically upload checkpoint and final model to HF\n# hub_model_id: username/custom_model_name\n\nplugins:\n  - axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin\n\nload_in_8bit: false\nload_in_4bit: true\n\ndatasets:\n  - path: fozziethebeat/alpaca_messages_2k_test\n    type: chat_template\n\ndataset_prepared_path: last_run_prepared\nval_set_size: 0.1\noutput_dir: ./outputs/lora-out\n\nadapter: qlora\nlora_model_dir:\n\nsequence_len: 2048\nsample_packing: true\n\nlora_r: 32\nlora_alpha: 16\nlora_dropout: 0.05\nlora_target_linear: true\nlora_target_modules:\n  - gate_proj\n  - down_proj\n  - up_proj\n  - q_proj\n  - v_proj\n  - k_proj\n  - o_proj\n\nwandb_project:\nwandb_entity:\nwandb_watch:\nwandb_name:\nwandb_log_model:\n\ngradient_accumulation_steps: 4\nmicro_batch_size: 2\nnum_epochs: 1\noptimizer: adamw_bnb_8bit\nlr_scheduler: cosine\nlearning_rate: 0.0002\n\nbf16: auto\ntf32: false\n\ngradient_checkpointing: true\nresume_from_checkpoint:\nlogging_steps: 1\nflash_attention: true\nscaling_softmax: true\n\nwarmup_ratio: 0.1\nevals_per_epoch: 1\nsaves_per_epoch: 1\n\n# save_first_step: true  # uncomment this to validate checkpoint saving works with your config\n"
  },
  {
    "path": "examples/ministral3/think/README.md",
    "content": "# Ministral3 2512 Thinking Fine-tuning\n\nThis guide covers fine-tuning [Ministral3 2512](https://huggingface.co/collections/mistralai/ministral-3) with thinking capabilities using Axolotl. The thinking model enables explicit Chain-of-Thought reasoning with separate thinking and response sections.\n\n## Prerequisites\n\nBefore starting, ensure you have:\n\n- Installed Axolotl (see [main README](../README.md))\n\n## Getting Started\n\nRun the thinking model fine-tuning:\n\n```bash\naxolotl train examples/ministral3/think/ministral3-3b-think-qlora.yaml\n```\n\nThis config uses about 4.76 GiB VRAM.\n\n### Tips\n\n- The dataset uses the multi-content format with `type: thinking` support. See [Dataset Format](#dataset-format) below.\n- You cannot mix `content: str` and `content: list[dict]`; dataset loading will fail if you do. Keep the content format consistent.\n\n## Dataset Format\n\nThe thinking model requires the multi-content dataset format with support for an extra `type: thinking` content entry within system and assistant messages.\n\nExample format:\n\n```json\n{\n    \"messages\": [\n        {\n            \"role\": \"system\",\n            \"content\": [\n                { \"type\": \"text\", \"text\": \"{SYSTEM_PROMPT}\"}\n            ]\n        },\n        {\n            \"role\": \"user\",\n            \"content\": [\n                { \"type\": \"text\", \"text\": \"Solve this step by step: What is 15% of 240?\"}\n            ]\n        },\n        {\n            \"role\": \"assistant\",\n            \"content\": [\n                {\n                    \"type\": \"thinking\",\n                    \"thinking\": \"I need to calculate 15% of 240. First, I'll convert 15% to decimal: 0.15. Then multiply: 0.15 × 240 = 36.\"\n                },\n                {\n                    \"type\": \"text\",\n                    \"text\": \"To find 15% of 240, I'll multiply 240 by 0.15:\\n\\n240 × 0.15 = 36\\n\\nTherefore, 15% of 240 is 36.\"\n                }\n            ]\n        }\n    ]\n}\n```\n\n### Advanced Options\n\nThe `thinking` content entry supports an optional boolean `closed` parameter (default: `true`) that controls whether the closing `[/THINK]` tag is added:\n\n```json\n{\n    \"type\": \"thinking\",\n    \"thinking\": \"Internal reasoning here...\",\n    \"closed\": true\n}\n```\n"
  },
  {
    "path": "examples/ministral3/think/ministral3-3b-think-qlora.yaml",
    "content": "base_model: mistralai/Ministral-3-3B-Reasoning-2512\n\n# Enable to use mistral-common tokenizer\ntokenizer_use_mistral_common: true\n\n# Automatically upload checkpoint and final model to HF\n# hub_model_id: username/custom_model_name\n\nplugins:\n  - axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin\n\nload_in_8bit: false\nload_in_4bit: true\n\ndatasets:\n  - path: Nanobit/text-think-2k-test\n    type: chat_template\n\ndataset_prepared_path: last_run_prepared\nval_set_size: 0\noutput_dir: ./outputs/lora-out\n\nadapter: qlora\nlora_model_dir:\n\nsequence_len: 2048\nsample_packing: true\n\nlora_r: 32\nlora_alpha: 16\nlora_dropout: 0.05\nlora_target_linear: true\nlora_target_modules:\n  - gate_proj\n  - down_proj\n  - up_proj\n  - q_proj\n  - v_proj\n  - k_proj\n  - o_proj\n\nwandb_project:\nwandb_entity:\nwandb_watch:\nwandb_name:\nwandb_log_model:\n\ngradient_accumulation_steps: 4\nmicro_batch_size: 2\nnum_epochs: 1\noptimizer: adamw_bnb_8bit\nlr_scheduler: cosine\nlearning_rate: 0.0002\n\nbf16: auto\ntf32: false\n\ngradient_checkpointing: true\nresume_from_checkpoint:\nlogging_steps: 1\nflash_attention: true\n\nwarmup_ratio: 0.1\nevals_per_epoch: 1\nsaves_per_epoch: 1\n\n# save_first_step: true  # uncomment this to validate checkpoint saving works with your config\n"
  },
  {
    "path": "examples/ministral3/vision/README.md",
    "content": "# Ministral3 2512 Vision Fine-tuning\n\nThis guide covers fine-tuning [Ministral3 2512](https://huggingface.co/collections/mistralai/ministral-3) with vision capabilities using Axolotl.\n\n## Prerequisites\n\nBefore starting, ensure you have:\n\n- Installed Axolotl from source (see [main README](../README.md))\n\n## Getting started\n\n1. Install the required vision lib:\n    ```bash\n    pip install 'mistral-common[opencv]==1.8.6'\n    ```\n\n2. Download the example dataset image:\n   ```bash\n   wget https://huggingface.co/datasets/Nanobit/text-vision-2k-test/resolve/main/African_elephant.jpg\n   ```\n\n3. Run the fine-tuning:\n   ```bash\n   axolotl train examples/ministral3/vision/ministral3-3b-vision-qlora.yml\n   ```\n\nWARNING: The loss and grad norm will be much higher than normal at first. We suspect this is inherent to the model at the moment. If anyone would like to submit a fix for this, we are happy to take a look.\n\n### Tips\n\nKey differences from the text-only model:\n- Multi-modal dataset format required\n- Sample packing not supported\n\n## Dataset Format\n\nThe vision model requires multi-modal dataset format as documented [here](https://docs.axolotl.ai/docs/multimodal.html#dataset-format).\n\nOne exception: passing `\"image\": PIL.Image` is not supported. MistralTokenizer only supports `path`, `url`, and `base64` for now.\n\nExample:\n```json\n{\n    \"messages\": [\n        {\"role\": \"system\", \"content\": [{ \"type\": \"text\", \"text\": \"{SYSTEM_PROMPT}\"}]},\n        {\"role\": \"user\", \"content\": [\n            { \"type\": \"text\", \"text\": \"What's in this image?\"},\n            {\"type\": \"image\", \"path\": \"path/to/image.jpg\" }\n        ]},\n        {\"role\": \"assistant\", \"content\": [{ \"type\": \"text\", \"text\": \"...\" }]}\n    ]\n}\n```\n\n## Limitations\n\n- Sample Packing is not supported for multi-modality training currently.\n"
  },
  {
    "path": "examples/ministral3/vision/ministral3-3b-vision-qlora.yml",
    "content": "base_model: mistralai/Ministral-3-3B-Reasoning-2512\nprocessor_type: AutoProcessor\n\n# Enable to use mistral-common tokenizer\ntokenizer_use_mistral_common: true\n\nplugins:\n  - axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin\n\nload_in_4bit: true\n\n# these 3 lines are needed for now to handle vision chat templates w images\nskip_prepare_dataset: true\nremove_unused_columns: false\nsample_packing: false\n\n# sample dataset below requires downloading image in advance\n# wget https://huggingface.co/datasets/Nanobit/text-vision-2k-test/resolve/main/African_elephant.jpg\ndatasets:\n  - path: Nanobit/text-vision-2k-test\n    type: chat_template\n\ndataset_prepared_path: last_run_prepared\nval_set_size: 0.01\noutput_dir: ./outputs/out\n\nadapter: qlora\nlora_model_dir:\n\nsequence_len: 2048\n\nlora_r: 32\nlora_alpha: 16\nlora_dropout: 0.05\nlora_target_modules: 'model.language_model.layers.[\\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'\n\nwandb_project:\nwandb_entity:\nwandb_watch:\nwandb_name:\nwandb_log_model:\n\ngradient_accumulation_steps: 1\nmicro_batch_size: 1\nnum_epochs: 1\noptimizer: adamw_bnb_8bit\nlr_scheduler: cosine\nlearning_rate: 0.0002\n\nbf16: true\nfp16:\ntf32: true\n\ngradient_checkpointing: true\nlogging_steps: 1\nflash_attention: true\n\nwarmup_ratio: 0.1\nevals_per_epoch: 1\nsaves_per_epoch: 1\nweight_decay: 0.0\nspecial_tokens:\n\n# save_first_step: true  # uncomment this to validate checkpoint saving works with your config\n"
  },
  {
    "path": "examples/mistral/README.md",
    "content": "**Mistral 7B** is a language model with a total of 7.3 billion parameters that shows strong performance across a variety of benchmarks.\n\nFine-tune:\n```shell\naccelerate launch -m axolotl.cli.train examples/mistral/config.yml\n```\n\nIf you run into CUDA OOM, use DeepSpeed with the `zero2.json` config:\n```shell\naccelerate launch -m axolotl.cli.train examples/mistral/config.yml --deepspeed deepspeed_configs/zero2.json\n```\n"
  },
  {
    "path": "examples/mistral/bigstral/bigstral-ds-zero3.yaml",
    "content": "base_model: mistral-community/Mixtral-8x22B-v0.1\n# optionally might have model_type or tokenizer_type\nmodel_type: AutoModelForCausalLM\ntokenizer_type: LlamaTokenizer\n# Automatically upload checkpoint and final model to HF\n# hub_model_id: username/custom_model_name\n\ntrust_remote_code: true\n\nunfrozen_parameters:\n  - ^lm_head.weight$\n  - ^model.embed_tokens.weight$\n  - model.layers.4[4-9]+.block_sparse_moe.gate\n  - model.layers.4[4-9]+.block_sparse_moe.experts\n  - model.layers.5[0-5]+.block_sparse_moe.gate\n  - model.layers.5[0-5]+.block_sparse_moe.experts\n\nmodel_config:\n  output_router_logits: true\n\ndatasets:\n  - path: tatsu-lab/alpaca\n    type: alpaca\ndataset_prepared_path: last_run_prepared\nval_set_size: 0.05\noutput_dir: ./outputs/out\n\nsequence_len: 2048\nsample_packing: true\n\n\ngradient_accumulation_steps: 1\nmicro_batch_size: 1\nnum_epochs: 3\noptimizer: adamw_bnb_8bit\nlr_scheduler: cosine\nlearning_rate: 0.0001\n\nbf16: auto\ntf32: false\n\ngradient_checkpointing: true\nresume_from_checkpoint:\nlogging_steps: 1\nflash_attention: true\n\nsave_total_limit: 1\nsave_steps:\n\ndeepspeed: deepspeed_configs/zero3_bf16_cpuoffload_params.json\nweight_decay: 0.0\nspecial_tokens:\n  eos_token: \"<|im_end|>\"\ntokens:\n  - \"<|im_start|>\"\n\n# save_first_step: true  # uncomment this to validate checkpoint saving works with your config\n"
  },
  {
    "path": "examples/mistral/config.yml",
    "content": "base_model: mistralai/Mistral-7B-v0.1\n# optionally might have model_type or tokenizer_type\nmodel_type: MistralForCausalLM\ntokenizer_type: LlamaTokenizer\n# Automatically upload checkpoint and final model to HF\n# hub_model_id: username/custom_model_name\n\ndatasets:\n  - path: mhenrichsen/alpaca_2k_test\n    type: alpaca\ndataset_prepared_path:\nval_set_size: 0.05\noutput_dir: ./outputs/out\n\nsequence_len: 8192\nsample_packing: true\n\neval_sample_packing: false\n\nwandb_project:\nwandb_entity:\nwandb_watch:\nwandb_name:\nwandb_log_model:\n\ngradient_accumulation_steps: 4\nmicro_batch_size: 2\nnum_epochs: 4\noptimizer: adamw_bnb_8bit\nlr_scheduler: cosine\nlearning_rate: 0.000005\n\nbf16: auto\ntf32: false\n\ngradient_checkpointing: true\nresume_from_checkpoint:\nlogging_steps: 1\nflash_attention: true\n\nwarmup_ratio: 0.1\nevals_per_epoch: 4\nsaves_per_epoch: 1\nweight_decay: 0.0\nspecial_tokens:\n\n# save_first_step: true  # uncomment this to validate checkpoint saving works with your config\n"
  },
  {
    "path": "examples/mistral/dpo/mistral-dpo-qlora.yml",
    "content": "#Note that we are switching from the regular chat template to chatml.\n#If you experience problems with the special tokens, training for more epochs can help.\n#After training, merge the model before inference otherwise you might\n#face problems with the special tokens.\n\nbase_model: mistralai/Mistral-7B-Instruct-v0.2\n# optionally might have model_type or tokenizer_type\nmodel_type: MistralForCausalLM\ntokenizer_type: LlamaTokenizer\n# Automatically upload checkpoint and final model to HF\n# hub_model_id: username/custom_model_name\n\nload_in_8bit: false\nload_in_4bit: true\n\nchat_template: chatml\nrl: dpo\ndatasets:\n  - path: olivermolenschot/alpaca_messages_dpo_test\n    type: chat_template.default\n    field_messages: conversation\n    field_chosen: chosen\n    field_rejected: rejected\n    message_property_mappings:\n      role: role\n      content: content\n\ndataset_prepared_path:\nval_set_size: 0.05\noutput_dir: ./outputs/dpo-qlora\n\nsequence_len: 2048\nsample_packing: false\n\n\nadapter: qlora\nlora_model_dir:\nlora_r: 8\nlora_alpha: 16\nlora_dropout: 0.2\nlora_target_linear: true\n\nlora_target_modules:\n  - gate_proj\n  - down_proj\n  - up_proj\n  - q_proj\n  - v_proj\n  - k_proj\n  - o_proj\nlora_modules_to_save:\n - embed_tokens\n - lm_head\n\nwandb_project:\nwandb_entity:\nwandb_watch:\nwandb_name:\nwandb_log_model:\n\ngradient_accumulation_steps: 4\nmicro_batch_size: 16\nnum_epochs: 6\noptimizer: adamw_bnb_8bit\nlr_scheduler: cosine\nlearning_rate: 0.0001\n\nbf16: auto\ntf32: false\n\ngradient_checkpointing: true\nresume_from_checkpoint:\nlogging_steps: 1\nflash_attention: false\n\nwarmup_ratio: 0.1\nevals_per_epoch: 4\nsaves_per_epoch: 1\nweight_decay: 0.0\nspecial_tokens:\n  bos_token: \"<|im_start|>\"\n  eos_token: \"<|im_end|>\"\n\n# save_first_step: true  # uncomment this to validate checkpoint saving works with your config\n"
  },
  {
    "path": "examples/mistral/lora.yml",
    "content": "base_model: mistralai/Mistral-7B-v0.1\n# optionally might have model_type or tokenizer_type\nmodel_type: MistralForCausalLM\ntokenizer_type: LlamaTokenizer\n# Automatically upload checkpoint and final model to HF\n# hub_model_id: username/custom_model_name\n\nload_in_8bit: true\nload_in_4bit: false\n\ndatasets:\n  - path: mhenrichsen/alpaca_2k_test\n    type: alpaca\ndataset_prepared_path: last_run_prepared\nval_set_size: 0.1\noutput_dir: ./outputs/lora-out\n\nadapter: lora\nlora_model_dir:\n\nsequence_len: 8192\nsample_packing: true\n\n\nlora_r: 32\nlora_alpha: 16\nlora_dropout: 0.05\nlora_target_linear: true\nlora_target_modules:\n  - gate_proj\n  - down_proj\n  - up_proj\n  - q_proj\n  - v_proj\n  - k_proj\n  - o_proj\n\nwandb_project:\nwandb_entity:\nwandb_watch:\nwandb_name:\nwandb_log_model:\n\ngradient_accumulation_steps: 4\nmicro_batch_size: 2\nnum_epochs: 1\noptimizer: adamw_bnb_8bit\nlr_scheduler: cosine\nlearning_rate: 0.0002\n\nbf16: auto\ntf32: false\n\ngradient_checkpointing: true\nresume_from_checkpoint:\nlogging_steps: 1\nflash_attention: true\n\nloss_watchdog_threshold: 5.0\nloss_watchdog_patience: 3\n\nwarmup_ratio: 0.1\nevals_per_epoch: 4\nsaves_per_epoch: 1\nweight_decay: 0.0\nspecial_tokens:\n\n# save_first_step: true  # uncomment this to validate checkpoint saving works with your config\n"
  },
  {
    "path": "examples/mistral/mistral-qlora-fsdp.yml",
    "content": "base_model: mistralai/Mixtral-8x7B-v0.1\n# optionally might have model_type or tokenizer_type\nmodel_type: AutoModelForCausalLM\ntokenizer_type: LlamaTokenizer\n# Automatically upload checkpoint and final model to HF\n# hub_model_id: username/custom_model_name\n\ntrust_remote_code: true\n\nload_in_8bit: false\nload_in_4bit: true\n\ndatasets:\n  - path: tatsu-lab/alpaca\n    type: alpaca\ndataset_prepared_path: last_run_prepared\nval_set_size: 0.02\noutput_dir: ./outputs/qlora-out\n\nmodel_config:\n  output_router_logits: true\n\nadapter: qlora\nlora_model_dir:\n\nsequence_len: 1024\nsample_packing: false\npad_to_sequence_len: false\n\nlora_r: 32\nlora_alpha: 16\nlora_dropout: 0.05\nlora_target_linear: true\n\nwandb_project:\nwandb_entity:\nwandb_watch:\nwandb_name:\nwandb_log_model:\n\ngradient_accumulation_steps: 4\nmicro_batch_size: 2\nnum_epochs: 1\noptimizer: paged_adamw_8bit\nlr_scheduler: cosine\nlearning_rate: 0.0002\n\nbf16: auto\ntf32: false\n\ngradient_checkpointing: true\nresume_from_checkpoint:\nlogging_steps: 1\nflash_attention: true\n\nloss_watchdog_threshold: 5.0\nloss_watchdog_patience: 3\n\nwarmup_ratio: 0.1\nevals_per_epoch: 4\nsaves_per_epoch: 1\n\nweight_decay: 0.0\nfsdp:\n  - full_shard\n  - auto_wrap\nfsdp_config:\n  fsdp_limit_all_gathers: true\n  fsdp_sync_module_states: true\n  fsdp_offload_params: false\n  fsdp_use_orig_params: false\n  fsdp_cpu_ram_efficient_loading: false\n  fsdp_transformer_layer_cls_to_wrap: MistralDecoderLayer\n  fsdp_state_dict_type: FULL_STATE_DICT\n  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP\nspecial_tokens:\n\n# save_first_step: true  # uncomment this to validate checkpoint saving works with your config\n"
  },
  {
    "path": "examples/mistral/mixtral/mixtral-8x22b-qlora-fsdp.yml",
    "content": "base_model: mistral-community/Mixtral-8x22B-v0.1\n# optionally might have model_type or tokenizer_type\nmodel_type: AutoModelForCausalLM\ntokenizer_type: LlamaTokenizer\n# Automatically upload checkpoint and final model to HF\n# hub_model_id: username/custom_model_name\n\nload_in_8bit: false\nload_in_4bit: true\n\ndatasets:\n  - path: tatsu-lab/alpaca\n    type: alpaca\ndataset_prepared_path: last_run_prepared\nval_set_size: 0.02\noutput_dir: ./outputs/qlora-out\n\nmodel_config:\n  output_router_logits: true\n\nadapter: qlora\nlora_model_dir:\n\nsequence_len: 1024\nsample_packing: false\npad_to_sequence_len: false\n\nlora_r: 32\nlora_alpha: 16\nlora_dropout: 0.05\nlora_target_linear: true\n\nwandb_project:\nwandb_entity:\nwandb_watch:\nwandb_name:\nwandb_log_model:\n\ngradient_accumulation_steps: 4\nmicro_batch_size: 2\nnum_epochs: 1\noptimizer: adamw_torch_fused\nlr_scheduler: cosine\nlearning_rate: 0.0002\n\nbf16: auto\ntf32: true\n\ngradient_checkpointing: true\nresume_from_checkpoint:\nlogging_steps: 1\nflash_attention: true\n\nloss_watchdog_threshold: 5.0\nloss_watchdog_patience: 3\n\nwarmup_ratio: 0.1\nevals_per_epoch: 4\nsaves_per_epoch: 1\n\nweight_decay: 0.0\nfsdp:\n  - full_shard\n  - auto_wrap\nfsdp_config:\n  fsdp_limit_all_gathers: true\n  fsdp_sync_module_states: true\n  fsdp_offload_params: true\n  fsdp_use_orig_params: false\n  fsdp_cpu_ram_efficient_loading: true\n  fsdp_transformer_layer_cls_to_wrap: MixtralSparseMoeBlock\n  fsdp_state_dict_type: FULL_STATE_DICT\n  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP\nspecial_tokens:\n\n# save_first_step: true  # uncomment this to validate checkpoint saving works with your config\n"
  },
  {
    "path": "examples/mistral/mixtral/mixtral-qlora-fsdp.yml",
    "content": "base_model: mistralai/Mixtral-8x7B-v0.1\n# optionally might have model_type or tokenizer_type\nmodel_type: AutoModelForCausalLM\ntokenizer_type: LlamaTokenizer\n# Automatically upload checkpoint and final model to HF\n# hub_model_id: username/custom_model_name\n\ntrust_remote_code: true\n\nload_in_8bit: false\nload_in_4bit: true\n\ndatasets:\n  - path: tatsu-lab/alpaca\n    type: alpaca\ndataset_prepared_path: last_run_prepared\nval_set_size: 0.02\noutput_dir: ./outputs/qlora-out\n\nmodel_config:\n  output_router_logits: true\n\nadapter: qlora\nlora_model_dir:\n\nsequence_len: 1024\nsample_packing: false\npad_to_sequence_len: false\n\nlora_r: 32\nlora_alpha: 16\nlora_dropout: 0.05\nlora_target_linear: true\n\nwandb_project:\nwandb_entity:\nwandb_watch:\nwandb_name:\nwandb_log_model:\n\ngradient_accumulation_steps: 4\nmicro_batch_size: 2\nnum_epochs: 1\noptimizer: adamw_torch_fused\nlr_scheduler: cosine\nlearning_rate: 0.0002\n\nbf16: auto\ntf32: true\n\ngradient_checkpointing: true\nresume_from_checkpoint:\nlogging_steps: 1\nflash_attention: true\n\nloss_watchdog_threshold: 5.0\nloss_watchdog_patience: 3\n\nwarmup_ratio: 0.1\nevals_per_epoch: 4\nsaves_per_epoch: 1\n\nweight_decay: 0.0\nfsdp:\n  - full_shard\n  - auto_wrap\nfsdp_config:\n  fsdp_limit_all_gathers: true\n  fsdp_sync_module_states: true\n  fsdp_offload_params: true\n  fsdp_use_orig_params: false\n  fsdp_cpu_ram_efficient_loading: true\n  fsdp_transformer_layer_cls_to_wrap: MixtralSparseMoeBlock\n  fsdp_state_dict_type: FULL_STATE_DICT\n  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP\n  fsdp_sharding_strategy: FULL_SHARD\n  fsdp_forward_prefetch: false\n  fsdp_backward_prefetch: BACKWARD_PRE\nspecial_tokens:\n\n# save_first_step: true  # uncomment this to validate checkpoint saving works with your config\n"
  },
  {
    "path": "examples/mistral/mixtral/mixtral.yml",
    "content": "base_model: mistralai/Mixtral-8x7B-v0.1\n# optionally might have model_type or tokenizer_type\nmodel_type: AutoModelForCausalLM\ntokenizer_type: LlamaTokenizer\n# Automatically upload checkpoint and final model to HF\n# hub_model_id: username/custom_model_name\n\ntrust_remote_code: true\n\nload_in_8bit: false\nload_in_4bit: true\n\ndatasets:\n  - path: tatsu-lab/alpaca\n    type: alpaca\ndataset_prepared_path: last_run_prepared\nval_set_size: 0.0\noutput_dir: ./outputs/qlora-out\n\n## You can optionally freeze the entire model and unfreeze a subset of parameters\nunfrozen_parameters:\n#  - ^lm_head.weight$\n#  - ^model.embed_tokens.weight$[:32000]\n#  - model.layers.2[0-9]+.block_sparse_moe.gate\n#  - model.layers.2[0-9]+.block_sparse_moe.experts\n#  - model.layers.3[0-9]+.block_sparse_moe.gate\n#  - model.layers.3[0-9]+.block_sparse_moe.experts\n\nmodel_config:\n  output_router_logits: true\n\nadapter: qlora\nlora_model_dir:\n\nsequence_len: 4096\nsample_packing: true\n\n\nlora_r: 32\nlora_alpha: 16\nlora_dropout: 0.05\nlora_target_linear: true\n#lora_target_modules:\n#  - gate\n#  - q_proj\n#  - k_proj\n#  - v_proj\n#  - o_proj\n#  - w1\n#  - w2\n#  - w3\n\nwandb_project:\nwandb_entity:\nwandb_watch:\nwandb_name:\nwandb_log_model:\n\ngradient_accumulation_steps: 2\nmicro_batch_size: 1\nnum_epochs: 1\noptimizer: adamw_bnb_8bit\nlr_scheduler: cosine\nlearning_rate: 0.0002\n\nbf16: auto\ntf32: false\n\ngradient_checkpointing: true\nresume_from_checkpoint:\nlogging_steps: 1\nflash_attention: true\n\nloss_watchdog_threshold: 5.0\nloss_watchdog_patience: 3\n\nwarmup_ratio: 0.1\nevals_per_epoch: 4\nsaves_per_epoch: 1\n\ndeepspeed: deepspeed_configs/zero2.json\nweight_decay: 0.0\nspecial_tokens:\n\n# save_first_step: true  # uncomment this to validate checkpoint saving works with your config\n"
  },
  {
    "path": "examples/mistral/mixtral/mixtral_22.yml",
    "content": "base_model: mistral-community/Mixtral-8x22B-v0.1\n# optionally might have model_type or tokenizer_type\nmodel_type: AutoModelForCausalLM\ntokenizer_type: LlamaTokenizer\n# Automatically upload checkpoint and final model to HF\n# hub_model_id: username/custom_model_name\n\ntrust_remote_code: true\n\nunfrozen_parameters:\n  - ^lm_head.weight$\n  - ^model.embed_tokens.weight$\n  - model.layers.4[4-9]+.block_sparse_moe.gate\n  - model.layers.4[4-9]+.block_sparse_moe.experts\n  - model.layers.5[0-5]+.block_sparse_moe.gate\n  - model.layers.5[0-5]+.block_sparse_moe.experts\n\nmodel_config:\n  output_router_logits: true\n\ndatasets:\n  - path: yahma/alpaca-cleaned\n    type: alpaca\noutput_dir: ./outputs/out\n\nsequence_len: 8000\nsample_packing: true\n\n\ngradient_accumulation_steps: 1\nmicro_batch_size: 1\nnum_epochs: 3\noptimizer: adamw_bnb_8bit\nlr_scheduler: cosine\nlearning_rate: 0.0001\n\nbf16: auto\ntf32: false\n\ngradient_checkpointing: true\nresume_from_checkpoint:\nlogging_steps: 1\nflash_attention: true\n\nsave_total_limit: 1\nsave_steps:\n\ndeepspeed: deepspeed_configs/zero3_bf16_cpuoffload_all.json\nweight_decay: 0.0\nspecial_tokens:\n  eos_token: \"<|im_end|>\"\ntokens:\n  - \"<|im_start|>\"\n\n# save_first_step: true  # uncomment this to validate checkpoint saving works with your config\n"
  },
  {
    "path": "examples/mistral/mps/lora-mps.yml",
    "content": "base_model: mistralai/Mistral-7B-v0.1\n# optionally might have model_type or tokenizer_type\nmodel_type: MistralForCausalLM\ntokenizer_type: LlamaTokenizer\n# Automatically upload checkpoint and final model to HF\n# hub_model_id: username/custom_model_name\n\ndatasets:\n  - path: mhenrichsen/alpaca_2k_test\n    type: alpaca\ndataset_prepared_path: last_run_prepared\nval_set_size: 0\noutput_dir: ./outputs/lora-out\neval_sample_packing: false\n\nadapter: lora\nlora_model_dir:\n\nsequence_len: 4096\nsample_packing: true\n\n\nlora_r: 32\nlora_alpha: 16\nlora_dropout: 0.05\nlora_target_linear: true\nlora_target_modules:\n  - gate_proj\n  - down_proj\n  - up_proj\n  - q_proj\n  - v_proj\n  - k_proj\n  - o_proj\n\nwandb_project:\nwandb_entity:\nwandb_watch:\nwandb_name:\nwandb_log_model:\n\ngradient_accumulation_steps: 8\nmicro_batch_size: 1\nnum_epochs: 2\noptimizer: adamw_torch_fused\nlr_scheduler: cosine\nlearning_rate: 0.0002\n\nbf16: auto\nfp16: false\ntf32: true\n\ngradient_checkpointing: true\nresume_from_checkpoint:\nlogging_steps: 1\nflash_attention: false\nsdp_attention: true\n\nloss_watchdog_threshold: 5.0\nloss_watchdog_patience: 3\n\nwarmup_ratio: 0.1\nevals_per_epoch: 4\nsaves_per_epoch: 1\nweight_decay: 0.0\nspecial_tokens:\n\n# save_first_step: true  # uncomment this to validate checkpoint saving works with your config\n"
  },
  {
    "path": "examples/mistral/orpo/mistral-qlora-orpo.yml",
    "content": "base_model: mistralai/Mistral-7B-v0.1\n# optionally might have model_type or tokenizer_type\nmodel_type: MistralForCausalLM\ntokenizer_type: LlamaTokenizer\n# Automatically upload checkpoint and final model to HF\n# hub_model_id: username/custom_model_name\n\nload_in_8bit: false\nload_in_4bit: true\n\nrl: orpo\norpo_alpha: 0.1\nremove_unused_columns: false\n\nchat_template: chatml\ndatasets:\n  - path: argilla/ultrafeedback-binarized-preferences-cleaned\n    type: chat_template.argilla\ndataset_prepared_path: last_run_prepared\nval_set_size: 0.1\noutput_dir: ./outputs/mistral-qlora-orpo-out\n\nadapter: qlora\nlora_model_dir:\n\nsequence_len: 4096\nsample_packing: false\n\n\nlora_r: 32\nlora_alpha: 16\nlora_dropout: 0.05\nlora_target_linear: true\nlora_target_modules:\n  - gate_proj\n  - down_proj\n  - up_proj\n  - q_proj\n  - v_proj\n  - k_proj\n  - o_proj\n\nwandb_project:\nwandb_entity:\nwandb_watch:\nwandb_name:\nwandb_log_model:\n\ngradient_accumulation_steps: 4\nmicro_batch_size: 2\nnum_epochs: 1\noptimizer: adamw_bnb_8bit\nlr_scheduler: cosine\nlearning_rate: 0.0002\n\nbf16: auto\ntf32: false\n\ngradient_checkpointing: true\nresume_from_checkpoint:\nlogging_steps: 1\nflash_attention: true\n\nloss_watchdog_threshold: 5.0\nloss_watchdog_patience: 3\n\nwarmup_ratio: 0.1\nevals_per_epoch: 4\nsaves_per_epoch: 1\nweight_decay: 0.0\nspecial_tokens:\n\n# save_first_step: true  # uncomment this to validate checkpoint saving works with your config\n"
  },
  {
    "path": "examples/mistral/qlora.yml",
    "content": "base_model: mistralai/Mistral-7B-v0.1\n# optionally might have model_type or tokenizer_type\nmodel_type: MistralForCausalLM\ntokenizer_type: LlamaTokenizer\n# Automatically upload checkpoint and final model to HF\n# hub_model_id: username/custom_model_name\n\nload_in_8bit: false\nload_in_4bit: true\n\ndatasets:\n  - path: mhenrichsen/alpaca_2k_test\n    type: alpaca\ndataset_prepared_path: last_run_prepared\nval_set_size: 0.1\noutput_dir: ./outputs/qlora-out\n\nadapter: qlora\nlora_model_dir:\n\nsequence_len: 8192\nsample_packing: true\n\n\nlora_r: 32\nlora_alpha: 16\nlora_dropout: 0.05\nlora_target_linear: true\nlora_target_modules:\n  - gate_proj\n  - down_proj\n  - up_proj\n  - q_proj\n  - v_proj\n  - k_proj\n  - o_proj\n\nwandb_project:\nwandb_entity:\nwandb_watch:\nwandb_name:\nwandb_log_model:\n\ngradient_accumulation_steps: 4\nmicro_batch_size: 2\nnum_epochs: 1\noptimizer: adamw_bnb_8bit\nlr_scheduler: cosine\nlearning_rate: 0.0002\n\nbf16: auto\ntf32: false\n\ngradient_checkpointing: true\nresume_from_checkpoint:\nlogging_steps: 1\nflash_attention: true\n\nloss_watchdog_threshold: 5.0\nloss_watchdog_patience: 3\n\nwarmup_ratio: 0.1\nevals_per_epoch: 4\nsaves_per_epoch: 1\nweight_decay: 0.0\nspecial_tokens:\n\n# save_first_step: true  # uncomment this to validate checkpoint saving works with your config\n"
  },
  {
    "path": "examples/mistral-small/README.md",
    "content": "# Mistral Small 3.1/3.2 Fine-tuning\n\nThis guide covers fine-tuning [Mistral Small 3.1](https://huggingface.co/mistralai/Mistral-Small-3.1-24B-Instruct-2503) and [Mistral Small 3.2](https://huggingface.co/mistralai/Mistral-Small-3.2-24B-Instruct-2506) with vision capabilities using Axolotl.\n\n## Prerequisites\n\nBefore starting, ensure you have:\n\n- Installed Axolotl (see [Installation docs](https://docs.axolotl.ai/docs/installation.html))\n\n## Getting Started\n\n1. Install the required vision library:\n   ```bash\n   pip install 'mistral-common[opencv]==1.8.5'\n   ```\n\n2. Download the example dataset image:\n   ```bash\n   wget https://huggingface.co/datasets/Nanobit/text-vision-2k-test/resolve/main/African_elephant.jpg\n   ```\n\n3. Run the fine-tuning:\n   ```bash\n   axolotl train examples/mistral-small/mistral-small-3.1-24B-lora.yml\n   ```\n\nThis config uses about 29.4 GiB VRAM.\n\n## Dataset Format\n\nThe vision model requires the multi-modal dataset format as documented [here](https://docs.axolotl.ai/docs/multimodal.html#dataset-format).\n\nOne exception: passing `\"image\": PIL.Image` is not supported. MistralTokenizer currently only supports `path`, `url`, and `base64`.\n\nExample:\n```json\n{\n    \"messages\": [\n        {\"role\": \"system\", \"content\": [{ \"type\": \"text\", \"text\": \"{SYSTEM_PROMPT}\"}]},\n        {\"role\": \"user\", \"content\": [\n            { \"type\": \"text\", \"text\": \"What's in this image?\"},\n            {\"type\": \"image\", \"path\": \"path/to/image.jpg\" }\n        ]},\n        {\"role\": \"assistant\", \"content\": [{ \"type\": \"text\", \"text\": \"...\" }]}\n    ]\n}\n```\n\n## Limitations\n\n- Sample packing is not currently supported for multi-modal training.\n"
  },
  {
    "path": "examples/mistral-small/mistral-small-3.1-24B-lora.yml",
    "content": "base_model: mistralai/Mistral-Small-3.1-24B-Instruct-2503\nprocessor_type: AutoProcessor\n\n# Enable to use mistral-common tokenizer\ntokenizer_use_mistral_common: true\n\nload_in_8bit: true\n\n# these 3 lines are needed for now to handle vision chat templates w images\nskip_prepare_dataset: true\nremove_unused_columns: false\nsample_packing: false\n\n# sample dataset below requires downloading image in advance\n# wget https://huggingface.co/datasets/Nanobit/text-vision-2k-test/resolve/main/African_elephant.jpg\ndatasets:\n  - path: Nanobit/text-vision-2k-test\n    type: chat_template\n\ndataset_prepared_path: last_run_prepared\nval_set_size: 0.01\noutput_dir: ./outputs/out\n\nadapter: lora\nlora_model_dir:\n\nsequence_len: 2048\npad_to_sequence_len: false\n\nlora_r: 32\nlora_alpha: 16\nlora_dropout: 0.05\nlora_target_modules: 'model.language_model.layers.[\\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'\n\nwandb_project:\nwandb_entity:\nwandb_watch:\nwandb_name:\nwandb_log_model:\n\ngradient_accumulation_steps: 1\nmicro_batch_size: 2\nnum_epochs: 1\noptimizer: adamw_bnb_8bit\nlr_scheduler: cosine\nlearning_rate: 0.0002\n\nbf16: true\nfp16:\ntf32: true\n\ngradient_checkpointing: true\nlogging_steps: 1\nflash_attention: true\n\nwarmup_ratio: 0.1\nevals_per_epoch: 1\nsaves_per_epoch: 1\nweight_decay: 0.0\nspecial_tokens:\n\n# save_first_step: true  # uncomment this to validate checkpoint saving works with your config\n"
  },
  {
    "path": "examples/mistral4/README.md",
    "content": "# Finetune Mistral Small 4 with Axolotl\n\nMistral Small 4 is a 119B-parameter (6.5B active) multimodal MoE model from MistralAI that unifies instruct, reasoning, and coding capabilities into a single model. It is available on HuggingFace at [Mistral-Small-4-119B-2603](https://huggingface.co/mistralai/Mistral-Small-4-119B-2603).\n\nThanks to the team at MistralAI for giving us early access to prepare for this release.\n\n## Getting started\n\n1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html).\n\n2. Install [Cut Cross Entropy](https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy) to reduce training VRAM usage.\n\n3. Install transformers from the main branch:\n\n   ```bash\n   pip install git+https://github.com/huggingface/transformers.git\n   ```\n\n4. Run one of the example configs:\n\n   ```bash\n   # text-only\n   axolotl train examples/mistral4/qlora-text.yml  # ~69 GiB (~93 GiB when also training expert layers)\n   axolotl train examples/mistral4/fft-text.yml\n\n   # text + vision\n   # run: wget https://huggingface.co/datasets/Nanobit/text-vision-2k-test/resolve/main/African_elephant.jpg\n   axolotl train examples/mistral4/qlora-vision.yml  # ~68 GiB without training expert layers\n   axolotl train examples/mistral4/fft-vision.yml\n   ```\n\nNote: The FFT configs are provided as a reference. Please adjust hyperparameters as needed.\n\n## Reasoning Effort\n\nThe chat template supports a `reasoning_effort` variable to control the model's reasoning depth:\n\n- `\"none\"` — instruct mode (default)\n- `\"high\"` — reasoning mode with explicit thinking steps\n\nPass it via `chat_template_kwargs` under your dataset config:\n\n```yaml\ndatasets:\n  - path: your/dataset\n    type: chat_template\n    chat_template_kwargs:\n      reasoning_effort: high\n```\n\n## Thinking Support\n\nThe chat template supports a `thinking` content type in assistant messages for training on reasoning traces (rendered as `[THINK]...[/THINK]` blocks).\n\nTo use thinking datasets, add the `thinking` mapping via `message_property_mappings`:\n\n```yaml\ndatasets:\n  - path: your/thinking-dataset\n    type: chat_template\n    message_property_mappings:\n      role: role\n      content: content\n      thinking: thinking\n    chat_template_kwargs:\n      reasoning_effort: high\n```\n\nSee the [Magistral thinking guide](../magistral/think/README.md) for dataset format details.\n\n## Tips\n\n- Read more on how to load your own dataset at [docs](https://docs.axolotl.ai/docs/dataset_loading.html).\n- The text dataset format follows the OpenAI Messages format as seen [here](https://docs.axolotl.ai/docs/dataset-formats/conversation.html#chat_template).\n- The vision model requires the multi-modal dataset format as documented [here](https://docs.axolotl.ai/docs/multimodal.html#dataset-format).\n\n## Related Resources\n\n- [MistralAI Mistral Small 4 Blog](https://mistral.ai/news/mistral-small-4)\n- [Axolotl Docs](https://docs.axolotl.ai)\n- [Axolotl GitHub](https://github.com/axolotl-ai-cloud/axolotl)\n- [Axolotl Discord](https://discord.gg/7m9sfhzaf3)\n"
  },
  {
    "path": "examples/mistral4/fft-text.yml",
    "content": "base_model: axolotl-ai-co/Mistral-Small-4-119B-2603-BF16\n\nplugins:\n  - axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin\n  - axolotl.integrations.kernels.KernelsPlugin\nuse_kernels: true\nuse_sonicmoe: true\n\n# only train language model layers, freeze vision tower\nunfrozen_parameters:\n  - model.language_model.*\n  - lm_head\n  - embed_tokens\n\ndatasets:\n  - path: fozziethebeat/alpaca_messages_2k_test\n    type: chat_template\n\ndataset_prepared_path: last_run_prepared\nval_set_size: 0.01\noutput_dir: ./outputs/out\n\nsequence_len: 2048\nsample_packing: true\n\nwandb_project:\nwandb_entity:\nwandb_watch:\nwandb_name:\nwandb_log_model:\n\ngradient_accumulation_steps: 1\nmicro_batch_size: 1\nnum_epochs: 1\noptimizer: adamw_bnb_8bit\nlr_scheduler: cosine\nlearning_rate: 2e-5\n\nbf16: true\ntf32: true\n\nlogging_steps: 1\nflash_attention: true\n\nwarmup_ratio: 0.1\nevals_per_epoch: 1\nsaves_per_epoch: 1\nweight_decay: 0.0\n\nfsdp_version: 2\nfsdp_config:\n  offload_params: false\n  cpu_ram_efficient_loading: false\n  state_dict_type: FULL_STATE_DICT\n  auto_wrap_policy: TRANSFORMER_BASED_WRAP\n  transformer_layer_cls_to_wrap: Mistral4DecoderLayer\n  reshard_after_forward: true\n  activation_checkpointing: true\n"
  },
  {
    "path": "examples/mistral4/fft-vision.yml",
    "content": "base_model: axolotl-ai-co/Mistral-Small-4-119B-2603-BF16\nprocessor_type: AutoProcessor\n\nplugins:\n  - axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin\n  - axolotl.integrations.kernels.KernelsPlugin\nuse_kernels: true\nuse_sonicmoe: true\n\n# vision requirements\nskip_prepare_dataset: true\nremove_unused_columns: false\nsample_packing: false\n\ndatasets:\n  - path: Nanobit/text-vision-2k-test\n    type: chat_template\n\ndataset_prepared_path: last_run_prepared\nval_set_size: 0.01\noutput_dir: ./outputs/out\n\nsequence_len: 2048\n\nwandb_project:\nwandb_entity:\nwandb_watch:\nwandb_name:\nwandb_log_model:\n\ngradient_accumulation_steps: 1\nmicro_batch_size: 1\nnum_epochs: 1\noptimizer: adamw_bnb_8bit\nlr_scheduler: cosine\nlearning_rate: 2e-5\n\nbf16: true\ntf32: true\n\nlogging_steps: 1\nflash_attention: true\n\nwarmup_ratio: 0.1\nevals_per_epoch: 1\nsaves_per_epoch: 1\nweight_decay: 0.0\n\nfsdp_version: 2\nfsdp_config:\n  offload_params: false\n  cpu_ram_efficient_loading: false\n  state_dict_type: FULL_STATE_DICT\n  auto_wrap_policy: TRANSFORMER_BASED_WRAP\n  transformer_layer_cls_to_wrap: Mistral4DecoderLayer\n  reshard_after_forward: true\n  activation_checkpointing: true\n"
  },
  {
    "path": "examples/mistral4/qlora-text.yml",
    "content": "base_model: axolotl-ai-co/Mistral-Small-4-119B-2603-BF16\n\nplugins:\n  - axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin\n\nload_in_4bit: true\nquantize_moe_experts: true\n\ndatasets:\n  - path: fozziethebeat/alpaca_messages_2k_test\n    type: chat_template\n\ndataset_prepared_path: last_run_prepared\nval_set_size: 0.01\noutput_dir: ./outputs/out\n\nadapter: qlora\n\nsequence_len: 2048\nsample_packing: true\n\nlora_r: 32\nlora_alpha: 16\nlora_dropout: 0.05\nlora_target_modules: 'model.language_model.layers.[\\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'\n\n# uncomment to train on expert layers\n# lora_target_parameters:\n#   - mlp.experts.gate_up_proj\n#   - mlp.experts.down_proj\n# lora_mlp_kernel: false\n# lora_qkv_kernel: false\n# lora_o_kernel: false\n\nwandb_project:\nwandb_entity:\nwandb_watch:\nwandb_name:\nwandb_log_model:\n\ngradient_accumulation_steps: 1\nmicro_batch_size: 1\nnum_epochs: 1\noptimizer: adamw_bnb_8bit\nlr_scheduler: cosine\nlearning_rate: 0.0002\n\nbf16: true\ntf32: true\n\ngradient_checkpointing: true\nlogging_steps: 1\nflash_attention: true\n\nwarmup_ratio: 0.1\nevals_per_epoch: 1\nsaves_per_epoch: 1\nweight_decay: 0.0\n"
  },
  {
    "path": "examples/mistral4/qlora-vision.yml",
    "content": "base_model: axolotl-ai-co/Mistral-Small-4-119B-2603-BF16\nprocessor_type: AutoProcessor\n\nplugins:\n  - axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin\n\nload_in_4bit: true\nquantize_moe_experts: true\n\n# vision chat template requirements\nskip_prepare_dataset: true\nremove_unused_columns: false\nsample_packing: false\n\ndatasets:\n  - path: Nanobit/text-vision-2k-test\n    type: chat_template\n\ndataset_prepared_path: last_run_prepared\nval_set_size: 0.01\noutput_dir: ./outputs/out\n\nadapter: qlora\n\nsequence_len: 2048\n\nlora_r: 32\nlora_alpha: 16\nlora_dropout: 0.05\nlora_target_modules: 'model.language_model.layers.[\\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'\n\n# uncomment to train on expert layers\n# lora_target_parameters:\n#   - mlp.experts.gate_up_proj\n#   - mlp.experts.down_proj\n# lora_mlp_kernel: false\n# lora_qkv_kernel: false\n# lora_o_kernel: false\n\nwandb_project:\nwandb_entity:\nwandb_watch:\nwandb_name:\nwandb_log_model:\n\ngradient_accumulation_steps: 1\nmicro_batch_size: 1\nnum_epochs: 1\noptimizer: adamw_bnb_8bit\nlr_scheduler: cosine\nlearning_rate: 0.0002\n\nbf16: true\ntf32: true\n\ngradient_checkpointing: true\nlogging_steps: 1\nflash_attention: true\n\nwarmup_ratio: 0.1\nevals_per_epoch: 1\nsaves_per_epoch: 1\nweight_decay: 0.0\n"
  },
  {
    "path": "examples/nemotron/nemotron-mini-4b-qlora.yaml",
    "content": "base_model: nvidia/Nemotron-Mini-4B-Instruct\n\nload_in_8bit: false\nload_in_4bit: true\n\ndatasets:\n  - path: fozziethebeat/alpaca_messages_2k_test\n    type: chat_template\n\ndataset_prepared_path: last_run_prepared\nval_set_size: 0.1\noutput_dir: ./outputs/nemotron-mini-4b-qlora\n\nadapter: qlora\nlora_model_dir:\n\nsequence_len: 4096\nsample_packing: true\n\nlora_r: 32\nlora_alpha: 16\nlora_dropout: 0.05\nlora_target_linear: true\nlora_target_modules:\n  - q_proj\n  - k_proj\n  - v_proj\n  - o_proj\n  - up_proj\n  - down_proj\n\nwandb_project:\nwandb_entity:\nwandb_watch:\nwandb_name:\nwandb_log_model:\n\ngradient_accumulation_steps: 4\nmicro_batch_size: 2\nnum_epochs: 1\noptimizer: adamw_bnb_8bit\nlr_scheduler: cosine\nlearning_rate: 0.0002\n\nbf16: auto\ntf32: false\n\ngradient_checkpointing: true\nresume_from_checkpoint:\nlogging_steps: 1\nflash_attention: true\n\nwarmup_ratio: 0.1\nevals_per_epoch: 1\nsaves_per_epoch: 1\n\nspecial_tokens:\n"
  },
  {
    "path": "examples/olmo3/README.md",
    "content": "# Finetune AllenAI's Olmo 3 with Axolotl\n\n[Olmo 3](https://huggingface.co/collections/allenai/olmo-3) is a family of 7B and 32B open-source models trained by the Allen Institute for Artificial Intelligence.\n\nThis guide shows how to fine-tune it with Axolotl on multi-turn conversations with proper masking.\n\n## Getting started\n\n1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html).\n\n2. Install [Cut Cross Entropy](https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy) to reduce training VRAM usage.\n\n3. Run the finetuning example:\n\n    ```bash\n    axolotl train examples/olmo3/olmo3-7b-qlora.yaml\n    ```\n\nThis uses about 11.3 GiB VRAM. Let us know how it goes. Happy finetuning! 🚀\n\n### TIPS\n\n- The example config can be reused for Olmo and Olmo 2.\n- You can run a full finetune by removing `adapter: qlora` and `load_in_4bit: true` from the config.\n- Read more on how to load your own dataset in the [docs](https://docs.axolotl.ai/docs/dataset_loading.html).\n- The dataset format follows the OpenAI Messages format as seen [here](https://docs.axolotl.ai/docs/dataset-formats/conversation.html#chat_template).\n\n## Optimization Guides\n\nPlease check the [Optimizations doc](https://docs.axolotl.ai/docs/optimizations.html).\n\n## Related Resources\n\n- [Olmo 3 Blog](https://allenai.org/blog/olmo3)\n- [Axolotl Docs](https://docs.axolotl.ai)\n- [Axolotl Website](https://axolotl.ai)\n- [Axolotl GitHub](https://github.com/axolotl-ai-cloud/axolotl)\n- [Axolotl Discord](https://discord.gg/7m9sfhzaf3)\n"
  },
  {
    "path": "examples/olmo3/olmo3-7b-qlora.yaml",
    "content": "base_model: allenai/Olmo-3-7B-Instruct-SFT\n\n# Automatically upload checkpoint and final model to HF\n# hub_model_id: username/custom_model_name\n\nplugins:\n  - axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin\n\nload_in_8bit: false\nload_in_4bit: true\n\ndatasets:\n  - path: fozziethebeat/alpaca_messages_2k_test\n    type: chat_template\n\ndataset_prepared_path: last_run_prepared\nval_set_size: 0.1\noutput_dir: ./outputs/lora-out\n\nadapter: qlora\nlora_model_dir:\n\nsequence_len: 2048\nsample_packing: true\n\nlora_r: 32\nlora_alpha: 16\nlora_dropout: 0.05\nlora_target_linear: true\nlora_target_modules:\n  - gate_proj\n  - down_proj\n  - up_proj\n  - q_proj\n  - v_proj\n  - k_proj\n  - o_proj\n\nwandb_project:\nwandb_entity:\nwandb_watch:\nwandb_name:\nwandb_log_model:\n\ngradient_accumulation_steps: 2\nmicro_batch_size: 2\nnum_epochs: 1\noptimizer: adamw_8bit\nlr_scheduler: cosine\nlearning_rate: 0.0002\n\nbf16: auto\ntf32: false\n\ngradient_checkpointing: true\nresume_from_checkpoint:\nlogging_steps: 1\nflash_attention: true\n\nwarmup_ratio: 0.1\nevals_per_epoch: 1\nsaves_per_epoch: 1\n\n# save_first_step: true  # uncomment this to validate checkpoint saving works with your config\n"
  },
  {
    "path": "examples/orpheus/README.md",
    "content": "# Finetuning LLMs to output audio\n\nIn this example, we finetune [canopylabs/orpheus-3b-0.1-pretrained](https://huggingface.co/canopylabs/orpheus-3b-0.1-pretrained) (a Llama 3.2 3B model) to output audio.\n\nWith its current settings, `finetune.yml` runs on any NVIDIA GPU with 45GB of VRAM or more. If you reduce the batch size, it can also run on GPUs with 24GB of VRAM or less.\n\n## Dataset pre-processing for pre-training\nIf you are adding another voice in English, please jump ahead to [Finetune pre-processing](#finetune-pre-processing).\n\nFor this to work, we need to preprocess our dataset. Since we are expecting to output audio, we will need to add tokens to the tokenizer.\n\nThe code below downloads the SNAC model, adds the correct tokens, and uploads the final dataset.\n\n```python\nimport torch\nfrom snac import SNAC\nfrom datasets import load_dataset\nfrom huggingface_hub import snapshot_download\nimport random\nimport torchaudio.transforms as T\nfrom transformers import AutoTokenizer\nimport os\n\nmy_original_dataset_name = \"<huggingface-id-of-dataset-that-we-want-to-preprocess>\"\nname_to_push_dataset_to = \"<huggingface-id-of-where-to-save-dataset>\"\n\ndsn = my_original_dataset_name\n\nsnapshot_download(\n    repo_id=dsn,\n    repo_type=\"dataset\",\n    revision=\"main\",\n    max_workers=64,\n)\n\n\nds = load_dataset(dsn, split=\"train\")\nds_sample_rate = ds[0][\"audio\"][\"sampling_rate\"]\n\n# SNAC and the waveforms it encodes must be on the same device\ndevice = \"cuda\"  # or \"mps\" on Apple silicon\n\nmodel = SNAC.from_pretrained(\"hubertsiuzdak/snac_24khz\")\nmodel = model.to(device)\n\ndef tokenise_audio(waveform):\n  waveform = torch.from_numpy(waveform).unsqueeze(0)\n  waveform = waveform.to(dtype=torch.float32)\n  resample_transform = T.Resample(orig_freq=ds_sample_rate, new_freq=24000)\n  waveform = resample_transform(waveform)\n\n  waveform = waveform.unsqueeze(0).to(device)\n\n  #generate the codes from snac\n  with torch.inference_mode():\n    codes = model.encode(waveform)\n\n  all_codes = []\n  for i in range(codes[0].shape[1]):\n    all_codes.append(codes[0][0][i].item()+128266)\n    all_codes.append(codes[1][0][2*i].item()+128266+4096)\n    all_codes.append(codes[2][0][4*i].item()+128266+(2*4096))\n    all_codes.append(codes[2][0][(4*i)+1].item()+128266+(3*4096))\n    all_codes.append(codes[1][0][(2*i)+1].item()+128266+(4*4096))\n    all_codes.append(codes[2][0][(4*i)+2].item()+128266+(5*4096))\n    all_codes.append(codes[2][0][(4*i)+3].item()+128266+(6*4096))\n\n\n  return all_codes\n\ndef add_codes(example):\n    # Always initialize codes_list to None\n    codes_list = None\n\n    try:\n        answer_audio = example.get(\"audio\")\n        # If there's a valid audio array, tokenise it\n        if answer_audio and \"array\" in answer_audio:\n            audio_array = answer_audio[\"array\"]\n            codes_list = tokenise_audio(audio_array)\n    except Exception as e:\n        print(f\"Skipping row due to error: {e}\")\n        # Keep codes_list as None if we fail\n    example[\"codes_list\"] = codes_list\n\n    return example\n\nds = ds.map(add_codes, remove_columns=[\"audio\"])\n\n#@title Load Tokenizer\ntokeniser_length = 128256\nstart_of_text = 128000\nend_of_text = 128009\n\nstart_of_speech = tokeniser_length + 1\nend_of_speech = tokeniser_length + 2\n\nstart_of_human = tokeniser_length + 3\nend_of_human = tokeniser_length + 4\n\nstart_of_ai = tokeniser_length + 5\nend_of_ai =  tokeniser_length + 6\npad_token = tokeniser_length + 7\n\naudio_tokens_start = tokeniser_length + 10\n\ntokenizer_name = \"canopylabs/orpheus-3b-0.1-pretrained\"\n\n\ntokenizer = AutoTokenizer.from_pretrained(tokenizer_name)\nnum_proc = max(1, os.cpu_count() - 2)\n\nds = ds.filter(lambda x: x[\"codes_list\"] is not None)\nds = ds.filter(lambda x: len(x[\"codes_list\"]) > 0)\n\n#@title Create Input Ids\ndef remove_duplicate_frames(example):\n    vals = example[\"codes_list\"]\n    if len(vals) % 7 != 0:\n        raise ValueError(\"Input list length must be divisible by 7\")\n\n    result = vals[:7]\n\n    removed_frames = 0\n\n    for i in range(7, len(vals), 7):\n        current_first = vals[i]\n        previous_first = result[-7]\n\n        if current_first != previous_first:\n            result.extend(vals[i:i+7])\n        else:\n            removed_frames += 1\n\n    example[\"codes_list\"] = result\n\n    return example\n\nds = ds.map(remove_duplicate_frames, num_proc=num_proc)\n\n\ndef create_input_ids(example):\n    text_ids = tokenizer.encode(example['text'], add_special_tokens=True)\n    text_ids.append(end_of_text)\n    example[\"text_tokens\"] = text_ids\n    input_ids = (\n        [start_of_human]\n        + example[\"text_tokens\"]\n        + [end_of_human]\n        + [start_of_ai]\n        + [start_of_speech]\n        + example[\"codes_list\"]\n        + [end_of_speech]\n        + [end_of_ai]\n    )\n    example[\"input_ids\"] = input_ids\n    example[\"labels\"] = input_ids\n    example[\"attention_mask\"] = [1] * len(input_ids)\n\n    return example\n\nds = ds.map(create_input_ids, num_proc=num_proc, remove_columns=[\"text\", \"codes_list\"])\n\n#@title Remove unnecessary columns\ncolumns_to_keep = [\"input_ids\", \"labels\", \"attention_mask\"]\ncolumns_to_remove = [col for col in ds.column_names if col not in columns_to_keep]\n\nds = ds.remove_columns(columns_to_remove)\n\nds.push_to_hub(name_to_push_dataset_to)\n```\n\n\n## Finetune pre-processing\nUse the code below to add a new voice.\n\n```python\nimport torch\nfrom snac import SNAC\nfrom datasets import load_dataset\nfrom huggingface_hub import snapshot_download\nimport random\nimport torchaudio.transforms as T\nfrom transformers import AutoTokenizer\nimport os\n\nmy_original_dataset_name = \"<huggingface-id-of-dataset-that-we-want-to-preprocess>\"\nname_to_push_dataset_to = \"<huggingface-id-of-where-to-save-dataset>\"\n\ndsn = my_original_dataset_name\n\nsnapshot_download(\n    repo_id=dsn,\n    repo_type=\"dataset\",\n    revision=\"main\",\n    max_workers=64,\n)\n\n\nds = load_dataset(dsn, split=\"train\")\nds_sample_rate = ds[0][\"audio\"][\"sampling_rate\"]\n\n# SNAC and the waveforms it encodes must be on the same device\ndevice = \"cuda\"  # or \"mps\" on Apple silicon\n\nmodel = SNAC.from_pretrained(\"hubertsiuzdak/snac_24khz\")\nmodel = model.to(device)\n\ndef tokenise_audio(waveform):\n  waveform = torch.from_numpy(waveform).unsqueeze(0)\n  waveform = waveform.to(dtype=torch.float32)\n  resample_transform = T.Resample(orig_freq=ds_sample_rate, new_freq=24000)\n  waveform = resample_transform(waveform)\n\n  waveform = waveform.unsqueeze(0).to(device)\n\n  #generate the codes from snac\n  with torch.inference_mode():\n    codes = model.encode(waveform)\n\n  all_codes = []\n  for i in range(codes[0].shape[1]):\n    all_codes.append(codes[0][0][i].item()+128266)\n    all_codes.append(codes[1][0][2*i].item()+128266+4096)\n    all_codes.append(codes[2][0][4*i].item()+128266+(2*4096))\n    all_codes.append(codes[2][0][(4*i)+1].item()+128266+(3*4096))\n    all_codes.append(codes[1][0][(2*i)+1].item()+128266+(4*4096))\n    all_codes.append(codes[2][0][(4*i)+2].item()+128266+(5*4096))\n    all_codes.append(codes[2][0][(4*i)+3].item()+128266+(6*4096))\n\n\n  return all_codes\n\ndef add_codes(example):\n    # Always initialize codes_list to None\n    codes_list = None\n\n    try:\n        answer_audio = example.get(\"audio\")\n        # If there's a valid audio array, tokenise it\n        if answer_audio and \"array\" in answer_audio:\n            audio_array = answer_audio[\"array\"]\n            codes_list = tokenise_audio(audio_array)\n    except Exception as e:\n        print(f\"Skipping row due to error: {e}\")\n        # Keep codes_list as None if we fail\n    example[\"codes_list\"] = codes_list\n\n    return example\n\nds = ds.map(add_codes, remove_columns=[\"audio\"])\n\n#@title Load Tokenizer\ntokeniser_length = 128256\nstart_of_text = 128000\nend_of_text = 128009\n\nstart_of_speech = tokeniser_length + 1\nend_of_speech = tokeniser_length + 2\n\nstart_of_human = tokeniser_length + 3\nend_of_human = tokeniser_length + 4\n\nstart_of_ai = tokeniser_length + 5\nend_of_ai =  tokeniser_length + 6\npad_token = tokeniser_length + 7\n\naudio_tokens_start = tokeniser_length + 10\n\ntokenizer_name = \"canopylabs/orpheus-3b-0.1-pretrained\"\n\n\ntokenizer = AutoTokenizer.from_pretrained(tokenizer_name)\nnum_proc = max(1, os.cpu_count() - 2)\n\nds = ds.filter(lambda x: x[\"codes_list\"] is not None)\nds = ds.filter(lambda x: len(x[\"codes_list\"]) > 0)\n\n#@title Create Input Ids\ndef remove_duplicate_frames(example):\n    vals = example[\"codes_list\"]\n    if len(vals) % 7 != 0:\n        raise ValueError(\"Input list length must be divisible by 7\")\n\n    result = vals[:7]\n\n    removed_frames = 0\n\n    for i in range(7, len(vals), 7):\n        current_first = vals[i]\n        previous_first = result[-7]\n\n        if current_first != previous_first:\n            result.extend(vals[i:i+7])\n        else:\n            removed_frames += 1\n\n    example[\"codes_list\"] = result\n\n    return example\n\nds = ds.map(remove_duplicate_frames, num_proc=num_proc)\n\ntok_info = '''*** HERE you can modify the text prompt\ne.g. if you want a multispeaker model like canopylabs/orpheus-3b-0.1-ft, you can pass:\nf\"{example['speaker_id']}: {example['text']}\", as done below.\n'''\nprint(tok_info)\n\ndef create_input_ids(example):\n    text_ids = tokenizer.encode(f\"{example['speaker_id']}: {example['text']}\", add_special_tokens=True)\n    text_ids.append(end_of_text)\n    example[\"text_tokens\"] = text_ids\n    input_ids = (\n        [start_of_human]\n        + example[\"text_tokens\"]\n        + [end_of_human]\n        + [start_of_ai]\n        + [start_of_speech]\n        + example[\"codes_list\"]\n        + [end_of_speech]\n        + [end_of_ai]\n    )\n    example[\"input_ids\"] = input_ids\n    example[\"labels\"] = input_ids\n    example[\"attention_mask\"] = [1] * len(input_ids)\n\n    return example\n\nds = ds.map(create_input_ids, num_proc=num_proc, remove_columns=[\"text\", \"codes_list\"])\n\n#@title Remove unnecessary columns\ncolumns_to_keep = [\"input_ids\", \"labels\", \"attention_mask\"]\ncolumns_to_remove = [col for col in ds.column_names if col not in columns_to_keep]\n\nds = ds.remove_columns(columns_to_remove)\n\nds.push_to_hub(name_to_push_dataset_to)\n```\n\n## Training\nAfter preprocessing is done, fill in the placeholders (`hub_model_id` and the dataset `path`) in `examples/orpheus/finetune.yml` and run `axolotl train examples/orpheus/finetune.yml`.\n\n## Inference\nFor inference, please refer to the original [Orpheus-TTS GitHub repository](https://github.com/canopyai/Orpheus-TTS/tree/main).\n"
  },
  {
    "path": "examples/orpheus/finetune.yml",
    "content": "base_model: canopylabs/orpheus-3b-0.1-pretrained\n\nhub_model_id: <your-hub-model-id>\n\nplugins:\n  - axolotl.integrations.liger.LigerPlugin\nliger_rope: true\nliger_rms_norm: true\nliger_glu_activation: true\nliger_fused_linear_cross_entropy: true\n\ndatasets:\n  - path: <your-hf-dataset-id>\n    type:  # leave empty to load pre-tokenized\ndataset_prepared_path: last_run_prepared\nval_set_size: 0.01\noutput_dir: ./outputs/out\n\nsequence_len: 8192\nsample_packing: true\n\n\nwandb_project:\nwandb_entity:\nwandb_watch:\nwandb_name:\nwandb_log_model:\n\ngradient_accumulation_steps: 8\nmicro_batch_size: 4\nnum_epochs: 3\noptimizer: adamw_torch_fused\nlr_scheduler: cosine\nlearning_rate: 2e-5\n\nbf16: auto\ntf32: false\n\ngradient_checkpointing: true\ngradient_checkpointing_kwargs:\n  use_reentrant: false\nresume_from_checkpoint:\nlogging_steps: 1\nflash_attention: true\n\nwarmup_ratio: 0.1\nevals_per_epoch: 5\nsaves_per_epoch: 5\nweight_decay: 0.05\n\nspecial_tokens:\n  pad_token: <custom_token_7>\n\n# save_first_step: true  # uncomment this to validate checkpoint saving works with your config\n"
  },
  {
    "path": "examples/phi/README.md",
    "content": "# Phi\n\nDue to some nuances in the Phi modeling code, please use DeepSpeed when running a full finetune of Phi:\n\n```shell\naccelerate launch -m axolotl.cli.train examples/phi/phi-ft.yml --deepspeed deepspeed_configs/zero1.json\n```\n\nThe QLoRA example can be launched directly:\n\n```shell\npython -m axolotl.cli.train examples/phi/phi-qlora.yml\n```\n"
  },
  {
    "path": "examples/phi/lora-3.5.yaml",
    "content": "base_model: microsoft/Phi-3.5-mini-instruct\n# optionally might have model_type or tokenizer_type\nmodel_type: AutoModelForCausalLM\ntokenizer_type: AutoTokenizer\n# Automatically upload checkpoint and final model to HF\n# hub_model_id: username/custom_model_name\n\nload_in_8bit: true\nload_in_4bit: false\n\nchat_template: phi_3\ndatasets:\n  - path: fozziethebeat/alpaca_messages_2k_test\n    type: chat_template\n\ndataset_prepared_path:\nval_set_size: 0.05\noutput_dir: ./outputs/lora-out\n\nsequence_len: 4096\nsample_packing: false\n\n\nadapter: lora\nlora_model_dir:\nlora_r: 32\nlora_alpha: 16\nlora_dropout: 0.05\nlora_target_linear: true\n\nwandb_project:\nwandb_entity:\nwandb_watch:\nwandb_name:\nwandb_log_model:\n\ngradient_accumulation_steps: 4\nmicro_batch_size: 4\nnum_epochs: 2\noptimizer: adamw_bnb_8bit\nlr_scheduler: cosine\nlearning_rate: 0.0002\n\nbfloat16: true\nbf16: true\nfp16:\ntf32: false\n\ngradient_checkpointing: true\nresume_from_checkpoint:\nlogging_steps: 1\n\nwarmup_ratio: 0.1\nevals_per_epoch: 4\nsaves_per_epoch: 4\nweight_decay: 0.0\n\n# save_first_step: true  # uncomment this to validate checkpoint saving works with your config\n"
  },
  {
    "path": "examples/phi/phi-ft.yml",
    "content": "base_model: microsoft/phi-1_5\n# optionally might have model_type or tokenizer_type\nmodel_type: AutoModelForCausalLM\ntokenizer_type: AutoTokenizer\n# Automatically upload checkpoint and final model to HF\n# hub_model_id: username/custom_model_name\n\ndatasets:\n  - path: garage-bAInd/Open-Platypus\n    type: alpaca\n\ndataset_prepared_path:\nval_set_size: 0.05\noutput_dir: ./outputs/phi-sft-out\n\nsequence_len: 2048\nsample_packing: true\n\n\nadapter:\nlora_model_dir:\nlora_r:\nlora_alpha:\nlora_dropout:\nlora_target_linear:\n\nwandb_project:\nwandb_entity:\nwandb_watch:\nwandb_name:\nwandb_log_model:\n\ngradient_accumulation_steps: 1\nmicro_batch_size: 2\nnum_epochs: 4\noptimizer: adamw_torch_fused\nadam_beta2: 0.95\nadam_epsilon: 0.00001\nmax_grad_norm: 1.0\nlr_scheduler: cosine\nlearning_rate: 0.000003\n\nbf16: auto\ntf32: true\n\ngradient_checkpointing: true\ngradient_checkpointing_kwargs:\n  use_reentrant: True\nresume_from_checkpoint:\nlogging_steps: 1\nflash_attention: true\n\nwarmup_ratio: 0.1\nevals_per_epoch: 4\nsaves_per_epoch: 1\nweight_decay: 0.1\nresize_token_embeddings_to_32x: true\nspecial_tokens:\n  pad_token: \"<|endoftext|>\"\n\n# save_first_step: true  # uncomment this to validate checkpoint saving works with your config\n"
  },
  {
    "path": "examples/phi/phi-qlora.yml",
    "content": "base_model: microsoft/phi-1_5\n# optionally might have model_type or tokenizer_type\nmodel_type: AutoModelForCausalLM\ntokenizer_type: AutoTokenizer\n# Automatically upload checkpoint and final model to HF\n# hub_model_id: username/custom_model_name\n\nload_in_8bit: false\nload_in_4bit: true\n\ndatasets:\n  - path: garage-bAInd/Open-Platypus\n    type: alpaca\n\ndataset_prepared_path:\nval_set_size: 0.05\noutput_dir: ./outputs/phi-sft-out\n\nsequence_len: 2048\nsample_packing: true\n\n\nadapter: qlora\nlora_model_dir:\nlora_r: 64\nlora_alpha: 32\nlora_dropout: 0.05\nlora_target_linear: true\n\nwandb_project:\nwandb_entity:\nwandb_watch:\nwandb_name:\nwandb_log_model:\n\ngradient_accumulation_steps: 1\nmicro_batch_size: 2\nnum_epochs: 4\noptimizer: adamw_torch_fused\nadam_beta2: 0.95\nadam_epsilon: 0.00001\nmax_grad_norm: 1.0\nlr_scheduler: cosine\nlearning_rate: 0.000003\n\nbf16: auto\ntf32: true\n\ngradient_checkpointing: true\ngradient_checkpointing_kwargs:\n  use_reentrant: True\nresume_from_checkpoint:\nlogging_steps: 1\nflash_attention: true\n\nwarmup_ratio: 0.1\nevals_per_epoch: 4\nsaves_per_epoch: 1\nweight_decay: 0.1\nresize_token_embeddings_to_32x: true\nspecial_tokens:\n  pad_token: \"<|endoftext|>\"\n\n# save_first_step: true  # uncomment this to validate checkpoint saving works with your config\n"
  },
  {
    "path": "examples/phi/phi2-ft.yml",
    "content": "base_model: microsoft/phi-2\n# optionally might have model_type or tokenizer_type\nmodel_type: AutoModelForCausalLM\ntokenizer_type: AutoTokenizer\n# Automatically upload checkpoint and final model to HF\n# hub_model_id: username/custom_model_name\n\ndatasets:\n  - path: garage-bAInd/Open-Platypus\n    type: alpaca\n\ndataset_prepared_path:\nval_set_size: 0.05\noutput_dir: ./outputs/phi-sft-out\n\nsequence_len: 2048\nsample_packing: true\n\n\nadapter:\nlora_model_dir:\nlora_r:\nlora_alpha:\nlora_dropout:\nlora_target_linear:\n\nwandb_project:\nwandb_entity:\nwandb_watch:\nwandb_name:\nwandb_log_model:\n\ngradient_accumulation_steps: 1\nmicro_batch_size: 2\nnum_epochs: 4\noptimizer: adamw_torch_fused\nadam_beta2: 0.95\nadam_epsilon: 0.00001\nmax_grad_norm: 1.0\nlr_scheduler: cosine\nlearning_rate: 0.000003\n\nbf16: auto\ntf32: true\n\ngradient_checkpointing: true\ngradient_checkpointing_kwargs:\n  use_reentrant: True\nresume_from_checkpoint:\nlogging_steps: 1\nflash_attention: true\n\nwarmup_ratio: 0.1\nevals_per_epoch: 4\nsaves_per_epoch: 1\nweight_decay: 0.1\nresize_token_embeddings_to_32x: true\nspecial_tokens:\n  pad_token: \"<|endoftext|>\"\n\n# save_first_step: true  # uncomment this to validate checkpoint saving works with your config\n"
  },
  {
    "path": "examples/phi/phi3-ft-fsdp.yml",
    "content": "base_model: microsoft/Phi-3-mini-4k-instruct\n# optionally might have model_type or tokenizer_type\nmodel_type: AutoModelForCausalLM\ntokenizer_type: AutoTokenizer\n# Automatically upload checkpoint and final model to HF\n# hub_model_id: username/custom_model_name\n\ndatasets:\n  - path: mhenrichsen/alpaca_2k_test\n    type: alpaca\n\ndataset_prepared_path:\nval_set_size: 0\noutput_dir: ./phi-sft-out\n\nsequence_len: 4096\nsample_packing: true\n\ntrust_remote_code: true\n\nadapter:\nlora_model_dir:\nlora_r:\nlora_alpha:\nlora_dropout:\nlora_target_linear:\n\nwandb_project: phi3\nwandb_entity:\nwandb_watch:\nwandb_name:\nwandb_log_model:\n\ngradient_accumulation_steps: 2\nmicro_batch_size: 12\nnum_epochs: 2\noptimizer: adamw_torch_fused\nadam_beta2: 0.95\nadam_epsilon: 0.00001\nmax_grad_norm: 1.0\nlr_scheduler: cosine\nlearning_rate: 0.000003\n\nbf16: auto\ntf32: true\n\ngradient_checkpointing: true\ngradient_checkpointing_kwargs:\n  use_reentrant: true\nresume_from_checkpoint:\nlogging_steps: 1\nflash_attention: true\n\nwarmup_ratio: 0.1\nevals_per_epoch: 4\nsaves_per_epoch: 1\nweight_decay: 0.1\nfsdp:\n  - full_shard\n  - auto_wrap\nfsdp_config:\n  fsdp_limit_all_gathers: true\n  fsdp_sync_module_states: true\n  fsdp_offload_params: true\n  fsdp_use_orig_params: false\n  fsdp_cpu_ram_efficient_loading: true\n  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP\n  fsdp_transformer_layer_cls_to_wrap: Phi3DecoderLayer\n  fsdp_state_dict_type: FULL_STATE_DICT\n  fsdp_sharding_strategy: FULL_SHARD\nresize_token_embeddings_to_32x: true\nspecial_tokens:\n  pad_token: \"<|endoftext|>\"\n\n# save_first_step: true  # uncomment this to validate checkpoint saving works with your config\n"
  },
  {
    "path": "examples/phi/phi3-ft.yml",
    "content": "base_model: microsoft/Phi-3-mini-4k-instruct\n# optionally might have model_type or tokenizer_type\ntrust_remote_code: true\nmodel_type: AutoModelForCausalLM\ntokenizer_type: AutoTokenizer\n# Automatically upload checkpoint and final model to HF\n# hub_model_id: username/custom_model_name\n\nchat_template: phi_3\n\ndatasets:\n  - path: garage-bAInd/Open-Platypus\n    type: alpaca:phi\n\ndataset_prepared_path:\nval_set_size: 0.01\noutput_dir: ./out\n\nsequence_len: 4096\nsample_packing: true\n\n\nadapter: lora\nlora_model_dir:\nlora_r: 64\nlora_alpha: 32\nlora_dropout: 0.05\nlora_target_linear: true\n\ngradient_accumulation_steps: 1\nmicro_batch_size: 2\nnum_epochs: 1\noptimizer: adamw_torch_fused\nadam_beta2: 0.95\nadam_epsilon: 0.00001\nmax_grad_norm: 1.0\nlr_scheduler: cosine\nlearning_rate: 5.0e-6\n\nbf16: auto\n\ngradient_checkpointing: true\ngradient_checkpointing_kwargs:\n  use_reentrant: True\nearly_stopping_patience: 3\nlogging_steps: 1\nflash_attention: true\n\neval_steps: 1000\nsave_steps: 5000\neval_batch_size: 2\neval_sample_packing: false\neval_table_size: 2\neval_max_new_tokens: 32\neval_causal_lm_metrics: [\"perplexity\"]\ndo_causal_lm_eval: true\n\nwarmup_ratio: 0.2\ndebug: true\nweight_decay: 0.1\nresize_token_embeddings_to_32x: true\n\n# save_first_step: true  # uncomment this to validate checkpoint saving works with your config\n"
  },
  {
    "path": "examples/pixtral/lora-12b.yml",
    "content": "base_model: mistral-community/pixtral-12b\nprocessor_type: AutoProcessor\n\n# these 3 lines are needed for now to handle vision chat templates w images\nskip_prepare_dataset: true\nremove_unused_columns: false\nsample_packing: false\n\nchat_template: pixtral\ndatasets:\n  - path: HuggingFaceH4/llava-instruct-mix-vsft\n    type: chat_template\n    split: train[:1%]\ndataset_prepared_path:\nval_set_size: 0.0\noutput_dir: ./outputs/out\n\nadapter: lora\nlora_model_dir:\n\nsequence_len: 8192\npad_to_sequence_len: false\n\nlora_r: 32\nlora_alpha: 16\nlora_dropout: 0.05\nlora_target_modules: 'model.language_model.layers.[\\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'\n\nwandb_project:\nwandb_entity:\nwandb_watch:\nwandb_name:\nwandb_log_model:\n\ngradient_accumulation_steps: 4\nmicro_batch_size: 1\nnum_epochs: 1\noptimizer: adamw_bnb_8bit\nlr_scheduler: cosine\nlearning_rate: 0.0002\n\nbf16: true\nfp16:\ntf32: true\n\ngradient_checkpointing: true\nlogging_steps: 1\nflash_attention: true\n\nwarmup_ratio: 0.1\nevals_per_epoch: 1\nsaves_per_epoch: 1\nweight_decay: 0.0\nspecial_tokens:\n  pad_token: <pad>\n\n# save_first_step: true  # uncomment this to validate checkpoint saving works with your config\n"
  },
  {
    "path": "examples/plano/README.md",
    "content": "# Finetune Katanemo's Plano-Orchestrator with Axolotl\n\n[Plano-Orchestrator](https://huggingface.co/collections/katanemo/plano-orchestrator) is a family of 4B and 30B-A3B routing and orchestration models designed for multi-agent systems. It analyzes user intent and conversation context to make precise routing decisions, excelling at multi-turn context understanding, multi-intent detection, and context-dependent routing.\n\nThis guide shows how to fine-tune it with Axolotl on multi-turn conversations with proper masking.\n\n## Getting started\n\n1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html).\n\n2. Install [Cut Cross Entropy](https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy) to reduce training VRAM usage.\n\n3. Run the finetuning example:\n\n    ```bash\n    axolotl train examples/plano/plano-4b-qlora.yaml\n    ```\n\nThis config uses about 5.1 GiB VRAM. Let us know how it goes. Happy finetuning! 🚀\n\n### Orchestration Prompt\n\nPlano-Orchestrator uses a specific orchestration prompt format for routing/agent decisions. Please check the [official model card](https://huggingface.co/katanemo/Plano-Orchestrator-4B) for proper prompt formatting and the `ORCHESTRATION_PROMPT` template.\n\n### Tips\n\n- To use the larger [Plano-Orchestrator-30B-A3B](https://huggingface.co/katanemo/Plano-Orchestrator-30B-A3B) MoE model, simply change `base_model: katanemo/Plano-Orchestrator-30B-A3B` in the config and enable multi-GPU training if needed.\n- You can run a full finetuning by removing `adapter: qlora` and `load_in_4bit: true` from the config.\n- Read more on how to load your own dataset in the [docs](https://docs.axolotl.ai/docs/dataset_loading.html).\n- The dataset format follows the OpenAI Messages format as seen [here](https://docs.axolotl.ai/docs/dataset-formats/conversation.html#chat_template).\n\n## Optimization Guides\n\nPlease check the [Optimizations doc](https://docs.axolotl.ai/docs/optimizations.html).\n\n## Related Resources\n\n- [Plano GitHub](https://github.com/katanemo/plano)\n- [Axolotl Docs](https://docs.axolotl.ai)\n- [Axolotl Website](https://axolotl.ai)\n- [Axolotl GitHub](https://github.com/axolotl-ai-cloud/axolotl)\n- [Axolotl Discord](https://discord.gg/7m9sfhzaf3)\n"
  },
  {
    "path": "examples/plano/plano-4b-qlora.yaml",
    "content": "base_model: katanemo/Plano-Orchestrator-4B\n\n# Automatically upload checkpoint and final model to HF\n# hub_model_id: username/custom_model_name\n\nplugins:\n  - axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin\n\nload_in_8bit: false\nload_in_4bit: true\n\nchat_template: qwen3\ndatasets:\n  - path: fozziethebeat/alpaca_messages_2k_test\n    type: chat_template\n\ndataset_prepared_path: last_run_prepared\nval_set_size: 0.1\noutput_dir: ./outputs/lora-out\n\nadapter: qlora\nlora_model_dir:\n\nsequence_len: 2048\nsample_packing: true\n\nlora_r: 32\nlora_alpha: 16\nlora_dropout: 0.05\nlora_target_linear: true\nlora_target_modules:\n  - gate_proj\n  - down_proj\n  - up_proj\n  - q_proj\n  - v_proj\n  - k_proj\n  - o_proj\n\nwandb_project:\nwandb_entity:\nwandb_watch:\nwandb_name:\nwandb_log_model:\n\ngradient_accumulation_steps: 4\nmicro_batch_size: 2\nnum_epochs: 1\noptimizer: adamw_bnb_8bit\nlr_scheduler: cosine\nlearning_rate: 0.0002\n\nbf16: auto\ntf32: false\n\ngradient_checkpointing: true\nresume_from_checkpoint:\nlogging_steps: 1\nflash_attention: true\n\nwarmup_ratio: 0.1\nevals_per_epoch: 1\nsaves_per_epoch: 1\n\n# save_first_step: true  # uncomment this to validate checkpoint saving works with your config\n"
  },
  {
    "path": "examples/qat_nvfp4/Gemma3-12B_baseline.yml",
    "content": "base_model: google/gemma-3-12b-it\n# Automatically upload checkpoint and final model to HF\n# hub_model_id: username/custom_model_name\n\nload_in_8bit: false\nload_in_4bit: false\nstrict: false\n\nplugins:\n  - axolotl.integrations.liger.LigerPlugin\n\nliger_rope: true\nliger_rms_norm: true\nliger_glu_activation: true\nliger_layer_norm: true\nliger_fused_linear_cross_entropy: true\nseed: 42\nchat_template: gemma3\ndatasets:\n  - path: tatsu-lab/alpaca\n    type: alpaca\n\noutput_dir: ./outputs/out_gemma/\n\nsequence_len: 8096\nsample_packing: true\nflash_attention: true\n\nwandb_entity:\nwandb_watch:\nwandb_name:\nwandb_log_model:\n\ngradient_accumulation_steps: 1\nmicro_batch_size: 16\n\nnum_epochs: 1\noptimizer: adamw_torch_fused\nlr_scheduler: cosine\nlearning_rate: 4e-5\n\nbf16: true\ntf32: true\n\nresume_from_checkpoint:\nlogging_steps: 1\n\n# evals_per_epoch: 1\nsaves_per_epoch: 1\n\nwarmup_ratio: 0.1\nweight_decay: 0.0\nfsdp_version: 2\n\nfsdp_config:\n  offload_params: false\n  cpu_ram_efficient_loading: true\n  auto_wrap_policy: TRANSFORMER_BASED_WRAP\n  transformer_layer_cls_to_wrap: Gemma3DecoderLayer\n  state_dict_type: FULL_STATE_DICT\n  sharding_strategy: FULL_SHARD\n  reshard_after_forward: true\n  activation_checkpointing: true\n\nspecial_tokens:\n\n# save_first_step: true  # uncomment this to validate checkpoint saving works with your config\n"
  },
  {
    "path": "examples/qat_nvfp4/Gemma3-12B_qat.yml",
    "content": "base_model: google/gemma-3-12b-it\n# Automatically upload checkpoint and final model to HF\n# hub_model_id: username/custom_model_name\n\nload_in_8bit: false\nload_in_4bit: false\nstrict: false\n\nplugins:\n  - axolotl.integrations.liger.LigerPlugin\n\nliger_rope: true\nliger_rms_norm: true\nliger_glu_activation: true\nliger_layer_norm: true\nliger_fused_linear_cross_entropy: true\nseed: 42\nchat_template: gemma3\ndatasets:\n  - path: tatsu-lab/alpaca\n    type: alpaca\n\noutput_dir: ./outputs/qat_out_gemma/\n\nsequence_len: 8096\nsample_packing: true\nflash_attention: true\n\nqat:\n  activation_dtype: nvfp4\n  weight_dtype: nvfp4\n  group_size: 16 # only group_size of 16 is supported with nvfp4\n\nwandb_entity:\nwandb_watch:\nwandb_name:\nwandb_log_model:\n\ngradient_accumulation_steps: 1\nmicro_batch_size: 16\n\nnum_epochs: 1\noptimizer: adamw_torch_fused\nlr_scheduler: cosine\nlearning_rate: 4e-5\n\nbf16: true\ntf32: true\n\nresume_from_checkpoint:\nlogging_steps: 1\n\nevals_per_epoch: 1\nsaves_per_epoch: 1\n\nwarmup_ratio: 0.1\nweight_decay: 0.0\nfsdp_version: 2\n\nfsdp_config:\n  offload_params: false\n  cpu_ram_efficient_loading: true\n  auto_wrap_policy: TRANSFORMER_BASED_WRAP\n  transformer_layer_cls_to_wrap: Gemma3DecoderLayer\n  state_dict_type: FULL_STATE_DICT\n  sharding_strategy: FULL_SHARD\n  reshard_after_forward: true\n  activation_checkpointing: true\n\nspecial_tokens:\n\n# save_first_step: true  # uncomment this to validate checkpoint saving works with your config\n"
  },
  {
    "path": "examples/qat_nvfp4/Math-Gemma3-12B_baseline.yml",
    "content": "base_model: google/gemma-3-12b-it\n# Math finetuning configuration for Gemma3-12B\n# hub_model_id: username/custom_model_name\n\nload_in_8bit: false\nload_in_4bit: false\nstrict: false\n\nplugins:\n  - axolotl.integrations.liger.LigerPlugin\n\nliger_rope: true\nliger_rms_norm: true\nliger_glu_activation: true\nliger_layer_norm: true\nliger_fused_linear_cross_entropy: true\nseed: 42\nchat_template: gemma3\ndatasets:\n  - path: AI-MO/NuminaMath-CoT\n    type: chat_template\n\noutput_dir: ./outputs/out_math_gemma/\n\nsequence_len: 4096\nsample_packing: true\nflash_attention: true\n\nwandb_entity:\nwandb_watch:\nwandb_name:\nwandb_log_model:\n\ngradient_accumulation_steps: 1\nmicro_batch_size: 8\n\nnum_epochs: 1\noptimizer: adamw_torch_fused\nlr_scheduler: cosine\nlearning_rate: 3e-5\n\nbf16: true\ntf32: true\n\nresume_from_checkpoint:\nlogging_steps: 1\n\n# evals_per_epoch: 1\nsaves_per_epoch: 1\n\nwarmup_ratio: 0.1\nweight_decay: 0.0\nfsdp_version: 2\n\nfsdp_config:\n  offload_params: false\n  cpu_ram_efficient_loading: true\n  auto_wrap_policy: TRANSFORMER_BASED_WRAP\n  transformer_layer_cls_to_wrap: Gemma3DecoderLayer\n  state_dict_type: FULL_STATE_DICT\n  sharding_strategy: FULL_SHARD\n  reshard_after_forward: true\n  activation_checkpointing: true\n\nspecial_tokens:\n\n# save_first_step: true  # uncomment this to validate checkpoint saving works with your config\n"
  },
  {
    "path": "examples/qat_nvfp4/Math-Gemma3-12B_qat.yml",
    "content": "base_model: google/gemma-3-12b-it\n# Math finetuning configuration for Gemma3-12B\n# hub_model_id: username/custom_model_name\n\nload_in_8bit: false\nload_in_4bit: false\nstrict: false\n\nplugins:\n  - axolotl.integrations.liger.LigerPlugin\n\nliger_rope: true\nliger_rms_norm: true\nliger_glu_activation: true\nliger_layer_norm: true\nliger_fused_linear_cross_entropy: true\nseed: 42\nchat_template: gemma3\ndatasets:\n  - path: AI-MO/NuminaMath-CoT\n    type: chat_template\n\noutput_dir: ./outputs/qat_out_math_gemma/\n\nsequence_len: 4096\nsample_packing: true\nflash_attention: true\n\nqat:\n  activation_dtype: nvfp4\n  weight_dtype: nvfp4\n  group_size: 16 # only group_size of 16 is supported with nvfp4\n\nwandb_entity:\nwandb_watch:\nwandb_name:\nwandb_log_model:\n\ngradient_accumulation_steps: 1\nmicro_batch_size: 8\n\nnum_epochs: 1\noptimizer: adamw_torch_fused\nlr_scheduler: cosine\nlearning_rate: 3e-5\n\nbf16: true\ntf32: true\n\nresume_from_checkpoint:\nlogging_steps: 1\n\n# evals_per_epoch: 1\nsaves_per_epoch: 1\n\nwarmup_ratio: 0.1\nweight_decay: 0.0\nfsdp_version: 2\n\nfsdp_config:\n  offload_params: false\n  cpu_ram_efficient_loading: true\n  auto_wrap_policy: TRANSFORMER_BASED_WRAP\n  transformer_layer_cls_to_wrap: Gemma3DecoderLayer\n  state_dict_type: FULL_STATE_DICT\n  sharding_strategy: FULL_SHARD\n  reshard_after_forward: true\n  activation_checkpointing: true\n\nspecial_tokens:\n\n# save_first_step: true  # uncomment this to validate checkpoint saving works with your config\n"
  },
  {
    "path": "examples/qat_nvfp4/Math-Gemma3-27B_baseline.yml",
    "content": "base_model: google/gemma-3-27b-it\n# Math finetuning configuration for Gemma3-27B\n# hub_model_id: username/custom_model_name\n\nload_in_8bit: false\nload_in_4bit: false\nstrict: false\n\nplugins:\n  - axolotl.integrations.liger.LigerPlugin\n\nliger_rope: true\nliger_rms_norm: true\nliger_glu_activation: true\nliger_layer_norm: true\nliger_fused_linear_cross_entropy: true\nseed: 42\nchat_template: gemma3\ndatasets:\n  - path: AI-MO/NuminaMath-CoT\n    type: chat_template\n\noutput_dir: ./outputs/out_math_gemma27/\n\nsequence_len: 4096\nsample_packing: true\nflash_attention: true\n\nwandb_entity:\nwandb_watch:\nwandb_name:\nwandb_log_model:\n\ngradient_accumulation_steps: 1\nmicro_batch_size: 16\n\nnum_epochs: 1\noptimizer: adamw_torch_fused\nlr_scheduler: cosine\nlearning_rate: 5e-6\neta_min: 7e-7\n\nbf16: true\ntf32: true\n\nresume_from_checkpoint:\nlogging_steps: 1\n\n# evals_per_epoch: 1\nsaves_per_epoch: 1\n\nwarmup_ratio: 0.1\nweight_decay: 0.0\nfsdp_version: 2\n\nfsdp_config:\n  offload_params: false\n  cpu_ram_efficient_loading: true\n  auto_wrap_policy: TRANSFORMER_BASED_WRAP\n  transformer_layer_cls_to_wrap: Gemma3DecoderLayer\n  state_dict_type: FULL_STATE_DICT\n  sharding_strategy: FULL_SHARD\n  reshard_after_forward: true\n  activation_checkpointing: true\n\nspecial_tokens:\n\n# save_first_step: true  # uncomment this to validate checkpoint saving works with your config\n"
  },
  {
    "path": "examples/qat_nvfp4/Math-Gemma3-27B_qat.yml",
    "content": "base_model: google/gemma-3-27b-it\n# Math finetuning configuration for Gemma3-27B\n# hub_model_id: username/custom_model_name\n\nload_in_8bit: false\nload_in_4bit: false\nstrict: false\n\nplugins:\n  - axolotl.integrations.liger.LigerPlugin\n\nliger_rope: true\nliger_rms_norm: true\nliger_glu_activation: true\nliger_layer_norm: true\nliger_fused_linear_cross_entropy: true\nseed: 42\nchat_template: gemma3\ndatasets:\n  - path: AI-MO/NuminaMath-CoT\n    type: chat_template\n\noutput_dir: ./outputs/qat_out_math_gemma27/\n\nsequence_len: 4096\nsample_packing: true\nflash_attention: true\n\nqat:\n  activation_dtype: nvfp4\n  weight_dtype: nvfp4\n  group_size: 16 # only group_size of 16 is supported with nvfp4\n\nwandb_entity:\nwandb_watch:\nwandb_name:\nwandb_log_model:\n\ngradient_accumulation_steps: 1\nmicro_batch_size: 16\n\nnum_epochs: 1\noptimizer: adamw_torch_fused\nlr_scheduler: cosine\nlearning_rate: 5e-6\neta_min: 7e-7\n\nbf16: true\ntf32: true\n\nresume_from_checkpoint:\nlogging_steps: 1\n\n# evals_per_epoch: 1\nsaves_per_epoch: 1\n\nwarmup_ratio: 0.1\nweight_decay: 0.0\nfsdp_version: 2\n\nfsdp_config:\n  offload_params: false\n  cpu_ram_efficient_loading: true\n  auto_wrap_policy: TRANSFORMER_BASED_WRAP\n  transformer_layer_cls_to_wrap: Gemma3DecoderLayer\n  state_dict_type: FULL_STATE_DICT\n  sharding_strategy: FULL_SHARD\n  reshard_after_forward: true\n  activation_checkpointing: true\n\nspecial_tokens:\n\n# save_first_step: true  # uncomment this to validate checkpoint saving works with your config\n"
  },
  {
    "path": "examples/qat_nvfp4/Math-Qwen2.5-72B_baseline.yml",
    "content": "base_model: Qwen/Qwen2.5-72B\n# Math finetuning configuration for Qwen2.5-72B (non-instruct)\n# hub_model_id: username/custom_model_name\n\nload_in_8bit: false\nload_in_4bit: false\nstrict: false\n\nplugins:\n  - axolotl.integrations.liger.LigerPlugin\n\nliger_rope: true\nliger_rms_norm: true\nliger_glu_activation: true\nliger_layer_norm: true\nliger_fused_linear_cross_entropy: true\nseed: 42\nchat_template: qwen_25\ndatasets:\n  - path: AI-MO/NuminaMath-CoT\n    type: chat_template\n\noutput_dir: ./outputs/out_math_72b/\n\nsequence_len: 4096\nsample_packing: true\nflash_attention: true\n\nwandb_entity:\nwandb_watch:\nwandb_name:\nwandb_log_model:\n\ngradient_accumulation_steps: 1\nmicro_batch_size: 8\nnum_epochs: 1\noptimizer: adamw_torch_fused\nlr_scheduler: cosine\nlearning_rate: 5e-6\neta_min: 7e-7\n\nbf16: true\ntf32: true\n\nresume_from_checkpoint:\nlogging_steps: 1\n\n# evals_per_epoch: 1\nsaves_per_epoch: 1\n\nwarmup_ratio: 0.1\nweight_decay: 0.0\nfsdp_version: 2\n\nfsdp_config:\n  offload_params: false\n  cpu_ram_efficient_loading: true\n  auto_wrap_policy: TRANSFORMER_BASED_WRAP\n  transformer_layer_cls_to_wrap: Qwen2DecoderLayer\n  state_dict_type: FULL_STATE_DICT\n  sharding_strategy: FULL_SHARD\n  reshard_after_forward: true\n  activation_checkpointing: true\n\nspecial_tokens:\n\n# save_first_step: true  # uncomment this to validate checkpoint saving works with your config\n"
  },
  {
    "path": "examples/qat_nvfp4/Math-Qwen2.5-72B_qat.yml",
    "content": "base_model: Qwen/Qwen2.5-72B\n# Math finetuning configuration for Qwen2.5-72B (non-instruct)\n# hub_model_id: username/custom_model_name\n\nload_in_8bit: false\nload_in_4bit: false\nstrict: false\n\nplugins:\n  - axolotl.integrations.liger.LigerPlugin\n\nliger_rope: true\nliger_rms_norm: true\nliger_glu_activation: true\nliger_layer_norm: true\nliger_fused_linear_cross_entropy: true\nseed: 42\nchat_template: qwen_25\ndatasets:\n  - path: AI-MO/NuminaMath-CoT\n    type: chat_template\n\noutput_dir: ./outputs/qat_out_math_72b/\n\nsequence_len: 4096\nsample_packing: true\nflash_attention: true\n\nqat:\n  activation_dtype: nvfp4\n  weight_dtype: nvfp4\n  group_size: 16 # only group_size of 16 is supported with nvfp4\n\nwandb_entity:\nwandb_watch:\nwandb_name:\nwandb_log_model:\n\ngradient_accumulation_steps: 1\nmicro_batch_size: 8\nnum_epochs: 1\noptimizer: adamw_torch_fused\nlr_scheduler: cosine\nlearning_rate: 5e-6\neta_min: 7e-7\n\nbf16: true\ntf32: true\n\nresume_from_checkpoint:\nlogging_steps: 1\n\n# evals_per_epoch: 1\nsaves_per_epoch: 1\n\nwarmup_ratio: 0.1\nweight_decay: 0.0\nfsdp_version: 2\n\nfsdp_config:\n  offload_params: false\n  cpu_ram_efficient_loading: true\n  auto_wrap_policy: TRANSFORMER_BASED_WRAP\n  transformer_layer_cls_to_wrap: Qwen2DecoderLayer\n  state_dict_type: FULL_STATE_DICT\n  sharding_strategy: FULL_SHARD\n  reshard_after_forward: true\n  activation_checkpointing: true\n\nspecial_tokens:\n\n# save_first_step: true  # uncomment this to validate checkpoint saving works with your config\n"
  },
  {
    "path": "examples/qat_nvfp4/Qwen2.5-72B_baseline.yml",
    "content": "base_model: Qwen/Qwen2.5-72B\n# Alpaca finetuning configuration for Qwen2.5-72B\n# hub_model_id: username/custom_model_name\n\nload_in_8bit: false\nload_in_4bit: false\nstrict: false\n\nplugins:\n  - axolotl.integrations.liger.LigerPlugin\n\nliger_rope: true\nliger_rms_norm: true\nliger_glu_activation: true\nliger_layer_norm: true\nliger_fused_linear_cross_entropy: true\nseed: 42\nchat_template: qwen_25\ndatasets:\n  - path: tatsu-lab/alpaca\n    type: alpaca\n\noutput_dir: ./outputs/out_qwen72b/\n\nsequence_len: 8096\nsample_packing: true\nflash_attention: true\n\nwandb_entity:\nwandb_watch:\nwandb_name:\nwandb_log_model:\n\ngradient_accumulation_steps: 1\nmicro_batch_size: 16\n\nnum_epochs: 1\noptimizer: adamw_torch_fused\nlr_scheduler: cosine\nlearning_rate: 2e-5\n\nbf16: true\ntf32: true\n\nresume_from_checkpoint:\nlogging_steps: 1\n\n# evals_per_epoch: 1\nsaves_per_epoch: 1\n\nwarmup_ratio: 0.1\nweight_decay: 0.0\nfsdp_version: 2\n\nfsdp_config:\n  offload_params: false\n  cpu_ram_efficient_loading: true\n  auto_wrap_policy: TRANSFORMER_BASED_WRAP\n  transformer_layer_cls_to_wrap: Qwen2DecoderLayer\n  state_dict_type: FULL_STATE_DICT\n  sharding_strategy: FULL_SHARD\n  reshard_after_forward: true\n  activation_checkpointing: true\n\nspecial_tokens:\n\n# save_first_step: true  # uncomment this to validate checkpoint saving works with your config\n"
  },
  {
    "path": "examples/qat_nvfp4/Qwen2.5-72B_qat.yml",
    "content": "base_model: Qwen/Qwen2.5-72B\n# Alpaca finetuning configuration for Qwen2.5-72B\n# hub_model_id: username/custom_model_name\n\nload_in_8bit: false\nload_in_4bit: false\nstrict: false\n\nplugins:\n  - axolotl.integrations.liger.LigerPlugin\n\nliger_rope: true\nliger_rms_norm: true\nliger_glu_activation: true\nliger_layer_norm: true\nliger_fused_linear_cross_entropy: true\nseed: 42\nchat_template: qwen_25\ndatasets:\n  - path: tatsu-lab/alpaca\n    type: alpaca\n\noutput_dir: ./outputs/qat_out_qwen72b/\n\nsequence_len: 8096\nsample_packing: true\nflash_attention: true\n\nqat:\n  activation_dtype: nvfp4\n  weight_dtype: nvfp4\n  group_size: 16 # only group_size of 16 is supported with nvfp4\n\nwandb_entity:\nwandb_watch:\nwandb_name:\nwandb_log_model:\n\ngradient_accumulation_steps: 1\nmicro_batch_size: 16\n\nnum_epochs: 1\noptimizer: adamw_torch_fused\nlr_scheduler: cosine\nlearning_rate: 2e-5\n\nbf16: true\ntf32: true\n\nresume_from_checkpoint:\nlogging_steps: 1\n\n# evals_per_epoch: 1\nsaves_per_epoch: 1\n\nwarmup_ratio: 0.1\nweight_decay: 0.0\nfsdp_version: 2\n\nfsdp_config:\n  offload_params: false\n  cpu_ram_efficient_loading: true\n  auto_wrap_policy: TRANSFORMER_BASED_WRAP\n  transformer_layer_cls_to_wrap: Qwen2DecoderLayer\n  state_dict_type: FULL_STATE_DICT\n  sharding_strategy: FULL_SHARD\n  reshard_after_forward: true\n  activation_checkpointing: true\n\nspecial_tokens:\n\n# save_first_step: true  # uncomment this to validate checkpoint saving works with your config\n"
  },
  {
    "path": "examples/qwen2/adamw-pretrain-fsdp2.yaml",
    "content": "base_model: Qwen/Qwen2.5-0.5B\nmodel_type: AutoModelForCausalLM\ntokenizer_type: AutoTokenizer\n\n# Use random initialization for fair comparison\nreinit_weights: true\n\nload_in_8bit: false\nload_in_4bit: false\nstrict: false\n\n# Pretraining dataset\npretraining_dataset:\n  - path: allenai/c4\n    name: en\n    type: pretrain\n    split: train\n\ndataset_prepared_path:\nval_set_size: 0.0\noutput_dir: ./outputs/compare-adamw-pretrain\n\nsequence_len: 2048\nsample_packing: true\npad_to_sequence_len: true\n\nwandb_project: dist_muon\nwandb_entity:\nwandb_watch:\nwandb_name: adamw\nwandb_log_model:\n\ngradient_accumulation_steps: 1\nmicro_batch_size: 4\nnum_epochs: 1\nmax_steps: 305\n\n# AdamW optimizer settings (standard LR for AdamW)\noptimizer: adamw_torch_fused\nlearning_rate: 0.0002\nweight_decay: 0.01\nlr_scheduler: cosine\n\ntrain_on_inputs: true\ngroup_by_length: false\nbf16: auto\nfp16: false\ntf32: false\n\ngradient_checkpointing: false\nlogging_steps: 1\nflash_attention: true\n\nwarmup_steps: 10\nevals_per_epoch: 0\nsaves_per_epoch: 1\n\n# Reproducibility\nseed: 42\n\nfsdp_config:\n  fsdp_version: 2\n  fsdp_offload_params: false\n  fsdp_state_dict_type: FULL_STATE_DICT\n  fsdp_transformer_layer_cls_to_wrap: Qwen2DecoderLayer\n  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP\n  fsdp_cpu_ram_efficient_loading: false\n  fsdp_reshard_after_forward: true\n\nspecial_tokens:\n"
  },
  {
    "path": "examples/qwen2/dpo.yaml",
    "content": "base_model: Qwen/Qwen2.5-0.5B\n# Automatically upload checkpoint and final model to HF\n# hub_model_id: username/custom_model_name\n\nchat_template: qwen_25\nrl: dpo\ndatasets:\n  - path: fozziethebeat/alpaca_messages_2k_dpo_test\n    type: chat_template.default\n    field_messages: conversation\n    field_chosen: chosen\n    field_rejected: rejected\n    message_property_mappings:\n      role: role\n      content: content\n    roles:\n      system:\n        - system\n      user:\n        - user\n      assistant:\n        - assistant\n\ndataset_prepared_path:\nval_set_size: 0.0\noutput_dir: ./outputs/dpo-out\n\nsequence_len: 2048\nsample_packing: false\n\n\nwandb_project:\nwandb_entity:\nwandb_watch:\nwandb_name:\nwandb_log_model:\n\ngradient_accumulation_steps: 4\nmicro_batch_size: 2\nnum_epochs: 4\noptimizer: adamw_bnb_8bit\nlr_scheduler: cosine\nlearning_rate: 0.0002\n\nbf16: auto\ntf32: false\n\ngradient_checkpointing: true\nresume_from_checkpoint:\nlogging_steps: 1\nflash_attention: true\n\nwarmup_ratio: 0.1\nevals_per_epoch: 4\nsaves_per_epoch: 1\nweight_decay: 0.0\n\n# save_first_step: true  # uncomment this to validate checkpoint saving works with your config\n"
  },
  {
    "path": "examples/qwen2/muon-pretrain-fsdp2.yaml",
    "content": "base_model: Qwen/Qwen2.5-0.5B\nmodel_type: AutoModelForCausalLM\ntokenizer_type: AutoTokenizer\n\n# Use random initialization for fair comparison\nreinit_weights: true\n\nload_in_8bit: false\nload_in_4bit: false\nstrict: false\n\n# Pretraining dataset\npretraining_dataset:\n  - path: allenai/c4\n    name: en\n    type: pretrain\n    split: train\n\ndataset_prepared_path:\nval_set_size: 0.0\noutput_dir: ./outputs/compare-muon-pretrain\n\nsequence_len: 2048\nsample_packing: true\npad_to_sequence_len: true\n\nwandb_project: dist_muon\nwandb_entity:\nwandb_watch:\nwandb_name: muon\nwandb_log_model:\n\ngradient_accumulation_steps: 1\nmicro_batch_size: 4\nnum_epochs: 1\nmax_steps: 305\n\n# Muon optimizer settings\noptimizer: muon\nlearning_rate: 0.02\nweight_decay: 0.01\nlr_scheduler: cosine\n\ntrain_on_inputs: true\ngroup_by_length: false\nbf16: auto\nfp16: false\ntf32: false\n\ngradient_checkpointing: false\nlogging_steps: 1\nflash_attention: true\n\nwarmup_steps: 10\nevals_per_epoch: 0\nsaves_per_epoch: 1\n\n# Reproducibility\nseed: 42\n\nfsdp_config:\n  fsdp_version: 2\n  fsdp_offload_params: false\n  fsdp_state_dict_type: FULL_STATE_DICT\n  fsdp_transformer_layer_cls_to_wrap: Qwen2DecoderLayer\n  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP\n  fsdp_cpu_ram_efficient_loading: false\n  fsdp_reshard_after_forward: true\n\nspecial_tokens:\n"
  },
  {
    "path": "examples/qwen2/prm.yaml",
    "content": "base_model: Qwen/Qwen2.5-3B\n# optionally might have model_type or tokenizer_type\nmodel_type: AutoModelForTokenClassification\nnum_labels: 2\ntokenizer_type: AutoTokenizer\n# Automatically upload checkpoint and final model to HF\n# hub_model_id: username/custom_model_name\n\nprocess_reward_model: true\nchat_template:\ndatasets:\n  - path: trl-lib/math_shepherd\n    type: stepwise_supervised\n    step_separator: \"\\n\"\n    max_completion_length:\n    train_on_last_step_only: false\n\nval_set_size: 0.2\noutput_dir: ./outputs/out\nremove_unused_columns: false\n\nsequence_len: 2048\nsample_packing: false\neval_sample_packing: false\n\n\nwandb_project:\nwandb_entity:\nwandb_watch:\nwandb_name:\nwandb_log_model:\n\n\ngradient_accumulation_steps: 1\nmicro_batch_size: 8\neval_batch_size: 8\nnum_epochs: 1\noptimizer: adamw_torch\nlr_scheduler: cosine\nlearning_rate: 0.0002\n\nbf16: true\nfp16:\ntf32:\ngradient_checkpointing: true\ngradient_checkpointing_kwargs:\n  use_reentrant: false\nresume_from_checkpoint:\nlogging_steps: 1\nflash_attention: true\n\nwarmup_ratio: 0.1\nevals_per_epoch:\neval_steps: 100\nsaves_per_epoch: 1\nweight_decay: 0.0\nspecial_tokens:\n\n# save_first_step: true  # uncomment this to validate checkpoint saving works with your config\n"
  },
  {
    "path": "examples/qwen2/qlora-fsdp.yaml",
    "content": "base_model: Qwen/Qwen2-7B\n# Automatically upload checkpoint and final model to HF\n# hub_model_id: username/custom_model_name\n\ntrust_remote_code: true\n\nload_in_8bit: false\nload_in_4bit: true\n\ndatasets:\n  - path: tatsu-lab/alpaca\n    type: alpaca\ndataset_prepared_path:\nval_set_size: 0.05\noutput_dir: ./outputs/out\n\nsequence_len: 2048\nsample_packing: true\neval_sample_packing: true\n\n\nadapter: qlora\nlora_model_dir:\nlora_r: 32\nlora_alpha: 64\nlora_dropout: 0.05\nlora_target_linear: true\n\nwandb_project:\nwandb_entity:\nwandb_watch:\nwandb_name:\nwandb_log_model:\n\ngradient_accumulation_steps: 4\nmicro_batch_size: 1\nnum_epochs: 4\noptimizer: adamw_torch_fused\nlr_scheduler: cosine\nlearning_rate: 0.0002\n\nbf16: auto\ntf32: true\n\ngradient_checkpointing: true\ngradient_checkpointing_kwargs:\n  use_reentrant: false\nresume_from_checkpoint:\nlogging_steps: 1\nflash_attention: true\n\nwarmup_ratio: 0.1\nevals_per_epoch: 4\nsaves_per_epoch: 1\nweight_decay: 0.0\nfsdp:\n  - full_shard\n  - auto_wrap\nfsdp_config:\n  fsdp_limit_all_gathers: true\n  fsdp_sync_module_states: true\n  fsdp_offload_params: true\n  fsdp_use_orig_params: false\n  fsdp_cpu_ram_efficient_loading: true\n  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP\n  fsdp_transformer_layer_cls_to_wrap: Qwen2DecoderLayer\n  fsdp_state_dict_type: FULL_STATE_DICT\n  fsdp_sharding_strategy: FULL_SHARD\nspecial_tokens:\n\n# save_first_step: true  # uncomment this to validate checkpoint saving works with your config\n"
  },
  {
    "path": "examples/qwen2/reward-model.yaml",
    "content": "base_model:  Qwen/Qwen2.5-0.5B\n# optionally might have model_type or tokenizer_type\nmodel_type: AutoModelForSequenceClassification\nnum_labels: 1\ntokenizer_type: AutoTokenizer\n# Automatically upload checkpoint and final model to HF\n# hub_model_id: username/custom_model_name\n\nreward_model: true\nchat_template: qwen_25\ndatasets:\n  - path: argilla/distilabel-intel-orca-dpo-pairs\n    type: bradley_terry.chat_template\nval_set_size: 0.0\noutput_dir: ./outputs/out\nremove_unused_columns: false\n\nsequence_len: 2048\nsample_packing: false\neval_sample_packing: false\n\n\nwandb_project:\nwandb_entity:\nwandb_watch:\nwandb_name:\nwandb_log_model:\n\ngradient_accumulation_steps: 4\nmicro_batch_size: 2\nnum_epochs: 4\noptimizer: adamw_bnb_8bit\nlr_scheduler: cosine\nlearning_rate: 0.0002\n\nbf16: true\nfp16:\ntf32: true\n\ngradient_checkpointing: true\ngradient_checkpointing_kwargs:\n  use_reentrant: false\nresume_from_checkpoint:\nlogging_steps: 1\nflash_attention: true\n\nwarmup_ratio: 0.1\nevals_per_epoch:\nsaves_per_epoch: 1\nweight_decay: 0.0\nspecial_tokens:\n\n# save_first_step: true  # uncomment this to validate checkpoint saving works with your config\n"
  },
  {
    "path": "examples/qwen2-vl/lora-7b.yaml",
    "content": "base_model: Qwen/Qwen2-VL-7B-Instruct\nprocessor_type: AutoProcessor\n\n# these 3 lines are needed for now to handle vision chat templates w images\nskip_prepare_dataset: true\nremove_unused_columns: false\nsample_packing: false\n\nchat_template: qwen2_vl\ndatasets:\n  - path: HuggingFaceH4/llava-instruct-mix-vsft\n    type: chat_template\n    split: train[:1%]\n\ndataset_prepared_path: last_run_prepared\nval_set_size: 0.0\noutput_dir: ./outputs/out\n\nadapter: lora\nlora_model_dir:\n\nsequence_len: 8192\npad_to_sequence_len: false\n\nlora_r: 32\nlora_alpha: 16\nlora_dropout: 0.05\nlora_target_modules: 'model.language_model.layers.[\\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'\n\nwandb_project:\nwandb_entity:\nwandb_watch:\nwandb_name:\nwandb_log_model:\n\ngradient_accumulation_steps: 4\nmicro_batch_size: 1\nnum_epochs: 1\noptimizer: adamw_bnb_8bit\nlr_scheduler: cosine\nlearning_rate: 0.0002\n\nbf16: true\nfp16:\ntf32: true\n\ngradient_checkpointing: true\nlogging_steps: 1\nflash_attention: true\neager_attention:\n\nwarmup_ratio: 0.1\nevals_per_epoch: 1\nsaves_per_epoch: 1\nweight_decay: 0.0\n\n# save_first_step: true  # uncomment this to validate checkpoint saving works with your config\n"
  },
  {
    "path": "examples/qwen2_5-vl/lora-7b.yaml",
    "content": "base_model: Qwen/Qwen2.5-VL-7B-Instruct\nprocessor_type: AutoProcessor\n\n# these 3 lines are needed for now to handle vision chat templates w images\nskip_prepare_dataset: true\nremove_unused_columns: false\nsample_packing: false\n\nchat_template: qwen2_vl\ndatasets:\n  - path: HuggingFaceH4/llava-instruct-mix-vsft\n    type: chat_template\n    split: train[:1%]\n\ndataset_prepared_path: last_run_prepared\nval_set_size: 0.0\noutput_dir: ./outputs/out\n\nadapter: lora\nlora_model_dir:\n\nsequence_len: 8192\npad_to_sequence_len: false\n\nlora_r: 32\nlora_alpha: 16\nlora_dropout: 0.05\nlora_target_modules: 'model.language_model.layers.[\\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'\n\nwandb_project:\nwandb_entity:\nwandb_watch:\nwandb_name:\nwandb_log_model:\n\ngradient_accumulation_steps: 4\nmicro_batch_size: 1\nnum_epochs: 1\noptimizer: adamw_bnb_8bit\nlr_scheduler: cosine\nlearning_rate: 0.0002\n\nbf16: true\nfp16:\ntf32: true\n\ngradient_checkpointing: true\nlogging_steps: 1\nflash_attention: true\neager_attention:\n\nwarmup_ratio: 0.1\nevals_per_epoch: 1\nsaves_per_epoch: 1\nweight_decay: 0.0\n\n# save_first_step: true  # uncomment this to validate checkpoint saving works with your config\n"
  },
  {
    "path": "examples/qwen3/32b-qlora.yaml",
    "content": "base_model: Qwen/Qwen3-32B\n# Automatically upload checkpoint and final model to HF\n# hub_model_id: username/custom_model_name\n\nplugins:\n  - axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin\nstrict: false\n\nchat_template: qwen3\ndatasets:\n  - path: mlabonne/FineTome-100k\n    type: chat_template\n    split: train[:20%]\n    field_messages: conversations\n    message_property_mappings:\n      role: from\n      content: value\nval_set_size: 0.0\noutput_dir: ./outputs/out\ndataset_prepared_path: last_run_prepared\n\nsequence_len: 2048\nsample_packing: true\neval_sample_packing: true\n\n\nload_in_4bit: true\nadapter: qlora\nlora_r: 16\nlora_alpha: 32\nlora_target_modules:\n  - q_proj\n  - k_proj\n  - v_proj\n  - o_proj\n  - down_proj\n  - up_proj\nlora_mlp_kernel: true\nlora_qkv_kernel: true\nlora_o_kernel: true\n\nwandb_project:\nwandb_entity:\nwandb_watch:\nwandb_name:\nwandb_log_model:\n\ngradient_accumulation_steps: 2\nmicro_batch_size: 1\nnum_epochs: 1\noptimizer: adamw_torch_4bit\nlr_scheduler: cosine\nlearning_rate: 0.0002\n\nbf16: auto\ntf32: true\n\ngradient_checkpointing: offload\ngradient_checkpointing_kwargs:\n  use_reentrant: false\nresume_from_checkpoint:\nlogging_steps: 1\nflash_attention: true\n\nwarmup_ratio: 0.1\nevals_per_epoch: 4\nsaves_per_epoch: 1\nweight_decay: 0.0\nspecial_tokens:\n\n# save_first_step: true  # uncomment this to validate checkpoint saving works with your config\n"
  },
  {
    "path": "examples/qwen3/8b-qat-fsdp2.yml",
    "content": "base_model: Qwen/Qwen3-8B\n# Automatically upload checkpoint and final model to HF\n# hub_model_id: username/custom_model_name\n\nload_in_8bit: false\nload_in_4bit: false\nstrict: false\n\nplugins:\n  - axolotl.integrations.liger.LigerPlugin\n\nliger_rope: true\nliger_rms_norm: true\nliger_glu_activation: true\nliger_layer_norm: true\nliger_fused_linear_cross_entropy: true\n\ndatasets:\n  - path: tatsu-lab/alpaca\n    type: alpaca\n\noutput_dir: ./outputs/qat_out/\n\nsequence_len: 2048\nsample_packing: true\nflex_attention: true\n\n\nflex_attn_compile_kwargs:\n  dynamic: false\n  mode: max-autotune-no-cudagraphs\n\nqat:\n  activation_dtype: int8\n  weight_dtype: int4\n  group_size: 256\n  fake_quant_after_n_steps: 1000\n\nwandb_project:\nwandb_entity:\nwandb_watch:\nwandb_name:\nwandb_log_model:\n\ngradient_accumulation_steps: 1\nmicro_batch_size: 2\nmax_steps: 2000\noptimizer: adamw_torch_fused\nlr_scheduler: cosine\nlearning_rate: 2e-5\n\nbf16: true\ntf32: true\n\nresume_from_checkpoint:\nlogging_steps: 1\n\nevals_per_epoch: 1\nsaves_per_epoch: 1\n\nwarmup_ratio: 0.1\nweight_decay: 0.0\nfsdp:\n  - full_shard\n  - auto_wrap\n\nfsdp_config:\n  fsdp_version: 2\n  fsdp_offload_params: false\n  fsdp_cpu_ram_efficient_loading: true\n  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP\n  fsdp_transformer_layer_cls_to_wrap: Qwen3DecoderLayer\n  fsdp_state_dict_type: FULL_STATE_DICT\n  fsdp_sharding_strategy: FULL_SHARD\n  fsdp_reshard_after_forward: true\n  fsdp_activation_checkpointing: true\n\nspecial_tokens:\n\n# save_first_step: true  # uncomment this to validate checkpoint saving works with your config\n"
  },
  {
    "path": "examples/qwen3/README.md",
    "content": "# Finetune Qwen3 with Axolotl\n\n[Qwen3](https://huggingface.co/collections/Qwen/qwen3) are a family of open source models trained by Alibaba.\n\nThis guide shows how to fine-tune it with Axolotl with multi-turn conversations and proper masking.\n\n## Getting started\n\n1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html).\n\n2. Install [Cut Cross Entropy](https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy) to reduce training VRAM usage.\n\n3. Run the finetuning example:\n\n    ```bash\n    axolotl train examples/qwen3/32b-qlora.yaml\n    ```\n\nLet us know how it goes. Happy finetuning! 🚀\n\n### Chat template masking a few tokens off\n\nIf you notice that the `chat_template` masking for assistant prompts are off by a few tokens, please ensure that you are adding the below to the yaml.\n\n```yaml\nchat_template: qwen3\n```\n\n### TIPS\n\n- For inference, please check the official model card as it depends on your reasoning mode.\n- You can run a full finetuning by removing the `adapter: qlora` and `load_in_4bit: true` from the config.\n- Read more on how to load your own dataset at [docs](https://docs.axolotl.ai/docs/dataset_loading.html).\n- The dataset format follows the OpenAI Messages format as seen [here](https://docs.axolotl.ai/docs/dataset-formats/conversation.html#chat_template).\n\n## Optimization Guides\n\nPlease check the [Optimizations doc](https://docs.axolotl.ai/docs/optimizations.html).\n\n## Related Resources\n\n- [Qwen3 Blog](https://qwenlm.github.io/blog/qwen3/)\n- [Axolotl Docs](https://docs.axolotl.ai)\n- [Axolotl Website](https://axolotl.ai)\n- [Axolotl GitHub](https://github.com/axolotl-ai-cloud/axolotl)\n- [Axolotl Discord](https://discord.gg/7m9sfhzaf3)\n"
  },
  {
    "path": "examples/qwen3/qlora-fsdp.yaml",
    "content": "base_model: Qwen/Qwen3-8B\n# Automatically upload checkpoint and final model to HF\n# hub_model_id: username/custom_model_name\n\nload_in_8bit: false\nload_in_4bit: true\nstrict: false\n\ndatasets:\n  - path: tatsu-lab/alpaca\n    type: alpaca\ndataset_prepared_path:\nval_set_size: 0.05\noutput_dir: ./outputs/out\n\nsequence_len: 2048\nsample_packing: true\neval_sample_packing: true\n\n\nadapter: qlora\nlora_model_dir:\nlora_r: 32\nlora_alpha: 64\nlora_dropout: 0.05\nlora_target_linear: true\n\nwandb_project:\nwandb_entity:\nwandb_watch:\nwandb_name:\nwandb_log_model:\n\ngradient_accumulation_steps: 4\nmicro_batch_size: 1\nnum_epochs: 1\noptimizer: adamw_torch_fused\nlr_scheduler: cosine\nlearning_rate: 0.0002\n\nbf16: auto\ntf32: true\n\ngradient_checkpointing: true\ngradient_checkpointing_kwargs:\n  use_reentrant: false\nresume_from_checkpoint:\nlogging_steps: 1\nflash_attention: true\n\nwarmup_ratio: 0.1\nevals_per_epoch: 4\nsaves_per_epoch: 1\nweight_decay: 0.0\nfsdp:\n  - full_shard\n  - auto_wrap\nfsdp_config:\n  fsdp_limit_all_gathers: true\n  fsdp_sync_module_states: true\n  fsdp_offload_params: true\n  fsdp_use_orig_params: false\n  fsdp_cpu_ram_efficient_loading: true\n  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP\n  fsdp_transformer_layer_cls_to_wrap: Qwen3DecoderLayer\n  fsdp_state_dict_type: FULL_STATE_DICT\n  fsdp_sharding_strategy: FULL_SHARD\nspecial_tokens:\n\n# save_first_step: true  # uncomment this to validate checkpoint saving works with your config\n"
  },
  {
    "path": "examples/qwen3/reward-model.yaml",
    "content": "base_model: Skywork/Skywork-Reward-V2-Qwen3-8B\nmodel_type: AutoModelForSequenceClassification\nnum_labels: 1\n\nreward_model: true\ncenter_rewards_coefficient: 0.01  # Incentivize mean-zero rewards for improved stability\nchat_template: qwen3\ndatasets:\n  - path: argilla/distilabel-intel-orca-dpo-pairs\n    type: bradley_terry.chat_template\n\nval_set_size: 0.0\noutput_dir: ./outputs/out\n\nsequence_len: 8192\nsample_packing: false\neval_sample_packing: false\npad_to_sequence_len: true\n\ndeepspeed: deepspeed_configs/zero1.json\n\nwandb_project:\nwandb_entity:\nwandb_watch:\nwandb_name:\nwandb_log_model:\n\ngradient_accumulation_steps: 4\nmicro_batch_size: 1\neval_batch_size: 1\nnum_epochs: 3\noptimizer: adamw_bnb_8bit\nlr_scheduler: linear\nlearning_rate: 0.00002\n\nbf16: true\ntf32: true\n\ngradient_checkpointing: true\ngradient_checkpointing_kwargs:\n  use_reentrant: false\nwarmup_ratio: 0.1\nlogging_steps: 1\nweight_decay: 0.01\n"
  },
  {
    "path": "examples/qwen3-next/README.md",
    "content": "# Finetune Qwen3-Next with Axolotl\n\n[Qwen3-Next](https://huggingface.co/collections/Qwen/qwen3-next-68c25fd6838e585db8eeea9d) represents the next-generation foundation models optimized for extreme context length and large-scale parameter efficiency. The series introduces architectural innovations including Hybrid Attention (Gated DeltaNet + Gated Attention), High-Sparsity MoE with 1:50 activation ratio, and Multi-Token Prediction for enhanced performance and inference acceleration.\n\nThis guide shows how to fine-tune it with Axolotl with multi-turn conversations and proper masking.\n\n## Getting started\n\n1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html).\n\n2. Install [Cut Cross Entropy](https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy) to reduce training VRAM usage.\n\n3. Install FLA for improved performance\n```bash\npip3 uninstall -y causal-conv1d && pip3 install flash-linear-attention==0.4.1\n```\n\n4. Run the finetuning example:\n\n```bash\naxolotl train examples/qwen3-next/qwen3-next-80b-a3b-qlora.yaml\n```\n\nThis config uses about ~47 GiB (no target experts) and ~71GiB (target experts) VRAM.\n\nLet us know how it goes. Happy finetuning! 🚀\n\n### TIPS\n\n- For inference, you can experiment with `temperature: 0.7`, `top_p: 0.8`, `top_k: 20`, and `min_p: 0`.\n- You can run a full finetuning by removing the `adapter: qlora` and `load_in_4bit: true` from the config. 
See [Multi-GPU](#optimization-guides) section below.\n- Read more on how to load your own dataset at [docs](https://docs.axolotl.ai/docs/dataset_loading.html).\n- The dataset format follows the OpenAI Messages format as seen [here](https://docs.axolotl.ai/docs/dataset-formats/conversation.html#chat_template).\n\n## Optimization Guides\n\n- [Multi-GPU Training](https://docs.axolotl.ai/docs/multi-gpu.html)\n- [Multi-Node Training](https://docs.axolotl.ai/docs/multi-node.html)\n- [LoRA Optimizations](https://docs.axolotl.ai/docs/lora_optims.html)\n\n## Related Resources\n\n- [Qwen3-Next Blog](https://qwenlm.github.io/blog/qwen3_next/)\n- [Axolotl Docs](https://docs.axolotl.ai)\n- [Axolotl Website](https://axolotl.ai)\n- [Axolotl GitHub](https://github.com/axolotl-ai-cloud/axolotl)\n- [Axolotl Discord](https://discord.gg/7m9sfhzaf3)\n"
  },
  {
    "path": "examples/qwen3-next/qwen3-next-80b-a3b-qlora.yaml",
    "content": "base_model: Qwen/Qwen3-Next-80B-A3B-Instruct\n\n# Automatically upload checkpoint and final model to HF\n# hub_model_id: username/custom_model_name\n\nplugins:\n  - axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin\n\nload_in_8bit: false\nload_in_4bit: true\n\nquantize_moe_experts: true\n\ndatasets:\n  - path: fozziethebeat/alpaca_messages_2k_test\n    type: chat_template\n\ndataset_prepared_path: last_run_prepared\nval_set_size: 0.1\noutput_dir: ./outputs/lora-out\n\nadapter: qlora\nlora_model_dir:\n\nsequence_len: 2048\nsample_packing: true\n\nlora_r: 16\nlora_alpha: 8\nlora_dropout: 0\nlora_target_modules:\n  - linear_attn.in_proj_ba\n  - linear_attn.in_proj_qkvz\n  - linear_attn.out_proj\n  - shared_expert.up_proj\n  - shared_expert.down_proj\n  - shared_expert.gate_proj\n  - shared_expert_gate\n  - q_proj\n  - v_proj\n  - k_proj\n  - o_proj\n\n# lora_target_parameters:\n#   - mlp.experts.gate_up_proj\n#   - mlp.experts.down_proj\n\nlora_mlp_kernel: false\nlora_qkv_kernel: false\nlora_o_kernel: false\n\nwandb_project:\nwandb_entity:\nwandb_watch:\nwandb_name:\nwandb_log_model:\n\ngradient_accumulation_steps: 2\nmicro_batch_size: 2\nnum_epochs: 1\noptimizer: adamw_bnb_8bit\nlr_scheduler: cosine\nlearning_rate: 0.0002\n\nbf16: auto\ntf32: false\n\ngradient_checkpointing: true\nresume_from_checkpoint:\nlogging_steps: 1\nflash_attention: true\n\nwarmup_ratio: 0.1\nevals_per_epoch: 1\nsaves_per_epoch: 1\n\n# save_first_step: true  # uncomment this to validate checkpoint saving works with your config\n"
  },
  {
    "path": "examples/qwen3.5/122b-a10b-moe-qlora-fsdp.yaml",
    "content": "base_model: Qwen/Qwen3.5-122B-A10B\n\nplugins:\n  - axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin\nstrict: false\n\nchat_template: qwen3_5\ndatasets:\n  - path: mlabonne/FineTome-100k\n    type: chat_template\n    split: train[:20%]\n    field_messages: conversations\n    message_property_mappings:\n      role: from\n      content: value\nval_set_size: 0.0\noutput_dir: ./outputs/out\ndataset_prepared_path: last_run_prepared\n\nsequence_len: 2048\nsample_packing: true\n\nload_in_4bit: true\nquantize_moe_experts: true\nadapter: qlora\nlora_r: 16\nlora_alpha: 32\nlora_dropout: 0\nlora_target_modules:\n  - q_proj\n  - k_proj\n  - v_proj\n  - o_proj\n# Regex matching to target shared experts too\n# lora_target_modules: 'model\\.(language_model\\.)?layers\\.[\\d]+\\.(mlp|self_attn)\\.(shared_expert\\.)?(up|down|gate|gate_up|q|k|v|o)_proj'\n\n# Target experts\n# lora_target_parameters:\n#   - mlp.experts.gate_up_proj\n#   - mlp.experts.down_proj\n\nwandb_project:\nwandb_entity:\nwandb_watch:\nwandb_name:\nwandb_log_model:\n\ngradient_accumulation_steps: 2\nmicro_batch_size: 1\nnum_epochs: 1\noptimizer: adamw_torch_4bit\nlr_scheduler: cosine\nlearning_rate: 0.0002\n\nbf16: auto\ntf32: true\n\nlora_mlp_kernel: false\nlora_qkv_kernel: false\nlora_o_kernel: false\n\ngradient_checkpointing: true\ngradient_checkpointing_kwargs:\n  use_reentrant: false\nresume_from_checkpoint:\nlogging_steps: 1\nflash_attention: true\n\nwarmup_ratio: 0.1\nevals_per_epoch: 4\nsaves_per_epoch: 1\nweight_decay: 0.0\nspecial_tokens:\n\nfsdp_config:\n  fsdp_version: 2\n  offload_params: true\n  cpu_ram_efficient_loading: false\n  auto_wrap_policy: TRANSFORMER_BASED_WRAP\n  transformer_layer_cls_to_wrap: Qwen3_5MoeDecoderLayer\n  state_dict_type: FULL_STATE_DICT\n  sharding_strategy: FULL_SHARD\n  reshard_after_forward: true\n  activation_checkpointing: true\n"
  },
  {
    "path": "examples/qwen3.5/122b-a10b-moe-qlora.yaml",
    "content": "base_model: Qwen/Qwen3.5-122B-A10B\n\nplugins:\n  - axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin\nstrict: false\n\nchat_template: qwen3_5\ndatasets:\n  - path: mlabonne/FineTome-100k\n    type: chat_template\n    split: train[:20%]\n    field_messages: conversations\n    message_property_mappings:\n      role: from\n      content: value\nval_set_size: 0.0\noutput_dir: ./outputs/out\ndataset_prepared_path: last_run_prepared\n\nsequence_len: 2048\nsample_packing: true\n\nload_in_4bit: true\nquantize_moe_experts: true\nadapter: qlora\nlora_r: 16\nlora_alpha: 32\nlora_dropout: 0\nlora_target_modules:\n  - q_proj\n  - k_proj\n  - v_proj\n  - o_proj\n\n# Regex matching to target shared experts too\n# lora_target_modules: 'model\\.(language_model\\.)?layers\\.[\\d]+\\.(mlp|self_attn)\\.(shared_expert\\.)?(up|down|gate|gate_up|q|k|v|o)_proj'\n\n# Target experts\n# lora_target_parameters:\n#   - mlp.experts.gate_up_proj\n#   - mlp.experts.down_proj\n\nwandb_project:\nwandb_entity:\nwandb_watch:\nwandb_name:\nwandb_log_model:\n\ngradient_accumulation_steps: 2\nmicro_batch_size: 1\nnum_epochs: 1\noptimizer: adamw_torch_4bit\nlr_scheduler: cosine\nlearning_rate: 0.0002\n\nbf16: auto\ntf32: true\n\nlora_mlp_kernel: false\nlora_qkv_kernel: false\nlora_o_kernel: false\n\ngradient_checkpointing: true\ngradient_checkpointing_kwargs:\n  use_reentrant: false\nresume_from_checkpoint:\nlogging_steps: 1\nflash_attention: true\n\nwarmup_ratio: 0.1\nevals_per_epoch: 4\nsaves_per_epoch: 1\nweight_decay: 0.0\nspecial_tokens:\n"
  },
  {
    "path": "examples/qwen3.5/27b-fft.yaml",
    "content": "base_model: Qwen/Qwen3.5-27B\n# Automatically upload checkpoint and final model to HF\n# hub_model_id: username/custom_model_name\n\n# Full fine-tune (FFT) of the text-only path of Qwen3.5-27B.\n\nplugins:\n  - axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin\nstrict: false\n\nchat_template: qwen3_5\ndatasets:\n  - path: mlabonne/FineTome-100k\n    type: chat_template\n    split: train[:20%]\n    field_messages: conversations\n    message_property_mappings:\n      role: from\n      content: value\nval_set_size: 0.0\noutput_dir: ./outputs/out\ndataset_prepared_path: last_run_prepared\n\nsequence_len: 2048\nsample_packing: true\n\n# Freeze vision encoder\nunfrozen_parameters:\n  - model\\.language_model\\..*\n  - lm_head\\..*\n\nwandb_project:\nwandb_entity:\nwandb_watch:\nwandb_name:\nwandb_log_model:\n\ngradient_accumulation_steps: 2\nmicro_batch_size: 1\nnum_epochs: 1\noptimizer: adamw_bnb_8bit\nlr_scheduler: cosine\nlearning_rate: 0.0002\n\nbf16: auto\ntf32: true\n\ngradient_checkpointing: true\ngradient_checkpointing_kwargs:\n  use_reentrant: false\nresume_from_checkpoint:\nlogging_steps: 1\nflash_attention: true\n\nwarmup_ratio: 0.1\nevals_per_epoch: 4\nsaves_per_epoch: 1\nweight_decay: 0.0\nspecial_tokens:\n"
  },
  {
    "path": "examples/qwen3.5/27b-qlora-fsdp.yaml",
    "content": "base_model: Qwen/Qwen3.5-27B\n\n# Automatically upload checkpoint and final model to HF\n# hub_model_id: username/custom_model_name\n\nplugins:\n  - axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin\nstrict: false\n\nchat_template: qwen3_5\ndatasets:\n  - path: mlabonne/FineTome-100k\n    type: chat_template\n    split: train[:20%]\n    field_messages: conversations\n    message_property_mappings:\n      role: from\n      content: value\nval_set_size: 0.0\noutput_dir: ./outputs/out\ndataset_prepared_path: last_run_prepared\n\nsequence_len: 2048\nsample_packing: true\n\nload_in_4bit: true\nadapter: qlora\nlora_r: 16\nlora_alpha: 32\nlora_target_modules:\n  - q_proj\n  - k_proj\n  - v_proj\n  - o_proj\n  - down_proj\n  - up_proj\n  # Uncomment below to also target the linear attention projections.\n  # These use separate in_proj_qkv / in_proj_z / out_proj (Qwen3.5-specific).\n  # - linear_attn.in_proj_qkv\n  # - linear_attn.in_proj_z\n  # - linear_attn.out_proj\n\nwandb_project:\nwandb_entity:\nwandb_watch:\nwandb_name:\nwandb_log_model:\n\ngradient_accumulation_steps: 2\nmicro_batch_size: 1\nnum_epochs: 1\noptimizer: adamw_torch_4bit\nlr_scheduler: cosine\nlearning_rate: 0.0002\n\nbf16: auto\ntf32: true\n\ngradient_checkpointing: true\ngradient_checkpointing_kwargs:\n  use_reentrant: false\nresume_from_checkpoint:\nlogging_steps: 1\nflash_attention: true\n\nwarmup_ratio: 0.1\nevals_per_epoch: 4\nsaves_per_epoch: 1\nweight_decay: 0.0\nspecial_tokens:\n\nfsdp_config:\n  fsdp_version: 2\n  offload_params: false\n  cpu_ram_efficient_loading: false\n  auto_wrap_policy: TRANSFORMER_BASED_WRAP\n  transformer_layer_cls_to_wrap: Qwen3_5DecoderLayer\n  state_dict_type: FULL_STATE_DICT\n  sharding_strategy: FULL_SHARD\n  reshard_after_forward: true\n  activation_checkpointing: true\n"
  },
  {
    "path": "examples/qwen3.5/27b-qlora.yaml",
    "content": "base_model: Qwen/Qwen3.5-27B\n\n# Automatically upload checkpoint and final model to HF\n# hub_model_id: username/custom_model_name\n\nplugins:\n  - axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin\nstrict: false\n\nchat_template: qwen3_5\ndatasets:\n  - path: mlabonne/FineTome-100k\n    type: chat_template\n    split: train[:20%]\n    field_messages: conversations\n    message_property_mappings:\n      role: from\n      content: value\nval_set_size: 0.0\noutput_dir: ./outputs/out\ndataset_prepared_path: last_run_prepared\n\nsequence_len: 2048\nsample_packing: true\n\nload_in_4bit: true\nadapter: qlora\nlora_r: 16\nlora_alpha: 32\nlora_target_modules:\n  - q_proj\n  - k_proj\n  - v_proj\n  - o_proj\n  - down_proj\n  - up_proj\n  # Uncomment below to also target the linear attention projections.\n  # These use separate in_proj_qkv / in_proj_z / out_proj (Qwen3.5-specific).\n  # - linear_attn.in_proj_qkv\n  # - linear_attn.in_proj_z\n  # - linear_attn.out_proj\n\nwandb_project:\nwandb_entity:\nwandb_watch:\nwandb_name:\nwandb_log_model:\n\ngradient_accumulation_steps: 2\nmicro_batch_size: 1\nnum_epochs: 1\noptimizer: adamw_torch_4bit\nlr_scheduler: cosine\nlearning_rate: 0.0002\n\nbf16: auto\ntf32: true\n\ngradient_checkpointing: true\ngradient_checkpointing_kwargs:\n  use_reentrant: false\nresume_from_checkpoint:\nlogging_steps: 1\nflash_attention: true\n\nwarmup_ratio: 0.1\nevals_per_epoch: 4\nsaves_per_epoch: 1\nweight_decay: 0.0\nspecial_tokens:\n"
  },
  {
    "path": "examples/qwen3.5/35b-a3b-moe-qlora-fsdp.yaml",
    "content": "base_model: Qwen/Qwen3.5-35B-A3B\n\nplugins:\n  - axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin\nstrict: false\n\nchat_template: qwen3_5\ndatasets:\n  - path: mlabonne/FineTome-100k\n    type: chat_template\n    split: train[:20%]\n    field_messages: conversations\n    message_property_mappings:\n      role: from\n      content: value\nval_set_size: 0.0\noutput_dir: ./outputs/out\ndataset_prepared_path: last_run_prepared\n\nsequence_len: 2048\nsample_packing: true\n\nload_in_4bit: true\nquantize_moe_experts: true\nadapter: qlora\nlora_r: 16\nlora_alpha: 32\nlora_dropout: 0\nlora_target_modules:\n  - q_proj\n  - k_proj\n  - v_proj\n  - o_proj\n\n# Regex matching to target shared experts too\n# lora_target_modules: 'model\\.(language_model\\.)?layers\\.[\\d]+\\.(mlp|self_attn)\\.(shared_expert\\.)?(up|down|gate|gate_up|q|k|v|o)_proj'\n\n# Target experts\n# lora_target_parameters:\n#   - mlp.experts.gate_up_proj\n#   - mlp.experts.down_proj\n\nwandb_project:\nwandb_entity:\nwandb_watch:\nwandb_name:\nwandb_log_model:\n\ngradient_accumulation_steps: 2\nmicro_batch_size: 1\nnum_epochs: 1\noptimizer: adamw_torch_4bit\nlr_scheduler: cosine\nlearning_rate: 0.0002\n\nbf16: auto\ntf32: true\n\nlora_mlp_kernel: false\nlora_qkv_kernel: false\nlora_o_kernel: false\n\ngradient_checkpointing: true\ngradient_checkpointing_kwargs:\n  use_reentrant: false\nresume_from_checkpoint:\nlogging_steps: 1\nflash_attention: true\n\nwarmup_ratio: 0.1\nevals_per_epoch: 4\nsaves_per_epoch: 1\nweight_decay: 0.0\nspecial_tokens:\n\nfsdp_config:\n  fsdp_version: 2\n  offload_params: true\n  cpu_ram_efficient_loading: false\n  auto_wrap_policy: TRANSFORMER_BASED_WRAP\n  transformer_layer_cls_to_wrap: Qwen3_5MoeDecoderLayer\n  state_dict_type: FULL_STATE_DICT\n  sharding_strategy: FULL_SHARD\n  reshard_after_forward: true\n  activation_checkpointing: true\n"
  },
  {
    "path": "examples/qwen3.5/35b-a3b-moe-qlora.yaml",
    "content": "base_model: Qwen/Qwen3.5-35B-A3B\n\nplugins:\n  - axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin\nstrict: false\n\nchat_template: qwen3_5\ndatasets:\n  - path: mlabonne/FineTome-100k\n    type: chat_template\n    split: train[:20%]\n    field_messages: conversations\n    message_property_mappings:\n      role: from\n      content: value\nval_set_size: 0.0\noutput_dir: ./outputs/out\ndataset_prepared_path: last_run_prepared\n\nsequence_len: 2048\nsample_packing: true\n\nload_in_4bit: true\nquantize_moe_experts: true\nadapter: qlora\nlora_r: 16\nlora_alpha: 32\nlora_dropout: 0\nlora_target_modules:\n  - q_proj\n  - k_proj\n  - v_proj\n  - o_proj\n\n# Regex matching to target shared experts too\n# lora_target_modules: 'model\\.(language_model\\.)?layers\\.[\\d]+\\.(mlp|self_attn)\\.(shared_expert\\.)?(up|down|gate|gate_up|q|k|v|o)_proj'\n\n# Target experts\n# lora_target_parameters:\n#   - mlp.experts.gate_up_proj\n#   - mlp.experts.down_proj\n\nwandb_project:\nwandb_entity:\nwandb_watch:\nwandb_name:\nwandb_log_model:\n\ngradient_accumulation_steps: 2\nmicro_batch_size: 1\nnum_epochs: 1\noptimizer: adamw_torch_4bit\nlr_scheduler: cosine\nlearning_rate: 0.0002\n\nbf16: auto\ntf32: true\n\nlora_mlp_kernel: false\nlora_qkv_kernel: false\nlora_o_kernel: false\n\ngradient_checkpointing: true\ngradient_checkpointing_kwargs:\n  use_reentrant: false\nresume_from_checkpoint:\nlogging_steps: 1\nflash_attention: true\n\nwarmup_ratio: 0.1\nevals_per_epoch: 4\nsaves_per_epoch: 1\nweight_decay: 0.0\nspecial_tokens:\n"
  },
  {
    "path": "examples/qwen3.5/9b-fft-vision.yaml",
    "content": "base_model: Qwen/Qwen3.5-9B\nprocessor_type: AutoProcessor\n\n# Required for multimodal training\nskip_prepare_dataset: true\nremove_unused_columns: false\nsample_packing: false\n\nchat_template: qwen3_5\ndatasets:\n  - path: HuggingFaceH4/llava-instruct-mix-vsft\n    type: chat_template\n    split: train[:1%]\n\ndataset_prepared_path: last_run_prepared\nval_set_size: 0.0\noutput_dir: ./outputs/out\n\nsequence_len: 4096\npad_to_sequence_len: false\n\nwandb_project:\nwandb_entity:\nwandb_watch:\nwandb_name:\nwandb_log_model:\n\ngradient_accumulation_steps: 4\nmicro_batch_size: 1\nnum_epochs: 1\noptimizer: adamw_bnb_8bit\nlr_scheduler: cosine\nlearning_rate: 0.0002\n\nbf16: auto\ntf32: true\n\ngradient_checkpointing: true\ngradient_checkpointing_kwargs:\n  use_reentrant: false\nresume_from_checkpoint:\nlogging_steps: 1\nflash_attention: true\n\nwarmup_ratio: 0.1\nevals_per_epoch: 1\nsaves_per_epoch: 1\nweight_decay: 0.0\nspecial_tokens:\n"
  },
  {
    "path": "examples/qwen3.5/9b-lora-vision.yaml",
    "content": "base_model: Qwen/Qwen3.5-9B\nprocessor_type: AutoProcessor\n\n# These 3 lines are required for vision/multimodal training\nskip_prepare_dataset: true\nremove_unused_columns: false\nsample_packing: false\n\nchat_template: qwen3_5\ndatasets:\n  - path: HuggingFaceH4/llava-instruct-mix-vsft\n    type: chat_template\n    split: train[:1%]\n\ndataset_prepared_path: last_run_prepared\nval_set_size: 0.0\noutput_dir: ./outputs/out\n\nadapter: lora\nlora_model_dir:\n\nsequence_len: 8192\npad_to_sequence_len: false\n\nlora_r: 32\nlora_alpha: 16\nlora_dropout: 0.05\n# Targets the language model attention and MLP layers.\nlora_target_modules:\n  - q_proj\n  - k_proj\n  - v_proj\n  - o_proj\n  - down_proj\n  - up_proj\n  # Uncomment to also target the linear attention (GatedDeltaNet) projections:\n  # - linear_attn.in_proj_qkv\n  # - linear_attn.in_proj_z\n  # - linear_attn.out_proj\n\nwandb_project:\nwandb_entity:\nwandb_watch:\nwandb_name:\nwandb_log_model:\n\ngradient_accumulation_steps: 4\nmicro_batch_size: 1\nnum_epochs: 1\noptimizer: adamw_bnb_8bit\nlr_scheduler: cosine\nlearning_rate: 0.0002\n\nbf16: true\ntf32: true\n\ngradient_checkpointing: true\ngradient_checkpointing_kwargs:\n  use_reentrant: false\nlogging_steps: 1\nflash_attention: true\n\nwarmup_ratio: 0.1\nevals_per_epoch: 1\nsaves_per_epoch: 1\nweight_decay: 0.0\n"
  },
  {
    "path": "examples/qwen3.5/README.md",
    "content": "# Finetune Qwen3.5 with Axolotl\n\n[Qwen3.5](https://huggingface.co/collections/Qwen/qwen35) is a hybrid architecture model series combining Gated DeltaNet linear attention with standard Transformer attention. All Qwen3.5 models are early-fusion vision-language models: dense variants use `Qwen3_5ForConditionalGeneration` and MoE variants use `Qwen3_5MoeForConditionalGeneration`.\n\n## Getting started\n\n1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html).\n\n2. Install [Cut Cross Entropy](https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy) to reduce training VRAM usage.\n\n3. Install FLA for sample packing support with the Gated DeltaNet linear attention layers:\n  ```bash\n  pip3 uninstall -y causal-conv1d && pip3 install flash-linear-attention==0.4.1\n  ```\n  > FLA is required when `sample_packing: true`. Without it, training raises a `RuntimeError` on packed sequences. Vision configs use `sample_packing: false` so FLA is optional there.\n\n4. 
Pick any config from the table below and run:\n\n    ```bash\n    axolotl train examples/qwen3.5/<config>.yaml\n    ```\n\nAvailable configs:\n\n| Config | Model | Type | Peak VRAM |\n|---|---|---|---|\n| `9b-lora-vision.yaml` | Qwen3.5-9B | Vision+text LoRA, single GPU | — |\n| `9b-fft-vision.yaml` | Qwen3.5-9B | Vision+text FFT, single GPU | ~61 GiB |\n| `27b-qlora.yaml` | Qwen3.5-27B | Dense, text-only QLoRA | ~47 GiB |\n| `27b-fft.yaml` | Qwen3.5-27B | Dense, text-only FFT (vision frozen) | ~53 GiB |\n| `27b-qlora-fsdp.yaml` | Qwen3.5-27B | Dense, text-only QLoRA + FSDP2 | — |\n| `35b-a3b-moe-qlora.yaml` | Qwen3.5-35B-A3B | MoE, text-only QLoRA | — |\n| `35b-a3b-moe-qlora-fsdp.yaml` | Qwen3.5-35B-A3B | MoE, text-only QLoRA + FSDP2 | — |\n| `122b-a10b-moe-qlora.yaml` | Qwen3.5-122B-A10B | MoE, text-only QLoRA | — |\n| `122b-a10b-moe-qlora-fsdp.yaml` | Qwen3.5-122B-A10B | MoE, text-only QLoRA + FSDP2 | — |\n\n### Gated DeltaNet Linear Attention\n\nQwen3.5 interleaves standard attention with Gated DeltaNet linear attention layers. To apply LoRA to them, add to `lora_target_modules`:\n\n```yaml\nlora_target_modules:\n  # ... standard projections ...\n  - linear_attn.in_proj_qkv\n  - linear_attn.in_proj_z\n  - linear_attn.out_proj\n```\n\n### Routed Experts (MoE)\n\nTo apply LoRA to routed expert parameters, add `lora_target_parameters`:\n\n```yaml\nlora_target_parameters:\n  - mlp.experts.gate_up_proj\n  - mlp.experts.down_proj\n#  - mlp.gate.weight  # router\n```\n\n### Shared Experts (MoE)\n\nRouted experts and shared experts both have `gate_up_proj`/`down_proj`, so a plain module name in `lora_target_modules` would match both. 
Use a regex to target only attention and shared expert projections, while `lora_target_parameters` above handles routed experts separately:\n\n```yaml\nlora_target_modules: 'model\\.(language_model\\.)?layers\\.[\\d]+\\.(mlp|self_attn)\\.(shared_expert\\.)?(up|down|gate|gate_up|q|k|v|o)_proj'\n```\n\n### TIPS\n\n- For inference hyperparameters (e.g. sampling settings), please see the respective model card.\n- You can run a full finetuning of the smaller QLoRA configs by removing `adapter: qlora` and `load_in_4bit: true`. See [Optimization Guides](#optimization-guides) below for memory-saving and multi-GPU options.\n- Read more on loading your own dataset at [docs](https://docs.axolotl.ai/docs/dataset_loading.html).\n- The dataset format follows the OpenAI Messages format as seen [here](https://docs.axolotl.ai/docs/dataset-formats/conversation.html#chat_template).\n- For **multimodal** finetuning, set `processor_type: AutoProcessor`, `skip_prepare_dataset: true`, and `remove_unused_columns: false` as shown in `9b-lora-vision.yaml`.\n\n## Optimization Guides\n\n- [Optimizations Guide](https://docs.axolotl.ai/docs/optimizations.html)\n\n## Related Resources\n\n- [Qwen3.5 Blog](https://qwenlm.github.io/blog/qwen3.5/)\n- [Axolotl Docs](https://docs.axolotl.ai)\n- [Axolotl Website](https://axolotl.ai)\n- [Axolotl GitHub](https://github.com/axolotl-ai-cloud/axolotl)\n- [Axolotl Discord](https://discord.gg/7m9sfhzaf3)\n"
  },
  {
    "path": "examples/seed-oss/README.md",
    "content": "# Finetune ByteDance's Seed-OSS with Axolotl\n\n[Seed-OSS](https://huggingface.co/collections/ByteDance-Seed/seed-oss-68a609f4201e788db05b5dcd) is a series of 36B-parameter open-source models trained by ByteDance's Seed Team.\n\nThis guide shows how to fine-tune it with Axolotl on multi-turn conversations with proper masking.\n\n## Getting started\n\n1.  Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html).\n\n    Here is an example of how to install from pip:\n    ```bash\n    # Ensure you have a compatible version of PyTorch installed\n    pip3 install packaging setuptools wheel ninja\n    pip3 install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'\n\n    # Install Cut Cross Entropy\n    python scripts/cutcrossentropy_install.py | sh\n    ```\n\n2.  Run the finetuning example:\n\n    ```bash\n    axolotl train examples/seed-oss/seed-oss-36b-qlora.yaml\n    ```\n\n    This config uses about 27.7 GiB VRAM.\n\nLet us know how it goes. Happy finetuning! 🚀\n\n### TIPS\n\n- For inference, the official Seed Team recommends `top_p=0.95` and `temperature=1.1`.\n- You can run a full finetuning by removing the `adapter: qlora` and `load_in_4bit: true` from the config.\n- Read more on how to load your own dataset at [docs](https://docs.axolotl.ai/docs/dataset_loading.html).\n- The dataset format follows the OpenAI Messages format as seen [here](https://docs.axolotl.ai/docs/dataset-formats/conversation.html#chat_template).\n\n## Optimization Guides\n\nPlease check the [Optimizations doc](https://docs.axolotl.ai/docs/optimizations.html).\n\n## Related Resources\n\n- [ByteDance Seed Website](https://seed.bytedance.com/)\n- [Axolotl Docs](https://docs.axolotl.ai)\n- [Axolotl Website](https://axolotl.ai)\n- [Axolotl GitHub](https://github.com/axolotl-ai-cloud/axolotl)\n- [Axolotl Discord](https://discord.gg/7m9sfhzaf3)\n"
  },
  {
    "path": "examples/seed-oss/seed-oss-36b-qlora.yaml",
    "content": "base_model: ByteDance-Seed/Seed-OSS-36B-Instruct\n\n# Automatically upload checkpoint and final model to HF\n# hub_model_id: username/custom_model_name\n\nplugins:\n  - axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin\n\nload_in_8bit: false\nload_in_4bit: true\n\ndatasets:\n  - path: fozziethebeat/alpaca_messages_2k_test\n    type: chat_template\n\ndataset_prepared_path: last_run_prepared\nval_set_size: 0.1\noutput_dir: ./outputs/lora-out\n\nadapter: qlora\nlora_model_dir:\n\nsequence_len: 2048\nsample_packing: true\n\nlora_r: 32\nlora_alpha: 16\nlora_dropout: 0.05\nlora_target_linear: true\n\nwandb_project:\nwandb_entity:\nwandb_watch:\nwandb_name:\nwandb_log_model:\n\ngradient_accumulation_steps: 4\nmicro_batch_size: 2\nnum_epochs: 1\noptimizer: adamw_bnb_8bit\nlr_scheduler: cosine\nlearning_rate: 0.0002\n\nbf16: auto\ntf32: false\n\ngradient_checkpointing: true\nresume_from_checkpoint:\nlogging_steps: 1\nflash_attention: true\n\nwarmup_ratio: 0.1\nevals_per_epoch: 1\nsaves_per_epoch: 1\n\n# save_first_step: true  # uncomment this to validate checkpoint saving works with your config\n"
  },
  {
    "path": "examples/slurm/README.md",
    "content": "# SLURM Multi-Node Training\n\nThis directory contains an example SLURM script for running Axolotl training jobs across multiple nodes in a SLURM cluster.\n\n## Prerequisites\n\n- Access to a SLURM cluster with GPU nodes\n- Axolotl installed on all nodes (see [installation docs](https://docs.axolotl.ai/docs/installation.html))\n\n## Usage\n\n### Standard SLURM Clusters\n\n1. Copy [`axolotl.slurm`](./axolotl.slurm) to your working directory.\n2. Place your Axolotl config file (`train.yaml`) in the same directory.\n3. Set the appropriate environment variables for the job:\n    ```bash\n    export HF_TOKEN=\"your-huggingface-token\"\n\n    # metric tracking\n    # export WANDB_API_KEY=\"your-wandb-api-key\"\n    # ...\n    ```\n4. Submit the job:\n   ```bash\n   sbatch --export=ALL,NUM_NODES=2,NUM_TRAINERS=8,PRIMARY_ADDR=<master-node>,PRIMARY_PORT=29400 axolotl.slurm\n   ```\n\n   Where:\n   - `NUM_NODES`: Number of nodes to use\n   - `NUM_TRAINERS`: GPUs per node (typically 8)\n   - `PRIMARY_ADDR`: Hostname/IP of the master node\n   - `PRIMARY_PORT`: Port for the distributed rendezvous (e.g. 29400); there is no default, so it must be set\n\n5. (Optional) Run other slurm commands:\n    ```bash\n    # check job info (use the job ID printed by sbatch)\n    scontrol show job <job-id>\n\n    # check job queue\n    squeue\n\n    # check cluster status\n    sinfo\n    ```\n\n### RunPod Instant Clusters\n\nAxolotl works with RunPod Instant Clusters. This feature provides managed SLURM clusters with zero configuration.\n\n1. **Deploy a SLURM Cluster**:\n   - Go to [RunPod Instant Clusters](https://console.runpod.io/cluster)\n   - Click \"Create a Cluster\"\n   - Choose your GPU type, node count, and region\n   - Choose an [Axolotl cloud docker image](https://docs.axolotl.ai/docs/docker.html#cloud)\n   - Deploy the cluster\n\n2. **Connect to the Controller Node**: Find the controller node in the RunPod console and connect via SSH\n\n3. 
**Follow the instructions in [Standard SLURM Clusters](#standard-slurm-clusters)**\n\n## Additional Resources\n\n- [Axolotl Multi-Node Training](https://docs.axolotl.ai/docs/multi-node.html)\n- [SLURM Documentation](https://slurm.schedmd.com/documentation.html)\n- [RunPod SLURM Clusters Guide](https://docs.runpod.io/instant-clusters/slurm-clusters)\n"
  },
  {
    "path": "examples/slurm/axolotl.slurm",
    "content": "#!/bin/bash\n# Prior to running this script, export your HF_TOKEN and WANDB_API_KEY to your environment; i.e.\n# export HF_TOKEN=\"...\"\n# export WANDB_API_KEY=\"...\"\n#\n\n# ---------- SBATCH commands ---------- #\n#SBATCH --job-name=axolotl-slurm-multinode\n#SBATCH --ntasks-per-node=1\n#SBATCH --nodes=$NUM_NODES\n#SBATCH --gpus-per-task=8\n#SBATCH --cpus-per-task=128\n\nexport TORCH_DIST_INIT_BARRIER=0\n\nsrun axolotl preprocess train.yaml\n\nsrun axolotl train train.yaml --launcher torchrun -- \\\n    --nproc_per_node=$NUM_TRAINERS --nnodes=$NUM_NODES \\\n    --rdzv_id axolotl-cli --rdzv_backend c10d --rdzv_endpoint \"${PRIMARY_ADDR}:${PRIMARY_PORT}\" --rdzv-conf=\"join_timeout=1800\"\n"
  },
  {
    "path": "examples/smolvlm2/README.md",
    "content": "# Finetune SmolVLM2 with Axolotl\n\n[SmolVLM2](https://huggingface.co/collections/HuggingFaceTB/smolvlm2-smallest-video-lm-ever-67ab6b5e84bf8aaa60cb17c7) are a family of lightweight, open-source multimodal models from HuggingFace designed to analyze and understand video, image, and text content.\n\nThese models are built for efficiency, making them well-suited for on-device applications where computational resources are limited. Models are available in multiple sizes, including 2.2B, 500M, and 256M.\n\nThis guide shows how to fine-tune SmolVLM2 models with Axolotl.\n\n## Getting Started\n\n1.  Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html).\n\n    Here is an example of how to install from pip:\n    ```bash\n    # Ensure you have a compatible version of Pytorch installed\n    pip3 install packaging setuptools wheel ninja\n    pip3 install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'\n    ```\n\n2. Install an extra dependency:\n\n    ```bash\n    pip3 install num2words==0.5.14\n    ```\n\n3.  Run the finetuning example:\n\n    ```bash\n    # LoRA SFT (1x48GB @ 6.8GiB)\n    axolotl train examples/smolvlm2/smolvlm2-2B-lora.yaml\n    ```\n\n## TIPS\n\n- **Dataset Format**: For video finetuning, your dataset must be compatible with the multi-content Messages format. For more details, see our documentation on [Multimodal Formats](https://docs.axolotl.ai/docs/multimodal.html#dataset-format).\n- **Dataset Loading**: Read more on how to prepare and load your own datasets in our [documentation](https://docs.axolotl.ai/docs/dataset_loading.html).\n\n## Optimization Guides\n\nPlease check the [Optimizations doc](https://docs.axolotl.ai/docs/optimizations.html).\n\n## Related Resources\n\n- [SmolVLM2 Blog](https://huggingface.co/blog/smolvlm2)\n- [Axolotl Docs](https://docs.axolotl.ai)\n- [Axolotl GitHub](https://github.com/axolotl-ai-cloud/axolotl)\n- [Axolotl Discord](https://discord.gg/7m9sfhzaf3)\n"
  },
  {
    "path": "examples/smolvlm2/smolvlm2-2B-lora.yaml",
    "content": "base_model: HuggingFaceTB/SmolVLM2-2.2B-Instruct\ntrust_remote_code: true\nprocessor_type: AutoProcessor\n\n# these 3 lines are needed for now to handle vision chat templates w images\nskip_prepare_dataset: true\nremove_unused_columns: false\nsample_packing: false\n\ndatasets:\n  - path: HuggingFaceH4/llava-instruct-mix-vsft\n    type: chat_template\n    split: train[:1%]\ndataset_prepared_path: last_run_prepared\nval_set_size: 0.0\noutput_dir: ./outputs/out\n\nadapter: lora\nlora_model_dir:\n\nsequence_len: 8192\npad_to_sequence_len: false\n\nlora_r: 32\nlora_alpha: 16\nlora_dropout: 0.05\nlora_target_modules: 'model.text_model.layers.[\\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'\n\nwandb_project:\nwandb_entity:\nwandb_watch:\nwandb_name:\nwandb_log_model:\n\ngradient_accumulation_steps: 4\nmicro_batch_size: 1\nnum_epochs: 1\noptimizer: adamw_bnb_8bit\nlr_scheduler: cosine\nlearning_rate: 0.0002\n\nbf16: true\nfp16:\ntf32: true\n\ngradient_checkpointing: true\nlogging_steps: 1\nflash_attention: true\neager_attention:\n\nwarmup_ratio: 0.1\nevals_per_epoch: 1\nsaves_per_epoch: 1\nweight_decay: 0.0\n\n# save_first_step: true  # uncomment this to validate checkpoint saving works with your config\n"
  },
  {
    "path": "examples/streaming/README.md",
    "content": "# Streaming Dataset Examples\n\nThis directory contains example configurations for using Axolotl's streaming dataset\nfunctionality, which enables memory-efficient training with large datasets.\n\n## Examples\n\nRun the following examples with e.g. `axolotl train examples/streaming/sft.yaml`; no\n`axolotl preprocess` required!\n\n### Pretraining (`pretrain.yaml`)\n\nDemonstrates streaming configuration for pretraining tasks using the fineweb-edu dataset\nwith SmolLM2-135M.\n\n- Uses `pretraining_dataset` configuration for automatic streaming\n- Sets `pretrain_multipack_attn: true` to prevent cross-attention between packed sequences\n- Buffer size configuration for memory management\n\n### SFT (`sft.yaml`)\n\nShows how to use streaming for supervised fine-tuning with the Alpaca dataset.\n\n- Explicit `streaming: true` flag for SFT datasets\n- Memory-efficient training on instruction datasets\n- Evaluation datasets are currently not streamed\n\n## Key Configuration Options\n\n### `streaming`\n- Enables streaming mode for standard datasets\n- Automatically enabled for `pretraining_dataset`\n\n### `streaming_multipack_buffer_size`\n- Controls buffer size for sample packing (default: 10,000)\n- Larger values improve packing efficiency but use more memory\n- Adjust based on available memory\n\n### `shuffle_merged_datasets`\n- Enables shuffling of streaming datasets\n- Requires additional memory for the shuffle buffer\n\n### `sample_packing`\n- Packs multiple samples into single sequences\n- Minimizes per-step padding tokens\n\n## Performance Tips\n\n- Download small / frequently-used datasets locally for better performance\n- Larger buffer sizes improve packing efficiency\n"
  },
  {
    "path": "examples/streaming/pretrain.yaml",
    "content": "base_model: HuggingFaceTB/SmolLM2-135M\n\n# Streaming pretraining configuration\npretraining_dataset:\n  - path: HuggingFaceFW/fineweb-edu\n    name: sample-10BT\n    type: pretrain\n    text_column: text\n    split: train\n\n# Streaming-specific settings\nstreaming_multipack_buffer_size: 10000\nshuffle_merged_datasets: true\n\n# Training configuration\nmax_steps: 1000\noutput_dir: ./outputs/smollm2-135m-pretrain-streaming\n\n# Sequence and packing settings\nsequence_len: 1024\nsample_packing: true\npretrain_multipack_attn: true  # Prevent cross-attention between packed sequences\nflash_attention: true\n\n# Batch size settings\ngradient_accumulation_steps: 8\nmicro_batch_size: 1\n\n# Optimizer and scheduler\noptimizer: adamw_torch\nlr_scheduler: cosine\nlearning_rate: 5e-4\nwarmup_ratio: 0.1\nweight_decay: 0.01\n\n# Precision and performance\nbf16: auto\ntf32: true\n\n# Logging and checkpointing\nlogging_steps: 10\nsave_strategy: steps\nsave_steps: 250\nsave_total_limit: 3\n\n# Weights & Biases (optional)\nwandb_project:\nwandb_entity:\nwandb_watch:\nwandb_name:\nwandb_log_model:\n\n# Special tokens\nspecial_tokens:\n  pad_token: \"<|endoftext|>\"\n\n# save_first_step: true  # uncomment this to validate checkpoint saving works with your config\n"
  },
  {
    "path": "examples/streaming/sft.yaml",
    "content": "base_model: HuggingFaceTB/SmolLM2-135M\n\n# Dataset configuration\ndatasets:\n  - path: tatsu-lab/alpaca\n    type: alpaca\n    split: train\n\n# Streaming-specific settings\nstreaming: true\nstreaming_multipack_buffer_size: 10000\nshuffle_merged_datasets: true\n\n# Training configuration\nmax_steps: 1000\noutput_dir: ./outputs/smollm2-135m-sft-streaming\n\n# Sequence and packing settings\nsequence_len: 1024\nsample_packing: true\nflash_attention: true\n\n# Batch size settings\ngradient_accumulation_steps: 4\nmicro_batch_size: 1\n\n# Optimizer and scheduler\noptimizer: adamw_torch\nlr_scheduler: cosine\nlearning_rate: 2e-4\nwarmup_ratio: 0.1\nweight_decay: 0.0\n\n# Precision and performance\nbf16: auto\ntf32: true\n\n# Logging and checkpointing\nlogging_steps: 10\nsave_strategy: steps\nsave_steps: 100\nsave_total_limit: 3\n\n# Weights & Biases (optional)\nwandb_project:\nwandb_entity:\nwandb_watch:\nwandb_name:\nwandb_log_model:\n\n# Special tokens\nspecial_tokens:\n  pad_token: \"<|endoftext|>\"\n\n# save_first_step: true  # uncomment this to validate checkpoint saving works with your config\n"
  },
  {
    "path": "examples/swanlab/README.md",
    "content": "# SwanLab Integration Examples\n\nThis directory contains example configurations demonstrating SwanLab integration with Axolotl.\n\n## Examples Overview\n\n### 1. DPO with Completion Logging\n**File**: `dpo-swanlab-completions.yml`\n\nDemonstrates DPO (Direct Preference Optimization) training with RLHF completion table logging.\n\n**Features**:\n- Basic SwanLab experiment tracking\n- Completion table logging (prompts, chosen/rejected responses, rewards)\n- Memory-bounded buffer for long training runs\n- Cloud sync configuration\n\n**Best for**: RLHF practitioners who want to analyze model outputs qualitatively\n\n**Quick start**:\n```bash\nexport SWANLAB_API_KEY=your-api-key\naccelerate launch -m axolotl.cli.train examples/swanlab/dpo-swanlab-completions.yml\n```\n\n---\n\n### 2. LoRA with Performance Profiling\n**File**: `lora-swanlab-profiling.yml`\n\nDemonstrates standard LoRA fine-tuning with performance profiling enabled.\n\n**Features**:\n- SwanLab experiment tracking\n- Automatic profiling of trainer methods\n- Profiling metrics visualization\n- Performance optimization guidance\n\n**Best for**: Engineers optimizing training performance and comparing different configurations\n\n**Quick start**:\n```bash\nexport SWANLAB_API_KEY=your-api-key\naccelerate launch -m axolotl.cli.train examples/swanlab/lora-swanlab-profiling.yml\n```\n\n---\n\n### 3. 
Full-Featured DPO Production Setup\n**File**: `dpo-swanlab-full-featured.yml`\n\nComprehensive production-ready configuration with ALL SwanLab features enabled.\n\n**Features**:\n- Experiment tracking with team workspace\n- RLHF completion logging\n- Performance profiling\n- Lark (Feishu) team notifications\n- Private deployment support\n- Production checklist and troubleshooting\n\n**Best for**: Production RLHF training with team collaboration\n\n**Quick start**:\n```bash\nexport SWANLAB_API_KEY=your-api-key\nexport SWANLAB_LARK_WEBHOOK_URL=https://open.feishu.cn/...\nexport SWANLAB_LARK_SECRET=your-webhook-secret\naccelerate launch -m axolotl.cli.train examples/swanlab/dpo-swanlab-full-featured.yml\n```\n\n---\n\n### 4. Custom Trainer Profiling (Python)\n**File**: `custom_trainer_profiling.py`\n\nPython code examples showing how to add SwanLab profiling to custom trainers.\n\n**Features**:\n- `@swanlab_profile` decorator examples\n- Context manager profiling for fine-grained timing\n- `ProfilingConfig` for advanced filtering and throttling\n- Multiple profiling patterns and best practices\n\n**Best for**: Advanced users creating custom trainers\n\n**Usage**:\n```python\nfrom custom_trainer_profiling import CustomTrainerWithProfiling\n# See file for detailed examples and patterns\n```\n\n---\n\n## Feature Matrix\n\n| Example | Tracking | Completion Logging | Profiling | Lark Notifications | Team Workspace |\n|---------|----------|-------------------|-----------|-------------------|----------------|\n| dpo-swanlab-completions.yml | ✅ | ✅ | ✅ (auto) | ➖ (commented) | ➖ (commented) |\n| lora-swanlab-profiling.yml | ✅ | ➖ (disabled) | ✅ (auto) | ➖ (commented) | ➖ (commented) |\n| dpo-swanlab-full-featured.yml | ✅ | ✅ | ✅ (auto) | ✅ | ✅ |\n| custom_trainer_profiling.py | N/A | N/A | ✅ (manual) | N/A | N/A |\n\n---\n\n## Configuration Quick Reference\n\n### Basic SwanLab Setup\n```yaml\nplugins:\n  - axolotl.integrations.swanlab.SwanLabPlugin\n\nuse_swanlab: 
true\nswanlab_project: my-project\nswanlab_experiment_name: my-experiment\nswanlab_mode: cloud  # cloud, local, offline, disabled\n```\n\n### RLHF Completion Logging\n```yaml\nswanlab_log_completions: true\nswanlab_completion_log_interval: 100  # Log every 100 steps\nswanlab_completion_max_buffer: 128    # Memory-bounded buffer\n```\n\n### Lark Team Notifications\n```yaml\nswanlab_lark_webhook_url: https://open.feishu.cn/...\nswanlab_lark_secret: your-webhook-secret  # Required for production\n```\n\n### Team Workspace\n```yaml\nswanlab_workspace: my-research-team\n```\n\n### Private Deployment\n```yaml\nswanlab_web_host: https://swanlab.yourcompany.com\nswanlab_api_host: https://api.swanlab.yourcompany.com\n```\n\n---\n\n## Authentication\n\n### Recommended: Environment Variable\n```bash\nexport SWANLAB_API_KEY=your-api-key\nexport SWANLAB_LARK_WEBHOOK_URL=https://open.feishu.cn/...\nexport SWANLAB_LARK_SECRET=your-webhook-secret\n```\n\n### Alternative: Config File (less secure)\n```yaml\nswanlab_api_key: your-api-key\nswanlab_lark_webhook_url: https://open.feishu.cn/...\nswanlab_lark_secret: your-webhook-secret\n```\n\n---\n\n## Common Use Cases\n\n### Use Case 1: Migrate from WandB to SwanLab\nStart with `lora-swanlab-profiling.yml`, add your model/dataset config, disable WandB:\n```yaml\nuse_swanlab: true\nuse_wandb: false\n```\n\n### Use Case 2: Analyze DPO Model Outputs\nUse `dpo-swanlab-completions.yml` and adjust the completion logging interval to your training length. Set only one value, since a YAML key may appear only once:\n```yaml\nswanlab_completion_log_interval: 50   # More frequent for short training\n# swanlab_completion_log_interval: 200  # Less frequent for long training\n```\n\n### Use Case 3: Optimize Training Performance\nUse `lora-swanlab-profiling.yml`, run multiple experiments with different optimizations:\n- Baseline: `flash_attention: false, gradient_checkpointing: false`\n- Flash Attention: `flash_attention: true`\n- Gradient Checkpointing: `gradient_checkpointing: true`\n- Both: `flash_attention: 
true, gradient_checkpointing: true`\n\nCompare profiling metrics in SwanLab dashboard.\n\n### Use Case 4: Production RLHF with Team Collaboration\nUse `dpo-swanlab-full-featured.yml`, set up team workspace and Lark notifications:\n```yaml\nswanlab_workspace: ml-team\nswanlab_lark_webhook_url: ...\nswanlab_lark_secret: ...\n```\n\n---\n\n## Viewing Your Experiments\n\n### Cloud Mode\nVisit [https://swanlab.cn](https://swanlab.cn) and navigate to your project.\n\n**Dashboard sections**:\n- **Metrics**: Training loss, learning rate, profiling metrics\n- **Tables**: RLHF completions (for DPO/KTO/ORPO/GRPO)\n- **Config**: Hyperparameters and configuration\n- **System**: Resource usage (GPU, memory, CPU)\n- **Files**: Logged artifacts\n\n### Local Mode\n```bash\nswanlab watch ./swanlog\n# Open browser to http://localhost:5092\n```\n\n---\n\n## Troubleshooting\n\n### SwanLab not initializing\n```bash\n# Check API key\necho $SWANLAB_API_KEY\n\n# Verify SwanLab is installed\npip show swanlab\n\n# Check config\ngrep -A 5 \"use_swanlab\" your-config.yml\n```\n\n### Completions not appearing\n- Verify you're using an RLHF trainer (DPO/KTO/ORPO/GRPO)\n- Check `swanlab_log_completions: true`\n- Wait for `swanlab_completion_log_interval` steps\n- Look for \"Registered SwanLab RLHF completion logging\" in logs\n\n### Lark notifications not working\n- Test webhook manually: `curl -X POST \"$SWANLAB_LARK_WEBHOOK_URL\" ...`\n- Verify `SWANLAB_LARK_SECRET` is set correctly\n- Check bot is added to Lark group chat\n- Look for \"Registered Lark notification callback\" in logs\n\n### Profiling metrics not appearing\n- Verify `use_swanlab: true`\n- Check SwanLab is initialized (look for init log message)\n- Profiling metrics are under \"profiling/\" namespace\n- Profiling auto-enabled when SwanLab is enabled\n\n---\n\n## Performance Notes\n\n### Overhead Comparison\n\n| Feature | Overhead per Step | Memory Usage |\n|---------|------------------|--------------|\n| Basic tracking | < 0.1% | 
~10 MB |\n| Completion logging | < 0.5% | ~64 KB (buffer=128) |\n| Profiling | < 0.1% | ~1 KB |\n| **Total** | **< 0.7%** | **~10 MB** |\n\n### Best Practices\n1. Use ONE logging tool in production (disable WandB/MLflow when using SwanLab)\n2. Adjust completion log interval based on training length (100-200 steps)\n3. Keep completion buffer size reasonable (128-512)\n4. Profile critical path methods first (training_step, compute_loss)\n5. Use ProfilingConfig to throttle high-frequency operations\n\n---\n\n## Further Reading\n\n- **Full Documentation**: [src/axolotl/integrations/swanlab/README.md](../../src/axolotl/integrations/swanlab/README.md)\n- **SwanLab Docs**: [https://docs.swanlab.cn](https://docs.swanlab.cn)\n- **Axolotl Docs**: [https://axolotl-ai-cloud.github.io/axolotl/](https://axolotl-ai-cloud.github.io/axolotl/)\n- **DPO Paper**: [Direct Preference Optimization](https://arxiv.org/abs/2305.18290)\n\n---\n\n## Contributing\n\nFound an issue or have an improvement? Please submit a PR or open an issue:\n- [Axolotl Issues](https://github.com/axolotl-ai-cloud/axolotl/issues)\n- [SwanLab Issues](https://github.com/SwanHubX/SwanLab/issues)\n"
  },
  {
    "path": "examples/swanlab/custom_trainer_profiling.py",
    "content": "\"\"\"Example: Custom Trainer with SwanLab Profiling\n\nThis example demonstrates how to add SwanLab profiling to your custom trainer.\n\nFeatures:\n- @swanlab_profile decorator for automatic profiling\n- swanlab_profiling_context for fine-grained profiling\n- ProfilingConfig for advanced filtering and throttling\n\nUsage:\n    1. Create your custom trainer extending AxolotlTrainer\n    2. Add @swanlab_profile decorators to methods you want to profile\n    3. Use swanlab_profiling_context for fine-grained profiling within methods\n    4. Enable SwanLab in your config (use_swanlab: true)\n\nSee also:\n    - examples/swanlab/lora-swanlab-profiling.yml for config\n    - src/axolotl/integrations/swanlab/profiling.py for implementation\n\"\"\"\n\nfrom axolotl.core.trainers.base import AxolotlTrainer\nfrom axolotl.integrations.swanlab.profiling import (\n    ProfilingConfig,\n    swanlab_profile,\n    swanlab_profiling_context,\n    swanlab_profiling_context_advanced,\n)\n\n\nclass CustomTrainerWithProfiling(AxolotlTrainer):\n    \"\"\"Custom trainer with SwanLab profiling enabled.\n\n    This trainer demonstrates three profiling patterns:\n    1. Decorator-based profiling (@swanlab_profile)\n    2. Context manager profiling (swanlab_profiling_context)\n    3. 
Advanced profiling with filtering (ProfilingConfig)\n    \"\"\"\n\n    def __init__(self, *args, **kwargs):\n        super().__init__(*args, **kwargs)\n\n        # Create custom profiling config for high-frequency operations\n        self.fast_op_config = ProfilingConfig(\n            enabled=True,\n            min_duration_ms=0.5,  # Only log if duration > 0.5ms\n            log_interval=50,  # Log every 50th call\n        )\n\n    # ========================================================================\n    # Pattern 1: Decorator-based Profiling\n    # ========================================================================\n    # Best for: Methods you always want to profile\n    # Overhead: ~2-5 microseconds per call (negligible)\n\n    @swanlab_profile\n    def training_step(self, model, inputs):\n        \"\"\"Main training step - always profile.\n\n        Profiling metric: profiling/Time taken: CustomTrainerWithProfiling.training_step\n        \"\"\"\n        return super().training_step(model, inputs)\n\n    @swanlab_profile\n    def compute_loss(self, model, inputs, return_outputs=False):\n        \"\"\"Loss computation - always profile.\n\n        Profiling metric: profiling/Time taken: CustomTrainerWithProfiling.compute_loss\n        \"\"\"\n        return super().compute_loss(model, inputs, return_outputs)\n\n    @swanlab_profile\n    def prediction_step(self, model, inputs, prediction_loss_only, ignore_keys=None):\n        \"\"\"Prediction step - always profile.\n\n        Profiling metric: profiling/Time taken: CustomTrainerWithProfiling.prediction_step\n        \"\"\"\n        return super().prediction_step(model, inputs, prediction_loss_only, ignore_keys)\n\n    # ========================================================================\n    # Pattern 2: Fine-grained Context Manager Profiling\n    # ========================================================================\n    # Best for: Profiling specific code blocks within a method\n    # Use 
case: When you want to profile forward vs backward separately\n\n    def complex_training_step(self, model, inputs):\n        \"\"\"Training step with fine-grained profiling.\n\n        Profiling metrics:\n        - profiling/Time taken: CustomTrainerWithProfiling.forward_pass\n        - profiling/Time taken: CustomTrainerWithProfiling.backward_pass\n        - profiling/Time taken: CustomTrainerWithProfiling.optimizer_step\n        \"\"\"\n        # Profile just the forward pass\n        with swanlab_profiling_context(self, \"forward_pass\"):\n            outputs = model(**inputs)\n            loss = outputs.loss\n\n        # Profile just the backward pass\n        with swanlab_profiling_context(self, \"backward_pass\"):\n            loss.backward()\n\n        # Profile optimizer step\n        with swanlab_profiling_context(self, \"optimizer_step\"):\n            self.optimizer.step()\n            self.optimizer.zero_grad()\n\n        return outputs\n\n    # ========================================================================\n    # Pattern 3: Advanced Profiling with Filtering\n    # ========================================================================\n    # Best for: High-frequency operations where you want to throttle logging\n    # Use case: Methods called 100+ times per step\n\n    def _prepare_inputs(self, inputs):\n        \"\"\"Prepare inputs - throttled profiling.\n\n        This method is called frequently (once per batch), so we throttle\n        profiling to reduce overhead:\n        - Only log if duration > 0.5ms (skip very fast operations)\n        - Only log every 50th call (reduce logging frequency)\n\n        Profiling metric: profiling/Time taken: CustomTrainerWithProfiling.prepare_inputs\n        \"\"\"\n        with swanlab_profiling_context_advanced(\n            self, \"prepare_inputs\", config=self.fast_op_config\n        ):\n            return super()._prepare_inputs(inputs)\n\n    def _prepare_input_for_model(self, input_ids):\n      
  \"\"\"Another high-frequency operation - throttled profiling.\n\n        Profiling metric: profiling/Time taken: CustomTrainerWithProfiling.prepare_input_for_model\n        \"\"\"\n        with swanlab_profiling_context_advanced(\n            self, \"prepare_input_for_model\", config=self.fast_op_config\n        ):\n            # Your custom input preparation logic\n            return input_ids\n\n    # ========================================================================\n    # Pattern 4: Exception-safe Profiling\n    # ========================================================================\n    # Profiling is exception-safe: duration is logged even if method raises\n\n    @swanlab_profile\n    def potentially_failing_method(self):\n        \"\"\"This method may raise an exception.\n\n        SwanLab profiling will still log the duration before re-raising.\n        Profiling metric: profiling/Time taken: CustomTrainerWithProfiling.potentially_failing_method\n        \"\"\"\n        # Do some work\n        result = self._do_risky_computation()\n\n        # If this raises, profiling duration is still logged\n        if result < 0:\n            raise ValueError(\"Invalid result\")\n\n        return result\n\n    def _do_risky_computation(self):\n        \"\"\"Placeholder for risky computation.\"\"\"\n        return 42\n\n\n# ============================================================================\n# Advanced Example: Custom ProfilingConfig Per Method\n# ============================================================================\n\n\nclass AdvancedProfilingTrainer(AxolotlTrainer):\n    \"\"\"Trainer with method-specific profiling configurations.\"\"\"\n\n    def __init__(self, *args, **kwargs):\n        super().__init__(*args, **kwargs)\n\n        # Different profiling configs for different method types\n        self.critical_path_config = ProfilingConfig(\n            enabled=True,\n            min_duration_ms=0.0,  # Log everything on critical path\n      
      log_interval=1,  # Log every call\n        )\n\n        self.fast_path_config = ProfilingConfig(\n            enabled=True,\n            min_duration_ms=1.0,  # Only log if > 1ms\n            log_interval=100,  # Log every 100th call\n        )\n\n        self.debug_config = ProfilingConfig(\n            enabled=True,\n            min_duration_ms=0.0,  # Log everything\n            log_interval=1,  # Log every call\n        )\n\n    def training_step(self, model, inputs):\n        \"\"\"Critical path - log everything.\"\"\"\n        with swanlab_profiling_context_advanced(\n            self, \"training_step\", config=self.critical_path_config\n        ):\n            return super().training_step(model, inputs)\n\n    def _prepare_inputs(self, inputs):\n        \"\"\"Fast path - throttle logging.\"\"\"\n        with swanlab_profiling_context_advanced(\n            self, \"prepare_inputs\", config=self.fast_path_config\n        ):\n            return super()._prepare_inputs(inputs)\n\n    def _debug_method(self, data):\n        \"\"\"Debug-only method - verbose logging.\"\"\"\n        with swanlab_profiling_context_advanced(\n            self, \"debug_method\", config=self.debug_config\n        ):\n            # Your debug logic\n            pass\n\n\n# ============================================================================\n# How to Use This Custom Trainer\n# ============================================================================\n\n\"\"\"\nTo use this custom trainer:\n\n1. Save this file to your project (e.g., my_custom_trainer.py)\n\n2. Create a config file that uses your custom trainer:\n\n    # config.yml\n    base_model: NousResearch/Llama-3.2-1B\n\n    # ... other config ...\n\n    plugins:\n      - axolotl.integrations.swanlab.SwanLabPlugin\n\n    use_swanlab: true\n    swanlab_project: my-profiling-experiment\n\n    # Optional: Specify custom trainer\n    # (Or modify axolotl to use your custom trainer class)\n\n3. 
Run training:\n\n    export SWANLAB_API_KEY=your-api-key\n    accelerate launch -m axolotl.cli.train config.yml\n\n4. View profiling metrics in SwanLab dashboard:\n   - profiling/Time taken: CustomTrainerWithProfiling.training_step\n   - profiling/Time taken: CustomTrainerWithProfiling.forward_pass\n   - profiling/Time taken: CustomTrainerWithProfiling.backward_pass\n   - etc.\n\n5. Compare profiling metrics across runs:\n   - Run baseline without optimizations\n   - Run with flash_attention enabled\n   - Run with gradient_checkpointing enabled\n   - Compare profiling metrics to see performance impact\n\"\"\"\n\n# ============================================================================\n# Tips for Effective Profiling\n# ============================================================================\n\n\"\"\"\n1. Profile the critical path first:\n   - training_step, compute_loss, prediction_step\n   - These methods are called most frequently and have biggest impact\n\n2. Use throttling for high-frequency operations:\n   - Methods called 100+ times per step\n   - Use log_interval=50 or log_interval=100\n   - Reduces profiling overhead and dashboard clutter\n\n3. Filter noise with min_duration_ms:\n   - Set min_duration_ms=1.0 to skip very fast operations\n   - Focus on operations that actually take time\n\n4. Compare across runs:\n   - Run same config multiple times to check consistency\n   - Compare different optimization strategies\n   - Track profiling trends over time\n\n5. Monitor distributed training:\n   - Check for per-rank timing differences\n   - Look for stragglers (slower ranks)\n   - Identify synchronization bottlenecks\n\n6. Disable profiling in production:\n   - from axolotl.integrations.swanlab.profiling import DEFAULT_PROFILING_CONFIG\n   - DEFAULT_PROFILING_CONFIG.enabled = False\n\n7. 
Exception handling:\n   - Profiling is exception-safe\n   - Duration logged even if method raises\n   - Useful for debugging methods that fail intermittently\n\"\"\"\n"
  },
  {
    "path": "examples/swanlab/dpo-swanlab-completions.yml",
    "content": "# SwanLab DPO Training Example with Completion Logging\n#\n# This example demonstrates DPO (Direct Preference Optimization) training\n# with SwanLab integration for experiment tracking and completion table logging.\n#\n# Features enabled:\n# - SwanLab experiment tracking\n# - RLHF completion table logging (prompts, chosen/rejected responses, rewards)\n# - Lark (Feishu) team notifications (optional)\n#\n# To run:\n#   export SWANLAB_API_KEY=your-api-key\n#   accelerate launch -m axolotl.cli.train examples/swanlab/dpo-swanlab-completions.yml\n\n# Model Configuration\nbase_model: meta-llama/Meta-Llama-3-8B-Instruct\nmodel_type: LlamaForCausalLM\ntokenizer_type: AutoTokenizer\n\nspecial_tokens:\n  pad_token: <|finetune_right_pad_id|>\n  eos_token: <|eot_id|>\n\n# Quantization\nload_in_8bit: true\nload_in_4bit: false\n\n# LoRA Configuration\nadapter: lora\nlora_r: 32\nlora_alpha: 16\nlora_dropout: 0.05\nlora_target_linear: true\n\n# DPO Configuration\nchat_template: llama3\nrl: dpo\n\ndatasets:\n  - path: fozziethebeat/alpaca_messages_2k_dpo_test\n    type: chat_template.default\n    field_messages: conversation\n    field_chosen: chosen\n    field_rejected: rejected\n    message_property_mappings:\n      role: role\n      content: content\n    roles:\n      system:\n        - system\n      user:\n        - user\n      assistant:\n        - assistant\n\n# Dataset and Output\ndataset_prepared_path:\nval_set_size: 0.05\noutput_dir: ./outputs/dpo-swanlab-out\n\n# Training Configuration\nsequence_len: 4096\nsample_packing: false\nmicro_batch_size: 2\ngradient_accumulation_steps: 4\nnum_epochs: 4\n\n# Optimization\noptimizer: adamw_bnb_8bit\nlr_scheduler: cosine\nlearning_rate: 0.0002\nwarmup_ratio: 0.1\nweight_decay: 0.0\n\n# Precision\nbf16: auto\ntf32: false\n\n# Performance\ngradient_checkpointing: true\nflash_attention: true\n\n# Checkpointing and Logging\nlogging_steps: 1\nevals_per_epoch: 4\nsaves_per_epoch: 1\n\n# 
============================================================================\n# SwanLab Integration\n# ============================================================================\n\nplugins:\n  - axolotl.integrations.swanlab.SwanLabPlugin\n\n# Basic SwanLab Configuration\nuse_swanlab: true\nswanlab_project: dpo-training\nswanlab_experiment_name: llama-3-dpo-completions-demo\nswanlab_description: \"DPO training with completion table logging\"\nswanlab_mode: cloud  # Options: cloud, local, offline, disabled\n\n# SwanLab Authentication\n# Recommended: Set via environment variable\n#   export SWANLAB_API_KEY=your-api-key\n# Or set in config (less secure):\n# swanlab_api_key: your-api-key\n\n# Optional: Team workspace\n# swanlab_workspace: my-research-team\n\n# ============================================================================\n# RLHF Completion Table Logging\n# ============================================================================\n#\n# Automatically logs model completions to SwanLab for qualitative analysis:\n# - Prompts from your DPO dataset\n# - Chosen responses (preferred)\n# - Rejected responses (non-preferred)\n# - Reward differences\n#\n# View the table in SwanLab dashboard under \"rlhf_completions\"\n\nswanlab_log_completions: true\nswanlab_completion_log_interval: 100  # Log every 100 training steps\nswanlab_completion_max_buffer: 128    # Keep last 128 completions in memory\n\n# Memory Usage Notes:\n# - Buffer size 128: ~64 KB (default, recommended)\n# - Buffer size 512: ~256 KB (for more historical completions)\n# - Buffer size 1024: ~512 KB (maximum for very long training runs)\n\n# Performance Notes:\n# - Completion logging overhead: < 0.5% per training step\n# - Only logs every N steps to minimize impact\n# - Memory-bounded buffer prevents memory leaks\n\n# ============================================================================\n# Optional: Lark (Feishu) Team Notifications\n# 
============================================================================\n#\n# Get real-time training notifications in your team chat\n# Uncomment to enable:\n\n# swanlab_lark_webhook_url: https://open.feishu.cn/open-apis/bot/v2/hook/xxxxxxxxxx\n# swanlab_lark_secret: your-webhook-secret  # Recommended for production\n\n# Notifications sent for:\n# - Training start\n# - Training completion\n# - Training errors\n# - Metric milestones (if configured)\n\n# ============================================================================\n# Optional: Private SwanLab Deployment\n# ============================================================================\n#\n# For enterprise users with private SwanLab deployment:\n\n# swanlab_web_host: https://swanlab.yourcompany.com\n# swanlab_api_host: https://api.swanlab.yourcompany.com\n\n# ============================================================================\n# Disable WandB if you're migrating from it\n# ============================================================================\n\n# wandb_project:\n# wandb_entity:\n# use_wandb: false\n"
  },
  {
    "path": "examples/swanlab/dpo-swanlab-full-featured.yml",
    "content": "# SwanLab Full-Featured DPO Training Example\n#\n# This example demonstrates ALL SwanLab integration features:\n# - Experiment tracking with cloud sync\n# - RLHF completion table logging\n# - Performance profiling\n# - Lark (Feishu) team notifications\n# - Team workspace collaboration\n#\n# Use this as a reference for production RLHF training setups.\n#\n# To run:\n#   export SWANLAB_API_KEY=your-api-key\n#   export SWANLAB_LARK_WEBHOOK_URL=https://open.feishu.cn/...\n#   export SWANLAB_LARK_SECRET=your-webhook-secret\n#   accelerate launch -m axolotl.cli.train examples/swanlab/dpo-swanlab-full-featured.yml\n\n# ============================================================================\n# Model Configuration\n# ============================================================================\n\nbase_model: meta-llama/Meta-Llama-3-8B-Instruct\nmodel_type: LlamaForCausalLM\ntokenizer_type: AutoTokenizer\n\nspecial_tokens:\n  pad_token: <|finetune_right_pad_id|>\n  eos_token: <|eot_id|>\n\n# Quantization for efficient training\nload_in_8bit: true\nload_in_4bit: false\n\n# ============================================================================\n# LoRA Configuration\n# ============================================================================\n\nadapter: lora\nlora_r: 32\nlora_alpha: 16\nlora_dropout: 0.05\nlora_target_linear: true  # Target all linear layers\n\n# ============================================================================\n# DPO (Direct Preference Optimization) Configuration\n# ============================================================================\n\nchat_template: llama3\nrl: dpo  # Enable DPO trainer\n\ndatasets:\n  - path: fozziethebeat/alpaca_messages_2k_dpo_test\n    type: chat_template.default\n    field_messages: conversation\n    field_chosen: chosen\n    field_rejected: rejected\n    message_property_mappings:\n      role: role\n      content: content\n    roles:\n      system:\n        - system\n      user:\n        
- user\n      assistant:\n        - assistant\n\n# ============================================================================\n# Dataset and Output Configuration\n# ============================================================================\n\ndataset_prepared_path:\nval_set_size: 0.05\noutput_dir: ./outputs/dpo-swanlab-full-featured-out\n\n# ============================================================================\n# Training Configuration\n# ============================================================================\n\nsequence_len: 4096\nsample_packing: false\n\nmicro_batch_size: 2\ngradient_accumulation_steps: 4\nnum_epochs: 4\n\n# ============================================================================\n# Optimization\n# ============================================================================\n\noptimizer: adamw_bnb_8bit\nlr_scheduler: cosine\nlearning_rate: 0.0002\nwarmup_ratio: 0.1\nweight_decay: 0.0\n\n# ============================================================================\n# Precision and Performance\n# ============================================================================\n\nbf16: auto\ntf32: false\n\ngradient_checkpointing: true\nflash_attention: true\n\n# ============================================================================\n# Checkpointing and Logging\n# ============================================================================\n\nlogging_steps: 1\nevals_per_epoch: 4\nsaves_per_epoch: 1\n\n# ============================================================================\n# SwanLab Integration - Full Configuration\n# ============================================================================\n\nplugins:\n  - axolotl.integrations.swanlab.SwanLabPlugin\n\n# ------------------------------------------------------------------------------\n# Basic SwanLab Configuration\n# ------------------------------------------------------------------------------\n\nuse_swanlab: true\nswanlab_project: 
dpo-production\nswanlab_experiment_name: llama-3-dpo-full-featured-v1\nswanlab_description: |\n  Production DPO training with all SwanLab features enabled:\n  - Completion table logging for qualitative analysis\n  - Performance profiling for optimization\n  - Lark notifications for team collaboration\n\nswanlab_mode: cloud  # Options: cloud, local, offline, disabled\n\n# ------------------------------------------------------------------------------\n# Team Collaboration\n# ------------------------------------------------------------------------------\n\n# Workspace for team collaboration (shared experiments)\nswanlab_workspace: ml-research-team\n\n# Authentication (recommended: use environment variable)\n#   export SWANLAB_API_KEY=your-api-key\n# Or set in config (less secure):\n# swanlab_api_key: your-api-key\n\n# ------------------------------------------------------------------------------\n# RLHF Completion Table Logging\n# ------------------------------------------------------------------------------\n# Automatically logs model completions for qualitative analysis:\n# - Prompts from your DPO dataset\n# - Chosen responses (preferred)\n# - Rejected responses (non-preferred)\n# - Reward differences\n#\n# View in SwanLab dashboard under \"rlhf_completions\" table\n\nswanlab_log_completions: true\nswanlab_completion_log_interval: 100  # Log every 100 steps\nswanlab_completion_max_buffer: 256    # Larger buffer for long training runs\n\n# Buffer size recommendations:\n# - 128: Default, ~64 KB memory (recommended for most cases)\n# - 256: ~128 KB memory (this config, good for longer training)\n# - 512: ~256 KB memory (maximum for very long runs)\n\n# ------------------------------------------------------------------------------\n# Lark (Feishu) Team Notifications\n# ------------------------------------------------------------------------------\n# Get real-time training notifications in your team chat\n#\n# Notifications sent for:\n# - Training start\n# - Training 
completion\n# - Training errors\n# - Metric milestones (if configured)\n\n# Recommended: Set via environment variables\n#   export SWANLAB_LARK_WEBHOOK_URL=https://open.feishu.cn/...\n#   export SWANLAB_LARK_SECRET=your-webhook-secret\n\n# Or set in config (less secure):\n# swanlab_lark_webhook_url: https://open.feishu.cn/open-apis/bot/v2/hook/xxxxxxxxxx\n# swanlab_lark_secret: your-webhook-secret  # REQUIRED for production\n\n# Security note: ALWAYS use swanlab_lark_secret in production to prevent\n# unauthorized parties from sending fake notifications to your team chat.\n\n# ------------------------------------------------------------------------------\n# Performance Profiling\n# ------------------------------------------------------------------------------\n# Profiling is automatically enabled when SwanLab is enabled.\n# Metrics logged to SwanLab under \"profiling/\" namespace:\n#   profiling/Time taken: AxolotlTrainer.training_step\n#   profiling/Time taken: AxolotlTrainer.compute_loss\n#   profiling/Time taken: AxolotlTrainer.prediction_step\n#\n# Use these metrics to:\n# - Identify bottlenecks in training loop\n# - Compare performance across different configurations\n# - Monitor performance regressions over time\n# - Debug unexpected slowdowns\n\n# For custom profiling in your own trainer, see:\n#   examples/swanlab/custom_trainer_profiling.py\n\n# ------------------------------------------------------------------------------\n# Optional: Private SwanLab Deployment\n# ------------------------------------------------------------------------------\n# For enterprise users with private SwanLab deployment:\n\n# swanlab_web_host: https://swanlab.yourcompany.com\n# swanlab_api_host: https://api.swanlab.yourcompany.com\n\n# ------------------------------------------------------------------------------\n# Optional: Model Checkpointing to SwanLab\n# ------------------------------------------------------------------------------\n# Log model checkpoints to SwanLab 
(coming soon)\n\nswanlab_log_model: false\n\n# ============================================================================\n# Disable Other Logging Tools (Recommended)\n# ============================================================================\n# Using multiple logging tools simultaneously can impact performance:\n# - Expected overhead: ~1-2% per logger\n# - Potential config/callback conflicts\n#\n# For production training, use ONLY SwanLab:\n\n# wandb_project:\n# use_wandb: false\n#\n# use_mlflow: false\n#\n# use_comet: false\n\n# ============================================================================\n# Expected Training Behavior\n# ============================================================================\n\n# With this configuration, you should see:\n#\n# 1. SwanLab Initialization (rank 0 only):\n#    INFO: SwanLab initialized for project: dpo-production\n#    INFO: SwanLab experiment: llama-3-dpo-full-featured-v1\n#    INFO: SwanLab mode: cloud\n#    INFO: SwanLab workspace: ml-research-team\n#\n# 2. Completion Logging (rank 0 only):\n#    INFO: Registered SwanLab RLHF completion logging callback for DPOTrainer\n#          (log_interval=100, max_buffer=256)\n#\n# 3. Lark Notifications (rank 0 only):\n#    INFO: Registered Lark notification callback with HMAC authentication\n#\n# 4. Distributed Training Detection (if multi-GPU):\n#    INFO: Distributed training detected (world_size=N)\n#    INFO: Only rank 0 will initialize SwanLab\n#    INFO: Other ranks will skip SwanLab to avoid conflicts\n#\n# 5. Training Start Notification (Lark):\n#    Your team chat receives: \"Training started: llama-3-dpo-full-featured-v1\"\n#\n# 6. Periodic Completion Logging:\n#    Every 100 steps, completion table is updated in SwanLab dashboard\n#\n# 7. Training Complete Notification (Lark):\n#    Your team chat receives: \"Training completed: llama-3-dpo-full-featured-v1\"\n#    With link to SwanLab dashboard and final metrics\n#\n# 8. 
SwanLab Dashboard Shows:\n#    - Training metrics (loss, learning rate, etc.)\n#    - Completion table (rlhf_completions)\n#    - Profiling metrics (profiling/Time taken: ...)\n#    - Hyperparameters and configuration\n#    - System resource usage\n\n# ============================================================================\n# Production Checklist\n# ============================================================================\n\n# Before deploying to production, verify:\n# ✅ SwanLab API key is set via environment variable (not in config)\n# ✅ Lark webhook secret is set (required for HMAC authentication)\n# ✅ Workspace is set to your team's workspace\n# ✅ Experiment name is descriptive and unique\n# ✅ Only SwanLab is enabled (other loggers disabled)\n# ✅ Completion logging buffer size is appropriate for your training duration\n# ✅ Private deployment hosts are set (if using enterprise SwanLab)\n# ✅ Test run completes successfully and shows up in SwanLab dashboard\n# ✅ Lark notifications are received in team chat\n# ✅ Profiling metrics are logged correctly\n\n# ============================================================================\n# Troubleshooting\n# ============================================================================\n\n# If SwanLab initialization fails:\n# 1. Check SWANLAB_API_KEY environment variable is set\n# 2. Verify swanlab_project is set in config\n# 3. Check swanlab_mode is valid (cloud/local/offline/disabled)\n# 4. Verify internet connectivity (for cloud mode)\n\n# If Lark notifications not received:\n# 1. Check SWANLAB_LARK_WEBHOOK_URL is set correctly\n# 2. Verify SWANLAB_LARK_SECRET matches your Lark bot settings\n# 3. Test webhook manually: curl -X POST \"$SWANLAB_LARK_WEBHOOK_URL\" ...\n# 4. Check training logs for \"Registered Lark notification callback\"\n# 5. Verify bot is added to the target Lark group chat\n\n# If completions not appearing in SwanLab:\n# 1. Verify you're using an RLHF trainer (DPO/KTO/ORPO/GRPO)\n# 2. 
Check swanlab_log_completions is true\n# 3. Wait for log_interval steps (default: 100)\n# 4. Check training logs for \"Registered SwanLab RLHF completion logging\"\n\n# If profiling metrics not appearing:\n# 1. Verify use_swanlab is true\n# 2. Check SwanLab is initialized (check logs)\n# 3. Look under \"profiling/\" namespace in dashboard\n# 4. Profiling may be disabled if DEFAULT_PROFILING_CONFIG.enabled = False\n\n# For more help:\n# - SwanLab docs: https://docs.swanlab.cn\n# - Axolotl SwanLab integration: src/axolotl/integrations/swanlab/README.md\n# - GitHub issues: https://github.com/axolotl-ai-cloud/axolotl/issues\n"
  },
  {
    "path": "examples/swanlab/lora-swanlab-profiling.yml",
    "content": "# SwanLab LoRA Training Example with Performance Profiling\n#\n# This example demonstrates standard LoRA fine-tuning with SwanLab integration\n# for performance profiling and optimization.\n#\n# Features enabled:\n# - SwanLab experiment tracking\n# - Performance profiling (training step, forward/backward pass timing)\n# - Real-time metrics visualization\n#\n# To run:\n#   export SWANLAB_API_KEY=your-api-key\n#   accelerate launch -m axolotl.cli.train examples/swanlab/lora-swanlab-profiling.yml\n\n# Model Configuration\nbase_model: NousResearch/Llama-3.2-1B\n\n# Dataset Configuration\ndatasets:\n  - path: teknium/GPT4-LLM-Cleaned\n    type: alpaca\n\nval_set_size: 0.1\noutput_dir: ./outputs/lora-swanlab-profiling-out\n\n# LoRA Configuration\nadapter: lora\nlora_r: 16\nlora_alpha: 32\nlora_dropout: 0.05\nlora_target_modules:\n  - gate_proj\n  - down_proj\n  - up_proj\n  - q_proj\n  - v_proj\n  - k_proj\n  - o_proj\n\n# Training Configuration\nsequence_len: 2048\nsample_packing: true\neval_sample_packing: true\n\nmicro_batch_size: 2\ngradient_accumulation_steps: 2\nnum_epochs: 1\n\n# Optimization\noptimizer: adamw_8bit\nlr_scheduler: cosine\nlearning_rate: 0.0002\nwarmup_ratio: 0.1\nweight_decay: 0.0\n\n# Precision\nbf16: auto\ntf32: false\n\n# Performance\ngradient_checkpointing: true\nflash_attention: true\n\n# Checkpointing and Logging\nlogging_steps: 1\nevals_per_epoch: 4\nsaves_per_epoch: 1\n\n# Loss Monitoring\nloss_watchdog_threshold: 5.0\nloss_watchdog_patience: 3\n\nspecial_tokens:\n  pad_token: \"<|end_of_text|>\"\n\n# ============================================================================\n# SwanLab Integration\n# ============================================================================\n\nplugins:\n  - axolotl.integrations.swanlab.SwanLabPlugin\n\n# Basic SwanLab Configuration\nuse_swanlab: true\nswanlab_project: lora-profiling\nswanlab_experiment_name: llama-3.2-1b-profiling-demo\nswanlab_description: \"LoRA fine-tuning with 
performance profiling\"\nswanlab_mode: cloud  # Options: cloud, local, offline, disabled\n\n# SwanLab Authentication\n# Recommended: Set via environment variable\n#   export SWANLAB_API_KEY=your-api-key\n# Or set in config (less secure):\n# swanlab_api_key: your-api-key\n\n# Optional: Team workspace\n# swanlab_workspace: my-ml-team\n\n# ============================================================================\n# Performance Profiling\n# ============================================================================\n#\n# SwanLab automatically profiles trainer methods when enabled.\n# Profiling metrics appear in SwanLab dashboard under \"profiling/\" namespace.\n#\n# Built-in profiling:\n# - Minimal overhead (< 0.1% per step)\n# - High-precision timing (microsecond accuracy)\n# - Exception-safe (logs duration even if method fails)\n#\n# View profiling metrics in SwanLab dashboard:\n#   profiling/Time taken: AxolotlTrainer.training_step\n#   profiling/Time taken: AxolotlTrainer.compute_loss\n#   profiling/Time taken: AxolotlTrainer.prediction_step\n#\n# For custom profiling in your own trainer, see:\n#   examples/swanlab/custom_trainer_profiling.py\n\n# Completion logging is disabled for non-RLHF trainers\nswanlab_log_completions: false  # Only works with DPO/KTO/ORPO/GRPO\n\n# ============================================================================\n# Optional: Compare with Multiple Runs\n# ============================================================================\n#\n# To compare profiling metrics across different configurations:\n#\n# 1. Run a baseline with both optimizations disabled:\n#    swanlab_experiment_name: llama-3.2-1b-baseline\n#    flash_attention: false\n#    gradient_checkpointing: false\n#\n# 2. Run with gradient checkpointing only:\n#    swanlab_experiment_name: llama-3.2-1b-grad-checkpoint\n#    flash_attention: false\n#    gradient_checkpointing: true\n#\n# 3. 
Run with both:\n#    swanlab_experiment_name: llama-3.2-1b-optimized\n#    flash_attention: true\n#    gradient_checkpointing: true\n#\n# Then compare profiling metrics in SwanLab dashboard to see performance impact\n\n# ============================================================================\n# Optional: Lark (Feishu) Team Notifications\n# ============================================================================\n#\n# Get notified when profiling experiments complete:\n\n# swanlab_lark_webhook_url: https://open.feishu.cn/open-apis/bot/v2/hook/xxxxxxxxxx\n# swanlab_lark_secret: your-webhook-secret\n\n# ============================================================================\n# Profiling Best Practices\n# ============================================================================\n#\n# 1. Run multiple epochs to see profiling trends over time\n# 2. Ignore first ~10 steps (warmup period, slower)\n# 3. Look for outliers (steps that take significantly longer)\n# 4. Compare profiling metrics before/after optimization changes\n# 5. Monitor per-rank profiling in distributed training\n#\n# Common bottlenecks to profile:\n# - training_step: Overall step time (should be consistent)\n# - compute_loss: Loss computation (scales with sequence length)\n# - prediction_step: Evaluation time (can be slow for large val sets)\n#\n# If you see inconsistent timing:\n# - Check for data loading bottlenecks\n# - Monitor GPU utilization (may be CPU-bound)\n# - Check for gradient accumulation effects\n# - Verify CUDA kernel synchronization\n\n# ============================================================================\n# Disable WandB if you're migrating from it\n# ============================================================================\n\n# wandb_project:\n# use_wandb: false\n"
  },
  {
    "path": "examples/trinity/README.md",
    "content": "# Finetune ArceeAI's Trinity with Axolotl\n\n[Trinity](https://huggingface.co/collections/arcee-ai/trinity) is a family of open weight MoE models trained by Arcee.ai.\n\nThis guide shows how to fine-tune it with Axolotl with multi-turn conversations and proper masking.\n\n## Getting started\n\n1. Install Axolotl following the main from the [installation guide](https://docs.axolotl.ai/docs/installation.html#sec-edge-build).\n\n2. Install [Cut Cross Entropy](https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy) to reduce training VRAM usage.\n\n3. Run the finetuning example:\n\n    ```bash\n    axolotl train examples/trinity/trinity-nano-preview-qlora.yaml\n    ```\n\nThis config uses about 24.9 GiB VRAM (w/o CCE).\n\nLet us know how it goes. Happy finetuning! 🚀\n\n### TIPS\n\n- For inference, the official Arcee.ai team recommends `top_p: 0.75`, `temperature: 0.15`, `top_k: 50`, and `min_p: 0.06`.\n- You can run a full finetuning by removing the `adapter: qlora` and `load_in_4bit: true` from the config.\n- Read more on how to load your own dataset at [docs](https://docs.axolotl.ai/docs/dataset_loading.html).\n- The dataset format follows the OpenAI Messages format as seen [here](https://docs.axolotl.ai/docs/dataset-formats/conversation.html#chat_template).\n\n## Optimization Guides\n\nPlease check the [Optimizations doc](https://docs.axolotl.ai/docs/optimizations.html).\n\n## Related Resources\n\n- [Trinity Blog](https://www.arcee.ai/blog/the-trinity-manifesto)\n- [Axolotl Docs](https://docs.axolotl.ai)\n- [Axolotl Website](https://axolotl.ai)\n- [Axolotl GitHub](https://github.com/axolotl-ai-cloud/axolotl)\n- [Axolotl Discord](https://discord.gg/7m9sfhzaf3)\n"
  },
  {
    "path": "examples/trinity/trinity-nano-preview-qlora.yaml",
    "content": "base_model: arcee-ai/Trinity-Nano-Preview\nrevision_of_model: 2ee94b0\n\n# Automatically upload checkpoint and final model to HF\n# hub_model_id: username/custom_model_name\n\n# CCE - N/A as of now\n# plugins:\n#   - axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin\n\nload_in_8bit: false\nload_in_4bit: true\n\ndatasets:\n  - path: fozziethebeat/alpaca_messages_2k_test\n    type: chat_template\n\ndataset_prepared_path: last_run_prepared\nval_set_size: 0.1\noutput_dir: ./outputs/lora-out\n\nadapter: qlora\nlora_model_dir:\n\nsequence_len: 2048\nsample_packing: true\n\nlora_r: 32\nlora_alpha: 16\nlora_dropout: 0.05\nlora_target_linear: true\nlora_target_modules:\n  - gate_proj\n  - down_proj\n  - up_proj\n  - q_proj\n  - v_proj\n  - k_proj\n  - o_proj\n\nwandb_project:\nwandb_entity:\nwandb_watch:\nwandb_name:\nwandb_log_model:\n\ngradient_accumulation_steps: 4\nmicro_batch_size: 2\nnum_epochs: 1\noptimizer: adamw_bnb_8bit\nlr_scheduler: cosine\nlearning_rate: 0.0002\n\nbf16: auto\ntf32: false\n\ngradient_checkpointing: true\nresume_from_checkpoint:\nlogging_steps: 1\n# flash_attention: true  # Not supported\nsdp_attention: true\n\nwarmup_ratio: 0.1\nevals_per_epoch: 1\nsaves_per_epoch: 1\n\n# save_first_step: true  # uncomment this to validate checkpoint saving works with your config\n"
  },
  {
    "path": "examples/voxtral/README.md",
    "content": "# Finetune Voxtral with Axolotl\n\nVoxtral is a [3B](https://huggingface.co/mistralai/Voxtral-Mini-3B-2507)/[24B](https://huggingface.co/mistralai/Voxtral-Small-24B-2507) parameter opensource model from MistralAI found on HuggingFace. This guide shows how to fine-tune it with Axolotl.\n\nThanks to the team at MistralAI for giving us early access to prepare for this release.\n\n## Getting started\n\n1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html).\n\n    Here is an example of how to install from pip:\n\n```bash\n# Ensure you have Pytorch installed (Pytorch 2.6.0 min)\npip3 install packaging==26.0 setuptools==75.8.0 wheel ninja\npip3 install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'\n```\n\n2. Please install the below.\n\n```bash\n# audio\npip3 install librosa==0.11.0\npip3 install 'mistral_common[audio]==1.8.3'\n\n# Install CCE https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy\npython scripts/cutcrossentropy_install.py | sh\n```\n\n3. Download sample dataset files\n\n```bash\n# for text + audio only\nwget https://huggingface.co/datasets/Nanobit/text-audio-2k-test/resolve/main/En-us-African_elephant.oga\n```\n\n4. Run the finetuning example:\n\n```bash\n# text only\naxolotl train examples/voxtral/voxtral-mini-qlora.yml\n\n# text + audio\naxolotl train examples/voxtral/voxtral-mini-audio-qlora.yml\n```\n\nThese configs use about 4.8 GB VRAM.\n\nLet us know how it goes. Happy finetuning! 
🚀\n\n### TIPS\n\n- For inference, the official MistralAI team recommends `temperature: 0.2` and `top_p: 0.95` for audio understanding and `temperature: 0.0` for transcription.\n- You can run a full finetuning by removing the `adapter: qlora` and `load_in_4bit: true` from the config.\n- Read more on how to load your own dataset at [docs](https://docs.axolotl.ai/docs/dataset_loading.html).\n- The text dataset format follows the OpenAI Messages format as seen [here](https://docs.axolotl.ai/docs/dataset-formats/conversation.html#chat_template).\n- The multimodal dataset format follows the OpenAI multi-content Messages format as seen [here](https://docs.axolotl.ai/docs/multimodal.html#dataset-format).\n\n\n## Optimization Guides\n\n- [Multi-GPU Training](https://docs.axolotl.ai/docs/multi-gpu.html)\n- [Multi-Node Training](https://docs.axolotl.ai/docs/multi-node.html)\n- [LoRA Optimizations](https://docs.axolotl.ai/docs/lora_optims.html)\n\n## Limitations\n\nThe `mistral-common` tokenizer is currently supported only for supervised fine-tuning (SFT), and only with `type: chat_template`.\n\nIn addition, we do not support overriding tokens yet.\n\n## Related Resources\n\n- [MistralAI Voxtral Blog](https://mistral.ai/news/voxtral)\n- [Axolotl Docs](https://docs.axolotl.ai)\n- [Axolotl Website](https://axolotl.ai)\n- [Axolotl GitHub](https://github.com/axolotl-ai-cloud/axolotl)\n- [Axolotl Discord](https://discord.gg/7m9sfhzaf3)\n\n## Future Work\n\n- Add support for Preference Tuning, RL, etc.\n- Add support for other tokenizer configs like overriding tokens.\n"
  },
  {
    "path": "examples/voxtral/voxtral-mini-audio-qlora.yml",
    "content": "base_model: mistralai/Voxtral-Mini-3B-2507\nprocessor_type: VoxtralProcessor\n\n# Automatically upload checkpoint and final model to HF\n# hub_model_id: username/custom_model_name\n\n# Enable to use mistral-common tokenizer\ntokenizer_use_mistral_common: true\n\nplugins:\n  - axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin\n\n# for use with fft to only train on language model layers\n# unfrozen_parameters:\n  # - language_model.model.*\n  # - lm_head\n  # - embed_tokens\n\nload_in_4bit: true\n\n# these 3 lines are needed for now to handle vision chat templates w images\nskip_prepare_dataset: true\nremove_unused_columns: false\nsample_packing: false\n\n# gemma3 doesn't seem to play nice with ddp\nddp_find_unused_parameters: true\n\neot_tokens:\n  - <end_of_turn>\n\n# sample dataset below requires downloading audio/image in advance\n# wget https://huggingface.co/datasets/Nanobit/text-audio-2k-test/resolve/main/En-us-African_elephant.oga\ndatasets:\n  - path: NanoBit/text-audio-2k-test\n    type: chat_template\ndataset_prepared_path:\nval_set_size: 0.01\noutput_dir: ./outputs/out\n\nadapter: qlora\nlora_model_dir:\n\nsequence_len: 2048\npad_to_sequence_len: false\n\nlora_r: 32\nlora_alpha: 16\nlora_dropout: 0.05\nlora_target_modules: 'language_model.model.layers.[\\d]+.(mlp|self_attn).(up|down|gate|q|k|v|o)_proj'\n\nwandb_project:\nwandb_entity:\nwandb_watch:\nwandb_name:\nwandb_log_model:\n\ngradient_accumulation_steps: 4\nmicro_batch_size: 2\nnum_epochs: 1\noptimizer: adamw_bnb_8bit\nlr_scheduler: cosine\nlearning_rate: 0.0002\n\nbf16: true\nfp16:\ntf32: true\n\ngradient_checkpointing: true\ngradient_checkpointing_kwargs:\n  use_reentrant: false\nlogging_steps: 1\nflash_attention: true\n\nwarmup_ratio: 0.1\nevals_per_epoch: 1\nsaves_per_epoch: 1\nweight_decay: 0.0\n"
  },
  {
    "path": "examples/voxtral/voxtral-mini-qlora.yml",
    "content": "base_model: mistralai/Voxtral-Mini-3B-2507\n\n# Automatically upload checkpoint and final model to HF\n# hub_model_id: username/custom_model_name\n\n# Enable to use mistral-common tokenizer\ntokenizer_use_mistral_common: true\n\nplugins:\n  - axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin\n\nload_in_8bit: false\nload_in_4bit: true\n\n# for use with fft to only train on language model layers\n# unfrozen_parameters:\n  # - language_model.model.*\n  # - lm_head\n  # - embed_tokens\n\neot_tokens:\n  - <end_of_turn>\ndatasets:\n  - path: cgato/SlimOrcaDedupCleaned\n    type: chat_template\n    split: train[:1%]\n    field_messages: conversations\n    message_property_mappings:\n      role: from\n      content: value\n\nval_set_size: 0.0\noutput_dir: ./outputs/out\n\nadapter: qlora\nlora_r: 32\nlora_alpha: 16\nlora_dropout: 0.05\nlora_target_modules: 'language_model.model.layers.[\\d]+.(mlp|self_attn).(up|down|gate|q|k|v|o)_proj'\n\nsequence_len: 2048\nsample_packing: true\neval_sample_packing: true\npad_to_sequence_len: true\n\nwandb_project:\nwandb_entity:\nwandb_watch:\nwandb_name:\nwandb_log_model:\n\ngradient_accumulation_steps: 1\nmicro_batch_size: 1\nnum_epochs: 4\noptimizer: adamw_bnb_8bit\nlr_scheduler: cosine\nlearning_rate: 0.0002\n\nbf16: auto\ntf32: true\n\ngradient_checkpointing: true\ngradient_checkpointing_kwargs:\n  use_reentrant: false\nresume_from_checkpoint:\nlogging_steps: 1\nflash_attention: true\n\nwarmup_ratio: 0.1\nevals_per_epoch:\nsaves_per_epoch: 1\nweight_decay: 0.0\nspecial_tokens:\n"
  },
  {
    "path": "index.qmd",
    "content": "---\n# toc-location: right-body\n# toc-title: Table Of Contents\n# toc-expand: 2\n---\n\n```{python}\n#|output: asis\n#|echo: false\n\n# This cell steals the README as the home page for now, but excludes the table of contents (quarto adds its own)\nimport re\npattern = re.compile(\n    r\"<table>\\s*<tr>\\s*<td>\\s*## Table of Contents.*?</td>\\s*</tr>\\s*</table>\",\n    re.DOTALL | re.IGNORECASE\n)\n\nwith open('README.md', 'r') as f:\n    txt = f.read()\n\ncleaned = pattern.sub(\"\", txt)\nprint(cleaned)\n```\n"
  },
  {
    "path": "pyproject.toml",
    "content": "[build-system]\nrequires = [\"setuptools>=64\", \"wheel\", \"setuptools_scm>=8\", \"packaging==26.0\"]\nbuild-backend = \"setuptools.build_meta\"\n\n[project]\nname = \"axolotl\"\ndynamic = [\"version\", \"dependencies\", \"optional-dependencies\"]\ndescription = \"LLM Trainer\"\nreadme = \"README.md\"\nrequires-python = \">=3.10\"\n# license = \"Apache-2.0\"\n\n[project.scripts]\naxolotl = \"axolotl.cli.main:main\"\n\n[project.urls]\nHomepage = \"https://axolotl.ai/\"\nDocumentation = \"https://docs.axolotl.ai/\"\nRepository = \"https://github.com/axolotl-ai-cloud/axolotl.git\"\n\n[tool.setuptools_scm]\n\n[tool.setuptools]\npy-modules = [\"setuptools_axolotl_dynamic_dependencies\"]\ninclude-package-data = true\n\n[tool.setuptools.dynamic]\nversion = { file = \"VERSION\" }\n\n[tool.setuptools.cmdclass]\nbuild_py = \"setuptools_axolotl_dynamic_dependencies.BuildPyCommand\"\n\n[tool.ruff]\nline-length = 88\ntarget-version = \"py310\"\n\n[tool.ruff.lint]\nselect = [\"E\", \"F\", \"W\", \"C90\", \"B\", \"I\"]\nignore = [\n    \"E203\",  # Whitespace before ':'\n    \"E501\",  # Line too long\n    \"C901\",  # Too complex\n    \"B019\",  # Use of functools.cache on methods\n    \"E722\",  # Bare except\n    \"F821\",  # Undefined name (for dynamic exec)\n]\n\n[tool.ruff.lint.isort]\nknown-third-party = [\"wandb\", \"comet_ml\"]\nknown-local-folder = [\"src\", \"tests\"]\n# Black-compatible isort settings\nforce-single-line = false\ncombine-as-imports = true\nsplit-on-trailing-comma = true\n\n[tool.ruff.format]\n# Use black's formatting style exactly\nquote-style = \"double\"\nindent-style = \"space\"\nskip-magic-trailing-comma = false\nline-ending = \"auto\"\ndocstring-code-format = false\n\n[tool.uv.extra-build-dependencies]\naxolotl = [\"huggingface_hub\"]\n"
  },
  {
    "path": "requirements-dev.txt",
    "content": "black\nmypy\npre-commit\ntypes-requests\nquartodoc\njupyter\nblobfile\ntiktoken\n"
  },
  {
    "path": "requirements-tests.txt",
    "content": "codecov\ncodecov-cli\npytest\npytest-cov\npytest-retry\npytest-sugar\npytest-xdist\ntbparse\n"
  },
  {
    "path": "requirements.txt",
    "content": "--extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/\n\n# START section of dependencies that don't install on Darwin/MacOS\nbitsandbytes==0.49.1\ntriton>=3.4.0\nmamba-ssm==1.2.0.post1\nxformers>=0.0.23.post1\nliger-kernel==0.7.0\n# END section\n\npackaging==26.0\nhuggingface_hub>=1.1.7\npeft>=0.18.1\ntokenizers>=0.22.1\ntransformers==5.3.0\naccelerate==1.13.0\ndatasets==4.5.0\ndeepspeed>=0.18.6,<0.19.0\ntrl==0.29.0\nhf_xet==1.3.2\nkernels==0.12.2\n\nfla-core==0.4.1\nflash-linear-attention==0.4.1\n\ntrackio>=0.16.1\ntyping-extensions>=4.15.0\n\noptimum==1.16.2\nhf_transfer\nsentencepiece\ngradio>=6.2.0,<7.0\n\nmodal==1.3.0.post1\npydantic>=2.10.6\naddict\nfire\nPyYAML>=6.0\nrequests\nwandb\neinops\ncolorama\nnumba>=0.61.2\nnumpy>=2.2.6\n\n# qlora things\nevaluate==0.4.1\nscipy\nnvidia-ml-py==12.560.30\nart\ntensorboard\npython-dotenv==1.0.1\n\n# remote filesystems\ns3fs>=2024.5.0\ngcsfs>=2025.3.0\nadlfs>=2024.5.0\nocifs==1.3.2\n\nzstandard==0.22.0\nfastcore\n\n# lm eval harness\nlm_eval==0.4.7\nlangdetect==1.0.9\nimmutabledict==4.2.0\nantlr4-python3-runtime==4.13.2\n\ntorchao==0.16.0\nopenenv-core==0.1.0\nschedulefree==1.4.1\n\naxolotl-contribs-lgpl==0.0.7\naxolotl-contribs-mit==0.0.6\n# telemetry\nposthog==6.7.11\n\nmistral-common==1.10.0\n"
  },
  {
    "path": "scripts/chat_datasets.py",
    "content": "\"\"\"\nhelper script to parse chat datasets into a usable yaml\n\"\"\"\n\nimport click\nimport yaml\nfrom datasets import load_dataset\n\n\n@click.command()\n@click.argument(\"dataset\", type=str)\n@click.option(\"--split\", type=str, default=\"train\")\ndef parse_dataset(dataset=None, split=\"train\"):\n    ds_cfg = {}\n    ds_cfg[\"path\"] = dataset\n    ds_cfg[\"split\"] = split\n    ds_cfg[\"type\"] = \"chat_template\"\n    ds_cfg[\"chat_template\"] = \"<<<Replace based on your model>>>\"\n\n    dataset = load_dataset(dataset, split=split)\n    features = dataset.features\n    feature_keys = features.keys()\n    field_messages = None\n    for key in [\"conversation\", \"conversations\", \"messages\"]:\n        if key in feature_keys:\n            field_messages = key\n            break\n    if not field_messages:\n        raise ValueError(\n            f\"No conversation field found in dataset: {', '.join(feature_keys)}\"\n        )\n    ds_cfg[\"field_messages\"] = field_messages\n\n    message_fields = features[field_messages][0].keys()\n\n    message_property_mappings = {\"role\": None, \"content\": None}\n    for key in [\"from\", \"role\"]:\n        if key in message_fields:\n            message_property_mappings[\"role\"] = key\n            break\n    if not message_property_mappings[\"role\"]:\n        raise ValueError(\n            f\"No role field found in messages: {', '.join(message_fields)}\"\n        )\n\n    for key in [\"content\", \"text\", \"value\"]:\n        if key in message_fields:\n            message_property_mappings[\"content\"] = key\n            break\n    if not message_property_mappings[\"content\"]:\n        raise ValueError(\n            f\"No content field found in messages: {', '.join(message_fields)}\"\n        )\n    ds_cfg[\"message_property_mappings\"] = message_property_mappings\n\n    print(yaml.dump({\"datasets\": [ds_cfg]}))\n\n\nif __name__ == \"__main__\":\n    parse_dataset()\n"
  },
  {
    "path": "scripts/cloud-entrypoint-term.sh",
    "content": "#!/bin/bash\n\n# Export specific ENV variables to /etc/rp_environment\necho \"Exporting environment variables...\"\nprintenv | grep -E '^RUNPOD_|^PATH=|^_=' | sed 's/^\\(.*\\)=\\(.*\\)$/export \\1=\"\\2\"/' >> /etc/rp_environment\nconda init\n# this needs to come after conda init\necho 'source /etc/rp_environment' >> ~/.bashrc\n\nadd_keys_to_authorized() {\n    local key_value=$1\n\n    # Create the ~/.ssh directory and set permissions\n    mkdir -p ~/.ssh\n    chmod 700 ~/.ssh\n\n    # Create the authorized_keys file if it doesn't exist\n    touch ~/.ssh/authorized_keys\n\n    # Initialize an empty key variable\n    local key=\"\"\n\n    # Read the key variable word by word\n    for word in $key_value; do\n        # Check if the word looks like the start of a key\n        if [[ $word == ssh-* ]]; then\n            # If there's a key being built, add it to the authorized_keys file\n            if [[ -n $key ]]; then\n                echo $key >> ~/.ssh/authorized_keys\n            fi\n            # Start a new key\n            key=$word\n        else\n            # Append the word to the current key\n            key=\"$key $word\"\n        fi\n    done\n\n    # Add the last key to the authorized_keys file\n    if [[ -n $key ]]; then\n        echo $key >> ~/.ssh/authorized_keys\n    fi\n\n    # Set the correct permissions\n    chmod 600 ~/.ssh/authorized_keys\n    chmod 700 -R ~/.ssh\n}\n\nif [[ $PUBLIC_KEY ]]; then\n    # runpod\n    add_keys_to_authorized \"$PUBLIC_KEY\"\n    # Start the SSH service in the background\n    service ssh start\nelif [[ $SSH_KEY ]]; then\n    # latitude.sh\n    add_keys_to_authorized \"$SSH_KEY\"\n    # Start the SSH service in the background\n    service ssh start\nelse\n    echo \"No PUBLIC_KEY or SSH_KEY environment variable provided, not starting openSSH daemon\"\nfi\n\n# Check if JUPYTER_PASSWORD is set and not empty\nif [ -n \"$JUPYTER_PASSWORD\" ]; then\n    # Set JUPYTER_TOKEN to the value of JUPYTER_PASSWORD\n  
  export JUPYTER_TOKEN=\"$JUPYTER_PASSWORD\"\nfi\n\nif [ \"$JUPYTER_DISABLE\" != \"1\" ]; then\n    # Run Jupyter Lab in the background\n    jupyter lab --port=8888 --ip=* --allow-root --ServerApp.allow_origin=* &\nfi\n\nif [ ! -d \"/workspace/data/axolotl-artifacts\" ]; then\n    mkdir -p /workspace/data/axolotl-artifacts\nfi\nif [ ! -L \"/workspace/axolotl/outputs\" ]; then\n    ln -sf /workspace/data/axolotl-artifacts /workspace/axolotl/outputs\nfi\n\n# Execute the passed arguments (CMD)\nexec \"$@\"\n"
  },
  {
    "path": "scripts/cloud-entrypoint.sh",
    "content": "#!/bin/bash\n\n# Export specific ENV variables to /etc/rp_environment\necho \"Exporting environment variables...\"\nprintenv | grep -E '^HF_|^BNB_|^CUDA_|^NCCL_|^NV|^RUNPOD_|^PATH=|^_=' | sed 's/^\\([^=]*\\)=\\(.*\\)$/export \\1=\"\\2\"/' | grep -v 'printenv' >> /etc/rp_environment\necho 'source /etc/rp_environment' >> ~/.bashrc\n\nadd_keys_to_authorized() {\n    local key_value=$1\n\n    # Create the ~/.ssh directory and set permissions\n    mkdir -p ~/.ssh\n    chmod 700 ~/.ssh\n\n    # Create the authorized_keys file if it doesn't exist\n    touch ~/.ssh/authorized_keys\n\n    # Initialize an empty key variable\n    local key=\"\"\n\n    # Read the key variable word by word\n    for word in $key_value; do\n        # Check if the word looks like the start of a key\n        if [[ $word == ssh-* ]]; then\n            # If there's a key being built, add it to the authorized_keys file\n            if [[ -n $key ]]; then\n                echo $key >> ~/.ssh/authorized_keys\n            fi\n            # Start a new key\n            key=$word\n        else\n            # Append the word to the current key\n            key=\"$key $word\"\n        fi\n    done\n\n    # Add the last key to the authorized_keys file\n    if [[ -n $key ]]; then\n        echo $key >> ~/.ssh/authorized_keys\n    fi\n\n    # Set the correct permissions\n    chmod 600 ~/.ssh/authorized_keys\n    chmod 700 -R ~/.ssh\n}\n\n# Set SSH port\nif [ ! 
-z \"$SSH_PORT\" ]; then\n    sed -i \"s/#Port 22/Port $SSH_PORT/\" /etc/ssh/sshd_config\nfi\n\nif [[ $PUBLIC_KEY ]]; then\n    # runpod, prime intellect\n    add_keys_to_authorized \"$PUBLIC_KEY\"\n    # Start the SSH service in the background\n    service ssh start\nelif [[ $SSH_KEY ]]; then\n    # latitude.sh\n    add_keys_to_authorized \"$SSH_KEY\"\n    # Start the SSH service in the background\n    service ssh start\nelse\n    echo \"No PUBLIC_KEY or SSH_KEY environment variable provided, not starting openSSH daemon\"\nfi\n\n# Check if JUPYTER_PASSWORD is set and not empty\nif [ -n \"$JUPYTER_PASSWORD\" ]; then\n    # Set JUPYTER_TOKEN to the value of JUPYTER_PASSWORD\n    export JUPYTER_TOKEN=\"$JUPYTER_PASSWORD\"\nfi\n\nif [ \"$JUPYTER_DISABLE\" != \"1\" ]; then\n    # Run Jupyter Lab in the background\n    jupyter lab --port=8888 --ip=* --allow-root --ServerApp.allow_origin=* &\nfi\n\nif [ ! -d \"/workspace/data/axolotl-artifacts\" ]; then\n    mkdir -p /workspace/data/axolotl-artifacts\nfi\nif [ ! -L \"/workspace/axolotl/outputs\" ]; then\n    ln -sf /workspace/data/axolotl-artifacts /workspace/axolotl/outputs\nfi\n\n# start the runpod slurm init\nSLURM_INIT=\"${SLURM_INIT:-/slurm-init.sh}\"\n\nif [[ -f \"$SLURM_INIT\" ]]; then\n  echo \"[entrypoint] running $SLURM_INIT...\"\n  bash \"$SLURM_INIT\"\nfi\n\n# Execute the passed arguments (CMD)\nexec \"$@\"\n"
  },
  {
    "path": "scripts/cutcrossentropy_install.py",
    "content": "\"\"\"Script to output the correct installation command for cut-cross-entropy.\"\"\"\n\nimport importlib.util\nimport sys\n\ntry:\n    import torch\nexcept ImportError as exc:\n    raise ImportError(\"Install torch via `pip install torch`\") from exc\nfrom packaging.version import Version as V\n\nUSE_UV = \"--uv\" in sys.argv[1:]\n\nv = V(torch.__version__)\n\n# no cut-cross-entropy support for torch < 2.4.0\nif v < V(\"2.4.0\"):\n    print(\"\")\n    sys.exit(0)\n\ncce_spec = importlib.util.find_spec(\"cut_cross_entropy\")\n\nUNINSTALL_PREFIX = \"\"\nif cce_spec:\n    if not importlib.util.find_spec(\"cut_cross_entropy.transformers\"):\n        UNINSTALL_PREFIX = \"pip uninstall -y cut-cross-entropy && \"\n\nUV_PREFIX = \"uv \" if USE_UV else \"\"\n\nprint(\n    UNINSTALL_PREFIX\n    + f'{UV_PREFIX}pip install \"cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@63b15e6\"'\n)\n"
  },
  {
    "path": "scripts/motd",
    "content": "\n     #@@ #@@      @@# @@#\n    @@  @@          @@  @@           =@@#                               @@                 #@    =@@#.\n    @@    #@@@@@@@@@    @@           #@#@=                              @@                 #@     .=@@\n      #@@@@@@@@@@@@@@@@@            =@# @#     ##=     ##    =####=+    @@      =#####+  =#@@###.   @@\n    @@@@@@@@@@/  +@@/  +@@          #@  =@=     #@=   @@   =@#+  +#@#   @@    =@#+  +#@#   #@.      @@\n    @@@@@@@@@@  ##@@  ##@@         =@#   @#      =@# @#    @@      @@   @@    @@      #@   #@       @@\n     @@@@@@@@@@@@@@@@@@@@          #@=+++#@=      =@@#     @@      @@   @@    @@      #@   #@       @@\n                                  =@#=====@@     =@# @#    @@      @@   @@    @@      #@   #@       @@\n    @@@@@@@@@@@@@@@@  @@@@        #@      #@=   #@=  +@@   #@#    =@#   @@.   =@#    =@#   #@.      @@\n                                 =@#       @#  #@=     #@   =#@@@@#=    +#@@=  +#@@@@#=    .##@@+   @@\n    @@@@  @@@@@@@@@@@@@@@@\n\nWelcome to the axolotl cloud image! If you've mounted a disk to /workspace and the axolotl directory is empty, run the commands below:\n\nNeed help with your post-training workloads? Reach out to us at contact@axolotl.ai for assistance.\n\n```\ncd /workspace\nrm -rf /workspace/axolotl\ngit clone https://github.com/axolotl-ai-cloud/axolotl.git\ncd axolotl\npip install --no-build-isolation --no-deps -e .\n```\n"
  },
  {
    "path": "scripts/unsloth_install.py",
    "content": "# noqa\nimport sys\n\ntry:\n    import torch\nexcept ImportError as error:\n    raise ImportError(\"Install torch via `pip install torch`\") from error\nfrom packaging.version import Version as V\n\nuse_uv = \"--uv\" in sys.argv[1:]\n\nv = V(torch.__version__)\ncuda = str(torch.version.cuda)\ntry:\n    is_ampere = torch.cuda.get_device_capability()[0] >= 8\nexcept RuntimeError:\n    is_ampere = False\nif cuda != \"12.1\" and cuda != \"11.8\" and cuda != \"12.4\":\n    raise RuntimeError(f\"CUDA = {cuda} not supported!\")\nif v <= V(\"2.1.0\"):\n    raise RuntimeError(f\"Torch = {v} too old!\")\nelif v <= V(\"2.1.1\"):\n    x = \"cu{}{}-torch211\"\nelif v <= V(\"2.1.2\"):\n    x = \"cu{}{}-torch212\"\nelif v < V(\"2.3.0\"):\n    x = \"cu{}{}-torch220\"\nelif v < V(\"2.4.0\"):\n    x = \"cu{}{}-torch230\"\nelif v < V(\"2.5.0\"):\n    x = \"cu{}{}-torch240\"\nelif v < V(\"2.6.0\"):\n    x = \"cu{}{}-torch250\"\nelse:\n    raise RuntimeError(f\"Torch = {v} too new!\")\nx = x.format(cuda.replace(\".\", \"\"), \"-ampere\" if is_ampere else \"\")\nuv_prefix = \"uv \" if use_uv else \"\"\nprint(\n    f'{uv_prefix}pip install unsloth-zoo==2024.12.1 && {uv_prefix}pip install --no-deps \"unsloth[{x}]==2024.12.4\"'\n)\n"
  },
  {
    "path": "setup.py",
    "content": "\"\"\"setup.py for axolotl\"\"\"\n\nimport os\nimport platform\nimport re\nfrom importlib.metadata import PackageNotFoundError, version\nfrom pathlib import Path\n\nfrom setuptools import find_packages, setup\n\n\ndef parse_requirements(extras_require_map):\n    _install_requires = []\n    _dependency_links = []\n    with open(\"./requirements.txt\", encoding=\"utf-8\") as requirements_file:\n        lines = [r.strip() for r in requirements_file.readlines()]\n        for line in lines:\n            is_extras = \"deepspeed\" in line or \"mamba-ssm\" in line\n            if line.startswith(\"--extra-index-url\"):\n                # Handle custom index URLs\n                _, url = line.split()\n                _dependency_links.append(url)\n            elif not is_extras and line and line[0] != \"#\":\n                # Handle standard packages\n                _install_requires.append(line)\n    try:\n        xformers_version = [req for req in _install_requires if \"xformers\" in req][0]\n        install_xformers = platform.machine() != \"aarch64\"\n        if platform.machine() == \"aarch64\":\n            # skip on ARM64\n            skip_packages = [\n                \"torchao\",\n                \"fla-core\",\n                \"flash-linear-attention\",\n            ]\n            _install_requires = [\n                req\n                for req in _install_requires\n                if re.split(r\"[>=<]\", req)[0].strip() not in skip_packages\n            ]\n        if \"Darwin\" in platform.system():\n            # skip packages not compatible with OSX\n            skip_packages = [\n                \"bitsandbytes\",\n                \"triton\",\n                \"mamba-ssm\",\n                \"xformers\",\n                \"liger-kernel\",\n            ]\n            _install_requires = [\n                req\n                for req in _install_requires\n                if re.split(r\"[>=<]\", req)[0].strip() not in skip_packages\n         
   ]\n            print(\n                _install_requires, [req in skip_packages for req in _install_requires]\n            )\n        else:\n            # detect the version of torch already installed\n            # and set it so dependencies don't clobber the torch version\n            try:\n                torch_version = version(\"torch\")\n            except PackageNotFoundError:\n                torch_version = \"2.8.0\"  # default to torch 2.8.0\n            _install_requires.append(f\"torch=={torch_version}\")\n\n            version_match = re.match(r\"^(\\d+)\\.(\\d+)(?:\\.(\\d+))?\", torch_version)\n            if version_match:\n                major, minor, patch = version_match.groups()\n                major, minor = int(major), int(minor)\n                patch = (\n                    int(patch) if patch is not None else 0\n                )  # Default patch to 0 if not present\n            else:\n                raise ValueError(\"Invalid version format\")\n\n            torch_parts = torch_version.split(\"+\")\n            if len(torch_parts) == 2:\n                torch_cuda_version = torch_parts[1]\n                _dependency_links.append(\n                    f\"https://download.pytorch.org/whl/{torch_cuda_version}\"\n                )\n\n            if (major, minor) >= (2, 9):\n                extras_require_map.pop(\"fbgemm-gpu\")\n                extras_require_map[\"fbgemm-gpu\"] = [\n                    \"fbgemm-gpu==1.4.0\",\n                    \"fbgemm-gpu-genai==1.4.2\",\n                ]\n                extras_require_map[\"vllm\"] = [\"vllm==0.11.1\"]\n                if not install_xformers:\n                    _install_requires.pop(_install_requires.index(xformers_version))\n                extras_require_map[\"vllm\"] = [\"vllm==0.13.0\"]\n                if patch == 0:\n                    extras_require_map[\"vllm\"] = [\"vllm==0.13.0\"]\n                else:\n                    extras_require_map[\"vllm\"] = 
[\"vllm==0.14.0\"]\n            elif (major, minor) >= (2, 8):\n                extras_require_map.pop(\"fbgemm-gpu\")\n                extras_require_map[\"fbgemm-gpu\"] = [\"fbgemm-gpu-genai==1.3.0\"]\n                extras_require_map[\"vllm\"] = [\"vllm==0.11.0\"]\n                if not install_xformers:\n                    _install_requires.pop(_install_requires.index(xformers_version))\n            elif (major, minor) >= (2, 7):\n                _install_requires.pop(_install_requires.index(xformers_version))\n                if patch == 0:\n                    if install_xformers:\n                        _install_requires.append(\"xformers==0.0.30\")\n                    # vllm 0.9.x is incompatible with latest transformers\n                    extras_require_map.pop(\"vllm\")\n                else:\n                    if install_xformers:\n                        _install_requires.append(\"xformers==0.0.31\")\n                    extras_require_map[\"vllm\"] = [\"vllm==0.10.1\"]\n            elif (major, minor) >= (2, 6):\n                _install_requires.pop(_install_requires.index(xformers_version))\n                if install_xformers:\n                    _install_requires.append(\"xformers==0.0.29.post3\")\n                # since we only support 2.6.0+cu126\n                _dependency_links.append(\"https://download.pytorch.org/whl/cu126\")\n                extras_require_map.pop(\"vllm\")\n            elif (major, minor) >= (2, 5):\n                _install_requires.pop(_install_requires.index(xformers_version))\n                if install_xformers:\n                    if patch == 0:\n                        _install_requires.append(\"xformers==0.0.28.post2\")\n                    else:\n                        _install_requires.append(\"xformers>=0.0.28.post3\")\n                extras_require_map.pop(\"vllm\")\n            elif (major, minor) >= (2, 4):\n                extras_require_map.pop(\"vllm\")\n                if 
install_xformers:\n                    if patch == 0:\n                        _install_requires.pop(_install_requires.index(xformers_version))\n                        _install_requires.append(\"xformers>=0.0.27\")\n                    else:\n                        _install_requires.pop(_install_requires.index(xformers_version))\n                        _install_requires.append(\"xformers==0.0.28.post1\")\n            else:\n                raise ValueError(\"axolotl requires torch>=2.4\")\n\n    except PackageNotFoundError:\n        pass\n    return _install_requires, _dependency_links, extras_require_map\n\n\ndef get_package_version():\n    with open(\n        Path(os.path.dirname(os.path.abspath(__file__))) / \"VERSION\",\n        \"r\",\n        encoding=\"utf-8\",\n    ) as fin:\n        version_ = fin.read().strip()\n    return version_\n\n\nextras_require = {\n    \"flash-attn\": [\"flash-attn==2.8.3\"],\n    \"ring-flash-attn\": [\n        \"flash-attn==2.8.3\",\n        \"ring-flash-attn>=0.1.7\",\n    ],\n    \"deepspeed\": [\n        \"deepspeed==0.18.2\",\n        \"deepspeed-kernels\",\n    ],\n    \"mamba-ssm\": [\n        \"mamba-ssm==1.2.0.post1\",\n        \"causal_conv1d\",\n    ],\n    \"auto-gptq\": [\n        \"auto-gptq==0.5.1\",\n    ],\n    \"mlflow\": [\n        \"mlflow\",\n    ],\n    \"galore\": [\n        \"galore_torch\",\n    ],\n    \"apollo\": [\n        \"apollo-torch\",\n    ],\n    \"optimizers\": [\n        \"galore_torch\",\n        \"apollo-torch\",\n        \"lomo-optim==0.1.1\",\n        \"torch-optimi==0.2.1\",\n        \"came_pytorch==0.1.3\",\n    ],\n    \"ray\": [\n        \"ray[train]>=2.52.1\",\n    ],\n    \"vllm\": [\n        \"vllm==0.10.0\",\n    ],\n    \"llmcompressor\": [\n        \"llmcompressor==0.5.1\",\n    ],\n    \"fbgemm-gpu\": [\"fbgemm-gpu-genai==1.3.0\"],\n    \"opentelemetry\": [\n        \"opentelemetry-api\",\n        \"opentelemetry-sdk\",\n        \"opentelemetry-exporter-prometheus\",\n        
\"prometheus-client\",\n    ],\n}\ninstall_requires, dependency_links, extras_require_build = parse_requirements(\n    extras_require\n)\n\nsetup(\n    version=get_package_version(),\n    package_dir={\"\": \"src\"},\n    packages=find_packages(\"src\"),\n    install_requires=install_requires,\n    dependency_links=dependency_links,\n    entry_points={\n        \"console_scripts\": [\n            \"axolotl=axolotl.cli.main:main\",\n        ],\n    },\n    extras_require=extras_require_build,\n)\n"
  },
  {
    "path": "src/axolotl/__init__.py",
    "content": "\"\"\"Axolotl - Train and fine-tune large language models\"\"\"\n\nimport pkgutil\nfrom importlib.metadata import PackageNotFoundError, version\n\n__path__ = pkgutil.extend_path(__path__, __name__)  # Make this a namespace package\n\ntry:\n    __version__ = version(\"axolotl\")\nexcept PackageNotFoundError:\n    __version__ = \"unknown\"\n"
  },
  {
    "path": "src/axolotl/cli/__init__.py",
    "content": "\"\"\"Axolotl CLI module initialization.\"\"\"\n\nimport os\n\nfrom axolotl.logging_config import configure_logging\n\nos.environ.setdefault(\"TOKENIZERS_PARALLELISM\", \"false\")\nos.environ.setdefault(\"HF_XET_HIGH_PERFORMANCE\", \"1\")\nos.environ.setdefault(\"TRL_EXPERIMENTAL_SILENCE\", \"1\")\n\nconfigure_logging()\n"
  },
  {
    "path": "src/axolotl/cli/args.py",
    "content": "\"\"\"Module for axolotl CLI command arguments.\"\"\"\n\nfrom dataclasses import dataclass, field\nfrom typing import Optional\n\n\n@dataclass\nclass PreprocessCliArgs:\n    \"\"\"Dataclass with CLI arguments for `axolotl preprocess` command.\"\"\"\n\n    debug: bool = field(default=False)\n    debug_text_only: bool = field(default=False)\n    debug_num_examples: int = field(default=1)\n    prompter: Optional[str] = field(default=None)\n    download: Optional[bool] = field(default=True)\n    iterable: Optional[bool] = field(\n        default=False,\n        metadata={\n            \"help\": (\n                \"Deprecated in v0.13.0, will be removed in v0.14.0. For streaming \"\n                \"datasets, use 'axolotl train' and set 'streaming: true' in your YAML \"\n                \"config, or pass --streaming instead in the CLI.\"\n            )\n        },\n    )\n\n\n@dataclass\nclass TrainerCliArgs:\n    \"\"\"Dataclass with CLI arguments for `axolotl train` command.\"\"\"\n\n    debug: bool = field(default=False)\n    debug_text_only: bool = field(default=False)\n    debug_num_examples: int = field(default=0)\n    prompter: Optional[str] = field(default=None)\n    shard: bool = field(default=False)\n\n\n@dataclass\nclass VllmServeCliArgs:\n    \"\"\"Dataclass with CLI arguments for `axolotl vllm-serve` command.\"\"\"\n\n    tensor_parallel_size: Optional[int] = field(\n        default=None,\n        metadata={\"help\": \"Number of tensor parallel workers to use.\"},\n    )\n    data_parallel_size: Optional[int] = field(\n        default=None,\n        metadata={\n            \"help\": \"Number of data parallel workers to use for vLLM serving. 
This controls how many model replicas are used for parallel inference.\"\n        },\n    )\n    host: Optional[str] = field(\n        default=None,  # nosec B104\n        metadata={\"help\": \"Host address to run the server on.\"},\n    )\n    port: Optional[int] = field(\n        default=None,\n        metadata={\"help\": \"Port to run the server on.\"},\n    )\n    gpu_memory_utilization: Optional[float] = field(\n        default=None,\n        metadata={\n            \"help\": \"Ratio (between 0 and 1) of GPU memory to reserve for the model weights, activations, and KV \"\n            \"cache on the device dedicated to generation powered by vLLM. Higher values will increase the KV cache \"\n            \"size and thus improve the model's throughput. However, if the value is too high, it may cause \"\n            \"out-of-memory (OOM) errors during initialization.\"\n        },\n    )\n    dtype: Optional[str] = field(\n        default=None,\n        metadata={\n            \"help\": \"Data type to use for vLLM generation. If set to 'auto', the data type will be automatically \"\n            \"determined based on the model configuration. Find the supported values in the vLLM documentation.\"\n        },\n    )\n    max_model_len: Optional[int] = field(\n        default=None,\n        metadata={\n            \"help\": \"If set, the `max_model_len` to use for vLLM. This can be useful when running with reduced \"\n            \"`vllm_gpu_memory_utilization`, leading to a reduced KV cache size. If not set, vLLM will use the model \"\n            \"context size, which might be much larger than the KV cache, leading to inefficiencies.\"\n        },\n    )\n    enable_prefix_caching: Optional[bool] = field(\n        default=None,\n        metadata={\n            \"help\": \"Whether to enable prefix caching in vLLM. 
If set to `True`, ensure that the model and the \"\n            \"hardware support this feature.\"\n        },\n    )\n    serve_module: Optional[str] = field(\n        default=None,\n        metadata={\n            \"help\": \"Module to serve. If not set, the default module will be used.\"\n        },\n    )\n\n    enable_reasoning: Optional[bool] = field(\n        default=None,\n    )\n\n    reasoning_parser: Optional[str] = field(\n        default=None,\n    )\n\n\n@dataclass\nclass QuantizeCliArgs:\n    \"\"\"Dataclass with CLI arguments for `axolotl quantize` command.\"\"\"\n\n    base_model: Optional[str] = field(default=None)\n    weight_dtype: Optional[str] = field(default=None)\n    activation_dtype: Optional[str] = field(default=None)\n    quantize_embedding: Optional[bool] = field(default=None)\n    group_size: Optional[int] = field(default=None)\n    output_dir: Optional[str] = field(default=None)\n    hub_model_id: Optional[str] = field(default=None)\n\n\n@dataclass\nclass EvaluateCliArgs:\n    \"\"\"Dataclass with CLI arguments for `axolotl evaluate` command.\"\"\"\n\n    debug: bool = field(default=False)\n    debug_text_only: bool = field(default=False)\n    debug_num_examples: int = field(default=0)\n\n\n@dataclass\nclass InferenceCliArgs:\n    \"\"\"Dataclass with CLI arguments for `axolotl inference` command.\"\"\"\n\n    prompter: Optional[str] = field(default=None)\n"
  },
  {
    "path": "src/axolotl/cli/art.py",
    "content": "\"\"\"Axolotl ASCII logo utils.\"\"\"\n\nfrom axolotl.utils.distributed import is_main_process\n\nAXOLOTL_LOGO = \"\"\"\n     #@@ #@@      @@# @@#\n    @@  @@          @@  @@           =@@#                               @@                 #@    =@@#.\n    @@    #@@@@@@@@@    @@           #@#@=                              @@                 #@     .=@@\n      #@@@@@@@@@@@@@@@@@            =@# @#     ##=     ##    =####=+    @@      =#####+  =#@@###.   @@\n    @@@@@@@@@@/  +@@/  +@@          #@  =@=     #@=   @@   =@#+  +#@#   @@    =@#+  +#@#   #@.      @@\n    @@@@@@@@@@  ##@@  ##@@         =@#   @#      =@# @#    @@      @@   @@    @@      #@   #@       @@\n     @@@@@@@@@@@@@@@@@@@@          #@=+++#@=      =@@#     @@      @@   @@    @@      #@   #@       @@\n                                  =@#=====@@     =@# @#    @@      @@   @@    @@      #@   #@       @@\n    @@@@@@@@@@@@@@@@  @@@@        #@      #@=   #@=  +@@   #@#    =@#   @@.   =@#    =@#   #@.      @@\n                                 =@#       @#  #@=     #@   =#@@@@#=    +#@@=  +#@@@@#=    .##@@+   @@\n    @@@@  @@@@@@@@@@@@@@@@\n\"\"\"\n\nHAS_PRINTED_LOGO = False\n\n\ndef print_axolotl_text_art():\n    \"\"\"Prints axolotl ASCII art.\"\"\"\n\n    global HAS_PRINTED_LOGO\n    if HAS_PRINTED_LOGO:\n        return\n    if is_main_process():\n        HAS_PRINTED_LOGO = True\n        print(AXOLOTL_LOGO)\n"
  },
  {
    "path": "src/axolotl/cli/checks.py",
    "content": "\"\"\"Various checks for Axolotl CLI.\"\"\"\n\nimport os\nfrom pathlib import Path\n\nfrom accelerate.commands.config import config_args\nfrom huggingface_hub import HfApi\nfrom huggingface_hub.utils import LocalTokenNotFoundError\nfrom requests import HTTPError\n\nfrom axolotl.utils.logging import get_logger\n\nLOG = get_logger(__name__)\n\n\ndef check_accelerate_default_config() -> None:\n    \"\"\"Logs at warning level if a default accelerate config file is found.\"\"\"\n    if Path(config_args.default_yaml_config_file).exists():\n        LOG.warning(\n            f\"accelerate config file found at {config_args.default_yaml_config_file}. This can lead to unexpected errors\"\n        )\n\n\ndef check_user_token() -> bool:\n    \"\"\"Checks for HF user info. Check is skipped if HF_HUB_OFFLINE=1.\n\n    Returns:\n        Boolean indicating successful check (i.e., HF_HUB_OFFLINE=1 or HF user info is retrieved).\n        Returns False, after logging a warning, if no local HF token is found\n        (`LocalTokenNotFoundError`) or the request to HF fails (`HTTPError`).\n    \"\"\"\n    # Skip check if HF_HUB_OFFLINE=1\n    if os.getenv(\"HF_HUB_OFFLINE\") == \"1\":\n        LOG.info(\n            \"Skipping HuggingFace token verification because HF_HUB_OFFLINE is set to True. Only local files will be used.\"\n        )\n        return True\n\n    # Verify if token is valid\n    api = HfApi()\n    try:\n        user_info = api.whoami()\n        return bool(user_info)\n    except LocalTokenNotFoundError:\n        LOG.warning(\n            \"Error verifying HuggingFace token. Remember to log in using `hf auth login` and get your access token from https://huggingface.co/settings/tokens if you want to use gated models or datasets.\"\n        )\n        return False\n    except HTTPError:\n        LOG.warning(\n            \"Error accessing HuggingFace. This may be due to a network issue or rate limiting.\"\n        )\n        return False\n"
  },
  {
    "path": "src/axolotl/cli/cloud/__init__.py",
    "content": "\"\"\"\nlaunch axolotl in supported cloud platforms\n\"\"\"\n\nfrom pathlib import Path\nfrom typing import Literal\n\nimport yaml\n\nfrom axolotl.cli.cloud.base import Cloud\nfrom axolotl.cli.cloud.baseten import BasetenCloud\nfrom axolotl.cli.cloud.modal_ import ModalCloud\nfrom axolotl.utils.dict import DictDefault\n\n\ndef load_cloud_cfg(cloud_config: Path | str) -> DictDefault:\n    \"\"\"Load and validate cloud configuration.\"\"\"\n    # Load cloud configuration.\n    with open(cloud_config, encoding=\"utf-8\") as file:\n        cloud_cfg: DictDefault = DictDefault(yaml.safe_load(file))\n    return cloud_cfg\n\n\ndef do_cli_preprocess(\n    cloud_config: Path | str,\n    config: Path | str,\n) -> None:\n    cloud_cfg = load_cloud_cfg(cloud_config)\n    cloud = ModalCloud(cloud_cfg)\n    with open(config, \"r\", encoding=\"utf-8\") as file:\n        config_yaml = file.read()\n    cloud.preprocess(config_yaml)\n\n\ndef do_cli_train(\n    cloud_config: Path | str,\n    config: Path | str,\n    launcher: Literal[\"accelerate\", \"torchrun\", \"python\"] = \"accelerate\",\n    launcher_args: list[str] | None = None,\n    cwd=None,\n    **kwargs,\n) -> None:\n    cloud_cfg: DictDefault = load_cloud_cfg(cloud_config)\n    provider = cloud_cfg.provider or \"modal\"\n    cloud: Cloud | None\n    if provider == \"modal\":\n        cloud = ModalCloud(cloud_cfg)\n    elif provider == \"baseten\":\n        cloud = BasetenCloud(cloud_cfg.to_dict())\n    else:\n        raise ValueError(f\"Unsupported cloud provider: {provider}\")\n    with open(config, \"r\", encoding=\"utf-8\") as file:\n        config_yaml = file.read()\n    local_dirs = {}\n    if cwd and not Path(cwd).joinpath(\"src\", \"axolotl\").exists():\n        local_dirs = {\"/workspace/mounts\": cwd}\n    cloud.train(\n        config_yaml,\n        launcher=launcher,\n        launcher_args=launcher_args,\n        local_dirs=local_dirs,\n        **kwargs,\n    )\n\n\ndef do_cli_lm_eval(\n    
cloud_config: Path | str,\n    config: Path | str,\n) -> None:\n    cloud_cfg = load_cloud_cfg(cloud_config)\n    cloud = ModalCloud(cloud_cfg)\n    with open(config, \"r\", encoding=\"utf-8\") as file:\n        config_yaml = file.read()\n    cloud.lm_eval(config_yaml)\n"
  },
  {
    "path": "src/axolotl/cli/cloud/base.py",
    "content": "\"\"\"\nbase class for cloud platforms from cli\n\"\"\"\n\nfrom abc import ABC, abstractmethod\nfrom typing import Literal\n\n\nclass Cloud(ABC):\n    \"\"\"\n    Abstract base class for cloud platforms.\n    \"\"\"\n\n    @abstractmethod\n    def preprocess(self, config_yaml: str, *args, **kwargs) -> None:\n        pass\n\n    @abstractmethod\n    def train(\n        self,\n        config_yaml: str,\n        launcher: Literal[\"accelerate\", \"torchrun\", \"python\"] = \"accelerate\",\n        launcher_args: list[str] | None = None,\n        local_dirs: dict[str, str] | None = None,\n        **kwargs,\n    ):\n        pass\n"
  },
  {
    "path": "src/axolotl/cli/cloud/baseten/__init__.py",
    "content": "\"\"\"Baseten Cloud CLI\"\"\"\n\nimport shutil\nimport subprocess  # nosec B404\nimport tempfile\nfrom os.path import dirname\nfrom typing import Literal\n\nimport yaml\n\nfrom axolotl.cli.cloud.base import Cloud\n\n\nclass BasetenCloud(Cloud):\n    \"\"\"Baseten Cloud Axolotl CLI\"\"\"\n\n    def __init__(self, config: dict):\n        self.config = config\n\n    def preprocess(self, config_yaml: str, *args, **kwargs) -> None:\n        raise NotImplementedError(\n            \"Separate preprocess function for Baseten is not \"\n            \"implemented and will happen during the train step.\"\n        )\n\n    def train(\n        self,\n        config_yaml: str,\n        launcher: Literal[\"accelerate\", \"torchrun\", \"python\"] = \"accelerate\",\n        launcher_args: list[str] | None = None,\n        local_dirs: dict[str, str] | None = None,  # pylint: disable=unused-argument\n        **kwargs,\n    ):\n        with tempfile.TemporaryDirectory() as tmp_dir:\n            config = self.config.copy()\n            config[\"launcher\"] = launcher\n            config[\"launcher_args\"] = launcher_args\n            with open(tmp_dir + \"/cloud.yaml\", \"w\", encoding=\"utf-8\") as cloud_fout:\n                yaml.dump(config, cloud_fout)\n            with open(tmp_dir + \"/train.yaml\", \"w\", encoding=\"utf-8\") as config_fout:\n                config_fout.write(config_yaml)\n            shutil.copyfile(dirname(__file__) + \"/template/run.sh\", tmp_dir + \"/run.sh\")\n            shutil.copyfile(\n                dirname(__file__) + \"/template/train_sft.py\", tmp_dir + \"/train_sft.py\"\n            )\n            subprocess.run(  # nosec B603 B607\n                [\"truss\", \"train\", \"push\", \"train_sft.py\"], cwd=tmp_dir, check=False\n            )\n"
  },
  {
    "path": "src/axolotl/cli/cloud/baseten/template/run.sh",
    "content": "#!/bin/bash\nset -eux\n\nexport NCCL_SOCKET_IFNAME=\"^docker0,lo\"\nexport NCCL_IB_DISABLE=0\nexport NCCL_TIMEOUT=1800000\n\naxolotl preprocess train.yaml\naxolotl train train.yaml --launcher ${AXOLOTL_LAUNCHER} ${AXOLOTL_LAUNCHER_ARGS}\n"
  },
  {
    "path": "src/axolotl/cli/cloud/baseten/template/train_sft.py",
    "content": "\"\"\"\nBaseten Training Script for Axolotl\n\"\"\"\n\n# pylint: skip-file\nimport yaml\nfrom truss.base import truss_config\n\n# Import necessary classes from the Baseten Training SDK\nfrom truss_train import definitions\n\ncloud_config = yaml.safe_load(open(\"cloud.yaml\", \"r\"))\ngpu = cloud_config.get(\"gpu\", \"h100\")\ngpu_count = int(cloud_config.get(\"gpu_count\", 1))\nnode_count = int(cloud_config.get(\"node_count\", 1))\nproject_name = cloud_config.get(\"project_name\", \"axolotl-project\") or \"axolotl-project\"\nsecrets = cloud_config.get(\"secrets\", [])\nlauncher = cloud_config.get(\"launcher\", \"accelerate\")\nlauncher_args = cloud_config.get(\"launcher_args\", [])\nscript_name = \"run.sh\"\n\nlauncher_args_str = \"\"\nif launcher_args:\n    launcher_args_str = \"-- \" + \" \".join(launcher_args)\n\n# 1. Define a base image for your training job\nBASE_IMAGE = \"axolotlai/axolotl:main-py3.11-cu128-2.9.1\"\n\n# 2. Define the Runtime Environment for the Training Job\n# This includes start commands and environment variables.a\n# Secrets from the baseten workspace like API keys are referenced using\n# `SecretReference`.\n\nenv_vars = {\n    \"AXOLOTL_LAUNCHER\": launcher,\n    \"AXOLOTL_LAUNCHER_ARGS\": launcher_args_str,\n}\nfor secret_name in secrets:\n    env_vars[secret_name] = definitions.SecretReference(name=secret_name)\n\ntraining_runtime = definitions.Runtime(\n    start_commands=[  # Example: list of commands to run your training script\n        f\"/bin/sh -c 'chmod +x ./{script_name} && ./{script_name}'\"\n    ],\n    environment_variables=env_vars,\n)\n\n# 3. Define the Compute Resources for the Training Job\ntraining_compute = definitions.Compute(\n    node_count=node_count,\n    accelerator=truss_config.AcceleratorSpec(\n        accelerator=truss_config.Accelerator.H100,\n        count=gpu_count,\n    ),\n)\n\n# 4. 
Define the Training Job\n# This brings together the image, compute, and runtime configurations.\nmy_training_job = definitions.TrainingJob(\n    image=definitions.Image(base_image=BASE_IMAGE),\n    compute=training_compute,\n    runtime=training_runtime,\n)\n\n\n# This config will be pushed using the Truss CLI.\n# The association of the job to the project happens at the time of push.\nfirst_project_with_job = definitions.TrainingProject(\n    name=project_name, job=my_training_job\n)\n"
  },
  {
    "path": "src/axolotl/cli/cloud/modal_.py",
    "content": "\"\"\"\nModal Cloud support from CLI\n\"\"\"\n\nimport copy\nimport json\nimport os\nimport subprocess  # nosec B404\nfrom pathlib import Path\nfrom random import randint\nfrom typing import Literal\n\nimport modal\n\nfrom axolotl.cli.cloud.base import Cloud\n\n\ndef run_cmd(cmd: str, run_folder: str, volumes=None):\n    \"\"\"Run a command inside a folder, with Modal Volume reloading before and commit on success.\"\"\"\n    # Ensure volumes contain latest files.\n    if volumes:\n        for _, vol in volumes.items():\n            vol.reload()\n\n    # modal workaround so it doesn't use the automounted axolotl\n    new_env = copy.deepcopy(os.environ)\n\n    if \"PYTHONPATH\" in new_env:\n        paths = [\"/workspace/mounts\"]\n        for sub_python_path_str in new_env[\"PYTHONPATH\"].split(\":\"):\n            sub_python_path = Path(sub_python_path_str)\n            if not sub_python_path.joinpath(\"src\", \"axolotl\").exists():\n                # we don't want to use the automounted axolotl or unexpected behavior happens\n                paths.append(str(sub_python_path))\n        if paths:\n            new_env[\"PYTHONPATH\"] = \":\".join(paths)\n        else:\n            del new_env[\"PYTHONPATH\"]\n\n    # Propagate errors from subprocess.\n    if exit_code := subprocess.call(  # nosec B603\n        cmd.split(), cwd=run_folder, env=new_env\n    ):\n        exit(exit_code)\n\n    # Commit writes to volume.\n    if volumes:\n        for _, vol in volumes.items():\n            vol.commit()\n\n\nclass ModalCloud(Cloud):\n    \"\"\"\n    Modal Cloud implementation.\n    \"\"\"\n\n    def __init__(self, config, app=None):\n        self.config = config\n        if not app:\n            app = modal.App()\n        self.app = app\n\n        self.volumes = {}\n        if config.volumes:\n            for volume_config in config.volumes:\n                _, mount, vol = self.create_volume(volume_config)\n                self.volumes[mount] = (vol, 
volume_config)\n\n    def get_env(self):\n        res = {\n            \"HF_DATASETS_CACHE\": \"/workspace/data/huggingface-cache/datasets\",\n            \"HF_HUB_CACHE\": \"/workspace/data/huggingface-cache/hub\",\n        }\n\n        for key in self.config.get(\"env\", []):\n            if isinstance(key, str):\n                if val := os.environ.get(key, \"\"):\n                    res[key] = val\n            elif isinstance(key, dict):\n                (key_, val) = list(key.items())[0]\n                res[key_] = val\n        return res\n\n    def get_image(self):\n        docker_tag = \"main-py3.11-cu128-2.9.1\"\n        if self.config.docker_tag:\n            docker_tag = self.config.docker_tag\n        docker_image = f\"axolotlai/axolotl:{docker_tag}\"\n\n        # grab the sha256 hash from docker hub for this image+tag\n        # this ensures that we always get the latest image for this tag, even if it's already cached\n        try:\n            manifest = subprocess.check_output(  # nosec\n                [\"docker\", \"manifest\", \"inspect\", docker_image],\n            ).decode(\"utf-8\")\n            sha256_hash = json.loads(manifest)[\"manifests\"][0][\"digest\"]\n        except subprocess.CalledProcessError:\n            sha256_hash = None\n\n        # create the image\n        if sha256_hash:\n            image = modal.Image.from_registry(f\"axolotlai/axolotl@{sha256_hash}\")\n        else:\n            image = modal.Image.from_registry(docker_image)\n\n        dockerfile_commands = []\n        if self.config.dockerfile_commands:\n            dockerfile_commands.extend(self.config.dockerfile_commands)\n\n        # branch\n        if self.config.branch:\n            dockerfile_commands.extend(\n                [\n                    # Random id for cache busting of branch commits\n                    f\"RUN echo '{str(randint(0, 1000000))}'\",  # nosec B311\n                    f\"RUN cd /workspace/axolotl && git fetch && git checkout 
{self.config.branch} && git pull\",\n                ]\n            )\n\n        if dockerfile_commands:\n            image = image.dockerfile_commands(dockerfile_commands)\n\n        if env := self.get_env():\n            image = image.env(env)\n\n        return image\n\n    def get_secrets(self):\n        res = []\n        if self.config.secrets:\n            for key in self.config.get(\"secrets\", []):\n                if isinstance(key, str):\n                    if val := os.environ.get(key, \"\"):\n                        res.append(modal.Secret.from_dict({key: val}))\n                elif isinstance(key, dict):\n                    (key_, val) = list(key.items())[0]\n                    res.append(modal.Secret.from_dict({key_: val}))\n        return res\n\n    def create_volume(self, volume_config):\n        name = volume_config.name\n        mount = volume_config.mount\n        return name, mount, modal.Volume.from_name(name, create_if_missing=True)\n\n    def get_ephemeral_disk_size(self):\n        return 1000 * 525  # 1 TiB\n\n    def get_preprocess_timeout(self):\n        if self.config.timeout_preprocess:\n            return int(self.config.timeout_preprocess)\n        return 60 * 60 * 3  # 3 hours\n\n    def get_preprocess_memory(self):\n        memory = 128  # default to 128GiB\n        if self.config.memory:\n            memory = int(self.config.memory)\n        if self.config.memory_preprocess:\n            memory = int(self.config.memory_preprocess)\n        return 1024 * memory\n\n    def get_preprocess_env(self):\n        return self.app.function(\n            image=self.get_image(),\n            volumes={k: v[0] for k, v in self.volumes.items()},\n            cpu=8.0,\n            ephemeral_disk=self.get_ephemeral_disk_size(),\n            memory=self.get_preprocess_memory(),\n            timeout=self.get_preprocess_timeout(),\n            secrets=self.get_secrets(),\n        )\n\n    def preprocess(self, config_yaml: str, *args, **kwargs):\n    
    modal_fn = self.get_preprocess_env()(_preprocess)\n        with modal.enable_output():\n            with self.app.run(detach=True):\n                modal_fn.remote(\n                    config_yaml,\n                    *args,\n                    volumes={k: v[0] for k, v in self.volumes.items()},\n                    **kwargs,\n                )\n\n    def get_train_timeout(self):\n        if self.config.timeout:\n            return int(self.config.timeout)\n        return 60 * 60 * 24  # 24 hours\n\n    def get_train_gpu(self):\n        count = self.config.gpu_count or 1\n        family = self.config.gpu.lower() or \"l40s\"\n\n        if family == \"l40s\":\n            return modal.gpu.L40S(count=count)\n        if family in [\"a100\", \"a100-40gb\"]:\n            return modal.gpu.A100(count=count, size=\"40GB\")\n        if family == \"a100-80gb\":\n            return modal.gpu.A100(count=count, size=\"80GB\")\n        if family in [\"a10\", \"a10g\"]:\n            return modal.gpu.A10G(count=count)\n        if family == \"h100\":\n            return f\"H100:{count}\"\n        if family == \"t4\":\n            return modal.gpu.T4(count=count)\n        if family == \"l4\":\n            return modal.gpu.L4(count=count)\n        raise ValueError(f\"Unsupported GPU family: {family}\")\n\n    def get_train_memory(self):\n        memory = 128  # default to 128GiB\n        if self.config.memory:\n            memory = int(self.config.memory)\n        return 1024 * memory\n\n    def get_train_env(self, local_dirs=None):\n        image = self.get_image()\n        for mount, local_dir in (local_dirs or {}).items():\n            image = image.add_local_dir(local_dir, mount)\n        return self.app.function(\n            image=image,\n            volumes={k: v[0] for k, v in self.volumes.items()},\n            cpu=16.0,\n            gpu=self.get_train_gpu(),\n            memory=self.get_train_memory(),\n            timeout=self.get_train_timeout(),\n            
secrets=self.get_secrets(),\n        )\n\n    def train(\n        self,\n        config_yaml: str,\n        launcher: Literal[\"accelerate\", \"torchrun\", \"python\"] = \"accelerate\",\n        launcher_args: list[str] | None = None,\n        local_dirs: dict[str, str] | None = None,\n        **kwargs,\n    ):\n        modal_fn = self.get_train_env(local_dirs)(_train)\n        with modal.enable_output():\n            with self.app.run(detach=True):\n                modal_fn.remote(\n                    config_yaml,\n                    launcher=launcher,\n                    launcher_args=launcher_args,\n                    volumes={k: v[0] for k, v in self.volumes.items()},\n                    **kwargs,\n                )\n\n    def lm_eval(self, config_yaml: str):\n        modal_fn = self.get_train_env()(_lm_eval)\n        with modal.enable_output():\n            with self.app.run(detach=True):\n                if self.config.get(\"spawn\", False):\n                    modal_fn_exec = modal_fn.spawn\n                else:\n                    modal_fn_exec = modal_fn.remote\n                modal_fn_exec(\n                    config_yaml,\n                    volumes={k: v[0] for k, v in self.volumes.items()},\n                )\n\n\ndef _preprocess(config_yaml: str, volumes=None):\n    Path(\"/workspace/mounts\").mkdir(parents=True, exist_ok=True)\n    with open(\"/workspace/mounts/config.yaml\", \"w\", encoding=\"utf-8\") as f_out:\n        f_out.write(config_yaml)\n    run_folder = \"/workspace/mounts\"\n    run_cmd(\n        \"axolotl preprocess /workspace/mounts/config.yaml --dataset-processes=8\",\n        run_folder,\n        volumes,\n    )\n\n\ndef _train(\n    config_yaml: str,\n    launcher: Literal[\"accelerate\", \"torchrun\", \"python\"] = \"accelerate\",\n    launcher_args: list[str] | None = None,\n    volumes=None,\n    **kwargs,\n):\n    Path(\"/workspace/mounts\").mkdir(parents=True, exist_ok=True)\n    with 
open(\"/workspace/mounts/config.yaml\", \"w\", encoding=\"utf-8\") as f_out:\n        f_out.write(config_yaml)\n    run_folder = \"/workspace/mounts\"\n\n    launcher_args = launcher_args or []\n\n    # Build the base command\n    if launcher == \"accelerate\":\n        launcher_arg = \"--launcher accelerate\"\n    elif launcher == \"torchrun\":\n        launcher_arg = \"--launcher torchrun\"\n    else:\n        launcher_arg = \"--launcher python\"\n\n    # Build launcher args string\n    launcher_args_str = \"\"\n    if launcher_args:\n        launcher_args_str = \"-- \" + \" \".join(launcher_args)\n\n    run_cmd(\n        f\"axolotl train {launcher_arg} /workspace/mounts/config.yaml {launcher_args_str}\".strip(),\n        run_folder,\n        volumes,\n    )\n\n\ndef _lm_eval(config_yaml: str, volumes=None):\n    Path(\"/workspace/mounts\").mkdir(parents=True, exist_ok=True)\n    with open(\"/workspace/mounts/config.yaml\", \"w\", encoding=\"utf-8\") as f_out:\n        f_out.write(config_yaml)\n    run_folder = \"/workspace/mounts\"\n    run_cmd(\n        \"axolotl lm-eval /workspace/mounts/config.yaml\",\n        run_folder,\n        volumes,\n    )\n"
  },
  {
    "path": "src/axolotl/cli/config.py",
    "content": "\"\"\"Configuration loading and processing.\"\"\"\n\nimport json\nimport os\nimport tempfile\nfrom pathlib import Path\nfrom tempfile import NamedTemporaryFile\nfrom typing import Any, Optional, Union\nfrom urllib.parse import urlparse\n\nimport requests\nimport torch\nimport yaml\nfrom transformers.utils import is_torch_bf16_gpu_available, is_torch_tf32_available\n\nfrom axolotl.integrations.base import PluginManager\nfrom axolotl.telemetry.errors import send_errors\nfrom axolotl.telemetry.manager import TelemetryManager\nfrom axolotl.utils.comet_ import setup_comet_env_vars\nfrom axolotl.utils.config import (\n    normalize_cfg_datasets,\n    normalize_config,\n    validate_config,\n)\nfrom axolotl.utils.dict import DictDefault\nfrom axolotl.utils.logging import get_logger\nfrom axolotl.utils.mlflow_ import setup_mlflow_env_vars\nfrom axolotl.utils.tee import prepare_debug_log\nfrom axolotl.utils.trackio_ import setup_trackio_env_vars\nfrom axolotl.utils.trainer import prepare_optim_env\nfrom axolotl.utils.wandb_ import setup_wandb_env_vars\n\nLOG = get_logger(__name__)\n\n\ndef _coerce_value(value: Any, existing: Optional[Any] = None) -> Any:\n    \"\"\"Coerce a string CLI value to its most likely Python type.\n\n    If an existing value is present in the config, its type is used to guide\n    casting.  
Otherwise, YAML-style inference is applied: booleans, ints,\n    floats, and None literals are recognised automatically.\n\n    Args:\n        value: The raw value (typically a string from the CLI).\n        existing: An optional existing config value whose type guides coercion.\n\n    Returns:\n        The value cast to the inferred or expected type.\n    \"\"\"\n    if not isinstance(value, str):\n        return value\n\n    # If the config already has a typed value, cast to match\n    if existing is not None:\n        if isinstance(existing, bool):\n            return value.lower() in (\"true\", \"1\", \"yes\")\n        if isinstance(existing, int):\n            try:\n                return int(value)\n            except (ValueError, TypeError):\n                return value\n        if isinstance(existing, float):\n            try:\n                return float(value)\n            except (ValueError, TypeError):\n                return value\n        # For other types (str, list, dict, etc.), return as-is\n        return value\n\n    # No existing value -- use YAML-style inference\n    lower = value.lower()\n    if lower in (\"true\", \"yes\"):\n        return True\n    if lower in (\"false\", \"no\"):\n        return False\n    if lower in (\"null\", \"none\", \"~\"):\n        return None\n\n    # Try int then float\n    try:\n        return int(value)\n    except ValueError:\n        pass\n    try:\n        return float(value)\n    except ValueError:\n        pass\n\n    return value\n\n\nAPI_KEY_FIELDS = {\"comet_api_key\"}\n\nTELEMETRY_MANAGER = TelemetryManager.get_instance()\n\n\ndef check_remote_config(config: Union[str, Path]) -> Union[str, Path]:\n    \"\"\"\n    First, determines if the passed config is a valid HTTPS URL. 
Then, attempts to query\n    for it and parse its content, first as JSON, then as YAML (YAML is preferred).\n    Finally, the parsed content is written to a local file and its path is returned.\n\n    Args:\n        config: HTTPS URL to a YAML or JSON file.\n\n    Returns:\n        Either the original `config` if it's not a valid HTTPS URL, or the path to the\n        downloaded remote config.\n\n    Raises:\n        ValueError: If the remote configuration is neither valid JSON or YAML.\n        RuntimeError: If some request-related exception occurs from the file download.\n        Exception: Catch-all for any other exception.\n    \"\"\"\n    # Check if the config is a valid HTTPS URL to a .yml or .yaml file\n    if not (isinstance(config, str) and config.startswith(\"https://\")):\n        return config  # Return the original value if it's not a valid URL\n\n    filename = os.path.basename(urlparse(config).path)\n    temp_dir = tempfile.mkdtemp()\n\n    try:\n        response = requests.get(config, timeout=30)\n        response.raise_for_status()  # Check for HTTP errors\n\n        content = response.content\n        try:\n            # Try parsing as JSON first to catch cases where JSON content is mistakenly\n            # considered YAML.\n            json.loads(content)\n\n            # Log a warning but do not raise an error; JSON is technically valid YAML.\n            # This can happen when you forget to point to a raw GitHub link.\n            LOG.warning(\n                f\"Warning: The content of the file at {config} is JSON, which is technically valid YAML but might not be intended.\"\n            )\n        except json.JSONDecodeError:\n            # If it's not valid JSON, verify it's valid YAML\n            try:\n                yaml.safe_load(content)\n            except yaml.YAMLError as err:\n                raise ValueError(\n                    f\"Failed to parse the content at {config} as YAML: {err}\"\n                ) from err\n\n        # 
Write the content to a file if it's valid YAML (or JSON treated as YAML)\n        output_path = Path(temp_dir) / filename\n        with open(output_path, \"wb\") as file:\n            file.write(content)\n        LOG.info(\n            f\"Using the following config obtained from {config}: \\n\\n{content.decode('utf-8')}\\n\"\n        )\n        return output_path\n\n    except requests.RequestException as err:\n        # This catches all requests-related exceptions including HTTPError\n        raise RuntimeError(f\"Failed to download {config}: {err}\") from err\n    except Exception as err:\n        # Catch-all for any other exceptions\n        raise err\n\n\ndef choose_config(path: Path) -> str:\n    \"\"\"\n    Helper method for choosing a `axolotl` config YAML file (considering only files\n    ending with `.yml` or `.yaml`). If more than one config file exists in the passed\n    `path`, the user is prompted to choose one.\n\n    Args:\n        path: Directory in which config file(s) are stored.\n\n    Returns:\n        Path to either (1) the sole YAML file, or (2) if more than one YAML files exist,\n        the user-selected YAML file.\n\n    Raises:\n        ValueError: If no YAML files are found in the given `path`.\n    \"\"\"\n    yaml_files = list(path.glob(\"*.yml\")) + list(path.glob(\"*.yaml\"))\n\n    if not yaml_files:\n        raise ValueError(\n            \"No YAML config files found in the specified directory. Are you using a .yml extension?\"\n        )\n\n    if len(yaml_files) == 1:\n        LOG.info(f\"Using default YAML file '{yaml_files[0]}'\")\n        return str(yaml_files[0])\n\n    LOG.info(\"Choose a YAML file:\")\n    for idx, file in enumerate(yaml_files):\n        LOG.info(f\"{idx + 1}. 
{file}\")\n\n    chosen_file = None\n    while chosen_file is None:\n        try:\n            choice = int(input(\"Enter the number of your choice: \"))\n            if 1 <= choice <= len(yaml_files):\n                chosen_file = str(yaml_files[choice - 1])\n            else:\n                LOG.info(\"Invalid choice. Please choose a number from the list.\")\n        except ValueError:\n            LOG.info(\"Invalid input. Please enter a number.\")\n\n    return chosen_file\n\n\ndef prepare_plugins(cfg: DictDefault):\n    \"\"\"\n    Registers the plugins for the given configuration.\n\n    Args:\n        cfg: Dictionary mapping `axolotl` config keys to values.\n    \"\"\"\n    if cfg.get(\"plugins\"):\n        plugin_manager = PluginManager.get_instance()\n        for plugin_name in cfg[\"plugins\"]:\n            plugin_manager.register(plugin_name)\n        for plugin in plugin_manager.plugins.values():\n            plugin.register(cfg)\n\n\ndef plugin_set_cfg(cfg: DictDefault):\n    if cfg.get(\"plugins\"):\n        plugin_manager = PluginManager.get_instance()\n        plugin_manager.cfg = cfg\n\n\n@send_errors\ndef load_cfg(\n    config: str | Path | DictDefault = Path(\"examples/\"), **kwargs\n) -> DictDefault:\n    \"\"\"\n    Loads the `axolotl` configuration stored at `config`, validates it, and performs\n    various setup.\n\n    Args:\n        config: Path (local or remote) to `axolotl` config YAML file.\n        kwargs: Additional keyword arguments to override config file values.\n\n    Returns:\n        `DictDefault` mapping configuration keys to values.\n    \"\"\"\n    if isinstance(config, (str, Path)):\n        config = check_remote_config(config)\n        if Path(config).is_dir():\n            config = choose_config(Path(config))\n\n        # Load the config from the yaml file\n        with open(config, encoding=\"utf-8\") as file:\n            cfg: DictDefault = DictDefault(yaml.safe_load(file))\n\n        cfg.axolotl_config_path = config\n  
  else:\n        cfg = config\n        with NamedTemporaryFile(\n            mode=\"w\", delete=False, suffix=\".yml\", prefix=\"axolotl_config_\"\n        ) as temp_file:\n            temp_file.write(yaml.dump(config.to_dict()))\n            temp_file.close()\n        cfg.axolotl_config_path = temp_file.name\n\n    TELEMETRY_MANAGER.send_event(event_type=\"config-loaded\", properties=cfg)\n\n    # If there are any options passed in the cli, if it is something that seems valid\n    # from the yaml, then overwrite the value\n    cfg_keys = cfg.keys()\n\n    # Separate nested (dot-notation) kwargs from flat kwargs\n    nested_kwargs: dict[str, dict[str, Any]] = {}\n    flat_kwargs: dict[str, Any] = {}\n    for key, value in kwargs.items():\n        if \"__\" in key:\n            parent, child = key.split(\"__\", 1)\n            nested_kwargs.setdefault(parent, {})[child] = value\n        else:\n            flat_kwargs[key] = value\n\n    # Apply flat kwargs\n    for key, value in flat_kwargs.items():\n        # If not strict, allow writing to cfg even if it's not in the yml already\n        if key in cfg_keys or not cfg.strict:\n            cfg[key] = _coerce_value(value, cfg.get(key))\n\n    # Apply nested kwargs (e.g., trl__beta -> cfg.trl.beta)\n    for parent, children in nested_kwargs.items():\n        if parent not in cfg_keys and cfg.strict:\n            continue\n        if cfg[parent] is None:\n            cfg[parent] = {}\n        if not isinstance(cfg[parent], dict):\n            LOG.warning(\n                \"Overwriting non-dict value for '%s' with nested CLI overrides\", parent\n            )\n            cfg[parent] = {}\n        for child_key, child_value in children.items():\n            existing_child = cfg[parent].get(child_key)\n            cfg[parent][child_key] = _coerce_value(child_value, existing_child)\n\n    try:\n        device_props = torch.cuda.get_device_properties(\"cuda\")\n        gpu_version = \"sm_\" + str(device_props.major) + 
str(device_props.minor)\n    except (RuntimeError, AssertionError):\n        gpu_version = None\n\n    prepare_plugins(cfg)\n\n    cfg = validate_config(\n        cfg,\n        capabilities={\n            \"bf16\": is_torch_bf16_gpu_available(),\n            \"fp8\": compute_supports_fp8(),\n            \"tf32\": is_torch_tf32_available(),\n            \"n_gpu\": int(os.environ.get(\"WORLD_SIZE\", 1)),\n            \"compute_capability\": gpu_version,\n        },\n        env_capabilities={\n            \"torch_version\": str(torch.__version__).split(\"+\", maxsplit=1)[0]\n        },\n    )\n\n    # NOTE(djsaunde): We start outputting to output_dir/debug.log at this point since we\n    # have to wait for cfg.output to be resolved. We could call this earlier if we write\n    # to a temporary file, and then move it later.\n    prepare_debug_log(cfg)\n    prepare_optim_env(cfg)\n    normalize_config(cfg)\n    normalize_cfg_datasets(cfg)\n    setup_wandb_env_vars(cfg)\n    setup_mlflow_env_vars(cfg)\n    setup_comet_env_vars(cfg)\n    setup_trackio_env_vars(cfg)\n    plugin_set_cfg(cfg)\n\n    TELEMETRY_MANAGER.send_event(event_type=\"config-processed\", properties=cfg)\n    cfg_to_log = {\n        k: \"[REDACTED]\" if k in API_KEY_FIELDS else v\n        for k, v in cfg.items()\n        if v is not None\n    }\n    LOG.info(\n        \"config:\\n%s\",\n        json.dumps(cfg_to_log, indent=2, default=str, sort_keys=True),\n    )\n\n    return cfg\n\n\ndef compute_supports_fp8() -> bool:\n    try:\n        compute_capability = torch.cuda.get_device_capability()\n        return compute_capability >= (9, 0)\n    except RuntimeError:\n        return False\n"
  },
  {
    "path": "src/axolotl/cli/delinearize_llama4.py",
    "content": "\"\"\"\nCLI tool to delinearize quantized/Linearized Llama-4 models.\n\"\"\"\n\nimport os\nfrom pathlib import Path\nfrom typing import Generator, Union\n\nimport fire\nimport torch\nfrom accelerate import init_empty_weights\nfrom transformers import AutoProcessor\n\n\ndef iter_convert_patched_to_hf(model_state_dict, num_experts) -> Generator:\n    keys = list(model_state_dict.keys())\n    for key in keys:\n        if \".feed_forward.experts.\" not in key:\n            yield key, model_state_dict[key]\n        if \".feed_forward.experts.gate_projs\" in key:\n            # gate gets fused with up so skip the yield on this and we'll fuse it when asking for the up\n            continue\n        if \".feed_forward.experts.up_projs\" in key:\n            if \".feed_forward.experts.up_projs.0.\" in key:\n                # handle the re-shape and fusing of gate and up, and conversion from linear to parameter\n                prefix = key.split(\".up_projs.0.\")[0]\n                key = f\"{prefix}.gate_up_proj\"\n                # grab all the up_projs and gate_projs across all experts\n                gate_stacked = torch.stack(\n                    [\n                        model_state_dict[\n                            f\"{prefix}.gate_projs.{expert_idx}.weight\"\n                        ].transpose(0, 1)\n                        for expert_idx in range(num_experts)\n                    ]\n                )\n                up_stacked = torch.stack(\n                    [\n                        model_state_dict[\n                            f\"{prefix}.up_projs.{expert_idx}.weight\"\n                        ].transpose(0, 1)\n                        for expert_idx in range(num_experts)\n                    ]\n                )\n                gate_up_proj = torch.cat((gate_stacked, up_stacked), dim=-1)\n                del gate_stacked, up_stacked\n                yield key, gate_up_proj\n            else:\n                del 
model_state_dict[key]\n                continue\n        if \".feed_forward.experts.down_projs\" in key:\n            if \".feed_forward.experts.down_projs.0.\" in key:\n                # stack the per-expert down_projs and convert from linear to parameter\n                prefix = key.split(\".down_projs.0.\")[0]\n                key = f\"{prefix}.down_proj\"\n                # grab all the down_projs across all experts\n                down_stacked = torch.stack(\n                    [\n                        model_state_dict[\n                            f\"{prefix}.down_projs.{expert_idx}.weight\"\n                        ].transpose(0, 1)\n                        for expert_idx in range(num_experts)\n                    ]\n                )\n                yield key, down_stacked\n            else:\n                del model_state_dict[key]\n                continue\n\n\ndef do_cli(model: Union[Path, str], output: Union[Path, str]) -> None:\n    \"\"\"\n    Convert a patched HF format Llama4 model (with separated projections)\n    back to the original HF format (with fused projections).\n\n    Args:\n        model: Path to the patched HF model\n        output: Path to save the converted model\n    \"\"\"\n    print(f\"Loading model from {model}\")\n    from axolotl.monkeypatch.models.llama4.modeling import (\n        patch_llama4_linearized_modeling,\n    )\n\n    unpatch_llama4 = patch_llama4_linearized_modeling()\n    from transformers import Llama4ForConditionalGeneration\n\n    model_ = Llama4ForConditionalGeneration.from_pretrained(model, dtype=torch.bfloat16)\n    processor = AutoProcessor.from_pretrained(model)\n    processor.save_pretrained(output)\n\n    device = model_.device.type\n    if device == \"cuda\":\n        print(\n            f\"peak memory allocated: {torch.cuda.max_memory_allocated() / 1024**2} MB\"\n        )\n        print(f\"peak memory reserved: {torch.cuda.max_memory_reserved() / 1024**2} MB\")\n    model_config 
= model_.config\n    config = model_.config.get_text_config()\n\n    # Get key dimensions from the config\n    hidden_size = config.hidden_size\n    intermediate_size = config.intermediate_size\n    num_experts = config.num_local_experts\n\n    print(\n        f\"Model dimensions: hidden_size={hidden_size}, intermediate_size={intermediate_size}, num_experts={num_experts}\"\n    )\n\n    # Create output directory if it doesn't exist\n    os.makedirs(output, exist_ok=True)\n\n    # Get state dict\n    state_dict = model_.state_dict()\n    del model_\n\n    # Create a new state dict for the converted model\n    converted_state_dict = {}\n\n    # First, copy all keys that don't need modification\n    for key, value in iter_convert_patched_to_hf(state_dict, num_experts):\n        converted_state_dict[key] = value\n\n    del state_dict\n    if device == \"cuda\":\n        torch.cuda.empty_cache()\n        print(\"State dict converted.\")\n        print(\n            f\"peak memory allocated: {torch.cuda.max_memory_allocated() / 1024**2} MB\"\n        )\n        print(f\"peak memory reserved: {torch.cuda.max_memory_reserved() / 1024**2} MB\")\n    # Ideally re-load the model import to load the converted state dict\n    # Save the converted model\n    with init_empty_weights():\n        unpatch_llama4()\n        model_ = Llama4ForConditionalGeneration(model_config)\n\n    if device == \"cuda\":\n        print(\"State dict loaded into model.\")\n        print(\n            f\"peak memory allocated: {torch.cuda.max_memory_allocated() / 1024**2} MB\"\n        )\n        print(f\"peak memory reserved: {torch.cuda.max_memory_reserved() / 1024**2} MB\")\n    model_.load_state_dict(converted_state_dict, strict=False, assign=True)\n    print(f\"Saving converted model to {output}...\")\n    model_.save_pretrained(output)\n\n    print(f\"Model successfully converted and saved to {output}\")\n\n\nif __name__ == \"__main__\":\n    fire.Fire(do_cli)\n"
  },
  {
    "path": "src/axolotl/cli/evaluate.py",
    "content": "\"\"\"CLI to run evaluation on a model.\"\"\"\n\nimport os\nfrom pathlib import Path\nfrom typing import Union\n\nimport fire\nfrom transformers.hf_argparser import HfArgumentParser\n\nfrom axolotl.cli.args import TrainerCliArgs\nfrom axolotl.cli.checks import check_accelerate_default_config, check_user_token\nfrom axolotl.cli.config import load_cfg\nfrom axolotl.common.datasets import load_datasets, load_preference_datasets\nfrom axolotl.evaluate import evaluate\nfrom axolotl.utils.dict import DictDefault\nfrom axolotl.utils.logging import get_logger\n\nLOG = get_logger(__name__)\n\n\ndef do_evaluate(cfg: DictDefault, cli_args: TrainerCliArgs) -> None:\n    \"\"\"\n    Evaluates a `transformers` model by first loading the dataset(s) specified in the\n    `axolotl` config, and then calling `axolotl.evaluate.evaluate`, which computes\n    evaluation metrics on the given dataset(s) and writes them to disk.\n\n    Args:\n        cfg: Dictionary mapping `axolotl` config keys to values.\n        cli_args: CLI arguments.\n    \"\"\"\n\n    check_accelerate_default_config()\n    if int(os.getenv(\"LOCAL_RANK\", \"0\")) == 0:\n        check_user_token()\n\n    if cfg.rl:\n        dataset_meta = load_preference_datasets(cfg=cfg, cli_args=cli_args)\n    else:\n        dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)\n\n    evaluate(cfg=cfg, dataset_meta=dataset_meta)\n\n\ndef do_cli(config: Union[Path, str] = Path(\"examples/\"), **kwargs) -> None:\n    \"\"\"\n    Parses `axolotl` config, CLI args, and calls `do_evaluate`.\n\n    Args:\n        config: Path to `axolotl` config YAML file.\n        kwargs: Additional keyword arguments to override config file values.\n    \"\"\"\n\n    parsed_cfg = load_cfg(config, **kwargs)\n    parser = HfArgumentParser(TrainerCliArgs)\n    parsed_cli_args, _ = parser.parse_args_into_dataclasses(\n        return_remaining_strings=True\n    )\n    do_evaluate(parsed_cfg, parsed_cli_args)\n\n\nif __name__ == 
\"__main__\":\n    fire.Fire(do_cli)\n"
  },
  {
    "path": "src/axolotl/cli/inference.py",
    "content": "\"\"\"CLI to run inference on a trained model.\"\"\"\n\nimport importlib\nimport sys\nfrom pathlib import Path\nfrom threading import Thread\nfrom typing import Union\n\nimport fire\nimport torch\nimport transformers\nfrom transformers import GenerationConfig, TextIteratorStreamer, TextStreamer\n\nfrom axolotl.cli.args import InferenceCliArgs\nfrom axolotl.cli.config import load_cfg\nfrom axolotl.cli.utils import load_model_and_tokenizer\nfrom axolotl.cli.utils.diffusion import (\n    diffusion_inference,\n    launch_diffusion_gradio_ui,\n)\nfrom axolotl.integrations.base import PluginManager\nfrom axolotl.telemetry.errors import send_errors\nfrom axolotl.utils.chat_templates import (\n    get_chat_template_from_config,\n)\nfrom axolotl.utils.dict import DictDefault\nfrom axolotl.utils.logging import get_logger\n\nLOG = get_logger(__name__)\n\n\ndef get_multi_line_input() -> str:\n    \"\"\"\n    Gets multi-line input from terminal.\n\n    Returns:\n        Possibly multi-line, possibly empty stdin input as a string.\n    \"\"\"\n    print(\"Give me an instruction (Ctrl + D to submit): \")\n    print(\"=\" * 80)\n\n    instruction = \"\"\n    for line in sys.stdin:\n        instruction += line\n\n    return instruction\n\n\n@send_errors\ndef do_inference(\n    *,\n    cfg: DictDefault,\n    cli_args: InferenceCliArgs,\n):\n    \"\"\"\n    Runs inference on the command line in a loop. 
User input is accepted, a chat\n    template is (optionally) applied, and the model specified in the `axolotl` config is\n    used to generate completions according to a default generation config.\n\n    Args:\n        cfg: Dictionary mapping `axolotl` config keys to values.\n        cli_args: Inference-specific CLI arguments.\n    \"\"\"\n    model, tokenizer, _ = load_model_and_tokenizer(cfg=cfg, inference=True)\n    prompter = cli_args.prompter\n\n    prompter_module = None\n    chat_template_str = None\n    if prompter:\n        prompter_module = getattr(\n            importlib.import_module(\"axolotl.prompters\"), prompter\n        )\n    elif cfg.chat_template:\n        chat_template_str = get_chat_template_from_config(\n            cfg, ds_cfg=None, tokenizer=tokenizer\n        )\n    elif cfg.datasets and cfg.datasets[0].type == \"chat_template\":\n        chat_template_str = get_chat_template_from_config(\n            cfg=cfg, ds_cfg=cfg.datasets[0], tokenizer=tokenizer\n        )\n\n    model = model.to(cfg.device, dtype=cfg.torch_dtype)\n\n    # Detect diffusion mode\n    plugin_manager = PluginManager.get_instance()\n    is_diffusion = any(\n        plugin.__class__.__name__ == \"DiffusionPlugin\"\n        for plugin in plugin_manager.plugins.values()\n    )\n\n    if is_diffusion:\n        print(\"=\" * 80)\n        print(\"Commands:\")\n        print(\":complete N -> completion mode with N tokens (default 64)\")\n        print(\":mask R     -> random masking with ratio R (0.0–1.0)\")\n\n    while True:\n        print(\"=\" * 80)\n        instruction = get_multi_line_input()\n        if not instruction:\n            return\n\n        if prompter_module:\n            prompt: str = next(\n                prompter_module().build_prompt(instruction=instruction.strip(\"\\n\"))\n            )\n        else:\n            prompt = instruction.strip()\n\n        if chat_template_str:\n            batch = tokenizer.apply_chat_template(\n                [\n       
             {\n                        \"role\": \"user\",\n                        \"content\": prompt,\n                    }\n                ],\n                return_tensors=\"pt\",\n                add_special_tokens=True,\n                add_generation_prompt=True,\n                chat_template=chat_template_str,\n                tokenize=True,\n                return_dict=True,\n            )\n        else:\n            batch = tokenizer(prompt, return_tensors=\"pt\", add_special_tokens=True)\n\n        print(\"=\" * 80)\n        model.eval()\n        with torch.no_grad():\n            if is_diffusion:\n                diffusion_inference(\n                    model=model,\n                    tokenizer=tokenizer,\n                    cfg=cfg,\n                    prompt=prompt,\n                    chat_template_str=chat_template_str,\n                )\n                continue\n\n            generation_config = GenerationConfig(\n                repetition_penalty=1.1,\n                max_new_tokens=1024,\n                temperature=0.9,\n                top_p=0.95,\n                top_k=40,\n                bos_token_id=tokenizer.bos_token_id,\n                eos_token_id=tokenizer.eos_token_id,\n                pad_token_id=tokenizer.pad_token_id,\n                do_sample=True,\n                use_cache=True,\n                return_dict_in_generate=True,\n                output_attentions=False,\n                output_hidden_states=False,\n                output_scores=False,\n            )\n            streamer = TextStreamer(tokenizer)\n            generated = model.generate(\n                inputs=batch[\"input_ids\"].to(cfg.device),\n                generation_config=generation_config,\n                streamer=streamer,\n            )\n        print(\"=\" * 80)\n        print(tokenizer.decode(generated[\"sequences\"].cpu().tolist()[0]))\n\n\n@send_errors\ndef do_inference_gradio(\n    *,\n    cfg: DictDefault,\n    cli_args: 
InferenceCliArgs,\n):\n    \"\"\"\n    Runs inference in a Gradio interface. User input is accepted, a chat template is\n    (optionally) applied, and the model specified in the `axolotl` config is used to\n    generate completions according to a default generation config.\n\n    Args:\n        cfg: Dictionary mapping `axolotl` config keys to values.\n        cli_args: Inference-specific CLI arguments.\n    \"\"\"\n    import gradio as gr\n\n    model, tokenizer, _ = load_model_and_tokenizer(cfg=cfg, inference=True)\n    prompter = cli_args.prompter\n\n    prompter_module = None\n    chat_template_str = None\n    if prompter:\n        prompter_module = getattr(\n            importlib.import_module(\"axolotl.prompters\"), prompter\n        )\n    elif cfg.chat_template:\n        chat_template_str = get_chat_template_from_config(\n            cfg, ds_cfg=None, tokenizer=tokenizer\n        )\n    elif cfg.datasets and cfg.datasets[0].type == \"chat_template\":\n        chat_template_str = get_chat_template_from_config(\n            cfg=cfg, ds_cfg=cfg.datasets[0], tokenizer=tokenizer\n        )\n\n    model = model.to(cfg.device, dtype=cfg.torch_dtype)\n\n    # Detect diffusion mode\n    plugin_manager = PluginManager.get_instance()\n    is_diffusion = any(\n        plugin.__class__.__name__ == \"DiffusionPlugin\"\n        for plugin in plugin_manager.plugins.values()\n    )\n\n    if is_diffusion:\n        launch_diffusion_gradio_ui(\n            model=model,\n            tokenizer=tokenizer,\n            cfg=cfg,\n            prompter_module=prompter_module,\n            chat_template_str=chat_template_str,\n        )\n        return\n\n    def generate(instruction):\n        if not instruction:\n            return\n        if prompter_module:\n            prompt: str = next(\n                prompter_module().build_prompt(instruction=instruction.strip(\"\\n\"))\n            )\n        else:\n            prompt = instruction.strip()\n\n        if 
chat_template_str:\n            batch = tokenizer.apply_chat_template(\n                [\n                    {\n                        \"role\": \"user\",\n                        \"content\": prompt,\n                    }\n                ],\n                return_tensors=\"pt\",\n                add_special_tokens=True,\n                add_generation_prompt=True,\n                chat_template=chat_template_str,\n                tokenize=True,\n                return_dict=True,\n            )\n        else:\n            batch = tokenizer(prompt, return_tensors=\"pt\", add_special_tokens=True)\n\n        model.eval()\n        with torch.no_grad():\n            generation_config = GenerationConfig(\n                repetition_penalty=1.1,\n                max_new_tokens=cfg.get(\"gradio_max_new_tokens\", 1024),\n                temperature=cfg.get(\"gradio_temperature\", 0.9),\n                top_p=0.95,\n                top_k=40,\n                bos_token_id=tokenizer.bos_token_id,\n                eos_token_id=tokenizer.eos_token_id,\n                pad_token_id=tokenizer.pad_token_id,\n                do_sample=True,\n                use_cache=True,\n                return_dict_in_generate=True,\n                output_attentions=False,\n                output_hidden_states=False,\n                output_scores=False,\n            )\n            streamer = TextIteratorStreamer(tokenizer)\n            generation_kwargs = {\n                \"inputs\": batch[\"input_ids\"].to(cfg.device),\n                \"attention_mask\": batch[\"attention_mask\"].to(cfg.device),\n                \"generation_config\": generation_config,\n                \"streamer\": streamer,\n            }\n\n            thread = Thread(target=model.generate, kwargs=generation_kwargs)\n            thread.start()\n\n            all_text = \"\"\n\n            for new_text in streamer:\n                all_text += new_text\n                yield all_text\n\n    demo = gr.Interface(\n   
     fn=generate,\n        inputs=\"textbox\",\n        outputs=\"text\",\n        title=cfg.get(\"gradio_title\", \"Axolotl Gradio Interface\"),\n    )\n\n    demo.launch(\n        footer_links=[\"gradio\", \"settings\"],\n        share=cfg.get(\"gradio_share\", True),\n        server_name=cfg.get(\"gradio_server_name\", \"127.0.0.1\"),\n        server_port=cfg.get(\"gradio_server_port\", None),\n    )\n\n\ndef do_cli(\n    config: Union[Path, str] = Path(\"examples/\"), gradio: bool = False, **kwargs\n) -> None:\n    \"\"\"\n    Parses axolotl config, CLI args, and calls `do_inference` or `do_inference_gradio`.\n\n    Args:\n        config: Path to `axolotl` config YAML file.\n        kwargs: Additional keyword arguments to override config file values.\n    \"\"\"\n\n    parsed_cfg = load_cfg(config, inference=True, rl=None, **kwargs)\n    parsed_cfg.sample_packing = False\n    parser = transformers.HfArgumentParser(InferenceCliArgs)\n    parsed_cli_args, _ = parser.parse_args_into_dataclasses(\n        return_remaining_strings=True\n    )\n\n    if gradio:\n        do_inference_gradio(cfg=parsed_cfg, cli_args=parsed_cli_args)\n    else:\n        do_inference(cfg=parsed_cfg, cli_args=parsed_cli_args)\n\n\nif __name__ == \"__main__\":\n    fire.Fire(do_cli)\n"
  },
  {
    "path": "src/axolotl/cli/main.py",
    "content": "\"\"\"Click CLI definitions for various axolotl commands.\"\"\"\n\nimport os\nimport subprocess  # nosec B404\nfrom typing import Literal, Optional\n\nimport click\nfrom dotenv import load_dotenv\n\nimport axolotl\nfrom axolotl.cli.args import (\n    EvaluateCliArgs,\n    PreprocessCliArgs,\n    QuantizeCliArgs,\n    TrainerCliArgs,\n    VllmServeCliArgs,\n)\nfrom axolotl.cli.art import print_axolotl_text_art\nfrom axolotl.cli.utils import (\n    add_options_from_config,\n    add_options_from_dataclass,\n    build_command,\n    fetch_from_github,\n    filter_none_kwargs,\n    generate_config_files,\n    launch_training,\n)\nfrom axolotl.integrations.lm_eval.cli import lm_eval\nfrom axolotl.utils import set_misc_env, set_pytorch_cuda_alloc_conf\nfrom axolotl.utils.logging import get_logger\nfrom axolotl.utils.schemas.config import AxolotlInputConfig\n\nLOG = get_logger(__name__)\n\nLAUNCHER_COMMAND_MAPPING = {\n    \"accelerate\": [\"accelerate\", \"launch\"],\n    \"torchrun\": [\"torchrun\"],\n}\n\n\n@click.group()\n@click.version_option(version=axolotl.__version__, prog_name=\"axolotl\")\ndef cli():\n    \"\"\"Axolotl CLI - Train and fine-tune large language models\"\"\"\n    print_axolotl_text_art()\n    load_dotenv()\n    set_pytorch_cuda_alloc_conf()\n    set_misc_env()\n\n\n@cli.command()\n@click.argument(\"config\", type=click.Path(exists=True, path_type=str))\n@click.option(\"--cloud\", default=None, type=click.Path(exists=True, path_type=str))\n@add_options_from_dataclass(PreprocessCliArgs)\n@add_options_from_config(AxolotlInputConfig)\n@filter_none_kwargs\ndef preprocess(config: str, cloud: Optional[str] = None, **kwargs):\n    \"\"\"\n    Preprocess datasets before training.\n\n    Args:\n        config: Path to `axolotl` config YAML file.\n        cloud: Path to a cloud accelerator configuration file.\n        kwargs: Additional keyword arguments which correspond to CLI args or `axolotl`\n            config options.\n    \"\"\"\n\n    if 
cloud:\n        from axolotl.cli.cloud import do_cli_preprocess\n\n        do_cli_preprocess(cloud_config=cloud, config=config)\n    else:\n        from axolotl.cli.preprocess import do_cli\n\n        do_cli(config=config, **kwargs)\n\n\n@cli.command(\n    context_settings={\"ignore_unknown_options\": True, \"allow_extra_args\": True}\n)\n@click.argument(\"config\", type=click.Path(exists=True, path_type=str))\n@click.option(\n    \"--launcher\",\n    type=click.Choice([\"accelerate\", \"torchrun\", \"python\"]),\n    default=\"accelerate\",\n    help=\"Launcher to use for multi-GPU training\",\n)\n@click.option(\"--cloud\", default=None, type=click.Path(exists=True, path_type=str))\n@click.option(\n    \"--sweep\",\n    type=click.Path(exists=True, path_type=str),\n    help=\"YAML config for sweeping hyperparameters\",\n)\n@add_options_from_dataclass(TrainerCliArgs)\n@add_options_from_config(AxolotlInputConfig)\n@filter_none_kwargs\n@click.pass_context\ndef train(\n    ctx: click.Context,\n    config: str,\n    launcher: Literal[\"accelerate\", \"torchrun\", \"python\"] = \"accelerate\",\n    cloud: str | None = None,\n    sweep: str | None = None,\n    **kwargs,\n):\n    \"\"\"\n    Train or fine-tune a model.\n\n    Args:\n        ctx: Click context for extra args.\n        config: Path to `axolotl` config YAML file.\n        launcher: Launcher to use for multi-GPU training (\"accelerate\", \"torchrun\", or \"python\").\n        cloud: Path to a cloud accelerator configuration file\n        sweep: Path to YAML config for sweeping hyperparameters.\n        kwargs: Additional keyword arguments which correspond to CLI args or `axolotl`\n            config options.\n    \"\"\"\n    # Extract launcher args from extra args (after --)\n    launcher_args = ctx.args if ctx.args else []\n\n    # Handle Ray launcher override\n    _launcher = None if kwargs.get(\"use_ray\") else launcher\n\n    # Process each configuration\n    for cfg_file, is_group in 
generate_config_files(config, sweep):\n        try:\n            use_exec = is_group is not True\n            launch_training(cfg_file, _launcher, cloud, kwargs, launcher_args, use_exec)\n        except subprocess.CalledProcessError as exc:\n            LOG.error(f\"Failed to train/fine-tune config '{cfg_file}': {exc}\")\n            if not sweep:\n                raise exc\n        finally:\n            # Only delete temp files, not the original config\n            if cfg_file != config:\n                os.unlink(cfg_file)\n\n\n@cli.command(\n    context_settings={\"ignore_unknown_options\": True, \"allow_extra_args\": True}\n)\n@click.argument(\"config\", type=click.Path(exists=True, path_type=str))\n@click.option(\n    \"--launcher\",\n    type=click.Choice([\"accelerate\", \"torchrun\", \"python\"]),\n    default=\"accelerate\",\n    help=\"Launcher to use for multi-GPU evaluation\",\n)\n@add_options_from_dataclass(EvaluateCliArgs)\n@add_options_from_config(AxolotlInputConfig)\n@filter_none_kwargs\n@click.pass_context\ndef evaluate(ctx: click.Context, config: str, launcher: str, **kwargs):\n    \"\"\"\n    Evaluate a model.\n\n    Args:\n        ctx: Click context for extra args.\n        config: Path to `axolotl` config YAML file.\n        launcher: Launcher to use for multi-GPU evaluation (\"accelerate\", \"torchrun\", or \"python\").\n        kwargs: Additional keyword arguments which correspond to CLI args or `axolotl`\n            config options.\n    \"\"\"\n    # Extract launcher args from extra args (after --)\n    launcher_args = ctx.args if ctx.args else []\n\n    if launcher in LAUNCHER_COMMAND_MAPPING:\n        base_cmd = (\n            LAUNCHER_COMMAND_MAPPING[launcher]\n            + launcher_args\n            + [\"-m\", \"axolotl.cli.evaluate\"]\n        )\n        if config:\n            base_cmd.append(config)\n        cmd = build_command(base_cmd, kwargs)\n        subprocess.run(cmd, check=True)  # nosec B603\n    else:\n        from 
axolotl.cli.evaluate import do_cli\n\n        do_cli(config=config, **kwargs)\n\n\n@cli.command(\n    context_settings={\"ignore_unknown_options\": True, \"allow_extra_args\": True}\n)\n@click.argument(\"config\", type=click.Path(exists=True, path_type=str))\n@click.option(\n    \"--launcher\",\n    type=click.Choice([\"accelerate\", \"torchrun\", \"python\"]),\n    default=\"accelerate\",\n    help=\"Launcher to use for multi-GPU inference\",\n)\n@click.option(\"--gradio\", is_flag=True, help=\"Launch Gradio interface\")\n@add_options_from_dataclass(TrainerCliArgs)\n@add_options_from_config(AxolotlInputConfig)\n@filter_none_kwargs\n@click.pass_context\ndef inference(ctx: click.Context, config: str, launcher: str, gradio: bool, **kwargs):\n    \"\"\"\n    Run inference with a trained model.\n\n    Args:\n        ctx: Click context for extra args.\n        config: Path to `axolotl` config YAML file.\n        launcher: Launcher to use for multi-GPU inference (\"accelerate\", \"torchrun\", or \"python\").\n        gradio: Whether to use Gradio browser interface or command line for inference.\n        kwargs: Additional keyword arguments which correspond to CLI args or `axolotl`\n            config options.\n    \"\"\"\n    # Extract launcher args from extra args (after --)\n    launcher_args = ctx.args if ctx.args else []\n\n    if launcher in LAUNCHER_COMMAND_MAPPING:\n        base_cmd = (\n            LAUNCHER_COMMAND_MAPPING[launcher]\n            + launcher_args\n            + [\"-m\", \"axolotl.cli.inference\"]\n        )\n        if config:\n            base_cmd.append(config)\n        if gradio:\n            base_cmd.append(\"--gradio\")\n        cmd = build_command(base_cmd, kwargs)\n        subprocess.run(cmd, check=True)  # nosec B603\n    else:\n        from axolotl.cli.inference import do_cli\n\n        do_cli(config=config, gradio=gradio, **kwargs)\n\n\n@cli.command(\n    context_settings={\"ignore_unknown_options\": True, \"allow_extra_args\": 
True}\n)\n@click.argument(\"config\", type=click.Path(exists=True, path_type=str))\n@click.option(\n    \"--launcher\",\n    type=click.Choice([\"accelerate\", \"torchrun\", \"python\"]),\n    default=\"accelerate\",\n    help=\"Launcher to use for weight merging\",\n)\n@add_options_from_dataclass(TrainerCliArgs)\n@add_options_from_config(AxolotlInputConfig)\n@filter_none_kwargs\n@click.pass_context\ndef merge_sharded_fsdp_weights(\n    ctx: click.Context, config: str, launcher: str, **kwargs\n):\n    \"\"\"\n    Merge sharded FSDP model weights.\n\n    Args:\n        ctx: Click context for extra args.\n        config: Path to `axolotl` config YAML file.\n        launcher: Launcher to use for weight merging (\"accelerate\", \"torchrun\", or \"python\").\n        kwargs: Additional keyword arguments which correspond to CLI args or `axolotl`\n            config options.\n    \"\"\"\n    # Extract launcher args from extra args (after --)\n    launcher_args = ctx.args if ctx.args else []\n\n    if launcher in LAUNCHER_COMMAND_MAPPING:\n        base_cmd = (\n            LAUNCHER_COMMAND_MAPPING[launcher]\n            + launcher_args\n            + [\"-m\", \"axolotl.cli.merge_sharded_fsdp_weights\"]\n        )\n        if config:\n            base_cmd.append(config)\n        cmd = build_command(base_cmd, kwargs)\n        subprocess.run(cmd, check=True)  # nosec B603\n    else:\n        from axolotl.cli.merge_sharded_fsdp_weights import do_cli\n\n        do_cli(config=config, **kwargs)\n\n\n@cli.command()\n@click.argument(\"config\", type=click.Path(exists=True, path_type=str))\n@add_options_from_dataclass(TrainerCliArgs)\n@add_options_from_config(AxolotlInputConfig)\n@filter_none_kwargs\ndef merge_lora(config: str, **kwargs):\n    \"\"\"\n    Merge trained LoRA adapters into a base model.\n\n    Args:\n        config: Path to `axolotl` config YAML file.\n        kwargs: Additional keyword arguments which correspond to CLI args or `axolotl`\n            config options.\n 
   \"\"\"\n    from axolotl.cli.merge_lora import do_cli\n\n    do_cli(config=config, **kwargs)\n\n\n@cli.command()\n@click.argument(\"directory\", type=click.Choice([\"examples\", \"deepspeed_configs\"]))\n@click.option(\"--dest\", help=\"Destination directory\")\ndef fetch(directory: str, dest: Optional[str]):\n    \"\"\"\n    Fetch example configs or other resources.\n\n    Available directories:\n    - examples: Example configuration files\n    - deepspeed_configs: DeepSpeed configuration files\n\n    Args:\n        directory: One of `examples`, `deepspeed_configs`.\n        dest: Optional destination directory.\n    \"\"\"\n    fetch_from_github(f\"{directory}/\", dest)\n\n\n@cli.command()\n@click.argument(\"config\", type=click.Path(exists=True, path_type=str))\n@add_options_from_dataclass(VllmServeCliArgs)\n@filter_none_kwargs\ndef vllm_serve(config: str, **cli_args: VllmServeCliArgs):\n    from axolotl.cli.vllm_serve import do_vllm_serve\n\n    do_vllm_serve(config, cli_args)\n\n\n@cli.command()\n@click.argument(\"config\", type=click.Path(exists=True, path_type=str))\n@add_options_from_dataclass(QuantizeCliArgs)\n@filter_none_kwargs\ndef quantize(config: str, **cli_args: QuantizeCliArgs):\n    from axolotl.cli.quantize import do_quantize\n\n    do_quantize(config, cli_args)\n\n\n@cli.command()\n@click.argument(\"model\", type=click.Path(exists=True, path_type=str))\n@click.argument(\"output\", type=click.Path(exists=False, path_type=str))\ndef delinearize_llama4(model: str, output: str):\n    from axolotl.cli.delinearize_llama4 import do_cli as do_delinearize_llama4\n\n    do_delinearize_llama4(model, output)\n\n\ncli.add_command(lm_eval)\n\n\ndef main():\n    cli()\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "src/axolotl/cli/merge_lora.py",
    "content": "\"\"\"CLI to merge a trained LoRA into a base model.\"\"\"\n\nfrom pathlib import Path\nfrom typing import Union\n\nimport fire\n\nfrom axolotl.cli.config import load_cfg\nfrom axolotl.cli.utils import load_model_and_tokenizer\nfrom axolotl.telemetry.errors import send_errors\nfrom axolotl.utils.dict import DictDefault\nfrom axolotl.utils.logging import get_logger\n\nLOG = get_logger(__name__)\n\n\n@send_errors\ndef do_merge_lora(*, cfg: DictDefault) -> None:\n    \"\"\"\n    Calls `transformers`' `merge_and_unload` on the model given in the `axolotl` config\n    along with the LoRA adapters to combine them into a single base model.\n\n    Args:\n        cfg: Dictionary mapping `axolotl` config keys to values.\n    \"\"\"\n    model, tokenizer, processor = load_model_and_tokenizer(cfg=cfg)\n\n    LOG.info(\"Running merge of LoRA with base model...\")\n    model = model.merge_and_unload(progressbar=True)\n    try:\n        model.to(dtype=cfg.torch_dtype)\n    except ValueError as e:\n        LOG.warning(\"Failed to convert model to dtype %s\", cfg.torch_dtype)\n        LOG.warning(\"Ignore this if the base_model is pre-quantized.\")\n        LOG.warning(\"Error raised: %s\", e)\n\n    model.generation_config.do_sample = True\n    model.config.use_cache = True\n\n    if cfg.local_rank == 0:\n        LOG.info(f\"Saving merged model to: {str(Path(cfg.output_dir) / 'merged')}...\")\n        model.save_pretrained(\n            str(Path(cfg.output_dir) / \"merged\"),\n            progressbar=True,\n        )\n        tokenizer.save_pretrained(\n            str(Path(cfg.output_dir) / \"merged\"),\n            save_jinja_files=cfg.tokenizer_save_jinja_files,\n        )\n\n        if processor:\n            processor.save_pretrained(str(Path(cfg.output_dir) / \"merged\"))\n\n\ndef do_cli(config: Union[Path, str] = Path(\"examples/\"), **kwargs) -> None:\n    \"\"\"\n    Parses `axolotl` config, CLI args, and calls `do_merge_lora`. 
Note that various\n    config values will be overwritten to allow the LoRA merge logic to work as expected\n    (`load_in_8bit=False`, `load_in_4bit=False`, `flash_attention=False`, etc.).\n\n    Args:\n        config: Path to `axolotl` config YAML file.\n        kwargs: Additional keyword arguments to override config file values.\n\n    Raises:\n        ValueError: If the LoRA adapter directory (`lora_model_dir`, defaulting to\n            `output_dir`) does not exist.\n    \"\"\"\n\n    parsed_cfg = load_cfg(\n        config,\n        merge_lora=True,\n        load_in_8bit=False,\n        load_in_4bit=False,\n        quantize_moe_experts=False,\n        flash_attention=False,\n        context_parallel_size=None,\n        deepspeed=None,\n        fsdp=None,\n        fsdp_config=None,\n        **kwargs,\n    )\n\n    if not parsed_cfg.lora_model_dir and parsed_cfg.output_dir:\n        parsed_cfg.lora_model_dir = parsed_cfg.output_dir\n    if not Path(parsed_cfg.lora_model_dir).exists():\n        raise ValueError(\n            f\"Target directory for merge: `{parsed_cfg.lora_model_dir}` does not exist.\"\n        )\n\n    do_merge_lora(cfg=parsed_cfg)\n\n\nif __name__ == \"__main__\":\n    fire.Fire(do_cli)\n"
  },
  {
    "path": "src/axolotl/cli/merge_sharded_fsdp_weights.py",
    "content": "\"\"\"CLI to merge sharded FSDP model checkpoints into a single combined checkpoint.\"\"\"\n\nimport json\nimport os\nimport shutil\nfrom pathlib import Path\nfrom typing import Dict, Union\n\nimport fire\nimport torch\nimport torch.distributed.checkpoint as dist_cp\nimport torch.distributed.checkpoint.format_utils as dist_cp_format_utils\nfrom accelerate import PartialState\nfrom accelerate.utils import (\n    SAFE_WEIGHTS_INDEX_NAME,\n    SAFE_WEIGHTS_NAME,\n    is_torch_version,\n)\nfrom huggingface_hub import split_torch_state_dict_into_shards\nfrom safetensors.torch import save_file as safe_save_file\nfrom torch.distributed.checkpoint.format_utils import _EmptyStateDictLoadPlanner\n\nfrom axolotl.cli.config import load_cfg\nfrom axolotl.telemetry.errors import send_errors\nfrom axolotl.utils.logging import get_logger\nfrom axolotl.utils.train import determine_last_checkpoint\n\nLOG = get_logger(__name__)\n\n\nclass BFloat16CastPlanner(_EmptyStateDictLoadPlanner):\n    \"\"\"A custom planner to cast tensors to bfloat16 on the fly during loading.\"\"\"\n\n    def commit_tensor(self, read_item, tensor):\n        tensor.copy_(tensor.to(torch.bfloat16))\n\n\ndef _distributed_checkpoint_to_merged_weights(\n    checkpoint_dir: Union[str, Path],\n    save_path: str,\n    max_shard_size: str = \"5GB\",\n) -> Path:\n    \"\"\"\n    Passthrough to `torch.distributed.checkpoint.format_utils.dcp_to_torch_save`. 
Will\n    save under `save_path` as `model.safetensors`.\n\n    Args:\n        checkpoint_dir: Directory where distributed checkpoint is saved.\n        save_path: Path to save model to.\n        max_shard_size: Max size of model shards to save.\n\n    Returns:\n        Path where model is saved.\n    \"\"\"\n\n    state_dict: Dict = {}\n    save_path_ = Path(save_path)\n    save_path_.mkdir(exist_ok=True)\n    dist_cp_format_utils._load_state_dict(\n        state_dict,\n        storage_reader=dist_cp.FileSystemReader(checkpoint_dir),\n        planner=BFloat16CastPlanner(),\n        no_dist=True,\n    )\n\n    # To handle if state is a dict like {model: {...}}\n    if len(state_dict.keys()) == 1:\n        state_dict = state_dict[list(state_dict)[0]]\n\n    # Ensure all tensors are in bfloat16\n    for key, value in state_dict.items():\n        if isinstance(value, torch.Tensor) and value.dtype != torch.bfloat16:\n            state_dict[key] = value.to(torch.bfloat16)\n\n    filename_pattern = SAFE_WEIGHTS_NAME.replace(\".safetensors\", \"{suffix}.safetensors\")\n    state_dict_split = split_torch_state_dict_into_shards(\n        state_dict, filename_pattern=filename_pattern, max_shard_size=max_shard_size\n    )\n\n    # Save index if sharded\n    index = None\n    if state_dict_split.is_sharded:\n        index = {\n            \"metadata\": state_dict_split.metadata,\n            \"weight_map\": state_dict_split.tensor_to_filename,\n        }\n\n    # Save the model\n    filename_to_tensors = state_dict_split.filename_to_tensors.items()\n\n    for shard_file, tensors in filename_to_tensors:\n        shard = {tensor: state_dict[tensor] for tensor in tensors}\n        safe_save_file(\n            shard, os.path.join(save_path_, shard_file), metadata={\"format\": \"pt\"}\n        )\n\n    if index is not None:\n        save_index_file = os.path.join(save_path_, SAFE_WEIGHTS_INDEX_NAME)\n        # Save the index as well\n        with open(save_index_file, \"w\", 
encoding=\"utf-8\") as fout:\n            content = json.dumps(index, indent=2, sort_keys=True) + \"\\n\"\n            fout.write(content)\n\n    return save_path_\n\n\n@send_errors\ndef merge_fsdp_weights(\n    checkpoint_dir: str,\n    output_path: str,\n    remove_checkpoint_dir: bool = False,\n):\n    \"\"\"\n    Merge the weights from sharded FSDP model checkpoints into a single combined checkpoint. Should be used if\n    `SHARDED_STATE_DICT` was used for the model. Weights will be saved to `{output_path}/model.safetensors`.\n\n    Note: this is a CPU-bound process.\n\n    Args:\n        checkpoint_dir (`str`):\n            The directory containing the FSDP checkpoints (can be either the model or optimizer).\n        output_path (`str`):\n            The path to save the merged checkpoint.\n        remove_checkpoint_dir (`bool`, *optional*, defaults to `False`):\n            Whether to remove the checkpoint directory after merging.\n\n    Raises:\n        ValueError: If torch version < 2.3.0, or if `checkpoint_dir` does not exist.\n    \"\"\"\n    checkpoint_dir_ = Path(checkpoint_dir)\n\n    if not is_torch_version(\">=\", \"2.3.0\"):\n        raise ValueError(\"`merge_fsdp_weights` requires PyTorch >= 2.3.0`\")\n\n    # Verify that the checkpoint directory exists\n    if not checkpoint_dir_.exists():\n        model_path_exists = (checkpoint_dir_ / \"pytorch_model_fsdp_0\").exists()\n        optimizer_path_exists = (checkpoint_dir_ / \"optimizer_0\").exists()\n        err = f\"Tried to load from {checkpoint_dir_} but couldn't find a valid metadata file.\"\n        if model_path_exists and optimizer_path_exists:\n            err += (\n                \" However, potential model and optimizer checkpoint directories exist.\"\n            )\n            err += f\"Please pass in either {checkpoint_dir_}/pytorch_model_fsdp_0 or {checkpoint_dir_}/optimizer_0\"\n            err += \"instead.\"\n        elif model_path_exists:\n            err += \" However, a 
potential model checkpoint directory exists.\"\n            err += (\n                f\"Please try passing in {checkpoint_dir_}/pytorch_model_fsdp_0 instead.\"\n            )\n        elif optimizer_path_exists:\n            err += \" However, a potential optimizer checkpoint directory exists.\"\n            err += f\"Please try passing in {checkpoint_dir_}/optimizer_0 instead.\"\n        raise ValueError(err)\n\n    # To setup `save` to work\n    state = PartialState()\n    if state.is_main_process:\n        LOG.info(f\"Merging FSDP weights from {checkpoint_dir_}\")\n        save_path = _distributed_checkpoint_to_merged_weights(\n            checkpoint_dir_, output_path\n        )\n        LOG.info(f\"Successfully merged FSDP weights and saved to {save_path}\")\n        if remove_checkpoint_dir:\n            LOG.info(f\"Removing old checkpoint directory {checkpoint_dir_}\")\n            shutil.rmtree(checkpoint_dir_)\n\n\ndef do_cli(config: Union[Path, str] = Path(\"examples/\"), **kwargs):\n    \"\"\"\n    Parses `axolotl` config, CLI args, and calls `merge_fsdp_weights`.\n\n    Args:\n        config: Path to `axolotl` config YAML file.\n        kwargs: Additional keyword arguments to override config file values.\n    \"\"\"\n\n    parsed_cfg = load_cfg(config, **kwargs)\n\n    fsdp_dir = Path(parsed_cfg.output_dir) / \"pytorch_model_fsdp_0\"\n    if not fsdp_dir.exists():\n        checkpoint_dir = determine_last_checkpoint(parsed_cfg, update=False)\n        if checkpoint_dir:\n            fsdp_dir = Path(checkpoint_dir) / \"pytorch_model_fsdp_0\"\n        if not fsdp_dir.exists():\n            raise ValueError(\n                f\"Could not find FSDP checkpoint `pytorch_model_fsdp_0` in {checkpoint_dir}\"\n            )\n\n    output_path = str(Path(parsed_cfg.output_dir) / \"merged\")\n    merge_fsdp_weights(\n        checkpoint_dir=str(fsdp_dir),\n        output_path=output_path,\n    )\n    state = PartialState()\n    state.wait_for_everyone()\n    
LOG.info(\n        f\"FSDP SHARDED_STATE_DICT weights successfully merged to: {output_path}\",\n    )\n    LOG.info(\n        \"Merged weights are only the safetensors and doesn't include the model configuration \"\n        f\"or tokenizer which may be found in {parsed_cfg.output_dir}.\",\n    )\n\n\nif __name__ == \"__main__\":\n    fire.Fire(do_cli)\n"
  },
  {
    "path": "src/axolotl/cli/preprocess.py",
    "content": "\"\"\"CLI to run preprocessing of a dataset.\"\"\"\n\nimport os\nimport warnings\nfrom pathlib import Path\nfrom typing import Union\n\nimport fire\nimport transformers\nfrom accelerate import init_empty_weights\nfrom colorama import Fore\nfrom transformers import AutoModelForCausalLM\n\nfrom axolotl.cli.args import PreprocessCliArgs\nfrom axolotl.cli.checks import check_accelerate_default_config, check_user_token\nfrom axolotl.cli.config import load_cfg\nfrom axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH\nfrom axolotl.common.datasets import load_datasets, load_preference_datasets\nfrom axolotl.integrations.base import PluginManager\nfrom axolotl.telemetry.errors import send_errors\nfrom axolotl.utils.dict import DictDefault\nfrom axolotl.utils.logging import get_logger\nfrom axolotl.utils.trainer import disable_datasets_caching\n\nLOG = get_logger(__name__)\n\n\n@send_errors\ndef do_preprocess(cfg: DictDefault, cli_args: PreprocessCliArgs) -> None:\n    \"\"\"\n    Preprocesses dataset specified in axolotl config.\n\n    Args:\n        cfg: Dictionary mapping `axolotl` config keys to values.\n        cli_args: Preprocessing-specific CLI arguments.\n    \"\"\"\n    check_accelerate_default_config()\n    check_user_token()\n\n    if cli_args.iterable:\n        LOG.error(\n            \"The --iterable CLI argument for 'axolotl preprocess' is no longer \"\n            \"supported. For training, set 'streaming: true' in your YAML config or \"\n            \"pass '--streaming' in your 'axolotl train' command for on-the-fly \"\n            \"preprocessing.\"\n        )\n        return\n\n    for key in [\"skip_prepare_dataset\", \"pretraining_dataset\"]:\n        if cfg.get(key):\n            LOG.error(\n                f\"You have set `{key}:`. `preprocess` is not needed. 
Run the 'axolotl \"\n                \"train' CLI directly instead.\"\n            )\n            return\n\n    if not cfg.dataset_prepared_path:\n        msg = (\n            Fore.RED\n            + \"preprocess CLI called without dataset_prepared_path set, \"\n            + f\"using default path: {DEFAULT_DATASET_PREPARED_PATH}\"\n            + Fore.RESET\n        )\n        LOG.warning(msg)\n        cfg.dataset_prepared_path = DEFAULT_DATASET_PREPARED_PATH\n\n    with disable_datasets_caching():\n        plugin_manager = PluginManager.get_instance()\n        if plugin_manager.load_datasets(cfg, preprocess=True):\n            pass\n        elif cfg.rl:\n            load_preference_datasets(cfg=cfg, cli_args=cli_args)\n        else:\n            load_datasets(cfg=cfg, cli_args=cli_args)\n\n    if cli_args.download:\n        model_name = cfg.base_model\n        with warnings.catch_warnings():\n            # there are a bunch of useless UserWarnings about\n            # \"copying from a non-meta parameter in the checkpoint to a meta parameter in the current model\"\n            warnings.simplefilter(\"ignore\")\n            with init_empty_weights(include_buffers=True):\n                # fmt: off\n                try:\n                    AutoModelForCausalLM.from_pretrained(\n                        model_name, trust_remote_code=True\n                    )\n                except Exception:  # nosec B110\n                    pass\n                # fmt: on\n\n    LOG.info(\n        Fore.GREEN\n        + f\"Success! 
Preprocessed data path: `dataset_prepared_path: {cfg.dataset_prepared_path}`\"\n        + Fore.RESET\n    )\n\n\ndef do_cli(\n    config: Union[Path, str] = Path(\"examples/\"),\n    **kwargs,\n) -> None:\n    \"\"\"\n    Parses `axolotl` config, CLI args, and calls `do_preprocess`.\n\n    Args:\n        config: Path to `axolotl` config YAML file.\n        kwargs: Additional keyword arguments to override config file values.\n    \"\"\"\n\n    os.environ[\"AXOLOTL_IS_PREPROCESS\"] = \"1\"\n    is_preprocess = kwargs.pop(\"is_preprocess\", True)\n    parsed_cfg = load_cfg(config, is_preprocess=is_preprocess, **kwargs)\n    parsed_cfg.is_preprocess = True\n    parser = transformers.HfArgumentParser(PreprocessCliArgs)\n    parsed_cli_args, _ = parser.parse_args_into_dataclasses(\n        return_remaining_strings=True\n    )\n\n    do_preprocess(parsed_cfg, parsed_cli_args)\n\n\nif __name__ == \"__main__\":\n    fire.Fire(do_cli)\n"
  },
  {
    "path": "src/axolotl/cli/quantize.py",
    "content": "\"\"\"\nCLI to post-training quantize a model using torchao\n\"\"\"\n\nfrom pathlib import Path\nfrom typing import Union\n\nfrom transformers import AutoConfig, AutoModelForCausalLM, TorchAoConfig\n\nfrom axolotl.cli.config import load_cfg\nfrom axolotl.loaders import load_processor, load_tokenizer\nfrom axolotl.utils.logging import get_logger\nfrom axolotl.utils.quantization import (\n    TorchAOQuantDType,\n    get_quantization_config,\n    quantization_config_to_str,\n    quantize_model,\n)\n\nLOG = get_logger(__name__)\n\n\ndef do_quantize(\n    config: Union[Path, str],\n    cli_args: dict,\n):\n    \"\"\"\n    Quantizes a model's model's weights\n\n    Args:\n        config (Union[Path, str]): The path to the config file\n        cli_args (dict): Additional command-line arguments\n    \"\"\"\n\n    cfg = load_cfg(config)\n\n    if cfg.qat and cfg.quantization:\n        raise ValueError(\n            \"QAT and quantization cannot be used together. Please specify only one of qat or quantization in your config file.\"\n        )\n\n    if cfg.qat:\n        quantize_cfg = cfg.qat\n    elif cfg.quantization:\n        quantize_cfg = cfg.quantization\n    else:\n        raise ValueError(\n            \"No quantization configuration found. 
Please specify either qat or quantization in your config file.\"\n        )\n\n    model_path = cli_args.get(\"base_model\") or cfg.output_dir\n    if weight_dtype := cli_args.get(\"weight_dtype\"):\n        weight_dtype = TorchAOQuantDType.from_string(weight_dtype)\n    else:\n        weight_dtype = quantize_cfg.weight_dtype\n    if activation_dtype := cli_args.get(\"activation_dtype\"):\n        activation_dtype = TorchAOQuantDType.from_string(activation_dtype)\n    else:\n        activation_dtype = quantize_cfg.activation_dtype\n    group_size = cli_args.get(\"group_size\") or quantize_cfg.group_size\n    quantize_embedding = (\n        cli_args.get(\"quantize_embedding\") or quantize_cfg.quantize_embedding\n    )\n    output_dir = cli_args.get(\"output_dir\") or cfg.output_dir\n    hub_model_id = cli_args.get(\"hub_model_id\") or cfg.hub_model_id\n\n    LOG.info(f\"Loading model from {model_path}.\")\n    tokenizer = load_tokenizer(cfg)\n\n    processor = None\n    if cfg.is_multimodal:\n        processor = load_processor(cfg, tokenizer)\n\n    config = AutoConfig.from_pretrained(model_path)\n    torch_dtype = config.torch_dtype if hasattr(config, \"torch_dtype\") else None\n    model = AutoModelForCausalLM.from_pretrained(\n        model_path, device_map=\"auto\", dtype=torch_dtype\n    )\n\n    LOG.info(\n        f\"Quantizing model with configuration: \\n\"\n        f\"\\tweight_dtype: {weight_dtype}\\n\"\n        f\"\\tactivation_dtype: {activation_dtype}\\n\"\n        f\"\\tgroup_size: {group_size}\\n\"\n        f\"\\tquantize_embedding: {quantize_embedding}\"\n    )\n\n    quantize_model(\n        model, weight_dtype, group_size, activation_dtype, quantize_embedding\n    )\n\n    quantization_config = get_quantization_config(\n        weight_dtype, activation_dtype, group_size\n    )\n\n    ao_config = TorchAoConfig(\n        quant_type=quantization_config,\n        include_input_output_embeddings=quantize_embedding,\n    )\n    
model.config.quantization_config = ao_config\n\n    LOG.info(f\"Saving quantized model to: {str(Path(output_dir) / 'quantized')}.\")\n    model.save_pretrained(\n        str(Path(output_dir) / \"quantized\"),\n        progressbar=True,\n    )\n    tokenizer.save_pretrained(\n        str(Path(output_dir) / \"quantized\"),\n        progressbar=True,\n        save_jinja_files=cfg.tokenizer_save_jinja_files,\n    )\n\n    if processor:\n        LOG.info(f\"Saving processor to: {str(Path(output_dir) / 'quantized')}.\")\n        processor.save_pretrained(str(Path(output_dir) / \"quantized\"))\n\n    if hub_model_id:\n        hub_model_id = (\n            hub_model_id.rstrip(\"-\")\n            + f\"-{quantization_config_to_str[type(quantization_config)]}\"\n        )\n        model.push_to_hub(hub_model_id)\n        tokenizer.push_to_hub(hub_model_id)\n        if processor:\n            processor.push_to_hub(hub_model_id)\n        LOG.info(f\"Quantized model pushed to: {hub_model_id}.\")\n\n    LOG.info(f\"Quantized model saved to: {str(Path(output_dir) / 'quantized')}.\")\n"
  },
  {
    "path": "src/axolotl/cli/train.py",
    "content": "\"\"\"CLI to run training on a model.\"\"\"\n\nimport gc\nimport os\nfrom pathlib import Path\nfrom typing import Union\n\nimport fire\nfrom accelerate import Accelerator\nfrom transformers.hf_argparser import HfArgumentParser\n\nfrom axolotl.cli.args import TrainerCliArgs\nfrom axolotl.cli.checks import check_accelerate_default_config, check_user_token\nfrom axolotl.cli.config import load_cfg\nfrom axolotl.common.datasets import load_datasets, load_preference_datasets\nfrom axolotl.integrations.base import PluginManager\nfrom axolotl.train import train\nfrom axolotl.utils.config import normalize_config, resolve_dtype\nfrom axolotl.utils.dict import DictDefault\nfrom axolotl.utils.trainer import prepare_optim_env\n\n\ndef do_train(cfg: DictDefault, cli_args: TrainerCliArgs):\n    \"\"\"\n    Trains a `transformers` model by first loading the dataset(s) specified in the\n    `axolotl` config, and then calling `axolotl.train.train`. Also runs the plugin\n    manager's `post_train_unload` once training completes.\n\n    Args:\n        cfg: Dictionary mapping `axolotl` config keys to values.\n        cli_args: Training-specific CLI arguments.\n    \"\"\"\n    check_accelerate_default_config()\n    if int(os.getenv(\"LOCAL_RANK\", \"0\")) == 0:\n        check_user_token()\n\n    plugin_manager = PluginManager.get_instance()\n    dataset_meta = plugin_manager.load_datasets(cfg, preprocess=False)\n    if not dataset_meta:\n        if cfg.rl:\n            dataset_meta = load_preference_datasets(cfg=cfg, cli_args=cli_args)\n        else:\n            dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)\n\n    model, tokenizer, trainer = train(cfg=cfg, dataset_meta=dataset_meta)\n\n    del model, tokenizer, trainer\n\n    gc.collect()\n\n    plugin_manager = PluginManager.get_instance()\n    plugin_manager.post_train_unload(cfg)\n\n\ndef do_cli(config: Union[Path, str] = Path(\"examples/\"), **kwargs):\n    \"\"\"\n    Parses `axolotl` config, CLI args, 
and calls `do_train`.\n\n    Args:\n        config: Path to `axolotl` config YAML file.\n        kwargs: Additional keyword arguments to override config file values.\n    \"\"\"\n    parsed_cfg = load_cfg(config, **kwargs)\n    parser = HfArgumentParser(TrainerCliArgs)\n    parsed_cli_args, _ = parser.parse_args_into_dataclasses(\n        return_remaining_strings=True\n    )\n\n    if parsed_cfg.use_ray:\n        from ray.train import RunConfig, ScalingConfig\n        from ray.train.torch import TorchTrainer\n\n        train_loop_config = {\"cfg\": parsed_cfg.to_dict(), \"cli_args\": parsed_cli_args}\n        trainer = TorchTrainer(\n            ray_train_func,\n            train_loop_config=train_loop_config,\n            scaling_config=ScalingConfig(\n                num_workers=parsed_cfg.ray_num_workers,\n                resources_per_worker=parsed_cfg.resources_per_worker.to_dict(),\n                use_gpu=True,\n            ),\n            run_config=RunConfig(\n                name=parsed_cfg.ray_run_name,\n                storage_path=Path(parsed_cfg.output_dir).absolute().as_posix(),\n            ),\n        )\n        return trainer.fit()\n    return do_train(parsed_cfg, parsed_cli_args)\n\n\ndef ray_train_func(kwargs: dict):\n    # cast `cfg` back to DictDefault (ray tune deepcopy has issues with DictDefault so needed it to be dict)\n    # also renormalize the config now that TorchTrainer has spawned distributed workers\n    cfg = DictDefault(kwargs[\"cfg\"])\n    prepare_optim_env(cfg)\n    normalize_config(cfg)\n\n    # now that we are on the worker node, we can check `is_torch_bf16_gpu_available` to resolve dtype\n    resolve_dtype(cfg)\n\n    # ray serializing objects gets rid of frozen attribute - HF expects dict not DefaultDict\n    if cfg.deepspeed and hasattr(cfg.deepspeed, \"to_dict\"):\n        cfg.deepspeed = cfg.deepspeed.to_dict()\n\n    # initialize accelerator before model instantiation\n    
Accelerator(gradient_accumulation_steps=cfg.gradient_accumulation_steps)\n\n    # Register plugins in Ray workers\n    if cfg.get(\"plugins\"):\n        from axolotl.cli.config import plugin_set_cfg, prepare_plugins\n\n        prepare_plugins(cfg)\n        plugin_set_cfg(cfg)\n\n    kwargs[\"cfg\"] = cfg\n\n    do_train(**kwargs)\n\n\nif __name__ == \"__main__\":\n    fire.Fire(do_cli)\n"
  },
  {
    "path": "src/axolotl/cli/utils/__init__.py",
    "content": "\"\"\"Init for axolotl.cli.utils module.\"\"\"\n\nfrom .args import (\n    add_options_from_config,\n    add_options_from_dataclass,\n    filter_none_kwargs,\n)\nfrom .fetch import fetch_from_github\nfrom .load import load_model_and_tokenizer\nfrom .sweeps import generate_sweep_configs\nfrom .train import build_command, generate_config_files, launch_training\n\n__all__ = [\n    \"filter_none_kwargs\",\n    \"add_options_from_dataclass\",\n    \"add_options_from_config\",\n    \"build_command\",\n    \"generate_config_files\",\n    \"generate_sweep_configs\",\n    \"load_model_and_tokenizer\",\n    \"launch_training\",\n    \"fetch_from_github\",\n]\n"
  },
  {
    "path": "src/axolotl/cli/utils/args.py",
    "content": "\"\"\"Utilities for axolotl CLI args.\"\"\"\n\nimport dataclasses\nfrom functools import wraps\nfrom types import NoneType, UnionType\nfrom typing import Any, Callable, Type, Union, get_args, get_origin\n\nimport click\nfrom pydantic import BaseModel\n\n\ndef _strip_optional_type(field_type: type | str | None):\n    \"\"\"\n    Extracts the non-`None` type from an `Optional` / `Union` type.\n\n    Args:\n        field_type: Type of field for Axolotl CLI command.\n\n    Returns:\n        If the input type is `Union[T, None]` or `Optional[T]`, returns `T`. Otherwise\n            returns the input type unchanged.\n    \"\"\"\n    is_union = get_origin(field_type) is Union or isinstance(field_type, UnionType)\n    if is_union and type(None) in get_args(field_type):\n        field_type = next(\n            t for t in get_args(field_type) if not isinstance(t, NoneType)\n        )\n\n    return field_type\n\n\ndef filter_none_kwargs(func: Callable) -> Callable:\n    \"\"\"\n    Wraps function to remove `None`-valued `kwargs`.\n\n    Args:\n        func: Function to wrap.\n\n    Returns:\n        Wrapped function.\n    \"\"\"\n\n    @wraps(func)\n    def wrapper(*args, **kwargs) -> Callable:\n        \"\"\"Filters out `None`-valued `kwargs`.\"\"\"\n        filtered_kwargs = {k: v for k, v in kwargs.items() if v is not None}\n\n        return func(*args, **filtered_kwargs)\n\n    return wrapper\n\n\ndef add_options_from_dataclass(config_class: Type[Any]) -> Callable:\n    \"\"\"\n    Create Click options from the fields of a dataclass.\n\n    Args:\n        config_class: Dataclass with fields to parse from the CLI.\n\n    Returns:\n        Function decorator for Axolotl CLI command.\n    \"\"\"\n\n    def decorator(function: Callable) -> Callable:\n        # Process dataclass fields in reverse order for correct option ordering\n        for field in reversed(dataclasses.fields(config_class)):\n            field_type = _strip_optional_type(field.type)\n\n      
      if field_type is bool:\n                field_name = field.name.replace(\"_\", \"-\")\n                option_name = f\"--{field_name}/--no-{field_name}\"\n                function = click.option(\n                    option_name,\n                    default=field.default,\n                    help=field.metadata.get(\"description\"),\n                )(function)\n            else:\n                option_name = f\"--{field.name.replace('_', '-')}\"\n                function = click.option(\n                    option_name,\n                    type=field_type,\n                    default=field.default,\n                    help=field.metadata.get(\"description\"),\n                )(function)\n\n        return function\n\n    return decorator\n\n\ndef _is_pydantic_model(field_type: type) -> bool:\n    \"\"\"Check if a type is a Pydantic BaseModel subclass.\"\"\"\n    try:\n        return isinstance(field_type, type) and issubclass(field_type, BaseModel)\n    except TypeError:\n        return False\n\n\ndef _get_field_description(field) -> str | None:\n    \"\"\"Get description from a Pydantic field, checking both .description and json_schema_extra.\"\"\"\n    if field.description:\n        return field.description\n    if field.json_schema_extra and isinstance(field.json_schema_extra, dict):\n        return field.json_schema_extra.get(\"description\")\n    return None\n\n\ndef _add_nested_model_options(\n    function: Callable, parent_name: str, model_class: Type[BaseModel]\n) -> Callable:\n    \"\"\"\n    Add Click options for all fields of a nested Pydantic model using dot-notation.\n\n    Note: Only single-level nesting is supported (e.g., ``--trl.beta``).\n    Deeper nesting (e.g., ``--trl.scheduler.warmup``) is not handled.\n\n    Args:\n        function: Click command function to add options to.\n        parent_name: Parent field name (e.g., \"trl\").\n        model_class: Nested Pydantic model class.\n\n    Returns:\n        Function with added 
Click options.\n    \"\"\"\n    for sub_name, sub_field in reversed(model_class.model_fields.items()):\n        sub_type = _strip_optional_type(sub_field.annotation)\n        # Use dot notation: --parent.sub_field\n        cli_name = f\"{parent_name}.{sub_name}\".replace(\"_\", \"-\")\n        # The kwarg name uses double-underscore as separator\n        param_name = f\"{parent_name}__{sub_name}\"\n        description = _get_field_description(sub_field)\n\n        if sub_type is bool:\n            option_name = f\"--{cli_name}/--no-{cli_name}\"\n            function = click.option(\n                option_name, param_name, default=None, help=description\n            )(function)\n        else:\n            option_name = f\"--{cli_name}\"\n            click_type = {str: str, int: int, float: float}.get(sub_type)\n            function = click.option(\n                option_name, param_name, default=None, type=click_type, help=description\n            )(function)\n\n    return function\n\n\ndef add_options_from_config(config_class: Type[BaseModel]) -> Callable:\n    \"\"\"\n    Create Click options from the fields of a Pydantic model.\n\n    For fields whose type is itself a Pydantic BaseModel, dot-notation CLI options are\n    generated for each sub-field (e.g., ``--trl.beta=0.1``).\n\n    Args:\n        config_class: PyDantic model with fields to parse from the CLI\n\n    Returns:\n        Function decorator for Axolotl CLI command.\n    \"\"\"\n\n    def decorator(function: Callable) -> Callable:\n        # Process model fields in reverse order for correct option ordering\n        for name, field in reversed(config_class.model_fields.items()):\n            field_type = _strip_optional_type(field.annotation)\n\n            # Handle nested Pydantic models with dot-notation options\n            if _is_pydantic_model(field_type):\n                function = _add_nested_model_options(function, name, field_type)\n                continue\n\n            if field_type is 
bool:\n                field_name = name.replace(\"_\", \"-\")\n                option_name = f\"--{field_name}/--no-{field_name}\"\n                function = click.option(\n                    option_name, default=None, help=field.description\n                )(function)\n            else:\n                option_name = f\"--{name.replace('_', '-')}\"\n                function = click.option(\n                    option_name, default=None, help=field.description\n                )(function)\n\n        return function\n\n    return decorator\n"
  },
  {
    "path": "src/axolotl/cli/utils/diffusion.py",
    "content": "\"\"\"Helpers for diffusion-mode inference in CLI and Gradio.\"\"\"\n\nfrom __future__ import annotations\n\nimport gradio as gr\nfrom colorama import Fore, Style\n\nfrom axolotl.integrations.diffusion import generate, resolve_mask_token_id\nfrom axolotl.utils.dict import DictDefault\n\n\ndef diffusion_inference(\n    model,\n    tokenizer,\n    cfg,\n    prompt: str,\n    chat_template_str: str | None = None,\n):\n    \"\"\"Diffusion inference helper method.\"\"\"\n    mode = \"random\"\n    completion_tokens = 0\n    target_mask_ratio = None\n    mode, completion_tokens, target_mask_ratio, cleaned = _parse_commands(prompt)\n\n    if cleaned:\n        prompt = cleaned\n\n    info = run_diffusion(\n        model=model,\n        tokenizer=tokenizer,\n        cfg=cfg,\n        prompt=prompt,\n        chat_template_str=chat_template_str,\n        mode=mode,\n        target_mask_ratio=target_mask_ratio,\n        completion_tokens=completion_tokens,\n    )\n    masked_text = info[\"masked_text\"]\n    mask_ratio = info[\"mask_ratio\"]\n    generated_ids = info[\"generated_ids\"]\n    masked_positions = info[\"masked_positions\"]\n    orig_ids = info[\"orig_ids\"]\n\n    # Display with masked preview and colored diff\n    if masked_text is not None and mask_ratio is not None:\n        print(f\"Masked ({mask_ratio:.1%}):\\n{masked_text}\\n\")\n    if generated_ids is not None:\n        # Compute per-token style\n        styles: list[str] = []\n        for i, tid in enumerate(generated_ids):\n            if i in masked_positions:\n                if i < len(orig_ids) and tid == orig_ids[i]:\n                    styles.append(\"green\")  # correct fill\n                elif i < len(orig_ids):\n                    styles.append(\"red\")  # incorrect fill\n                else:\n                    styles.append(\"normal\")  # appended\n            else:\n                same = i < len(orig_ids) and tid == orig_ids[i]\n                styles.append(\"dim\" if 
same else \"normal\")\n\n        # Group contiguous spans by style\n        styled_spans: list[tuple[str, int, int]] = []\n        if generated_ids:\n            current_style = styles[0]\n            start = 0\n            for i in range(1, len(generated_ids)):\n                s = styles[i]\n                if s != current_style:\n                    styled_spans.append((current_style, start, i))\n                    current_style, start = s, i\n            styled_spans.append((current_style, start, len(generated_ids)))\n\n        out_parts = []\n        for style_name, a, b in styled_spans:\n            chunk_text = tokenizer.decode(generated_ids[a:b], skip_special_tokens=False)\n            if style_name == \"green\":\n                out_parts.append(Fore.GREEN + chunk_text + Style.RESET_ALL)\n            elif style_name == \"red\":\n                out_parts.append(Fore.RED + chunk_text + Style.RESET_ALL)\n            else:\n                if style_name == \"dim\":\n                    out_parts.append(Style.DIM + chunk_text + Style.RESET_ALL)\n                else:\n                    out_parts.append(chunk_text)\n        print(\"Generated:\\n\" + \"\".join(out_parts))\n    else:\n        print(\"Generated:\\n(no output)\")\n\n\ndef _parse_commands(text: str):\n    \"\"\"\n    Parse leading diffusion commands.\n\n    Supported at start of input (can be chained):\n      :complete N  -> completion mode with N tokens (default 64)\n      :mask R      -> random masking with ratio R in [0, 1]\n    \"\"\"\n    tokens = text.strip().split()\n    i = 0\n    mode = \"random\"\n    completion_tokens = 0\n    target_mask_ratio = None\n    consumed = 0\n    while i < len(tokens) and tokens[i].startswith(\":\"):\n        cmd = tokens[i]\n        i += 1\n        consumed = i\n        if cmd == \":complete\":\n            mode = \"completion\"\n            if i < len(tokens):\n                try:\n                    completion_tokens = int(tokens[i])\n                   
 i += 1\n                    consumed = i\n                except Exception:\n                    completion_tokens = 64\n            else:\n                completion_tokens = 64\n        elif cmd == \":mask\":\n            mode = \"random\"\n            if i < len(tokens):\n                try:\n                    target_mask_ratio = float(tokens[i])\n                    i += 1\n                    consumed = i\n                except Exception:\n                    target_mask_ratio = None\n        else:\n            i -= 1\n            consumed = i\n            break\n\n    cleaned = \" \".join(tokens[consumed:])\n\n    return mode, completion_tokens, target_mask_ratio, cleaned\n\n\ndef run_diffusion(\n    *,\n    model,\n    tokenizer,\n    cfg: DictDefault,\n    prompt: str,\n    chat_template_str: str | None,\n    mode: str = \"random\",\n    target_mask_ratio: float | None = None,\n    completion_tokens: int = 0,\n):\n    \"\"\"Run a single diffusion generation and return a structured result dict.\"\"\"\n    if chat_template_str:\n        batch = tokenizer.apply_chat_template(\n            [{\"role\": \"user\", \"content\": prompt}],\n            return_tensors=\"pt\",\n            add_special_tokens=True,\n            add_generation_prompt=True,\n            chat_template=chat_template_str,\n            tokenize=True,\n            return_dict=True,\n        )\n    else:\n        batch = tokenizer(prompt, return_tensors=\"pt\", add_special_tokens=True)\n\n    mask_token_id = resolve_mask_token_id(tokenizer, cfg, allow_add=False)\n\n    seq = batch[\"input_ids\"].to(cfg.device)\n    gen_mode = \"completion\" if mode == \"completion\" else \"random\"\n    comp_tokens = int(completion_tokens) if gen_mode == \"completion\" else 0\n\n    result = generate(\n        model,\n        tokenizer,\n        original_sequence=seq[:1],\n        num_diffusion_steps=cfg.diffusion.num_diffusion_steps,\n        temperature=cfg.diffusion.generation_temperature,\n        
mask_token_id=int(mask_token_id),\n        mode=gen_mode,  # type: ignore[arg-type]\n        completion_tokens=comp_tokens,\n        target_mask_ratio=target_mask_ratio,\n    )\n\n    masked_text = result.get(\"masked\") if isinstance(result, dict) else None\n    mask_ratio = result.get(\"mask_ratio\") if isinstance(result, dict) else None\n    generated_ids = result.get(\"generated_ids\") if isinstance(result, dict) else None\n    masked_positions = (\n        set(result.get(\"masked_positions\") or []) if isinstance(result, dict) else set()\n    )\n    orig_ids = seq[0].detach().cpu().tolist()\n\n    return {\n        \"masked_text\": masked_text,\n        \"mask_ratio\": mask_ratio,\n        \"generated_ids\": generated_ids,\n        \"masked_positions\": masked_positions,\n        \"orig_ids\": orig_ids,\n    }\n\n\ndef render_html(\n    *,\n    generated_ids: list[int] | None,\n    orig_ids: list[int],\n    masked_positions: set[int],\n    tokenizer,\n) -> str:\n    \"\"\"Render HTML visualizing diffusion outputs.\"\"\"\n    if not generated_ids:\n        return \"<pre>Generated:\\n(no output)</pre>\"\n\n    def _style_for(i: int, tid: int) -> str:\n        if i in masked_positions:\n            if i < len(orig_ids) and tid == orig_ids[i]:\n                return \"green\"\n            if i < len(orig_ids):\n                return \"red\"\n            return \"normal\"\n        same = i < len(orig_ids) and tid == orig_ids[i]\n        return \"dim\" if same else \"normal\"\n\n    # Group contiguous spans by style to reduce HTML size\n    spans: list[tuple[str, int, int]] = []\n    if generated_ids:\n        cur = _style_for(0, generated_ids[0])\n        start = 0\n        for i in range(1, len(generated_ids)):\n            s = _style_for(i, generated_ids[i])\n            if s != cur:\n                spans.append((cur, start, i))\n                cur, start = s, i\n        spans.append((cur, start, len(generated_ids)))\n\n    html_parts = []\n    for 
style_name, a, b in spans:\n        txt = tokenizer.decode(generated_ids[a:b], skip_special_tokens=False)\n        if style_name == \"green\":\n            html_parts.append(f'<span style=\"color:#2e7d32\">{txt}</span>')\n        elif style_name == \"red\":\n            html_parts.append(f'<span style=\"color:#c62828\">{txt}</span>')\n        elif style_name == \"dim\":\n            html_parts.append(f'<span style=\"opacity:0.6\">{txt}</span>')\n        else:\n            html_parts.append(txt)\n\n    legend = (\n        '<div style=\"font-size:0.9em;margin-bottom:4px\">'\n        '<span style=\"color:#2e7d32\">correct</span>, '\n        '<span style=\"color:#c62828\">incorrect</span>, '\n        '<span style=\"opacity:0.6\">unchanged</span>'\n        \"</div>\"\n    )\n\n    return (\n        legend\n        + '<pre style=\"white-space:pre-wrap\">Generated:\\n'\n        + \"\".join(html_parts)\n        + \"</pre>\"\n    )\n\n\ndef launch_diffusion_gradio_ui(\n    *,\n    model,\n    tokenizer,\n    cfg: DictDefault,\n    prompter_module=None,\n    chat_template_str: str | None = None,\n):\n    \"\"\"Build and launch a simple Gradio UI for diffusion inference.\"\"\"\n    with gr.Blocks(\n        title=cfg.get(\"gradio_title\", \"Axolotl Diffusion Interface\")\n    ) as demo:\n        gr.Markdown(\n            \"\"\"\n            ## Axolotl Diffusion Inference\n            - Mode \"Random\" masks tokens at a target ratio and fills them.\n            - Mode \"Completion\" appends N masked tokens at the end and fills them.\n            \"\"\"\n        )\n\n        with gr.Row():\n            mode = gr.Radio(\n                choices=[\"random\", \"completion\"],\n                value=\"random\",\n                label=\"Mode\",\n            )\n            mask_ratio = gr.Slider(\n                minimum=0.0,\n                maximum=1.0,\n                step=0.05,\n                value=0.4,\n                label=\"Mask ratio (random mode)\",\n                
interactive=True,\n            )\n            completion_tokens = gr.Number(\n                value=64,\n                precision=0,\n                label=\"Completion tokens (completion mode)\",\n                interactive=True,\n                visible=False,\n            )\n\n        instruction = gr.Textbox(label=\"Instruction\", lines=6)\n        run_btn = gr.Button(\"Generate\")\n\n        masked_preview = gr.Textbox(label=\"Masked preview\", lines=6)\n        html_out = gr.HTML(label=\"Generated\")\n\n        def _toggle_controls(selected_mode: str):\n            return (\n                gr.update(visible=(selected_mode == \"random\")),\n                gr.update(visible=(selected_mode == \"completion\")),\n            )\n\n        mode.change(\n            _toggle_controls,\n            inputs=[mode],\n            outputs=[mask_ratio, completion_tokens],\n        )\n\n        def _gen(instruction_text: str, selected_mode: str, mratio: float, ctoks: int):\n            if not instruction_text:\n                return \"\", \"<pre>Generated:\\n(no output)</pre>\"\n\n            if prompter_module:\n                prompt: str = next(\n                    prompter_module().build_prompt(\n                        instruction=instruction_text.strip(\"\\n\")\n                    )\n                )\n            else:\n                prompt = instruction_text.strip()\n\n            info = run_diffusion(\n                model=model,\n                tokenizer=tokenizer,\n                cfg=cfg,\n                prompt=prompt,\n                chat_template_str=chat_template_str,\n                mode=selected_mode,\n                target_mask_ratio=mratio if selected_mode == \"random\" else None,\n                completion_tokens=int(ctoks) if selected_mode == \"completion\" else 0,\n            )\n\n            masked_text = info.get(\"masked_text\")\n            mask_ratio_val = info.get(\"mask_ratio\")\n            generated_ids = 
info.get(\"generated_ids\")\n            masked_positions = info.get(\"masked_positions\") or set()\n            orig_ids = info.get(\"orig_ids\") or []\n\n            preview = (\n                f\"Masked ({mask_ratio_val:.1%}):\\n{masked_text}\"\n                if masked_text is not None and mask_ratio_val is not None\n                else \"\"\n            )\n            html = render_html(\n                generated_ids=generated_ids,\n                orig_ids=orig_ids,\n                masked_positions=masked_positions,\n                tokenizer=tokenizer,\n            )\n            return preview, html\n\n        run_btn.click(\n            _gen,\n            inputs=[instruction, mode, mask_ratio, completion_tokens],\n            outputs=[masked_preview, html_out],\n        )\n\n        demo.launch(\n            footer_links=[\"gradio\", \"settings\"],\n            share=cfg.get(\"gradio_share\", True),\n            server_name=cfg.get(\"gradio_server_name\", \"127.0.0.1\"),\n            server_port=cfg.get(\"gradio_server_port\", None),\n        )\n"
  },
  {
    "path": "src/axolotl/cli/utils/fetch.py",
    "content": "\"\"\"Utilities for axolotl fetch CLI command.\"\"\"\n\nimport concurrent.futures\nimport hashlib\nimport json\nfrom pathlib import Path\n\nimport click\nimport requests\n\nfrom axolotl.utils.logging import get_logger\n\nLOG = get_logger(__name__)\n\n\ndef _download_file(\n    file_info: tuple, raw_base_url: str, dest_path: Path, dir_prefix: str\n) -> tuple[str, str]:\n    \"\"\"\n    Download a single file and return its processing status.\n\n    Args:\n        file_info: Tuple of (file_path, remote_sha).\n        raw_base_url: Base URL for raw GitHub content.\n        dest_path: Local destination directory.\n        dir_prefix: Directory prefix to filter files.\n\n    Returns:\n        Tuple of (file_path, status) where status is 'new', 'updated', or 'unchanged'.\n    \"\"\"\n    file_path, remote_sha = file_info\n    raw_url = f\"{raw_base_url}/{file_path}\"\n    dest_file = dest_path / file_path.split(dir_prefix)[-1]\n\n    # Check if file exists and needs updating\n    if dest_file.exists():\n        with open(dest_file, \"rb\") as file:\n            content = file.read()\n            # Calculate git blob SHA\n            blob = b\"blob \" + str(len(content)).encode() + b\"\\0\" + content\n            local_sha = hashlib.sha1(blob, usedforsecurity=False).hexdigest()\n\n        if local_sha == remote_sha:\n            print(f\"Skipping {file_path} (unchanged)\")\n            return file_path, \"unchanged\"\n\n        print(f\"Updating {file_path}\")\n        status = \"updated\"\n    else:\n        print(f\"Downloading {file_path}\")\n        status = \"new\"\n\n    # Create directories if needed\n    dest_file.parent.mkdir(parents=True, exist_ok=True)\n\n    # Download and save file\n    try:\n        response = requests.get(raw_url, timeout=30)\n        response.raise_for_status()\n\n        with open(dest_file, \"wb\") as file:\n            file.write(response.content)\n\n        return file_path, status\n    except 
(requests.RequestException, IOError) as request_error:\n        print(f\"Error downloading {file_path}: {str(request_error)}\")\n        return file_path, \"error\"\n\n\ndef fetch_from_github(\n    dir_prefix: str, dest_dir: str | None = None, max_workers: int = 5\n) -> None:\n    \"\"\"\n    Sync files from a specific directory in the GitHub repository.\n    Only downloads files that don't exist locally or have changed.\n\n    Args:\n        dir_prefix: Directory prefix to filter files (e.g., 'examples/',\n            'deepspeed_configs/').\n        dest_dir: Local destination directory.\n        max_workers: Maximum number of concurrent downloads.\n    \"\"\"\n    api_url = \"https://api.github.com/repos/axolotl-ai-cloud/axolotl/git/trees/main?recursive=1\"\n    raw_base_url = \"https://raw.githubusercontent.com/axolotl-ai-cloud/axolotl/main\"\n\n    # Get repository tree with timeout\n    response = requests.get(api_url, timeout=30)\n    response.raise_for_status()\n    tree = json.loads(response.text)\n\n    # Filter for files and get their SHA\n    files = {\n        item[\"path\"]: item[\"sha\"]\n        for item in tree[\"tree\"]\n        if item[\"type\"] == \"blob\" and item[\"path\"].startswith(dir_prefix)\n    }\n\n    if not files:\n        raise click.ClickException(f\"No files found in {dir_prefix}\")\n\n    # Default destination directory is the last part of dir_prefix\n    default_dest = Path(dir_prefix.rstrip(\"/\"))\n    dest_path = Path(dest_dir) if dest_dir else default_dest\n\n    # Keep track of processed files for summary\n    files_processed: dict[str, list[str]] = {\n        \"new\": [],\n        \"updated\": [],\n        \"unchanged\": [],\n        \"error\": [],\n    }\n\n    # Process files in parallel using ThreadPoolExecutor\n    with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:\n        future_to_file = {\n            executor.submit(\n                _download_file,\n                (file_path, 
remote_sha),\n                raw_base_url,\n                dest_path,\n                dir_prefix,\n            ): file_path\n            for file_path, remote_sha in files.items()\n        }\n\n        # Process completed tasks as they finish\n        for future in concurrent.futures.as_completed(future_to_file):\n            file_path = future_to_file[future]\n            try:\n                file_path, status = future.result()\n                files_processed[status].append(file_path)\n            except (requests.RequestException, IOError) as request_error:\n                print(f\"Error processing {file_path}: {str(request_error)}\")\n                files_processed[\"error\"].append(file_path)\n\n    # Log summary\n    LOG.info(\"\\nSync Summary:\")\n    LOG.info(f\"New files: {len(files_processed['new'])}\")\n    LOG.info(f\"Updated files: {len(files_processed['updated'])}\")\n    LOG.info(f\"Unchanged files: {len(files_processed['unchanged'])}\")\n    if files_processed[\"error\"]:\n        LOG.info(f\"Failed files: {len(files_processed['error'])}\")\n"
  },
  {
    "path": "src/axolotl/cli/utils/load.py",
    "content": "\"\"\"Utilities for model, tokenizer, etc. loading.\"\"\"\n\nfrom typing import Any\n\nfrom transformers import (\n    PreTrainedModel,\n    PreTrainedTokenizer,\n    PreTrainedTokenizerFast,\n    ProcessorMixin,\n)\n\nfrom axolotl.loaders import load_processor, load_tokenizer\nfrom axolotl.loaders.model import ModelLoader\nfrom axolotl.utils.dict import DictDefault\nfrom axolotl.utils.logging import get_logger\n\nLOG = get_logger(__name__)\n\n\ndef load_model_and_tokenizer(\n    *,\n    cfg: DictDefault,\n    inference: bool = False,\n) -> tuple[\n    PreTrainedModel,\n    PreTrainedTokenizer | PreTrainedTokenizerFast | Any,\n    ProcessorMixin | None,\n]:\n    \"\"\"\n    Helper function for loading a model, tokenizer, and processor specified in the\n    given `axolotl` config.\n\n    Args:\n        cfg: Dictionary mapping `axolotl` config keys to values.\n        inference: Boolean denoting inference mode.\n\n    Returns:\n        Tuple of (PreTrainedModel, PreTrainedTokenizer, ProcessorMixin).\n    \"\"\"\n    LOG.info(f\"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}\")\n    tokenizer = load_tokenizer(cfg)\n\n    LOG.info(\"loading model...\")\n    model_loader = ModelLoader(cfg, tokenizer, inference=inference)\n    model, _ = model_loader.load()\n\n    processor = None\n    if cfg.is_multimodal:\n        LOG.info(\"loading processor...\")\n        processor = load_processor(cfg, tokenizer)\n\n    return model, tokenizer, processor\n"
  },
  {
    "path": "src/axolotl/cli/utils/sweeps.py",
    "content": "\"\"\"Utilities for handling sweeps over configs for axolotl train CLI command\"\"\"\n\nimport random\nfrom copy import deepcopy\nfrom itertools import product\nfrom typing import Any\n\n\ndef generate_sweep_configs(\n    base_config: dict[str, list], sweeps_config: dict[str, list]\n) -> list[dict[str, Any]]:\n    \"\"\"\n    Recursively generates all possible configurations by applying sweeps to the base config.\n\n    Args:\n        base_config (dict): The original configuration dictionary\n        sweeps_config (dict): Dictionary where keys are parameters and values are either:\n            - lists of values to sweep independently\n            - or for paired values, a list of dicts under the '_' key\n\n    Returns:\n        list: List of all possible configuration dictionaries\n\n    Example:\n        sweeps_config = {\n            'learning_rate': [0.1, 0.01],\n            '_': [\n                {'load_in_8bit': True, 'adapter': 'lora'},\n                {'load_in_4bit': True, 'adapter': 'qlora'}\n            ]\n        }\n    \"\"\"\n    # Separate paired values from regular sweeps\n    paired_values = sweeps_config.get(\"_\", [])\n    regular_sweeps = {k: v for k, v in sweeps_config.items() if k != \"_\"}\n\n    # Process regular sweeps\n    param_names = list(regular_sweeps.keys())\n    param_values = list(regular_sweeps.values())\n\n    # Generate combinations for regular sweeps\n    regular_combinations = list(product(*param_values)) if param_values else [()]\n\n    # Combine regular sweeps with paired values\n    all_combinations = []\n    for reg_combo in regular_combinations:\n        if paired_values:\n            for paired_set in paired_values:\n                new_config = {}\n                # new_config = deepcopy(base_config)\n                # Combine regular parameters with paired parameters\n                full_combo = {\n                    **dict(zip(param_names, reg_combo, strict=False)),\n                    
**paired_set,\n                }\n                for param_name, param_value in full_combo.items():\n                    new_config[param_name] = param_value\n                print(new_config)\n                all_combinations.append(new_config)\n        else:\n            # If no paired values, just use regular combinations\n            # new_config = deepcopy(base_config)\n            new_config = {}\n            for param_name, param_value in zip(param_names, reg_combo, strict=False):\n                new_config[param_name] = param_value\n            print(new_config)\n            all_combinations.append(new_config)\n\n    # randomize the order of trials\n    random.seed(42)\n    random.shuffle(all_combinations)\n\n    # Generate a new config for each combination\n    result_configs = []\n    for combination in all_combinations:\n        new_config = deepcopy(base_config)\n        for param_name, param_value in combination.items():\n            new_config[param_name] = param_value\n        result_configs.append(new_config)\n\n    return result_configs\n"
  },
  {
    "path": "src/axolotl/cli/utils/train.py",
    "content": "\"\"\"Utilities for axolotl train CLI command.\"\"\"\n\nimport os\nimport subprocess  # nosec\nimport sys\nimport tempfile\nfrom pathlib import Path\nfrom typing import Any, Iterator, Literal\n\nimport yaml\n\nfrom axolotl.cli.utils.sweeps import generate_sweep_configs\n\n\ndef _add_default_rdzv_args(launcher_args: list[str]) -> list[str]:\n    \"\"\"\n    Add default RDZV arguments if rdzv_endpoint is set but rdzv_backend/rdzv_id are missing.\n\n    Args:\n        launcher_args: List of launcher arguments\n\n    Returns:\n        Updated launcher args with defaults added if needed\n    \"\"\"\n    args = launcher_args.copy()\n\n    # Check if rdzv_endpoint is present\n    has_rdzv_endpoint = any(\"--rdzv_endpoint\" in arg for arg in args)\n\n    if has_rdzv_endpoint:\n        # Check if rdzv_backend is already provided\n        has_rdzv_backend = any(\"--rdzv_backend\" in arg for arg in args)\n        if not has_rdzv_backend:\n            args.extend([\"--rdzv_backend\", \"c10d\"])\n\n        # Check if rdzv_id is already provided\n        has_rdzv_id = any(\"--rdzv_id\" in arg for arg in args)\n        if not has_rdzv_id:\n            import uuid\n\n            args.extend([\"--rdzv_id\", str(uuid.uuid4())[:8]])\n\n    return args\n\n\ndef build_command(base_cmd: list[str], options: dict[str, Any]) -> list[str]:\n    \"\"\"\n    Build command list from base command and options.\n\n    Args:\n        base_cmd: Command without options.\n        options: Options to parse and append to base command.\n\n    Returns:\n        List of strings giving shell command.\n    \"\"\"\n    cmd = base_cmd.copy()\n\n    for key, value in options.items():\n        if value is None:\n            continue\n\n        key = key.replace(\"_\", \"-\")\n        cmd.append(f\"--{key}={value}\")\n\n    return cmd\n\n\ndef generate_config_files(config: str, sweep: str | None) -> Iterator[tuple[str, bool]]:\n    \"\"\"\n    Generate list of configuration files to process. 
Yields a tuple of the configuration file name and a boolean indicating\n    whether this is a group of configurations (i.e., a sweep).\n\n    Args:\n        config: Base configuration file\n        sweep: Sweep configuration file\n    \"\"\"\n\n    if not sweep:\n        yield config, False\n        return\n\n    # Load sweep and base configurations\n    with open(sweep, \"r\", encoding=\"utf-8\") as fin:\n        sweep_config: dict[str, list] = yaml.safe_load(fin)\n    with open(config, \"r\", encoding=\"utf-8\") as fin:\n        base_config: dict[str, list] = yaml.safe_load(fin)\n\n    # Generate all possible configurations\n    permutations = generate_sweep_configs(base_config, sweep_config)\n    is_group = len(permutations) > 1\n    base_output_dir = base_config.get(\"output_dir\", \"./model-out\")\n    for idx, permutation in enumerate(permutations, start=1):\n        permutation_dir = Path(permutation.get(\"output_dir\", base_output_dir))\n        permutation_id = f\"sweep{idx:04d}\"\n        permutation[\"output_dir\"] = str(permutation_dir / permutation_id)\n\n        temp_file = tempfile.NamedTemporaryFile(\n            mode=\"w\",\n            suffix=\".yaml\",\n            delete=False,\n            encoding=\"utf-8\",\n        )\n        yaml.dump(permutation, temp_file)\n        temp_file.close()\n        yield temp_file.name, is_group\n\n\ndef launch_training(\n    cfg_file: str,\n    launcher: Literal[\"accelerate\", \"torchrun\", \"python\"] | None,\n    cloud: str | None,\n    kwargs: dict,\n    launcher_args: list[str] | None = None,\n    use_exec: bool = False,\n) -> None:\n    \"\"\"Execute training with the given configuration.\"\"\"\n    launcher_args = launcher_args or []\n\n    if cloud:\n        _launch_cloud_training(cloud, cfg_file, launcher, kwargs, launcher_args)\n    elif launcher:\n        if launcher == \"accelerate\":\n            _launch_accelerate_training(cfg_file, kwargs, launcher_args, use_exec)\n        elif launcher == 
\"torchrun\":\n            _launch_torchrun_training(cfg_file, kwargs, launcher_args, use_exec)\n        elif launcher == \"python\":\n            _launch_python_training(cfg_file, kwargs)\n    elif launcher is None:\n        # handle ray train launch\n        _launch_python_training(cfg_file, kwargs)\n\n\ndef _launch_cloud_training(\n    cloud: str,\n    cfg_file: str,\n    launcher: Literal[\"accelerate\", \"torchrun\", \"python\"] | None,\n    kwargs: dict,\n    launcher_args: list[str] | None = None,\n) -> None:\n    \"\"\"Execute training via cloud launcher.\"\"\"\n    from axolotl.cli.cloud import do_cli_train\n\n    launcher_args = launcher_args or []\n    cwd = os.getcwd() if launcher else None\n\n    do_cli_train(\n        cloud_config=cloud,\n        config=cfg_file,\n        launcher=launcher or \"accelerate\",\n        launcher_args=launcher_args,\n        cwd=cwd,\n        **kwargs,\n    )\n\n\ndef _launch_accelerate_training(\n    cfg_file: str,\n    kwargs: dict,\n    launcher_args: list[str] | None = None,\n    use_exec: bool = False,\n) -> None:\n    \"\"\"Execute training via accelerate launcher.\"\"\"\n    launcher_args = launcher_args or []\n    internal_launcher_args = []\n\n    # Extract launcher-specific arguments from kwargs (legacy support)\n    if \"main_process_port\" in kwargs:\n        main_process_port = kwargs.pop(\"main_process_port\")\n        internal_launcher_args.extend([\"--main_process_port\", str(main_process_port)])\n\n    if \"num_processes\" in kwargs:\n        num_processes = kwargs.pop(\"num_processes\")\n        internal_launcher_args.extend([\"--num_processes\", str(num_processes)])\n\n    # Combine internal args with user-provided launcher args\n    all_launcher_args = internal_launcher_args + launcher_args\n\n    base_cmd = (\n        [\"accelerate\", \"launch\"] + all_launcher_args + [\"-m\", \"axolotl.cli.train\"]\n    )\n    if cfg_file:\n        base_cmd.append(cfg_file)\n\n    cmd = build_command(base_cmd, 
kwargs)\n    if use_exec:\n        # make sure to flush stdout and stderr before replacing the process\n        sys.stdout.flush()\n        sys.stderr.flush()\n        os.execvpe(cmd[0], cmd, os.environ)  # nosec B606\n    else:\n        subprocess.run(cmd, check=True)  # nosec B603\n\n\ndef _launch_torchrun_training(\n    cfg_file: str,\n    kwargs: dict,\n    launcher_args: list[str] | None = None,\n    use_exec: bool = False,\n) -> None:\n    \"\"\"Execute training via torchrun launcher.\"\"\"\n    launcher_args = launcher_args or []\n\n    # Add default RDZV arguments if rdzv_endpoint is set\n    launcher_args = _add_default_rdzv_args(launcher_args)\n\n    base_cmd = [\"torchrun\"] + launcher_args + [\"-m\", \"axolotl.cli.train\"]\n    if cfg_file:\n        base_cmd.append(cfg_file)\n\n    cmd = build_command(base_cmd, kwargs)\n    if use_exec:\n        # make sure to flush stdout and stderr before replacing the process\n        sys.stdout.flush()\n        sys.stderr.flush()\n        os.execvpe(cmd[0], cmd, os.environ)  # nosec B606\n    else:\n        subprocess.run(cmd, check=True)  # nosec B603\n\n\ndef _launch_python_training(cfg_file: str, kwargs: dict) -> None:\n    \"\"\"Execute training via python launcher.\"\"\"\n    from axolotl.cli.train import do_cli\n\n    do_cli(config=cfg_file, **kwargs)\n"
  },
  {
    "path": "src/axolotl/cli/vllm_serve.py",
    "content": "\"\"\"\nCLI to start the vllm server for online RL\n\"\"\"\n\nfrom dataclasses import dataclass, field\nfrom pathlib import Path\nfrom typing import Union\n\nfrom trl.scripts.vllm_serve import ScriptArguments\n\nfrom axolotl.cli.config import load_cfg\n\n\n@dataclass\nclass AxolotlScriptArguments(ScriptArguments):\n    \"\"\"\n    Additional arguments for the VLLM server\n    \"\"\"\n\n    reasoning_parser: str = field(default=\"\", kw_only=True)\n    enable_reasoning: bool | None = field(default=None, kw_only=True)\n\n\ndef do_vllm_serve(\n    config: Union[Path, str],\n    cli_args: dict,\n):\n    \"\"\"\n    Starts the VLLM server for serving LLM models used for online RL\n\n    Args\n        :param cfg: Parsed doct of the YAML config\n        :param cli_args: dict of additional command-line arguments of type VllmServeCliArgs\n\n    Returns:\n        process_id: the process id of the started VLLM server\n    \"\"\"\n    cfg = load_cfg(config)\n    model = cfg.base_model\n\n    # Determine serve module: explicit CLI/config > auto-select from vllm_lora_sync > default\n    serve_module = cli_args.get(\"serve_module\") or getattr(\n        cfg.vllm, \"serve_module\", None\n    )\n    if (\n        serve_module is None\n        and getattr(cfg, \"trl\", None)\n        and getattr(cfg.trl, \"vllm_lora_sync\", False)\n    ):\n        serve_module = \"axolotl.scripts.vllm_serve_lora\"\n    if serve_module is None:\n        serve_module = \"trl.scripts.vllm_serve\"\n    vllm_serve_main = __import__(serve_module, fromlist=[\"main\"]).main\n    tensor_parallel_size = 1\n    data_parallel_size = 1\n\n    if cli_args.get(\"tensor_parallel_size\") or cfg.vllm.tensor_parallel_size:\n        tensor_parallel_size = (\n            cli_args.get(\"tensor_parallel_size\") or cfg.vllm.tensor_parallel_size\n        )\n    if cli_args.get(\"data_parallel_size\") or cfg.vllm.data_parallel_size:\n        data_parallel_size = (\n            
cli_args.get(\"data_parallel_size\") or cfg.vllm.data_parallel_size\n        )\n    host = cli_args.get(\"host\") or cfg.vllm.host\n    port = cli_args.get(\"port\") or cfg.vllm.port\n    gpu_memory_utilization = (\n        cli_args.get(\"gpu_memory_utilization\") or cfg.vllm.gpu_memory_utilization\n    )\n    dtype = cli_args.get(\"dtype\") or cfg.vllm.dtype\n    max_model_len = cli_args.get(\"max_model_len\") or cfg.vllm.max_model_len\n    enable_prefix_caching = (\n        cli_args.get(\"enable_prefix_caching\") or cfg.vllm.enable_prefix_caching\n    )\n    reasoning_parser = (\n        cli_args.get(\"reasoning_parser\") or cfg.vllm.reasoning_parser or \"\"\n    )\n    enable_reasoning = (\n        cli_args.get(\"enable_reasoning\") or cfg.vllm.enable_reasoning or False\n    )\n\n    base_kwargs = dict(\n        model=model,\n        tensor_parallel_size=tensor_parallel_size,\n        data_parallel_size=data_parallel_size,\n        host=host,\n        port=port,\n        gpu_memory_utilization=gpu_memory_utilization,\n        dtype=dtype,\n        max_model_len=max_model_len,\n        enable_prefix_caching=enable_prefix_caching,\n    )\n\n    # Use LoRAScriptArguments when serving with native LoRA support\n    if serve_module == \"axolotl.scripts.vllm_serve_lora\":\n        from axolotl.scripts.vllm_serve_lora import LoRAScriptArguments\n\n        lora_kwargs = {}\n        if hasattr(cfg, \"lora_r\") and cfg.lora_r:\n            lora_kwargs[\"max_lora_rank\"] = cfg.lora_r\n        vllm_script_args = LoRAScriptArguments(**base_kwargs, **lora_kwargs)\n    else:\n        vllm_script_args = AxolotlScriptArguments(\n            **base_kwargs,\n            reasoning_parser=reasoning_parser,\n            enable_reasoning=enable_reasoning,\n        )\n\n    vllm_serve_main(vllm_script_args)\n"
  },
  {
    "path": "src/axolotl/common/__init__.py",
    "content": ""
  },
  {
    "path": "src/axolotl/common/architectures.py",
    "content": "\"\"\"\nCommon architecture specific constants\n\"\"\"\n\nMOE_ARCH_BLOCK = {\n    \"dbrx\": \"DbrxFFN\",\n    \"jamba\": \"JambaSparseMoeBlock\",\n    \"jetmoe\": [\n        \"JetMoeMoA\",\n        \"JetMoeMoE\",\n    ],\n    \"mixtral\": \"MixtralSparseMoeBlock\",\n    \"qwen2_moe\": \"Qwen2MoeSparseMoeBlock\",\n    \"qwen3_moe\": \"Qwen3MoeSparseMoeBlock\",\n    \"qwen3_5_moe\": \"Qwen3_5MoeSparseMoeBlock\",\n    \"qwen3_vl_moe\": \"Qwen3VLMoeTextSparseMoeBlock\",\n    \"deepseek_v2\": \"DeepseekV2MoE\",\n    \"deepseek_v3\": \"DeepseekV3MoE\",\n    \"mistral4\": \"Mistral4MoE\",\n    \"gpt_oss\": \"GptOssDecoderLayer\",\n    \"lfm2_moe\": \"Lfm2MoeSparseMoeBlock\",\n    \"afmoe\": \"AfmoeMoE\",\n    \"glm4_moe\": \"Glm4MoeDecoderLayer\",\n    \"glm4_moe_lite\": \"Glm4MoeLiteDecoderLayer\",\n    \"glm_moe_dsa\": \"GlmMoeDsaDecoderLayer\",\n}\n"
  },
  {
    "path": "src/axolotl/common/const.py",
    "content": "\"\"\"Various shared constants\"\"\"\n\nDEFAULT_DATASET_PREPARED_PATH = \"last_run_prepared\"\n"
  },
  {
    "path": "src/axolotl/common/datasets.py",
    "content": "\"\"\"Dataset loading utilities.\"\"\"\n\nimport math\nimport random\nfrom dataclasses import dataclass\n\nfrom datasets import Dataset\n\nimport axolotl.monkeypatch.data.batch_dataset_fetcher  # noqa: F401\nfrom axolotl.cli.args import PreprocessCliArgs, TrainerCliArgs\nfrom axolotl.loaders import load_processor, load_tokenizer\nfrom axolotl.telemetry.errors import send_errors\nfrom axolotl.utils.data import prepare_datasets, prepare_preference_datasets\nfrom axolotl.utils.dict import DictDefault\nfrom axolotl.utils.logging import get_logger\nfrom axolotl.utils.schemas.enums import RLType\nfrom axolotl.utils.tokenization import check_dataset_labels\n\nLOG = get_logger(__name__)\n\n\n@dataclass\nclass TrainDatasetMeta:\n    \"\"\"Dataclass with fields for training and validation datasets and metadata.\"\"\"\n\n    train_dataset: Dataset\n    eval_dataset: Dataset | None = None\n    total_num_steps: int | None = None\n\n\ndef sample_dataset(dataset: Dataset, num_samples: int) -> Dataset:\n    \"\"\"Randomly sample `num_samples` rows with replacement from `dataset`.\n\n    Any row index in `[0, len(dataset))` may be drawn, including the last one, so\n    single-row datasets are supported.\n\n    Raises:\n        ValueError: If `dataset` is empty and `num_samples` is positive.\n    \"\"\"\n    # `random.randrange(n)` draws from [0, n) (exclusive upper bound). Keep the\n    # `dataset.select(...)` call shape: the attribute is looked up before the\n    # indices are built, so objects without `select` fail with AttributeError,\n    # which `load_datasets` catches.\n    return dataset.select(\n        [random.randrange(len(dataset)) for _ in range(num_samples)]  # nosec\n    )\n\n\n@send_errors\ndef load_datasets(\n    *,\n    cfg: DictDefault,\n    cli_args: PreprocessCliArgs | TrainerCliArgs | None = None,\n    debug: bool = False,\n) -> TrainDatasetMeta:\n    \"\"\"Loads one or more training or evaluation datasets, calling\n    `axolotl.utils.data.prepare_datasets`. Optionally, logs out debug information.\n\n    Args:\n        cfg: Dictionary mapping `axolotl` config keys to values.\n        cli_args: Command-specific CLI arguments.\n        debug: Whether to print out tokenization of sample. This is duplicated in\n            `cfg` and `cli_args`, but is kept due to use in our Colab notebooks.\n\n    Returns:\n        Dataclass with fields for training and evaluation datasets and the computed\n            `total_num_steps`.\n    \"\"\"\n    tokenizer = load_tokenizer(cfg)\n    processor = load_processor(cfg, tokenizer=tokenizer) if cfg.processor_type else None\n\n    train_dataset, eval_dataset, total_num_steps, prompters = prepare_datasets(\n        cfg,\n        tokenizer,\n        processor=processor,\n    )\n\n    if (\n        cfg.debug\n        or getattr(cli_args, \"debug\", False)\n        or getattr(cli_args, \"debug_text_only\", False)\n        or getattr(cli_args, \"debug_num_examples\", 0) > 0\n        or debug\n    ):\n        LOG.info(\"check_dataset_labels...\")\n\n        num_examples = cli_args.debug_num_examples if cli_args else 1\n        text_only = cli_args.debug_text_only if cli_args else False\n        try:\n            train_samples = sample_dataset(train_dataset, num_examples)\n            check_dataset_labels(\n                train_samples,\n                tokenizer,\n                num_examples=num_examples,\n                text_only=text_only,\n            )\n        except AttributeError:\n            # can't sample iterable datasets\n            LOG.info(\"Skipping check_dataset_labels: dataset could not be sampled\")\n\n        LOG.info(\"printing prompters...\")\n        for prompter in prompters:\n            LOG.info(prompter)\n\n    return TrainDatasetMeta(\n        train_dataset=train_dataset,\n        eval_dataset=eval_dataset,\n        total_num_steps=total_num_steps,\n    )\n\n\n@send_errors\ndef load_preference_datasets(\n    *, cfg: DictDefault, cli_args: PreprocessCliArgs | TrainerCliArgs | None = None\n) -> TrainDatasetMeta:\n    \"\"\"Loads one or more training or evaluation datasets for RL training using paired\n    preference data, calling `axolotl.utils.data.rl.prepare_preference_datasets`.\n    Optionally, logs out debug information.\n\n    Args:\n        cfg: Dictionary mapping `axolotl` config keys to values.\n        cli_args: Command-specific CLI arguments.\n\n    Returns:\n        Dataclass with fields for training and evaluation datasets and the computed\n        `total_num_steps`.\n    \"\"\"\n    tokenizer = load_tokenizer(cfg)\n    train_dataset, eval_dataset = prepare_preference_datasets(cfg, tokenizer)\n\n    total_num_steps: int | None = None\n    # Use `!=` (as the ORPO check below does) rather than `is not`, so the result\n    # does not depend on `cfg.rl` being the identical enum member object.\n    if cfg.rl != RLType.GRPO:\n        total_num_steps = int(\n            math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size)\n        )\n\n    if ((cli_args and cli_args.debug) or cfg.debug) and cfg.rl != RLType.ORPO:\n        LOG.info(\"check_dataset_labels...\")\n\n        num_examples = cli_args.debug_num_examples if cli_args else 1\n        text_only = cli_args.debug_text_only if cli_args else False\n\n        # Reuse the tokenizer loaded above rather than loading it a second time.\n        train_samples = sample_dataset(train_dataset, num_examples)\n        check_dataset_labels(\n            dataset=train_samples,\n            tokenizer=tokenizer,\n            num_examples=num_examples,\n            text_only=text_only,\n            rl_mode=True,\n        )\n\n    return TrainDatasetMeta(\n        train_dataset=train_dataset,\n        eval_dataset=eval_dataset,\n        total_num_steps=total_num_steps,\n    )\n"
  },
  {
    "path": "src/axolotl/convert.py",
    "content": "\"\"\"Module containing File Reader, File Writer, Json Parser, and Jsonl Serializer classes\"\"\"\n\nimport json\nimport sys\n\n\nclass FileReader:\n    \"\"\"\n    Reads a file and returns its contents as a string\n    \"\"\"\n\n    def read(self, file_path):\n        with open(file_path, encoding=\"utf-8\") as file:\n            return file.read()\n\n\nclass FileWriter:\n    \"\"\"\n    Writes a string to a file\n    \"\"\"\n\n    def __init__(self, file_path):\n        self.file_path = file_path\n\n    def write(self, content):\n        with open(self.file_path, \"w\", encoding=\"utf-8\") as file:\n            file.write(content)\n\n\nclass StdoutWriter:\n    \"\"\"\n    Writes a string to stdout\n    \"\"\"\n\n    def write(self, content):\n        sys.stdout.write(content)\n        sys.stdout.write(\"\\n\")\n\n\nclass JsonParser:\n    \"\"\"\n    Parses a string as JSON and returns the result\n    \"\"\"\n\n    def parse(self, content):\n        return json.loads(content)\n\n\nclass JsonlSerializer:\n    \"\"\"\n    Serializes a list of JSON objects into a JSONL string\n    \"\"\"\n\n    def serialize(self, data):\n        lines = [json.dumps(item) for item in data]\n        return \"\\n\".join(lines)\n\n\nclass JsonToJsonlConverter:\n    \"\"\"\n    Converts a JSON file to JSONL\n    \"\"\"\n\n    def __init__(self, file_reader, file_writer, json_parser, jsonl_serializer):\n        self.file_reader = file_reader\n        self.file_writer = file_writer\n        self.json_parser = json_parser\n        self.jsonl_serializer = jsonl_serializer\n\n    def convert(self, input_file_path):\n        content = self.file_reader.read(input_file_path)\n        data = self.json_parser.parse(content)\n        # data = [r for r in data if r[\"conversations\"]]  # vicuna cleaned has rows with empty conversations\n        jsonl_content = self.jsonl_serializer.serialize(data)\n        self.file_writer.write(jsonl_content)\n"
  },
  {
    "path": "src/axolotl/core/__init__.py",
    "content": ""
  },
  {
    "path": "src/axolotl/core/attention/__init__.py",
    "content": ""
  },
  {
    "path": "src/axolotl/core/builders/__init__.py",
    "content": "\"\"\"Trainer builder classes\"\"\"\n\nfrom .causal import HFCausalTrainerBuilder\nfrom .rl import HFRLTrainerBuilder\n\n__all__ = [\"HFCausalTrainerBuilder\", \"HFRLTrainerBuilder\"]\n"
  },
  {
    "path": "src/axolotl/core/builders/base.py",
    "content": "# Copyright 2024 Axolotl AI. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"Base class for trainer builder\"\"\"\n\nimport abc\nimport importlib\nimport logging\nimport sys\nfrom abc import abstractmethod\nfrom contextlib import suppress\nfrom pathlib import Path\nfrom typing import Any\n\nimport torch\nfrom transformers import TrainerCallback\nfrom transformers.trainer_pt_utils import AcceleratorConfig\n\nfrom axolotl.integrations.base import PluginManager\nfrom axolotl.monkeypatch.trainer.lr import patch_trainer_get_lr\nfrom axolotl.telemetry.callbacks import TelemetryCallback\nfrom axolotl.telemetry.manager import TelemetryManager\nfrom axolotl.utils import (\n    is_comet_available,\n    is_mlflow_available,\n    is_opentelemetry_available,\n    is_trackio_available,\n)\nfrom axolotl.utils.callbacks import (\n    GCCallback,\n    SaveAxolotlConfigtoWandBCallback,\n    SaveModelOnFirstStepCallback,\n)\nfrom axolotl.utils.callbacks.profiler import PytorchProfilerCallback\nfrom axolotl.utils.distributed import build_parallelism_config\nfrom axolotl.utils.schemas.enums import CustomSupportedOptimizers\n\nLOG = logging.getLogger(__name__)\n\nwith suppress(ImportError):\n    import torch._dynamo\n\n\nclass TrainerBuilderBase(abc.ABC):\n    \"\"\"Base class for trainer builder.\"\"\"\n\n    def __init__(self, cfg, model, tokenizer, processor=None):\n        self.cfg = cfg\n        self.model = model\n        
self.tokenizer = tokenizer\n        self.processor = processor\n\n        self._train_dataset = None\n        self._eval_dataset = None\n        self._model_ref = None\n        self._peft_config = None\n\n        # If the model supports tagging, add the axolotl tag.\n        # This makes sure the tag is correctly pushed even if a user calls\n        # model.push_to_hub instead of trainer.push_to_hub.\n        if hasattr(model, \"add_model_tags\"):\n            model.add_model_tags([\"axolotl\"])\n\n        patch_trainer_get_lr()\n\n    @property\n    def model_ref(self):\n        return self._model_ref\n\n    @model_ref.setter\n    def model_ref(self, model):\n        self._model_ref = model\n\n    @property\n    def train_dataset(self):\n        return self._train_dataset\n\n    @train_dataset.setter\n    def train_dataset(self, dataset):\n        self._train_dataset = dataset\n\n    @property\n    def eval_dataset(self):\n        return self._eval_dataset\n\n    @eval_dataset.setter\n    def eval_dataset(self, dataset):\n        self._eval_dataset = dataset\n\n    @property\n    def peft_config(self):\n        return self._peft_config\n\n    @peft_config.setter\n    def peft_config(self, peft_config):\n        self._peft_config = peft_config\n\n    @abstractmethod\n    def build(self, total_num_steps):\n        pass\n\n    def get_callbacks(self) -> list[TrainerCallback]:\n        callbacks = []\n\n        plugin_manager = PluginManager.get_instance()\n        callbacks.extend(\n            plugin_manager.add_callbacks_pre_trainer(cfg=self.cfg, model=self.model)\n        )\n\n        if self.cfg.gc_steps:\n            callbacks.append(GCCallback(gc_steps=self.cfg.gc_steps))\n\n        if self.cfg.dynamic_checkpoint and self.cfg.dynamic_checkpoint.enabled:\n            from axolotl.utils.callbacks.dynamic_checkpoint import (\n                DynamicCheckpointCallback,\n            )\n\n            callbacks.append(DynamicCheckpointCallback(self.cfg))\n\n        if 
self.cfg.use_wandb:\n            callbacks.append(\n                SaveAxolotlConfigtoWandBCallback(self.cfg.axolotl_config_path)\n            )\n        if self.cfg.use_mlflow and is_mlflow_available():\n            from axolotl.utils.callbacks.mlflow_ import (\n                SaveAxolotlConfigtoMlflowCallback,\n            )\n\n            callbacks.extend(\n                [\n                    SaveAxolotlConfigtoMlflowCallback(self.cfg.axolotl_config_path),\n                ]\n            )\n        if self.cfg.use_comet and is_comet_available():\n            from axolotl.utils.callbacks.comet_ import SaveAxolotlConfigtoCometCallback\n\n            callbacks.append(\n                SaveAxolotlConfigtoCometCallback(self.cfg.axolotl_config_path)\n            )\n        if self.cfg.use_trackio and is_trackio_available():\n            from axolotl.utils.callbacks.trackio_ import (\n                SaveAxolotlConfigtoTrackioCallback,\n            )\n\n            callbacks.append(\n                SaveAxolotlConfigtoTrackioCallback(self.cfg.axolotl_config_path)\n            )\n        if self.cfg.use_otel_metrics and is_opentelemetry_available():\n            from axolotl.utils.callbacks.opentelemetry import (\n                OpenTelemetryMetricsCallback,\n            )\n\n            callbacks.append(OpenTelemetryMetricsCallback(self.cfg))\n        if self.cfg.save_first_step:\n            callbacks.append(SaveModelOnFirstStepCallback())\n\n        if self.cfg.profiler_steps:\n            callbacks.append(\n                PytorchProfilerCallback(\n                    steps_to_profile=self.cfg.profiler_steps,\n                    profiler_steps_start=self.cfg.profiler_steps_start,\n                )\n            )\n\n        telemetry_manager = TelemetryManager.get_instance()\n        if telemetry_manager.enabled:\n            callbacks.append(TelemetryCallback())\n\n        return callbacks\n\n    def get_post_trainer_create_callbacks(self, trainer):\n        
\"\"\"\n        Callbacks added after the trainer is created, usually b/c these need access to the trainer\n        \"\"\"\n        callbacks = []\n        if self.cfg.plugins:\n            plugin_manager = PluginManager.get_instance()\n            callbacks.extend(\n                [\n                    cb\n                    for cb in plugin_manager.add_callbacks_post_trainer(\n                        self.cfg, trainer\n                    )\n                    if cb\n                ]\n            )\n        return callbacks\n\n    def hook_pre_create_training_args(self, training_arguments_kwargs):\n        # TODO\n        return training_arguments_kwargs\n\n    def hook_post_create_training_args(self, training_arguments):\n        # TODO\n        return training_arguments\n\n    def hook_pre_create_trainer(self, trainer_kwargs, trainer_cls):\n        # TODO\n        return trainer_kwargs, trainer_cls\n\n    def hook_post_create_trainer(self, trainer):\n        # TODO\n        return trainer\n\n    def _configure_warmup_and_logging(\n        self, total_num_steps: int, training_args_kwargs: dict\n    ):\n        warmup_steps: int | float = 0\n        warmup_ratio = 0.0\n        if self.cfg.warmup_steps is not None:\n            warmup_steps = self.cfg.warmup_steps\n        elif self.cfg.warmup_ratio is not None:\n            if total_num_steps:\n                warmup_steps = max(int(self.cfg.warmup_ratio * total_num_steps), 0)\n            else:\n                warmup_ratio = self.cfg.warmup_ratio\n        elif total_num_steps:\n            warmup_steps = min(int(0.03 * total_num_steps), 100)\n        else:\n            warmup_ratio = 0.03\n\n        # transformers v5\n        if warmup_ratio > 0.0 and warmup_steps == 0:\n            warmup_steps = warmup_ratio\n\n        if warmup_steps == 1:\n            warmup_steps = 2\n\n        if self.cfg.logging_steps is not None:\n            training_args_kwargs[\"logging_steps\"] = self.cfg.logging_steps\n        
else:\n            training_args_kwargs[\"logging_steps\"] = (\n                500  # transformers defaults to 500\n                if not total_num_steps\n                else max(min(int(0.005 * total_num_steps), 10), 1)\n            )\n\n        training_args_kwargs[\"warmup_steps\"] = warmup_steps\n\n    def _configure_precision_settings(self, training_args_kwargs: dict):\n        training_args_kwargs[\"fp16\"] = (self.cfg.fp16 and not self.cfg.bf16) or False\n        training_args_kwargs[\"tf32\"] = True if self.cfg.tf32 is True else False\n        if self.cfg.bf16 == \"full\":\n            training_args_kwargs[\"bf16_full_eval\"] = True\n        else:\n            bf16 = self.cfg.bf16 or self.cfg.bfloat16\n            bf16 = bf16 if bf16 is not None else False\n            training_args_kwargs[\"bf16\"] = bf16\n\n    def _configure_scheduler(self, training_args_kwargs: dict):\n        if self.cfg.lr_scheduler in [\"one_cycle\", \"rex\"]:\n            training_args_kwargs[\"lr_scheduler_type\"] = \"cosine\"\n            training_args_kwargs[\"alternate_lr_scheduler_type\"] = self.cfg.lr_scheduler\n        else:\n            training_args_kwargs[\"lr_scheduler_type\"] = (\n                self.cfg.lr_scheduler if self.cfg.lr_scheduler else \"cosine\"\n            )\n        training_args_kwargs[\"lr_scheduler_kwargs\"] = (\n            self.cfg.lr_scheduler_kwargs if self.cfg.lr_scheduler_kwargs else {}\n        )\n\n    def _configure_optimizer(self, training_args_kwargs: dict, trainer_kwargs: dict):\n        def _configure_custom_optimizer(\n            training_args_kwargs: dict, trainer_kwargs: dict\n        ):\n            # Common optimizer kwargs\n            optimizer_kwargs = {\n                \"lr\": training_args_kwargs[\"learning_rate\"],\n                \"weight_decay\": training_args_kwargs[\"weight_decay\"],\n            }\n\n            # Adam-specific kwargs\n            adam_kwargs: dict = {}\n            if 
training_args_kwargs.get(\"adam_beta1\") and training_args_kwargs.get(\n                \"adam_beta2\"\n            ):\n                adam_kwargs[\"betas\"] = (\n                    training_args_kwargs.get(\"adam_beta1\"),\n                    training_args_kwargs.get(\"adam_beta2\"),\n                )\n            if training_args_kwargs.get(\"adam_epsilon\"):\n                adam_kwargs[\"eps\"] = training_args_kwargs.get(\"adam_epsilon\")\n\n            if self.cfg.optimizer == \"muon\":\n                _, device_mesh = build_parallelism_config(self.cfg)\n\n                if device_mesh is not None:\n                    from axolotl.contribs.mit.muon.dist_muon import (\n                        DistMuonOptimizerFactory,\n                    )\n\n                    optimizer_cls = DistMuonOptimizerFactory\n                    optimizer_kwargs[\"device_mesh\"] = device_mesh\n                else:\n                    from axolotl.contribs.mit.muon import (\n                        MuonOptimizerFactory,\n                    )\n\n                    optimizer_cls = MuonOptimizerFactory\n\n                optimizer_kwargs.update(adam_kwargs)\n            elif self.cfg.optimizer == \"dion\":\n                from axolotl.contribs.mit.dion import (\n                    DionOptimizerFactory,\n                )\n\n                optimizer_cls = DionOptimizerFactory\n                optimizer_kwargs[\"dion_lr\"] = training_args_kwargs[\"dion_learning_rate\"]\n                optimizer_kwargs[\"dion_mu\"] = training_args_kwargs[\"dion_momentum\"]\n                optimizer_kwargs.update(adam_kwargs)\n                _, device_mesh = build_parallelism_config(self.cfg)\n                if device_mesh is not None:\n                    optimizer_kwargs[\"device_mesh\"] = device_mesh\n            elif self.cfg.optimizer == \"optimi_adamw\":\n                from optimi import AdamW\n\n                optimizer_kwargs[\"foreach\"] = False\n                optimizer_cls = 
AdamW\n                optimizer_kwargs.update(adam_kwargs)\n            elif self.cfg.optimizer == \"ao_adamw_fp8\":\n                from torchao.prototype.low_bit_optim import AdamWFp8\n\n                optimizer_cls = AdamWFp8\n                optimizer_kwargs.update(adam_kwargs)\n            elif self.cfg.optimizer == \"adopt_adamw\":\n                from axolotl.utils.optimizers.adopt import ADOPT\n\n                optimizer_cls = ADOPT\n                adam_kwargs[\"decouple\"] = True\n                optimizer_kwargs.update(adam_kwargs)\n            elif self.cfg.optimizer == \"came_pytorch\":\n                from came_pytorch import CAME\n\n                optimizer_cls = CAME\n\n                beta1 = training_args_kwargs.get(\"adam_beta1\", 0.9)\n                beta2 = training_args_kwargs.get(\"adam_beta2\", 0.999)\n                beta3 = training_args_kwargs.get(\"adam_beta3\", 0.9999)\n                eps1 = training_args_kwargs.get(\"adam_epsilon\", 1e-30)\n                eps2 = training_args_kwargs.get(\"adam_epsilon2\", 1e-16)\n                adam_kwargs[\"betas\"] = (beta1, beta2, beta3)\n                adam_kwargs[\"eps\"] = (eps1, eps2)\n\n                optimizer_kwargs.update(adam_kwargs)\n            elif self.cfg.optimizer == \"flash_adamw\":\n                from flashoptim import FlashAdamW\n\n                optimizer_cls = FlashAdamW\n                optimizer_kwargs.update(adam_kwargs)\n            elif self.cfg.optimizer == \"flash_adam\":\n                from flashoptim import FlashAdam\n\n                optimizer_cls = FlashAdam\n                optimizer_kwargs.update(adam_kwargs)\n            elif self.cfg.optimizer == \"flash_sgd\":\n                from flashoptim import FlashSGD\n\n                optimizer_cls = FlashSGD\n            elif self.cfg.optimizer == \"flash_sgdw\":\n                from flashoptim import FlashSGDW\n\n                optimizer_cls = FlashSGDW\n            elif self.cfg.optimizer == 
\"flash_lion\":\n                from flashoptim import FlashLion\n\n                optimizer_cls = FlashLion\n                if \"betas\" in adam_kwargs:\n                    optimizer_kwargs[\"betas\"] = adam_kwargs[\"betas\"]\n            else:\n                raise ValueError(\n                    f\"Unhandled optimizer: {self.cfg.optimizer}. Please raise an Issue.\"\n                )\n\n            # Parse any additional optimizer args from config\n            if self.cfg.optim_args:\n                if isinstance(self.cfg.optim_args, dict):\n                    optimizer_kwargs.update(self.cfg.optim_args)\n                else:\n                    # Parse string format \"key1=value1,key2=value2\"\n                    for mapping in self.cfg.optim_args.replace(\" \", \"\").split(\",\"):\n                        key, value = mapping.split(\"=\")\n                        optimizer_kwargs[key] = value\n\n            # Note: This is not used in training_args_kwargs, but in trainer_kwargs\n            trainer_kwargs[\"optimizer_cls_and_kwargs\"] = (\n                optimizer_cls,\n                optimizer_kwargs,\n            )\n\n        # Handle custom optimizer\n        custom_supported_optimizers = [opt.value for opt in CustomSupportedOptimizers]\n        if self.cfg.optimizer in custom_supported_optimizers:\n            _configure_custom_optimizer(training_args_kwargs, trainer_kwargs)\n        else:\n            # Use transformers' optimizer\n            training_args_kwargs[\"optim\"] = self.cfg.optimizer\n\n            # Parse any additional optimizer args from config\n            if self.cfg.optim_args:\n                if isinstance(self.cfg.optim_args, dict):\n                    optim_args = \",\".join(\n                        [f\"{key}={value}\" for key, value in self.cfg.optim_args.items()]\n                    )\n                else:\n                    optim_args = self.cfg.optim_args\n                training_args_kwargs[\"optim_args\"] = 
optim_args\n\n            if (\n                self.cfg.optimizer == \"adamw_anyprecision\"\n                and Path(self.cfg.torchdistx_path).exists()\n            ):\n                sys.path.append(self.cfg.torchdistx_path)\n                importlib.import_module(\"torchdistx\")\n\n    def _configure_hub_parameters(self, training_args_kwargs: dict):\n        if self.cfg.hub_model_id:\n            training_args_kwargs[\"hub_model_id\"] = self.cfg.hub_model_id\n            training_args_kwargs[\"push_to_hub\"] = True\n            training_args_kwargs[\"hub_private_repo\"] = True\n            training_args_kwargs[\"hub_always_push\"] = True\n\n            if self.cfg.hub_strategy:\n                training_args_kwargs[\"hub_strategy\"] = self.cfg.hub_strategy\n\n            if self.cfg.hub_revision:\n                training_args_kwargs[\"hub_revision\"] = self.cfg.hub_revision\n\n    def _configure_save_and_eval_strategy(self, training_args_kwargs: dict):\n        # save_strategy and save_steps\n        if self.cfg.save_steps:\n            training_args_kwargs[\"save_strategy\"] = \"steps\"\n            training_args_kwargs[\"save_steps\"] = self.cfg.save_steps\n        elif self.cfg.save_strategy:\n            training_args_kwargs[\"save_strategy\"] = self.cfg.save_strategy\n        else:\n            # default to saving each epoch if not defined\n            training_args_kwargs[\"save_strategy\"] = \"epoch\"\n\n        training_args_kwargs[\"save_total_limit\"] = (\n            self.cfg.save_total_limit if self.cfg.save_total_limit else 4\n        )\n\n        # eval_strategy and eval_steps\n        if not self.eval_dataset and self.cfg.val_set_size == 0:\n            # do not eval if no eval_dataset and val_set_size=0\n            training_args_kwargs[\"eval_strategy\"] = \"no\"\n        elif self.cfg.eval_steps:\n            training_args_kwargs[\"eval_strategy\"] = \"steps\"\n            training_args_kwargs[\"eval_steps\"] = self.cfg.eval_steps\n         
   training_args_kwargs[\"eval_on_start\"] = True\n        elif self.cfg.eval_strategy:\n            training_args_kwargs[\"eval_strategy\"] = self.cfg.eval_strategy\n            training_args_kwargs[\"eval_on_start\"] = True\n\n    def _configure_reporting(self, training_args_kwargs: dict):\n        report_to = []\n        if self.cfg.use_wandb:\n            report_to.append(\"wandb\")\n        if self.cfg.use_mlflow:\n            report_to.append(\"mlflow\")\n        if self.cfg.use_tensorboard:\n            report_to.append(\"tensorboard\")\n        if self.cfg.use_comet:\n            report_to.append(\"comet_ml\")\n        if self.cfg.use_trackio:\n            report_to.append(\"trackio\")\n\n        training_args_kwargs[\"report_to\"] = report_to\n\n        if self.cfg.use_wandb:\n            training_args_kwargs[\"run_name\"] = self.cfg.wandb_name\n        elif self.cfg.use_mlflow:\n            training_args_kwargs[\"run_name\"] = self.cfg.mlflow_run_name\n        elif self.cfg.use_trackio:\n            training_args_kwargs[\"run_name\"] = self.cfg.trackio_run_name\n        else:\n            training_args_kwargs[\"run_name\"] = None\n\n    def _configure_torch_compile(self, training_args_kwargs: dict):\n        if self.cfg.torch_compile and getattr(torch, \"_dynamo\", None):\n            torch._dynamo.config.suppress_errors = True\n            torch._dynamo.config.accumulated_cache_size_limit = 256\n            training_args_kwargs[\"torch_compile\"] = self.cfg.torch_compile\n            if self.cfg.torch_compile_backend:\n                training_args_kwargs[\"torch_compile_backend\"] = (\n                    self.cfg.torch_compile_backend\n                )\n            if self.cfg.torch_compile_mode:\n                training_args_kwargs[\"torch_compile_mode\"] = self.cfg.torch_compile_mode\n\n    def _configure_accelerator_config(self, training_args_kwargs: dict):\n        if self.cfg.accelerator_config:\n            
training_args_kwargs[\"accelerator_config\"] = AcceleratorConfig(\n                **self.cfg.accelerator_config\n            )\n        else:\n            training_args_kwargs[\"accelerator_config\"] = AcceleratorConfig()\n\n    def _configure_gradient_checkpointing(self, training_args_kwargs: dict):\n        if self.cfg.activation_offloading is True:\n            # don't use the HF gradient checkpointing, manually wrap\n            training_args_kwargs[\"gradient_checkpointing\"] = False\n            training_args_kwargs[\"activation_offloading\"] = True\n        elif self.cfg.gradient_checkpointing is not None:\n            training_args_kwargs[\"gradient_checkpointing\"] = (\n                self.cfg.gradient_checkpointing\n            )\n            if self.cfg.gradient_checkpointing_kwargs is not None:\n                training_args_kwargs[\"gradient_checkpointing_kwargs\"] = (\n                    self.cfg.gradient_checkpointing_kwargs\n                )\n            else:\n                training_args_kwargs[\"gradient_checkpointing_kwargs\"] = {\n                    \"use_reentrant\": False\n                }\n\n    def _set_base_training_args(\n        self, total_num_steps\n    ) -> tuple[dict[str, Any], dict[str, Any]]:\n        training_args_kwargs: dict[str, Any] = {}\n        trainer_kwargs: dict[str, Any] = {}\n\n        self._configure_warmup_and_logging(total_num_steps, training_args_kwargs)\n        self._configure_precision_settings(training_args_kwargs)\n        self._configure_save_and_eval_strategy(training_args_kwargs)\n        self._configure_gradient_checkpointing(training_args_kwargs)\n\n        # set arg into trainer_args_kwargs with same name if value not None\n        for arg in [\n            # optim/scheduler\n            \"adam_beta1\",\n            \"adam_beta2\",\n            \"adam_beta3\",\n            \"adam_epsilon\",\n            \"adam_epsilon2\",\n            \"cosine_min_lr_ratio\",\n            
\"cosine_constant_lr_ratio\",\n            \"optim_target_modules\",\n            # trainer\n            \"max_grad_norm\",\n            \"dataloader_num_workers\",\n            \"dataloader_pin_memory\",\n            \"dataloader_prefetch_factor\",\n            \"gradient_accumulation_steps\",\n            \"learning_rate\",\n            \"embedding_lr\",\n            \"embedding_lr_scale\",\n            \"lr_groups\",\n            \"loraplus_lr_ratio\",\n            \"loraplus_lr_embedding\",\n            \"output_dir\",\n            \"save_only_model\",\n            \"weight_decay\",\n            \"seed\",\n            \"dion_momentum\",\n            \"dion_rank_fraction\",\n            \"dion_rank_multiple_of\",\n            \"dataset_num_proc\",\n        ]:\n            if hasattr(self.cfg, arg) and getattr(self.cfg, arg) is not None:\n                training_args_kwargs[arg] = getattr(self.cfg, arg)\n\n        arg_map = {\n            \"dion_learning_rate\": \"dion_lr\",\n            \"include_num_input_tokens_seen\": \"include_tokens_per_second\",\n        }\n        for kwarg, cfg_arg in arg_map.items():\n            if hasattr(self.cfg, cfg_arg) and getattr(self.cfg, cfg_arg) is not None:\n                training_args_kwargs[kwarg] = getattr(self.cfg, cfg_arg)\n\n        training_args_kwargs[\"per_device_train_batch_size\"] = self.cfg.micro_batch_size\n        training_args_kwargs[\"average_tokens_across_devices\"] = False\n\n        if self.cfg.eval_batch_size:\n            training_args_kwargs[\"per_device_eval_batch_size\"] = (\n                self.cfg.eval_batch_size\n            )\n\n        training_args_kwargs[\"include_tkps\"] = self.cfg.include_tkps\n        training_args_kwargs[\"max_steps\"] = self.cfg.max_steps or total_num_steps or -1\n        training_args_kwargs[\"num_train_epochs\"] = self.cfg.num_epochs\n\n        # max_length is not used in CausalTrainer\n        if self.cfg.reward_model or self.cfg.rl:\n            
training_args_kwargs[\"max_length\"] = self.cfg.sequence_len\n\n        if self.cfg.fsdp_config or self.cfg.fsdp:\n            training_args_kwargs[\"fsdp_config\"] = self.cfg.fsdp_config\n            training_args_kwargs[\"fsdp\"] = self.cfg.fsdp if self.cfg.fsdp else True\n\n        self._configure_reporting(training_args_kwargs)\n        self._configure_hub_parameters(training_args_kwargs)\n        self._configure_scheduler(training_args_kwargs)\n        self._configure_optimizer(training_args_kwargs, trainer_kwargs)\n        self._configure_torch_compile(training_args_kwargs)\n        self._configure_accelerator_config(training_args_kwargs)\n\n        return training_args_kwargs, trainer_kwargs\n"
  },
  {
    "path": "src/axolotl/core/builders/causal.py",
    "content": "\"\"\"Builder for causal trainers\"\"\"\n\nimport inspect\nimport math\nimport os\nfrom pathlib import Path\nfrom typing import Type, Union\n\nimport transformers\nfrom transformers import (\n    DataCollatorWithFlattening,\n    EarlyStoppingCallback,\n    Trainer,\n)\nfrom trl.trainer.reward_trainer import DataCollatorForPreference\n\nfrom axolotl.core.builders.base import TrainerBuilderBase\nfrom axolotl.core.trainers import (\n    AxolotlMambaTrainer,\n    AxolotlPRMTrainer,\n    AxolotlRewardTrainer,\n    AxolotlTrainer,\n)\nfrom axolotl.integrations.base import PluginManager\nfrom axolotl.monkeypatch.multipack import SUPPORTED_MULTIPACK_MODEL_TYPES\nfrom axolotl.monkeypatch.relora import ReLoRACallback\nfrom axolotl.processing_strategies import get_processing_strategy\nfrom axolotl.utils import is_comet_available, is_mlflow_available\nfrom axolotl.utils.callbacks import (\n    LossWatchDogCallback,\n    bench_eval_callback_factory,\n    causal_lm_bench_eval_callback_factory,\n    colab_inference_post_train_callback,\n    log_prediction_callback_factory,\n)\nfrom axolotl.utils.callbacks.lisa import lisa_callback_factory\nfrom axolotl.utils.callbacks.qat import QATCallback\nfrom axolotl.utils.callbacks.tokens_per_second import TokensPerSecondCallback\nfrom axolotl.utils.chat_templates import get_chat_template_from_config\nfrom axolotl.utils.collators import (\n    BatchSamplerDataCollatorForSeq2Seq,\n    DataCollatorForSeq2Seq,\n    MambaDataCollator,\n    V2BatchSamplerDataCollatorForSeq2Seq,\n)\nfrom axolotl.utils.collators.mm_chat import MultiModalChatDataCollator\nfrom axolotl.utils.import_helper import get_cls_from_module_str\nfrom axolotl.utils.logging import get_logger\n\nLOG = get_logger(__name__)\n\n\nclass HFCausalTrainerBuilder(TrainerBuilderBase):\n    \"\"\"\n    Build the HuggingFace training args/trainer for causal models and reward modeling\n    using TRL.\n    \"\"\"\n\n    def get_callbacks(self):\n        callbacks = 
super().get_callbacks()\n\n        if self.cfg.relora:\n            callbacks.append(ReLoRACallback(self.cfg))\n\n        # TODO: check if can move to base class\n        if self.cfg.loss_watchdog_threshold is not None:\n            callbacks.append(LossWatchDogCallback(self.cfg))\n\n        if self.cfg.qat:\n            callbacks.append(QATCallback(self.cfg.qat))\n\n        if self.cfg.include_tkps:\n            callbacks.append(\n                TokensPerSecondCallback(\n                    self.cfg.tensor_parallel_size,\n                    self.cfg.context_parallel_size,\n                    resume_from_checkpoint=self.cfg.resume_from_checkpoint,\n                )\n            )\n        return callbacks\n\n    def get_post_trainer_create_callbacks(self, trainer):\n        callbacks = []\n        if self.cfg.use_wandb and self.cfg.eval_table_size > 0:\n            LogPredictionCallback = log_prediction_callback_factory(\n                trainer, self.tokenizer, \"wandb\"\n            )\n            callbacks.append(LogPredictionCallback(self.cfg))\n        if (\n            self.cfg.use_mlflow\n            and is_mlflow_available()\n            and self.cfg.eval_table_size > 0\n        ):\n            LogPredictionCallback = log_prediction_callback_factory(\n                trainer, self.tokenizer, \"mlflow\"\n            )\n            callbacks.append(LogPredictionCallback(self.cfg))\n        if self.cfg.use_comet and is_comet_available() and self.cfg.eval_table_size > 0:\n            LogPredictionCallback = log_prediction_callback_factory(\n                trainer, self.tokenizer, \"comet_ml\"\n            )\n            callbacks.append(LogPredictionCallback(self.cfg))\n\n        if self.cfg.do_bench_eval:\n            callbacks.append(bench_eval_callback_factory(trainer, self.tokenizer))\n        if self.cfg.do_causal_lm_eval:\n            CausalLMBenchEvalCallback = causal_lm_bench_eval_callback_factory(\n                trainer, self.tokenizer\n         
   )\n            callbacks.append(CausalLMBenchEvalCallback(self.cfg))\n\n        if self.cfg.early_stopping_patience:\n            early_stop_cb = EarlyStoppingCallback(\n                self.cfg.early_stopping_patience,\n            )\n            callbacks.append(early_stop_cb)\n\n        if self.cfg.lisa_step_interval and self.cfg.lisa_n_layers:\n            callbacks.append(lisa_callback_factory(trainer))\n\n        if any(\"COLAB_\" in key for key in os.environ):\n            ColabCallback = colab_inference_post_train_callback(trainer)\n            callbacks.append(ColabCallback(self.cfg))\n\n        if getattr(self.cfg, \"generate_samples\", False):\n            from axolotl.utils.callbacks.generation import SFTGenerationCallback\n\n            callbacks.append(SFTGenerationCallback(trainer))\n            LOG.info(\"SFT sample generation enabled\")\n\n        callbacks.extend(super().get_post_trainer_create_callbacks(trainer=trainer))\n        return callbacks\n\n    def _get_trainer_cls(self):\n        \"\"\"\n        Gets the trainer class for the given configuration.\n        \"\"\"\n        if self.cfg.plugins:\n            plugin_manager = PluginManager.get_instance()\n            trainer_cls = plugin_manager.get_trainer_cls(self.cfg)\n            if trainer_cls:\n                return trainer_cls\n        if self.cfg.model_config_type == \"mamba\":\n            return AxolotlMambaTrainer\n        if self.cfg.reward_model:\n            return AxolotlRewardTrainer\n        if self.cfg.process_reward_model:\n            return AxolotlPRMTrainer\n\n        if self.cfg.trainer_cls:\n            # override the trainer cls\n            try:\n                trainer_cls = get_cls_from_module_str(self.cfg.trainer_cls)\n                LOG.debug(f\"Using custom trainer class: {self.cfg.trainer_cls}\")\n                return trainer_cls\n            except (ImportError, AttributeError, ValueError) as e:\n                raise ValueError(\n                    
f\"Failed to load custom trainer class '{self.cfg.trainer_cls}': {e}\"\n                ) from e\n\n        return AxolotlTrainer\n\n    def build(self, total_num_steps):\n        from axolotl.core.training_args import (\n            AxolotlPRMConfig,\n            AxolotlRewardConfig,\n            AxolotlTrainingArguments,\n        )\n\n        training_arguments_kwargs, trainer_kwargs = self._set_base_training_args(\n            total_num_steps\n        )\n        if self.cfg.adapter == \"qlora\":\n            training_arguments_kwargs[\"qlora\"] = True\n\n        # deepspeed\n        if self.cfg.deepspeed:\n            training_arguments_kwargs[\"deepspeed\"] = self.cfg.deepspeed\n\n        if self.cfg.lr_quadratic_warmup is not None:\n            training_arguments_kwargs[\"lr_quadratic_warmup\"] = (\n                self.cfg.lr_quadratic_warmup\n            )\n\n        if self.cfg.dataloader_drop_last is not None:\n            training_arguments_kwargs[\"dataloader_drop_last\"] = (\n                self.cfg.dataloader_drop_last\n            )\n        elif self.cfg.sample_packing and self.cfg.eval_sample_packing is False:\n            training_arguments_kwargs[\"dataloader_drop_last\"] = True\n\n        if self.cfg.remove_unused_columns is not None:\n            training_arguments_kwargs[\"remove_unused_columns\"] = (\n                self.cfg.remove_unused_columns\n            )\n\n        if self.cfg.do_bench_eval:\n            training_arguments_kwargs[\"do_bench_eval\"] = self.cfg.do_bench_eval\n            if self.cfg.bench_dataset:\n                training_arguments_kwargs[\"bench_dataset\"] = self.cfg.bench_dataset\n        if self.cfg.do_causal_lm_eval:\n            training_arguments_kwargs[\"do_causal_lm_eval\"] = self.cfg.do_causal_lm_eval\n        if self.cfg.metric_for_best_model:\n            training_arguments_kwargs[\"metric_for_best_model\"] = (\n                self.cfg.metric_for_best_model\n            )\n        if 
self.cfg.greater_is_better:\n            training_arguments_kwargs[\"greater_is_better\"] = self.cfg.greater_is_better\n\n        # DDP Config\n        if self.cfg.ddp_timeout:\n            training_arguments_kwargs[\"ddp_timeout\"] = self.cfg.ddp_timeout\n        # see https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html\n        if self.cfg.ddp_bucket_cap_mb:\n            training_arguments_kwargs[\"ddp_bucket_cap_mb\"] = self.cfg.ddp_bucket_cap_mb\n        if self.cfg.ddp_broadcast_buffers is not None:\n            training_arguments_kwargs[\"ddp_broadcast_buffers\"] = (\n                self.cfg.ddp_broadcast_buffers\n            )\n\n        # these are all the \"standard\" kwargs that are def used\n        training_arguments_kwargs[\"max_seq_length\"] = self.cfg.sequence_len\n\n        if self.cfg.auto_find_batch_size is not None:\n            training_arguments_kwargs[\"auto_find_batch_size\"] = (\n                self.cfg.auto_find_batch_size\n            )\n\n        training_arguments_kwargs[\"eval_accumulation_steps\"] = (\n            self.cfg.gradient_accumulation_steps\n        )\n\n        training_arguments_kwargs[\"load_best_model_at_end\"] = (\n            (\n                self.cfg.load_best_model_at_end is not False\n                or self.cfg.early_stopping_patience\n            )\n            and (\n                (not self.cfg.test_datasets and self.cfg.val_set_size > 0)\n                or (self.cfg.test_datasets and self.cfg.val_set_size == 0)\n            )\n            and self.cfg.save_steps\n            and self.cfg.eval_steps\n            and self.cfg.save_steps % self.cfg.eval_steps == 0\n        ) or False\n\n        # handle ddp\n        ddp_find_unused_parameters = None\n        if self.cfg.ddp:\n            ddp_find_unused_parameters = bool(self.cfg.ddp_find_unused_parameters)\n        training_arguments_kwargs[\"ddp_find_unused_parameters\"] = (\n            ddp_find_unused_parameters\n      
  )\n\n        if self.cfg.group_by_length:\n            training_arguments_kwargs[\"train_sampling_strategy\"] = \"group_by_length\"\n        training_arguments_kwargs[\"curriculum_sampling\"] = self.cfg.curriculum_sampling\n\n        training_arguments_kwargs[\"sample_packing\"] = bool(self.cfg.sample_packing)\n        training_arguments_kwargs[\"sample_packing_drop_attention_mask\"] = bool(\n            self.cfg.flash_attention\n            or self.cfg.xformers_attention\n            or self.cfg.flex_attention\n        )\n        training_arguments_kwargs[\"multipack_real_batches\"] = (\n            self.cfg.multipack_real_batches\n            if self.cfg.multipack_real_batches is not None\n            else not (\n                self.cfg.flash_attention\n                or self.cfg.flex_attention\n                or self.cfg.xformers_attention\n            )\n        )\n        training_arguments_kwargs[\"eval_sample_packing\"] = bool(\n            self.cfg.eval_sample_packing\n        )\n        if self.cfg.sample_packing_sequentially is not None:\n            training_arguments_kwargs[\"sample_packing_sequentially\"] = (\n                self.cfg.sample_packing_sequentially\n            )\n        if self.cfg.sample_packing_bin_size is not None:\n            training_arguments_kwargs[\"sample_packing_bin_size\"] = (\n                self.cfg.sample_packing_bin_size\n            )\n        if self.cfg.sample_packing_group_size is not None:\n            training_arguments_kwargs[\"sample_packing_group_size\"] = (\n                self.cfg.sample_packing_group_size\n            )\n        if self.cfg.sample_packing_eff_est:\n            training_arguments_kwargs[\"sample_packing_efficiency\"] = (\n                self.cfg.sample_packing_eff_est\n            )\n\n        if self.cfg.relora and self.cfg.jagged_restart_steps:\n            if self.cfg.relora_prune_ratio:\n                training_arguments_kwargs[\"relora_prune_ratio\"] = (\n                    
self.cfg.relora_prune_ratio\n                )\n\n        if self.cfg.jagged_restart_steps:\n            training_arguments_kwargs[\"jagged_restart_steps\"] = (\n                self.cfg.jagged_restart_steps\n            )\n            if self.cfg.jagged_restart_warmup_steps:\n                training_arguments_kwargs[\"jagged_restart_warmup_steps\"] = (\n                    self.cfg.jagged_restart_warmup_steps\n                )\n            if self.cfg.jagged_restart_anneal_steps:\n                training_arguments_kwargs[\"jagged_restart_anneal_steps\"] = (\n                    self.cfg.jagged_restart_anneal_steps\n                )\n\n        if self.cfg.lisa_step_interval and self.cfg.lisa_n_layers:\n            training_arguments_kwargs[\"lisa_n_layers\"] = self.cfg.lisa_n_layers\n            training_arguments_kwargs[\"lisa_step_interval\"] = (\n                self.cfg.lisa_step_interval\n            )\n            training_arguments_kwargs[\"lisa_layers_attribute\"] = (\n                self.cfg.lisa_layers_attribute\n            )\n\n        training_arguments_kwargs = self.hook_pre_create_training_args(\n            training_arguments_kwargs\n        )\n        training_arguments_kwargs[\"model_type\"] = self.cfg.model_config_type\n        training_arguments_kwargs[\"pretraining\"] = bool(self.cfg.pretraining_dataset)\n        if self.cfg.chat_template:\n            training_arguments_kwargs[\"chat_template\"] = get_chat_template_from_config(\n                cfg=self.cfg,\n                tokenizer=self.tokenizer,\n            )\n\n        if self.cfg.neftune_noise_alpha is not None:\n            training_arguments_kwargs[\"neftune_noise_alpha\"] = (\n                self.cfg.neftune_noise_alpha\n            )\n\n        if self.cfg.image_size:\n            training_arguments_kwargs[\"image_size\"] = self.cfg.image_size\n        if self.cfg.image_resize_algorithm:\n            training_arguments_kwargs[\"image_resize_algorithm\"] = (\n                
self.cfg.image_resize_algorithm\n            )\n\n        if self.cfg.plugins:\n            plugin_manager = PluginManager.get_instance()\n            plugin_training_args = plugin_manager.get_training_args(self.cfg)\n            if plugin_training_args:\n                training_arguments_kwargs.update(plugin_training_args)\n\n        if self.cfg.reward_model:\n            training_args_cls = AxolotlRewardConfig\n            if self.cfg.center_rewards_coefficient is not None:\n                training_arguments_kwargs[\"center_rewards_coefficient\"] = (\n                    self.cfg.center_rewards_coefficient\n                )\n        elif self.cfg.process_reward_model:\n            training_args_cls = AxolotlPRMConfig\n        else:\n            training_args_cls = AxolotlTrainingArguments\n        training_args = training_args_cls(\n            **training_arguments_kwargs,\n        )\n        training_args = self.hook_post_create_training_args(training_args)\n\n        # unset run_name so wandb sets up experiment names\n        if self.cfg.use_wandb and training_args.run_name == training_args.output_dir:\n            training_args.run_name = None\n\n        data_collator_kwargs = {\n            \"padding\": True,  # True/\"longest\" is the default\n        }\n        multiple = 64\n        if self.cfg.pad_to_sequence_len:\n            data_collator_kwargs[\"pad_to_multiple_of\"] = multiple * math.ceil(\n                self.cfg.sequence_len / multiple\n            )\n        elif self.cfg.pad_to_sequence_len is None:\n            # A100 is best at 64, while others at 8. 
Let's use the larger so we don't have to check\n            # https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html\n            data_collator_kwargs[\"pad_to_multiple_of\"] = multiple\n\n        if self.cfg.use_eaft:\n            from functools import partial\n\n            from axolotl.monkeypatch.loss.eaft import eaft_loss\n\n            configured_eaft_loss = partial(\n                eaft_loss,\n                alpha=self.cfg.eaft_alpha if self.cfg.eaft_alpha is not None else 1.0,\n                k=self.cfg.eaft_k if self.cfg.eaft_k is not None else 20,\n            )\n            trainer_kwargs[\"compute_loss_func\"] = configured_eaft_loss\n\n        trainer_cls = self._get_trainer_cls()\n\n        trainer_kwargs, trainer_cls = self.hook_pre_create_trainer(\n            trainer_kwargs, trainer_cls\n        )\n        if eval_data_collator := self.build_collator(\n            training_args, is_eval=True, **data_collator_kwargs\n        ):\n            if not (self.cfg.reward_model or self.cfg.process_reward_model):\n                trainer_kwargs[\"eval_data_collator\"] = eval_data_collator\n        if not (self.cfg.reward_model or self.cfg.process_reward_model):\n            trainer_kwargs[\"bench_data_collator\"] = transformers.DataCollatorForSeq2Seq(\n                self.tokenizer,\n                return_tensors=\"pt\",\n                **data_collator_kwargs,\n            )\n        sig = inspect.signature(trainer_cls)\n        if \"processing_class\" in sig.parameters or issubclass(trainer_cls, Trainer):\n            trainer_kwargs[\"processing_class\"] = self.tokenizer\n        elif \"tokenizer\" in sig.parameters:\n            trainer_kwargs[\"tokenizer\"] = self.tokenizer\n\n        if (\n            trainer_cls not in [AxolotlRewardTrainer, AxolotlPRMTrainer]\n            and self.cfg.datasets is not None\n        ):\n            trainer_kwargs[\"dataset_tags\"] = [\n                d[\"path\"] for d in 
self.cfg.datasets if not Path(d[\"path\"]).is_dir()\n            ]\n        # TRL's RewardTrainer validates num_labels=1 on pre-loaded models; ensure the\n        # config reflects this regardless of how the model was instantiated.\n        if (\n            self.cfg.reward_model\n            and getattr(self.model.config, \"num_labels\", None) != 1\n        ):\n            self.model.config.num_labels = 1\n        trainer = trainer_cls(\n            model=self.model,\n            train_dataset=self.train_dataset,\n            eval_dataset=self.eval_dataset,\n            args=training_args,\n            data_collator=self.build_collator(training_args, **data_collator_kwargs),\n            callbacks=self.get_callbacks(),\n            **trainer_kwargs,\n        )\n        trainer = self.hook_post_create_trainer(trainer)\n        # if the trainer has the `axolotl_cfg` property, set it\n        if hasattr(trainer, \"axolotl_cfg\"):\n            trainer.axolotl_cfg = self.cfg\n        for callback in self.get_post_trainer_create_callbacks(trainer):\n            trainer.add_callback(callback)\n\n        if self.cfg.deepspeed and self.cfg.sample_packing:\n            trainer.accelerator.state.deepspeed_plugin.deepspeed_config[\n                \"train_micro_batch_size_per_gpu\"\n            ] = self.cfg.micro_batch_size\n\n        return trainer\n\n    def build_collator(\n        self,\n        training_args,  # type: \"AxolotlTrainingArguments\"  # type: ignore\n        is_eval=False,\n        **kwargs,\n    ):\n        if training_args.pretraining:\n            if (\n                self.cfg.pretraining_sample_concatenation is False\n                or self.cfg.micro_batch_size > 1\n            ):\n                return DataCollatorForSeq2Seq(self.tokenizer, **kwargs)\n            if not (self.cfg.sample_packing and self.cfg.pretrain_multipack_attn) or (\n                self.cfg.micro_batch_size == 1 and is_eval is False\n            ):\n                return 
None\n\n        if self.cfg.model_config_type == \"mamba\":\n            return MambaDataCollator(tokenizer=self.tokenizer)\n\n        use_batch_sampler_collator = False\n        if is_eval is False and training_args.sample_packing:\n            use_batch_sampler_collator = True\n        if is_eval and training_args.eval_sample_packing:\n            use_batch_sampler_collator = True\n\n        collator: Type[\n            Union[\n                V2BatchSamplerDataCollatorForSeq2Seq,\n                BatchSamplerDataCollatorForSeq2Seq,\n                DataCollatorForSeq2Seq,\n                DataCollatorWithFlattening,\n                DataCollatorForPreference,\n            ]\n        ]\n        collator_args = [self.tokenizer]\n\n        collator_cls_and_kwargs = None\n        if self.cfg.plugins:\n            plugin_manager = PluginManager.get_instance()\n            collator_cls_and_kwargs = plugin_manager.get_collator_cls_and_kwargs(\n                self.cfg, is_eval=is_eval\n            )\n\n        if collator_cls_and_kwargs:\n            collator = collator_cls_and_kwargs[0]\n            if kwargs and isinstance(kwargs, dict):\n                kwargs.update(collator_cls_and_kwargs[1])\n        elif self.cfg.reward_model:\n            collator = DataCollatorForPreference\n            tokenizer = collator_args.pop(0)\n            kwargs[\"pad_token_id\"] = tokenizer.pad_token_id\n            kwargs.pop(\"padding\")\n        elif use_batch_sampler_collator:\n            # Use V2BatchSamplerDataCollatorForSeq2Seq for flex attention,\n            # supported multipack models, or non-flash-attention llama\n            if (\n                self.cfg.flex_attention\n                or self.cfg.model_config_type in SUPPORTED_MULTIPACK_MODEL_TYPES\n                or (\n                    self.cfg.model_config_type in [\"llama\"]\n                    and self.cfg.flash_attention is not True\n                )\n            ):\n                collator = 
V2BatchSamplerDataCollatorForSeq2Seq\n            else:\n                collator = BatchSamplerDataCollatorForSeq2Seq\n        else:\n            if self.cfg.processor_type and self.processor:\n                collator = MultiModalChatDataCollator\n                kwargs[\"processing_strategy\"] = get_processing_strategy(\n                    self.processor,\n                    training_args.chat_template,\n                    self.cfg.chat_template,\n                    image_size=training_args.image_size,\n                    image_resize_algorithm=training_args.image_resize_algorithm,\n                )\n            elif self.cfg.batch_flattening:\n                collator = DataCollatorWithFlattening\n                collator_args.pop(0)\n                kwargs.pop(\"pad_to_multiple_of\", None)\n                kwargs.pop(\"padding\", None)\n            else:\n                collator = DataCollatorForSeq2Seq\n\n        kwargs[\"return_tensors\"] = \"pt\"\n\n        return collator(\n            *collator_args,\n            **kwargs,\n        )\n"
  },
  {
    "path": "src/axolotl/core/builders/rl.py",
    "content": "\"\"\"Builder for RLHF trainers\"\"\"\n\nimport inspect\nfrom pathlib import Path\n\nfrom axolotl.core.builders.base import TrainerBuilderBase\nfrom axolotl.core.trainers import (\n    AxolotlCPOTrainer,\n    AxolotlKTOTrainer,\n    AxolotlORPOTrainer,\n)\nfrom axolotl.core.trainers.dpo import DPOStrategy\nfrom axolotl.core.trainers.dpo.args import AxolotlDPOConfig\nfrom axolotl.integrations.base import PluginManager\nfrom axolotl.loaders.utils import ensure_dtype\nfrom axolotl.utils.callbacks.qat import QATCallback\nfrom axolotl.utils.import_helper import get_cls_from_module_str\nfrom axolotl.utils.logging import get_logger\nfrom axolotl.utils.schemas.enums import RLType\n\nLOG = get_logger(__name__)\n\n\nclass HFRLTrainerBuilder(TrainerBuilderBase):\n    \"\"\"Trainer factory class for TRL-based RLHF trainers (e.g. DPO)\"\"\"\n\n    def get_callbacks(self):\n        callbacks = super().get_callbacks()\n\n        if self.cfg.qat:\n            callbacks.append(QATCallback(self.cfg.qat))\n\n        return callbacks\n\n    def get_post_trainer_create_callbacks(self, trainer):\n        callbacks = super().get_post_trainer_create_callbacks(trainer=trainer)\n        return callbacks\n\n    def _get_trainer_cls(self, trainer_kwargs: dict):\n        \"\"\"\n        Returns trainer_cls and trainer_cls_args\n        \"\"\"\n        if self.cfg.plugins:\n            plugin_manager = PluginManager.get_instance()\n            trainer_cls = plugin_manager.get_trainer_cls(self.cfg)\n            trainer_cls_args = []  # type: ignore\n\n            if trainer_cls is not None:\n                return trainer_cls, trainer_cls_args\n\n        trainer_cls = None\n        trainer_cls_args = [self.model]\n\n        if self.cfg.rl in {RLType.GRPO, RLType.GDPO}:\n            from axolotl.core.trainers.grpo import GRPOStrategy\n\n            async_grpo = bool(\n                self.cfg.trl\n                and (\n                    getattr(self.cfg.trl, \"async_prefetch\", 
False)\n                    or getattr(self.cfg.trl, \"use_data_producer\", False)\n                )\n            )\n            trainer_cls = GRPOStrategy.get_trainer_class(\n                sequence_parallel=self.cfg.context_parallel_size > 1,\n                async_grpo=async_grpo,\n            )\n            trainer_cls_args.extend(GRPOStrategy.set_trainer_args(self.cfg))\n            trainer_kwargs.update(GRPOStrategy.set_trainer_kwargs(self.cfg))\n\n        elif self.cfg.rl in [RLType.DPO, RLType.IPO]:\n            trainer_cls = DPOStrategy.get_trainer_class()\n            trainer_cls_args.append(self.model_ref)\n\n        elif self.cfg.rl is RLType.ORPO:\n            trainer_cls = AxolotlORPOTrainer\n        elif self.cfg.rl is RLType.KTO:\n            trainer_cls = AxolotlKTOTrainer\n        elif self.cfg.rl is RLType.SIMPO:\n            trainer_cls = AxolotlCPOTrainer\n        else:\n            raise ValueError(f\"Unsupported RL: {self.cfg.rl}\")\n\n        if self.cfg.trainer_cls:\n            # override the trainer cls\n            try:\n                trainer_cls = get_cls_from_module_str(self.cfg.trainer_cls)\n                LOG.debug(f\"Using custom trainer class: {self.cfg.trainer_cls}\")\n            except (ImportError, AttributeError, ValueError) as e:\n                raise ValueError(\n                    f\"Failed to load custom trainer class '{self.cfg.trainer_cls}': {e}\"\n                ) from e\n\n        return trainer_cls, trainer_cls_args\n\n    def _build_training_arguments(self, total_num_steps):\n        \"\"\"\n        Returns training_args and trainer_kwargs\n        \"\"\"\n        from axolotl.core.training_args import (\n            AxolotlCPOConfig,\n            AxolotlKTOConfig,\n            AxolotlORPOConfig,\n        )\n\n        training_args_kwargs, trainer_kwargs = self._set_base_training_args(\n            total_num_steps=total_num_steps\n        )\n\n        if self.cfg.remove_unused_columns is not None:\n           
 training_args_kwargs[\"remove_unused_columns\"] = (\n                self.cfg.remove_unused_columns\n            )\n        else:\n            training_args_kwargs[\"remove_unused_columns\"] = False\n\n        if self.cfg.trl and self.cfg.trl.beta is not None:\n            training_args_kwargs[\"beta\"] = self.cfg.trl.beta\n        elif self.cfg.rl_beta is not None:\n            training_args_kwargs[\"beta\"] = self.cfg.rl_beta\n        elif self.cfg.orpo_alpha is not None:\n            # trl does some odd mapping of alpha to beta to reuse the beta parameter ???\n            training_args_kwargs[\"beta\"] = self.cfg.orpo_alpha\n\n        if self.cfg.rpo_alpha is not None:\n            training_args_kwargs[\"rpo_alpha\"] = self.cfg.rpo_alpha\n\n        if self.cfg.use_wandb:\n            training_args_kwargs[\"run_name\"] = self.cfg.wandb_name\n\n        training_args_cls = None\n        blocklist_args_kwargs = []\n        if self.cfg.rl is RLType.SIMPO:\n            training_args_cls = AxolotlCPOConfig\n            training_args_kwargs[\"loss_type\"] = \"simpo\"\n            training_args_kwargs[\"simpo_gamma\"] = self.cfg.simpo_gamma\n            if self.cfg.cpo_alpha is not None:\n                training_args_kwargs[\"cpo_alpha\"] = self.cfg.cpo_alpha\n\n            blocklist_args_kwargs.append(\"max_prompt_length\")\n\n        elif self.cfg.rl is RLType.ORPO:\n            training_args_cls = AxolotlORPOConfig\n\n            blocklist_args_kwargs.append(\"max_prompt_length\")\n\n        elif self.cfg.rl is RLType.KTO:\n            training_args_cls = AxolotlKTOConfig\n            # KTOConfig in TRL >= 0.27.0 no longer accepts max_prompt_length\n            blocklist_args_kwargs.append(\"max_prompt_length\")\n\n            training_args_kwargs[\"desirable_weight\"] = (\n                self.cfg.kto_desirable_weight or 1.0\n            )\n            training_args_kwargs[\"undesirable_weight\"] = (\n                self.cfg.kto_undesirable_weight or 1.0\n         
   )\n\n        elif self.cfg.rl in {RLType.GRPO, RLType.GDPO}:\n            from axolotl.core.trainers.grpo import GRPOStrategy\n\n            async_grpo = bool(\n                self.cfg.trl\n                and (\n                    getattr(self.cfg.trl, \"async_prefetch\", False)\n                    or getattr(self.cfg.trl, \"use_data_producer\", False)\n                )\n            )\n            training_args_cls = GRPOStrategy.get_training_args_class(\n                async_grpo=async_grpo\n            )\n            training_args_kwargs.update(GRPOStrategy.set_training_args_kwargs(self.cfg))\n            blocklist_args_kwargs = GRPOStrategy.get_blocklist_args_kwargs()\n            if self.cfg.rl is RLType.GDPO:\n                training_args_kwargs.setdefault(\n                    \"multi_objective_aggregation\", \"normalize_then_sum\"\n                )\n\n        elif self.cfg.rl in [RLType.DPO, RLType.IPO]:\n            training_args_cls = AxolotlDPOConfig\n            training_args_kwargs.update(DPOStrategy.set_training_args_kwargs(self.cfg))\n        else:\n            raise ValueError(f\"Unsupported RL: {self.cfg.rl}\")\n\n        for blocklist_key in blocklist_args_kwargs:\n            if blocklist_key in training_args_kwargs:\n                del training_args_kwargs[blocklist_key]\n\n        if self.cfg.plugins:\n            plugin_manager = PluginManager.get_instance()\n            plugin_training_args = plugin_manager.get_training_args(self.cfg)\n            if plugin_training_args:\n                training_args_kwargs.update(plugin_training_args)\n\n        training_args = training_args_cls(\n            logging_first_step=True,\n            **training_args_kwargs,\n        )\n\n        # unset run_name so wandb sets up experiment names\n        if self.cfg.use_wandb and training_args.run_name == training_args.output_dir:\n            training_args.run_name = None\n\n        return training_args, trainer_kwargs\n\n    def build(self, 
total_num_steps):\n        training_args, trainer_kwargs = self._build_training_arguments(total_num_steps)\n\n        if self.eval_dataset:\n            trainer_kwargs[\"eval_dataset\"] = self.eval_dataset\n        if self.cfg.adapter and self.peft_config and self.cfg.rl is not RLType.GRPO:\n            trainer_kwargs[\"peft_config\"] = self.peft_config\n        if self.cfg.precompute_ref_log_probs is not None:\n            trainer_kwargs[\"precompute_ref_log_probs\"] = (\n                self.cfg.precompute_ref_log_probs\n            )\n\n        trainer_cls, trainer_cls_args = self._get_trainer_cls(trainer_kwargs)\n\n        sig = inspect.signature(trainer_cls)\n        if \"tokenizer\" in sig.parameters:\n            trainer_kwargs[\"tokenizer\"] = self.tokenizer\n        else:\n            trainer_kwargs[\"processing_class\"] = self.tokenizer\n\n        if self.cfg.datasets is not None and (\n            trainer_cls is DPOStrategy.get_trainer_class()\n        ):\n            trainer_kwargs[\"dataset_tags\"] = [\n                d[\"path\"] for d in self.cfg.datasets if not Path(d[\"path\"]).is_dir()\n            ]\n\n        trainer_kwargs, trainer_cls = self.hook_pre_create_trainer(\n            trainer_kwargs, trainer_cls\n        )\n\n        # Allow FP8-quantized models to be fine-tuned with LoRA adapters.\n        # transformers' validate_quantization_for_training blocks FP8 because\n        # hf_quantizer.is_trainable is False, but LoRA only trains the adapters\n        # (base weights stay frozen in FP8).\n        _orig_validate_quant = None\n        if (\n            self.cfg.adapter\n            and hasattr(self.model, \"is_quantized\")\n            and self.model.is_quantized\n        ):\n            import transformers.trainer as _trainer_module\n\n            _orig_validate_quant = _trainer_module.validate_quantization_for_training\n            _trainer_module.validate_quantization_for_training = lambda model: None\n\n        try:\n            
trainer = trainer_cls(\n                *trainer_cls_args,\n                args=training_args,\n                train_dataset=self.train_dataset,\n                callbacks=self.get_callbacks(),\n                **trainer_kwargs,\n            )\n        finally:\n            if _orig_validate_quant is not None:\n                import transformers.trainer as _trainer_module\n\n                _trainer_module.validate_quantization_for_training = (\n                    _orig_validate_quant\n                )\n        if self.cfg.fsdp_config or self.cfg.fsdp:\n            ensure_dtype(trainer.model, dtype=self.cfg.torch_dtype)\n            if self.cfg.rl in [RLType.DPO, RLType.IPO] and trainer.ref_model:\n                ensure_dtype(trainer.ref_model, dtype=self.cfg.torch_dtype)\n\n        trainer = self.hook_post_create_trainer(trainer)\n        for callback in self.get_post_trainer_create_callbacks(trainer):\n            trainer.add_callback(callback)\n\n        return trainer\n"
  },
  {
    "path": "src/axolotl/core/chat/__init__.py",
    "content": ""
  },
  {
    "path": "src/axolotl/core/chat/format/__init__.py",
    "content": ""
  },
  {
    "path": "src/axolotl/core/chat/format/chatml.py",
    "content": "\"\"\"\nChatML transformation functions for MessageContents\n\"\"\"\n\nfrom typing import Optional\n\nfrom ..messages import MessageContents, Messages\nfrom .shared import wrap_tools\n\n\ndef format_message(\n    message: Messages,\n    message_index: Optional[int] = None,\n) -> Messages:\n    if message.is_chat_formatted:\n        return message\n\n    # prepend the role prefix within a MessageContents to message.content\n    message.content.insert(\n        0,\n        MessageContents(\n            type=\"text\",\n            value=f\"<|im_start|>{message.role}\\n\",\n            weight=0,\n        ),\n    )\n    message.content.append(\n        MessageContents(type=\"text\", value=\"<|im_end|>\", weight=message.weight)\n    )\n    message.content.append(MessageContents(type=\"text\", value=\"\\n\", weight=0))\n\n    message = wrap_tools(message)\n\n    message.is_chat_formatted = True\n    return message\n"
  },
  {
    "path": "src/axolotl/core/chat/format/llama3x.py",
    "content": "\"\"\"\nLlama 3.x chat formatting functions for MessageContents\n\"\"\"\n\nfrom typing import Optional\n\nfrom ..messages import MessageContents, Messages\nfrom .shared import wrap_tools\n\n\ndef format_message(message: Messages, message_index: Optional[int] = None) -> Messages:\n    if message.is_chat_formatted:\n        return message\n\n    message_role = message.role\n    if message.role == \"tool\":\n        message_role = \"ipython\"\n\n    # prepend the role prefix within a MessageContents to message.content\n    message.content.insert(\n        0,\n        MessageContents(\n            type=\"text\",\n            value=f\"<|start_header_id|>{message_role}<|end_header_id|>\\n\\n\",\n            weight=0,\n        ),\n    )\n\n    message.content.append(\n        MessageContents(type=\"text\", value=\"<|eot_id|>\", weight=message.weight)\n    )\n\n    message = wrap_tools(message)\n\n    if message_index == 0:\n        message.content.insert(\n            0,\n            MessageContents(\n                type=\"text\",\n                value=\"<|begin_of_text|>\",\n                weight=0,\n            ),\n        )\n\n    message.is_chat_formatted = True\n    return message\n"
  },
  {
    "path": "src/axolotl/core/chat/format/shared.py",
    "content": "\"\"\"\nshared functions for format transforms\n\"\"\"\n\nfrom axolotl.core.chat.messages import MessageContents, Messages\n\n\ndef wrap_tools(message: Messages):\n    # loop over message.content by index to find tool calls, we need to wrap each with tags,\n    # so be wary of indexing issues when changing the list while iterating.\n    # iterate over the range in reverse order to avoid index shifting\n    for i in range(len(message.content) - 1, -1, -1):\n        if message.content[i].type == \"tool_call\":\n            # append a </tool_call> MessageContents text tag after\n            message.content.insert(\n                i + 1,\n                MessageContents(\n                    type=\"text\", value=\"</tool_call>\\n\", weight=message.weight\n                ),\n            )\n            # make sure the actual tool call content ends with a newline\n            message.content[i].has_newline = True\n            # prepend a <tool_call> MessageContents text tag before\n            message.content.insert(\n                i,\n                MessageContents(\n                    type=\"text\", value=\"<tool_call>\\n\", weight=message.weight\n                ),\n            )\n        elif message.content[i].type == \"tool_response\":\n            # append a </tool_response> MessageContents text tag after\n            message.content.insert(\n                i + 1,\n                MessageContents(\n                    type=\"text\", value=\"</tool_response>\\n\", weight=message.weight\n                ),\n            )\n            # make sure the actual tool response content ends with a newline\n            message.content[i].has_newline = True\n            # prepend a <tool_response> MessageContents text tag before\n            message.content.insert(\n                i,\n                MessageContents(\n                    type=\"text\", value=\"<tool_response>\\n\", weight=message.weight\n                ),\n            )\n\n    return message\n"
  },
  {
    "path": "src/axolotl/core/chat/messages.py",
    "content": "\"\"\"\ninternal message representations of chat messages\n\"\"\"\n\nimport json\nfrom enum import Enum\nfrom typing import Any, Callable, List, Optional, Union\n\nfrom pydantic import BaseModel\nfrom transformers import PreTrainedTokenizer\n\n\nclass MessageRoles(str, Enum):\n    \"\"\"\n    Message roles for the system, user, assistant, and tools\n    \"\"\"\n\n    system = \"system\"\n    user = \"user\"\n    assistant = \"assistant\"\n    tool = \"tool\"\n    ipython = (\n        # for responses from builtin tools\n        \"ipython\"\n    )\n\n\nclass MessageContentTypes(str, Enum):\n    \"\"\"\n    Message content types for text, image, audio, tool calls, and tool responses\n    \"\"\"\n\n    special_token = \"special_token\"  # nosec B105\n    text = \"text\"\n    image = \"image\"\n    audio = \"audio\"\n    tool_call = \"tool_call\"\n    tool_response = \"tool_response\"\n\n\nclass SpecialToken(str, Enum):\n    \"\"\"\n    Special tokens for beginning of string and end of string\n    \"\"\"\n\n    bos_token = \"bos_token\"  # nosec B105\n    eos_token = \"eos_token\"  # nosec B105\n\n\nclass ToolCallFunction(BaseModel):\n    \"\"\"\n    Tool call function with name and arguments\n    \"\"\"\n\n    name: str\n    arguments: dict[str, str]\n\n\nclass Tool(BaseModel):\n    \"\"\"\n    Tool with description, function, and parameters\n    \"\"\"\n\n    description: str\n    function: ToolCallFunction\n    parameters: dict[str, str]  # .properties\n\n\nclass ToolCallContents(BaseModel):\n    \"\"\"\n    Tool call contents with name, arguments, and optional id\n    \"\"\"\n\n    name: str\n    arguments: dict[str, Union[str, int]]\n    id: Optional[str] = None\n\n    def __str__(self) -> str:\n        data = {\"name\": self.name, \"arguments\": self.arguments}\n        if self.id is not None:\n            data[\"id\"] = self.id\n        return json.dumps(data)\n\n\nclass ToolResponseContents(BaseModel):\n    \"\"\"\n    Tool response contents with 
name, content, and optional id\n    \"\"\"\n\n    name: str\n    content: Union[str, dict[str, Union[str, int, float]]]\n    id: Optional[str] = None\n\n    def __str__(self) -> str:\n        data = {\"name\": self.name, \"content\": self.content}\n        if self.id is not None:\n            data[\"id\"] = self.id\n        return json.dumps(data)\n\n\nclass MessageContents(BaseModel):\n    \"\"\"\n    Message contents with type, value, metadata, weight, newline, and end of contents\n    \"\"\"\n\n    type: Union[str, MessageContentTypes]\n    value: Union[str, ToolCallContents, ToolResponseContents, SpecialToken]\n    meta: Optional[dict[str, Any]] = None  # support additional arbitrary metadata\n    weight: Optional[Union[int, float]] = None\n    has_newline: bool = False\n    eoc: bool = False  # end of contents\n\n    def __str__(self) -> str:\n        str_val = str(self.value)\n        if self.has_newline and not str_val.endswith(\"\\n\"):\n            str_val += \"\\n\"\n        return str_val\n\n\nclass Messages(BaseModel):\n    \"\"\"\n    Messages with role, content, metadata, weight, and chat formatting\n    \"\"\"\n\n    role: Union[MessageRoles, str]  # allows for arbitrary roles\n    content: List[\"MessageContents\"]\n    meta: Optional[dict[str, Any]] = None  # support additional arbitrary metadata\n    weight: Optional[Union[int, float]] = None\n    is_chat_formatted: bool = False\n\n    def __str__(self) -> str:\n        return \"\".join(str(c) for c in self.content)\n\n    def tokenized(\n        self, tokenizer: PreTrainedTokenizer, ignore_index=-100\n    ) -> dict[str, List[int]]:\n        # iterate over the contents, tokenizing the concatenated string values up to the current MessageContents\n        # returns a dictionary mapping w input_ids, attention_mask, and labels\n        input_ids: List[int] = []\n        labels: List[int] = []\n        pending_input_ids: List[int] = []\n        pending_weight = self.weight\n        running_content = 
\"\"\n        for _, msg_content in enumerate(self.content):\n            # TODO also handle non-text content types\n            if msg_content.type in [\n                MessageContentTypes.text.value,\n                MessageContentTypes.tool_call.value,\n                MessageContentTypes.tool_response.value,\n            ]:\n                running_content += str(msg_content)\n                tok_results = tokenizer(running_content, add_special_tokens=False)\n                tok_input_ids = tok_results[\"input_ids\"]\n                if pending_input_ids:\n                    new_pending_inputs = tok_input_ids[\n                        len(input_ids) : len(input_ids) + len(pending_input_ids)\n                    ]\n                    if new_pending_inputs != pending_input_ids:\n                        pending_input_ids = new_pending_inputs\n                    input_ids.extend(pending_input_ids)\n                    if pending_weight:\n                        labels.extend(pending_input_ids)\n                    else:\n                        labels.extend([ignore_index] * len(pending_input_ids))\n                pending_input_ids = tok_results[\"input_ids\"][len(input_ids) :]\n                pending_weight = self.weight and msg_content.weight not in [0, 0.0]\n        input_ids.extend(pending_input_ids)\n        if pending_weight:\n            labels.extend(pending_input_ids)\n        else:\n            labels.extend([ignore_index] * len(pending_input_ids))\n        attention_mask = [1] * len(input_ids)\n        return {\n            \"input_ids\": input_ids,\n            \"attention_mask\": attention_mask,\n            \"labels\": labels,\n        }\n\n\nclass Chats(BaseModel):\n    \"\"\"\n    top level data structure for chat conversations\n    \"\"\"\n\n    conversation: List[Messages]\n\n    def __str__(self) -> str:\n        return \"\".join(str(c) for c in self.conversation)\n\n    def tokenized(\n        self, tokenizer: Callable[[str], dict[str, 
List[int]]], ignore_index=-100\n    ) -> dict[str, List[int]]:\n        input_ids = []\n        attention_mask = []\n        labels = []\n        for msg in self.conversation:\n            msg_results = msg.tokenized(tokenizer, ignore_index)\n            input_ids.extend(msg_results[\"input_ids\"])\n            attention_mask.extend(msg_results[\"attention_mask\"])\n            labels.extend(msg_results[\"labels\"])\n        return {\n            \"input_ids\": input_ids,\n            \"attention_mask\": attention_mask,\n            \"labels\": labels,\n        }\n\n\nclass ChatFormattedChats(Chats):\n    \"\"\"\n    Chat formatted chats with formatter and optional train on inputs\n    \"\"\"\n\n    formatter: Callable  # [[Union[dict, Chats]], Chats]\n    train_on_inputs: bool = False\n\n    def model_post_init(self, __context):\n        for i, msg in enumerate(self.conversation):\n            self.conversation[i] = self.formatter(msg, message_index=i)\n            if self.train_on_inputs:\n                self.conversation[i].weight = 1\n\n\nclass PreferenceChats(BaseModel):\n    \"\"\"\n    representation for preference data for chat\n    \"\"\"\n\n    prompt: List[Messages]\n    chosen: Messages\n    rejected: Messages\n"
  },
  {
    "path": "src/axolotl/core/datasets/__init__.py",
    "content": ""
  },
  {
    "path": "src/axolotl/core/datasets/chat.py",
    "content": "\"\"\"\nchat dataset module\n\"\"\"\n\nfrom typing import Callable, Optional, Union\n\nfrom datasets import Dataset\nfrom transformers import PreTrainedTokenizer\n\nfrom axolotl.core.chat.messages import ChatFormattedChats\n\n\nclass TokenizedChatDataset(Dataset):\n    \"\"\"\n    Tokenized chat dataset\n    \"\"\"\n\n    def __init__(\n        self,\n        data: Dataset,\n        model_transform: Union[PreTrainedTokenizer, Callable],\n        *args,\n        message_transform: Optional[Callable] = None,\n        formatter=None,\n        process_count: Optional[int] = None,\n        keep_in_memory: Optional[bool] = False,\n        **kwargs,\n    ):\n        def map_fn(ex):\n            if message_transform is not None:\n                ex = message_transform(ex)\n            if formatter is not None:\n                ex = ChatFormattedChats(\n                    formatter=formatter,\n                    **ex,\n                )\n            else:\n                ex = ChatFormattedChats(\n                    **ex,\n                )\n            return ex.tokenized(model_transform)\n\n        features = data.features.keys()\n        tokenized_data = data.map(\n            map_fn,\n            num_proc=process_count,\n            keep_in_memory=keep_in_memory,\n            remove_columns=features,\n            desc=\"Tokenizing Chats\",\n        )\n        super().__init__(tokenized_data.data, *args, **kwargs)\n"
  },
  {
    "path": "src/axolotl/core/datasets/transforms/__init__.py",
    "content": ""
  },
  {
    "path": "src/axolotl/core/datasets/transforms/chat_builder.py",
    "content": "\"\"\"\nThis module contains a function that builds a transform that takes a row from the\ndataset and converts it to a Chat.\n\"\"\"\n\nfrom typing import Any, Mapping\n\n\ndef chat_message_transform_builder(\n    train_on_inputs=False,\n    conversations_field: str = \"messages\",\n    message_field_role: str | list[str] | None = None,  # commonly \"role\"\n    message_field_content: str | list[str] | None = None,  # commonly \"content\"\n    message_field_training: str | list[str] | None = None,  # commonly \"weight\"\n):\n    \"\"\"Builds a transform that takes a row from the dataset and converts it to a Chat\n\n    Args:\n        train_on_inputs (bool, optional):\n            If True, messages of every recognized role default to a training weight of 1.\n            If False, only assistant messages default to weight 1 and all other roles default to 0.\n            A per-message weight field, when present, overrides these defaults. Defaults to False.\n        conversations_field (str, optional):\n            The field name of the conversations. Defaults to \"messages\".\n        message_field_role (str | list[str], optional):\n            The field name(s) of the role. Defaults to [\"role\", \"from\"]; the first name found\n            in the first message is used.\n        message_field_content (str | list[str], optional):\n            The field name(s) of the message content. Defaults to [\"value\", \"text\", \"content\"].\n        message_field_training (str | list[str], optional):\n            The field name(s) of the train/weight value. Defaults to [\"train\", \"weight\"].\n\n    Returns:\n        Callable:\n            A function that takes a dataset row (a mapping containing `conversations_field`) and\n            returns a dict whose \"conversation\" key holds the normalized list of messages.\n    \"\"\"\n\n    if message_field_training is None:\n        message_field_training = [\"train\", \"weight\"]\n    if message_field_content is None:\n        message_field_content = [\"value\", \"text\", \"content\"]\n    if message_field_role is None:\n        message_field_role = [\"role\", \"from\"]\n    message_field_role = (\n        [message_field_role]\n        if isinstance(message_field_role, str)\n        else message_field_role\n    )\n    message_field_content = (\n        [message_field_content]\n        if isinstance(message_field_content, str)\n        else 
message_field_content\n    )\n    message_weight_fields = (\n        [message_field_training]\n        if isinstance(message_field_training, str)\n        else message_field_training\n    )\n\n    role_value_mappings = {\n        \"system\": \"system\",\n        \"user\": \"user\",\n        \"human\": \"user\",\n        \"assistant\": \"assistant\",\n        \"gpt\": \"assistant\",\n        \"tool\": \"tool\",\n        \"ipython\": \"ipython\",\n    }\n    if train_on_inputs:\n        role_default_weights_mappings = {\n            \"system\": 1,\n            \"user\": 1,\n            \"assistant\": 1,\n            \"tool\": 1,\n            \"ipython\": 1,\n        }\n    else:\n        role_default_weights_mappings = {\n            \"system\": 0,\n            \"user\": 0,\n            \"assistant\": 1,\n            \"tool\": 0,\n            \"ipython\": 0,\n        }\n\n    def transform_builder(sample: Mapping[str, Any]):\n        if conversations_field not in sample:\n            raise ValueError(f\"Field '{conversations_field}' not found in sample.\")\n        # if none of the role fields are in the message, raise an error\n        if not any(\n            role in sample[conversations_field][0] for role in message_field_role\n        ):\n            raise ValueError(\"No role field found in message.\")\n        role_field = next(\n            role\n            for role in message_field_role\n            if role in sample[conversations_field][0]\n        )\n        if not any(\n            field in sample[conversations_field][0] for field in message_field_content\n        ):\n            raise ValueError(\"No message_content field found in message.\")\n        message_content_field = next(\n            field\n            for field in message_field_content\n            if field in sample[conversations_field][0]\n        )\n        if not any(\n            field in sample[conversations_field][0] for field in message_field_training\n        ):\n            
message_weight_field = None\n        else:\n            message_weight_field = next(\n                field\n                for field in message_weight_fields\n                if field in sample[conversations_field][0]\n            )\n\n        messages = []\n        for message in sample[conversations_field]:\n            role = role_value_mappings[message[role_field]]\n            weight = (\n                int(message[message_weight_field])\n                if message_weight_field\n                else role_default_weights_mappings[role]\n            )\n\n            # TODO if \"tool_calls\" in message[message_content_field]: then convert tool call to ToolCallContents\n            if isinstance(message[message_content_field], str):\n                messages.append(\n                    {\n                        \"role\": role,\n                        \"content\": [\n                            {\n                                \"type\": \"text\",\n                                \"value\": message[message_content_field],\n                            }\n                        ],\n                        \"weight\": weight,\n                    }\n                )\n            else:\n                messages.append(\n                    {\n                        \"role\": role,\n                        \"content\": message[message_content_field],\n                        \"weight\": weight,\n                    }\n                )\n\n        return {\"conversation\": messages}\n\n    return transform_builder\n"
  },
  {
    "path": "src/axolotl/core/trainers/__init__.py",
    "content": "\"\"\"Init for axolotl.core.trainers\"\"\"\n\n# flake8: noqa\n\nfrom .base import AxolotlTrainer\nfrom .dpo.trainer import AxolotlDPOTrainer\nfrom .mamba import AxolotlMambaTrainer\nfrom .trl import (\n    AxolotlCPOTrainer,\n    AxolotlKTOTrainer,\n    AxolotlORPOTrainer,\n    AxolotlPRMTrainer,\n    AxolotlRewardTrainer,\n)\n"
  },
  {
    "path": "src/axolotl/core/trainers/base.py",
    "content": "\"\"\"Module for customized trainers\"\"\"\n\nfrom __future__ import annotations\n\nimport json\nimport math\nimport os\nfrom collections import defaultdict\nfrom functools import partial, wraps\nfrom typing import Any, Callable, Literal, Optional\n\nimport datasets\nimport safetensors\nimport torch\nfrom accelerate.state import AcceleratorState\nfrom datasets import Dataset\nfrom peft import PeftModel\nfrom torch.utils.data import (\n    BatchSampler,\n    DataLoader,\n    RandomSampler,\n    Sampler,\n    SequentialSampler,\n)\nfrom transformers import PreTrainedModel, Trainer\nfrom transformers.trainer import TRAINING_ARGS_NAME\nfrom transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, has_length, seed_worker\nfrom transformers.utils import SAFE_WEIGHTS_NAME, is_peft_available\nfrom trl.experimental.utils import pad_to_length\nfrom typing_extensions import override\n\nfrom axolotl.core.trainers.mixins import (\n    ActivationOffloadingMixin,\n    CheckpointSaveMixin,\n    DistributedParallelMixin,\n    OptimizerMixin,\n    PackingMixin,\n    RngLoaderMixin,\n    SchedulerMixin,\n)\nfrom axolotl.core.trainers.utils import (\n    sanitize_kwargs_for_ds_tagging,\n    sanitize_kwargs_for_tagging,\n)\nfrom axolotl.utils import get_not_null\nfrom axolotl.utils.bench import get_gpu_memory_usage\nfrom axolotl.utils.dict import DictDefault\nfrom axolotl.utils.distributed import is_distributed, is_main_process\nfrom axolotl.utils.logging import get_logger\nfrom axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths\n\nLOG = get_logger(__name__)\n\nTOKENS_STATE_FILE = \"tokens_state.\"\n\nREDUCTION_FNS = {\n    \"mean\": torch.mean,\n    \"min\": torch.min,\n    \"max\": torch.max,\n    \"sum\": torch.sum,\n}\n\n\nclass AxolotlTrainer(\n    PackingMixin,\n    SchedulerMixin,\n    OptimizerMixin,\n    RngLoaderMixin,\n    CheckpointSaveMixin,\n    ActivationOffloadingMixin,\n    DistributedParallelMixin,\n    Trainer,\n):\n    
\"\"\"Extend the base Trainer for axolotl helpers\"\"\"\n\n    args = None  # type: \"AxolotlTrainingArguments\"  # type: ignore[name-defined]\n    tag_names = [\"axolotl\"]\n    _axolotl_cfg: DictDefault | None = None\n\n    @property\n    def axolotl_cfg(self):\n        return self._axolotl_cfg\n\n    @axolotl_cfg.setter\n    def axolotl_cfg(self, cfg):\n        self._axolotl_cfg = cfg\n\n    def __init__(\n        self,\n        *_args,\n        bench_data_collator=None,\n        eval_data_collator=None,\n        dataset_tags=None,\n        **kwargs,\n    ):\n        self.bench_data_collator = bench_data_collator\n        self.eval_data_collator = eval_data_collator\n        self.dataset_tags = dataset_tags\n        self._signature_columns = None  # workaround for pylint\n\n        super().__init__(*_args, **kwargs)\n        self.train_data_collator = self.data_collator\n        self._stored_metrics = defaultdict(\n            lambda: defaultdict(lambda: {\"values\": [], \"reduction\": \"mean\"})\n        )\n        if self.args.orpo_alpha:\n            self.loss_fct = torch.nn.CrossEntropyLoss(reduction=\"none\")\n\n    def _create_multipack_sampler(\n        self, base_sampler: Sampler, dataset: Dataset\n    ) -> MultipackBatchSampler:\n        \"\"\"\n        Helper method to create a `MultipackBatchSampler` for multipacking sequences\n        for training.\n\n        Args:\n            base_sampler: Sampler to wrap with `MultipackBatchSampler`.\n            dataset: Dataset to sample from.\n\n        Returns:\n            Multipack (sample packing) batch sampler.\n        \"\"\"\n        if self.args.multipack_real_batches:\n            batch_size = self.args.per_device_train_batch_size\n            batch_max_len = self.args.max_seq_length\n        else:\n            batch_size = 1\n            train_batch_size = (\n                self.state.train_batch_size or self.args.per_device_train_batch_size\n            )\n            batch_max_len = 
train_batch_size * self.args.max_seq_length\n\n        sampler = MultipackBatchSampler(\n            base_sampler,\n            lengths=get_dataset_lengths(dataset),\n            packing_efficiency_estimate=self.args.sample_packing_efficiency,\n            batch_max_len=batch_max_len,\n            batch_size=batch_size,\n            group_size=self.args.sample_packing_group_size,\n            bin_size=self.args.sample_packing_bin_size,\n            sequential=self.args.sample_packing_sequentially,\n            drop_last=True,\n            num_processes=self.args.dataset_num_proc,\n            mp_start_method=self.args.sample_packing_mp_start_method or \"fork\",\n        )\n\n        len(sampler)\n        return sampler\n\n    def _get_train_sampler(\n        self, train_dataset: Dataset | None = None\n    ) -> Sampler | None:\n        \"\"\"\n        Helper method to get the sampler for training. Handles cases for sample packing\n        and curriculum sampling (sequential).\n\n        Returns:\n            If the dataset is non-empty, a sampler is returned, the type of which\n                depends on the passed training args.\n        \"\"\"\n        # from https://github.com/huggingface/transformers/blob/2166b6b4ff09f6dd3867ab982f262f66482aa968/src/transformers/trainer.py#L969C1-L972C24\n        if train_dataset is None:\n            train_dataset = self.train_dataset\n        if train_dataset is None or not has_length(train_dataset):\n            return None\n\n        use_sample_packing = self.args.sample_packing and not self.args.pretraining\n\n        # Determine the base sampler first\n        if self.args.curriculum_sampling:\n            base_sampler = SequentialSampler(train_dataset)\n        elif use_sample_packing:\n            base_sampler = RandomSampler(train_dataset)\n        else:\n            # Default to parent class implementation for standard random sampling\n            return super()._get_train_sampler(train_dataset)\n\n        # Apply 
multipack wrapper if needed\n        if use_sample_packing:\n            return self._create_multipack_sampler(\n                base_sampler=base_sampler,\n                dataset=train_dataset,\n            )\n\n        return base_sampler\n\n    def _get_eval_sampler(self, eval_dataset: Dataset | None = None) -> Sampler | None:\n        \"\"\"\n        Helper method to get the sampler for evaluation. Handles sample packing case.\n\n        Returns:\n            If the dataset is non-empty, a sampler is returned, the type of which\n                depends on the passed training args.\n        \"\"\"\n        # from https://github.com/huggingface/transformers/blob/2166b6b4ff09f6dd3867ab982f262f66482aa968/src/transformers/trainer.py#L1065C9-L1066C24\n        if eval_dataset is None or not has_length(eval_dataset):\n            return None\n\n        # Multipacking enabled if training is enabled and eval is not explicitly disabled\n        use_multipack = (\n            self.args.sample_packing and self.args.eval_sample_packing is not False\n        )\n\n        # Determine the base sampler\n        if use_multipack:\n            base_sampler = SequentialSampler(eval_dataset)\n        else:\n            return super()._get_eval_sampler(eval_dataset)\n\n        # Apply multipack wrapper if needed\n        if use_multipack:\n            return self._create_multipack_sampler(\n                base_sampler=base_sampler,\n                dataset=eval_dataset,\n            )\n\n        return base_sampler\n\n    def _get_dataloader(\n        self,\n        dataset: Dataset,\n        description: str,\n        batch_size: int,\n        sampler_fn: Optional[Callable[[Dataset], torch.utils.data.Sampler]] = None,\n        is_training: bool = False,\n        dataloader_key: Optional[str] = None,\n    ) -> DataLoader:\n        \"\"\"Create a [`~torch.utils.data.DataLoader`] from the given dataset.\"\"\"\n\n        data_collator = self.data_collator if is_training else 
self.eval_data_collator\n\n        if isinstance(dataset, datasets.Dataset):\n            if is_training:\n                if not self.args.sample_packing or self.args.pretraining:\n                    dataset = self._remove_unused_columns(\n                        dataset, description=\"training\"\n                    )\n            elif (\n                not is_training\n                and self.args.sample_packing\n                and self.args.eval_sample_packing is not False\n            ):\n                batch_size = (\n                    batch_size\n                    if self.args.sample_packing\n                    else self.args.per_device_eval_batch_size\n                )\n            else:\n                dataset = self._remove_unused_columns(dataset, description=description)\n        else:\n            data_collator = self._get_collator_with_removed_columns(\n                self.data_collator, description=description\n            )\n\n        dataloader_params = {\n            \"batch_size\": batch_size,\n            \"collate_fn\": data_collator,\n            \"num_workers\": self.args.dataloader_num_workers,\n            \"pin_memory\": self.args.dataloader_pin_memory,\n            \"persistent_workers\": self.args.dataloader_persistent_workers,\n        }\n\n        if not isinstance(dataset, torch.utils.data.IterableDataset):\n            dataloader_params[\"drop_last\"] = get_not_null(\n                self.args.dataloader_drop_last, True\n            )\n            if sampler_fn is not None:\n                sampler = sampler_fn(dataset)\n                if isinstance(sampler, BatchSampler):\n                    # batch_size and batch_sampler are mutually exclusive\n                    dataloader_params[\"batch_sampler\"] = sampler\n                    del dataloader_params[\"batch_size\"]\n                    del dataloader_params[\"drop_last\"]\n                else:\n                    dataloader_params[\"sampler\"] = sampler\n\n       
     dataloader_params[\"prefetch_factor\"] = self.args.dataloader_prefetch_factor\n            if is_training:\n                dataloader_params[\"worker_init_fn\"] = partial(\n                    seed_worker,\n                    num_workers=self.args.dataloader_num_workers,\n                    rank=self.args.process_index,\n                )\n        if self.args.sample_packing and (\n            (is_training and not self.args.pretraining)\n            or (not is_training and self.args.eval_sample_packing is not False)\n        ):\n            self.accelerator.even_batches = False\n\n        if dataset.column_names and \"length\" in dataset.column_names:\n            dataset = dataset.remove_columns([\"length\"])\n\n        if (\n            dataset.column_names\n            and \"position_ids\" in dataset.column_names\n            and \"attention_mask\" in dataset.column_names\n            and self.args.sample_packing\n            and self.args.sample_packing_drop_attention_mask\n        ):\n            dataset = dataset.remove_columns([\"attention_mask\"])\n\n        dataloader = DataLoader(dataset, **dataloader_params)\n\n        # Accelerator.free_memory() will destroy the references, so\n        # we need to store the non-prepared version for eval dataloaders.\n        # fmt: off\n        if dataloader_key is not None and self.args.dataloader_persistent_workers:\n            if hasattr(self, \"_eval_dataloaders\"):\n                self._eval_dataloaders[dataloader_key] = dataloader  # type: ignore\n            else:\n                self._eval_dataloaders = {dataloader_key: dataloader}\n        # fmt: on\n\n        return self.accelerator.prepare(dataloader)\n\n    def _get_bench_sampler(\n        self, bench_dataset: Dataset\n    ) -> torch.utils.data.Sampler | None:\n        if self.args.world_size <= 1:\n            return SequentialSampler(bench_dataset)\n        return None\n\n    def get_bench_dataloader(\n        self,\n        bench_dataset: 
Dataset,\n    ) -> DataLoader:\n        dataloader_params = {\n            \"batch_size\": self.args.eval_batch_size,\n            \"collate_fn\": self.bench_data_collator,\n            \"num_workers\": self.args.dataloader_num_workers,\n            \"pin_memory\": self.args.dataloader_pin_memory,\n        }\n        if self.args.dataloader_prefetch_factor:\n            dataloader_params[\"prefetch_factor\"] = self.args.dataloader_prefetch_factor\n\n        if not isinstance(bench_dataset, torch.utils.data.IterableDataset):\n            dataloader_params[\"sampler\"] = self._get_bench_sampler(bench_dataset)\n            dataloader_params[\"drop_last\"] = self.args.dataloader_drop_last\n\n        return DataLoader(bench_dataset, **dataloader_params)\n        # return self.accelerator.prepare(DataLoader(bench_dataset, **dataloader_params))\n\n    @override\n    def compute_loss(\n        self, model, inputs, return_outputs=False, num_items_in_batch=None\n    ):\n        # use one's weighted cross entropy loss calc\n        # if self.args.sample_packing:\n        #     labels = inputs.pop(\"labels\")\n        #     outputs = model(**inputs)\n        #     loss = trainer_weighted_loss(outputs, labels, shift_labels=True)\n        #     return (loss, outputs) if return_outputs else loss\n\n        # track number of tokens for tokens per second calculation\n        if self.args.include_tkps and model.training:\n            inputs_key = \"labels\" if \"labels\" in inputs else \"input_ids\"\n            trainable_tokens = (inputs[inputs_key] != -100).sum()\n            total_tokens = inputs[inputs_key].numel()\n            total_tokens = torch.tensor(total_tokens, device=inputs[inputs_key].device)\n\n            if is_distributed():\n                torch.distributed.all_reduce(\n                    trainable_tokens, op=torch.distributed.ReduceOp.SUM\n                )\n                torch.distributed.all_reduce(\n                    total_tokens, 
op=torch.distributed.ReduceOp.SUM\n                )\n\n            if not hasattr(self.state, \"tokens\"):\n                self.state.tokens = {\n                    \"trainable\": torch.zeros(1),\n                    \"total\": torch.zeros(1),\n                }\n\n            # trainable tokens for throughput and total token slots for summaries\n            self.state.tokens[\"trainable\"] = (\n                self.state.tokens[\"trainable\"] + trainable_tokens.detach().cpu()\n            )\n            self.state.tokens[\"total\"] = self.state.tokens[\"total\"] + total_tokens.cpu()\n            # Store per-step trainable tokens for throughput calculation\n            self.state.tokens[\"trainable_tokens\"] = trainable_tokens.detach().cpu()\n\n        if self.args.orpo_alpha:\n            return self.orpo_compute_loss(\n                model,\n                inputs,\n                return_outputs=return_outputs,\n                num_items_in_batch=num_items_in_batch,\n            )\n\n        return super().compute_loss(\n            model,\n            inputs,\n            return_outputs=return_outputs,\n            num_items_in_batch=num_items_in_batch,\n        )\n\n    @override\n    def evaluate(self, *args, **kwargs):\n        LOG.info(\"Running evaluation step...\")\n        return super().evaluate(*args, **kwargs)\n\n    @staticmethod\n    def orpo_concatenate_inputs(inputs, label_pad_token=-100, pad_token=0, device=None):\n        concatenated_batch = {}\n\n        max_length = max(\n            inputs[\"input_ids\"].shape[1], inputs[\"rejected_input_ids\"].shape[1]\n        )\n        # Concatenate positive and negative inputs\n        concatenated_batch[\"input_ids\"] = pad_to_length(\n            inputs[\"input_ids\"], max_length, pad_token\n        )\n        concatenated_batch[\"rejected_input_ids\"] = pad_to_length(\n            inputs[\"rejected_input_ids\"], max_length, pad_token\n        )\n        concatenated_batch[\"labels\"] = 
pad_to_length(\n            inputs[\"labels\"], max_length, label_pad_token\n        )\n        concatenated_batch[\"rejected_labels\"] = pad_to_length(\n            inputs[\"rejected_labels\"], max_length, label_pad_token\n        )\n        concatenated_batch[\"attention_mask\"] = pad_to_length(\n            inputs[\"attention_mask\"], max_length, 0\n        )\n        concatenated_batch[\"rejected_attention_mask\"] = pad_to_length(\n            inputs[\"rejected_attention_mask\"], max_length, 0\n        )\n        concatenated_batch[\"prompt_attention_mask\"] = pad_to_length(\n            inputs[\"prompt_attention_mask\"], max_length, 0\n        ).to(device=device)\n\n        input_ids = torch.cat(\n            [concatenated_batch[\"input_ids\"], concatenated_batch[\"rejected_input_ids\"]],\n            dim=0,\n        ).to(device=device)\n        attention_mask = torch.cat(\n            [\n                concatenated_batch[\"attention_mask\"],\n                concatenated_batch[\"rejected_attention_mask\"],\n            ],\n            dim=0,\n        ).to(device=device)\n        labels = torch.cat(\n            [concatenated_batch[\"labels\"], concatenated_batch[\"rejected_labels\"]], dim=0\n        ).to(device=device)\n\n        return {\n            \"input_ids\": input_ids,\n            \"labels\": labels,\n            \"attention_mask\": attention_mask,\n            \"prompt_attention_mask\": concatenated_batch[\"prompt_attention_mask\"],\n        }\n\n    def orpo_compute_custom_loss(self, logits, labels):\n        logits = logits.contiguous()\n        loss = 0.0\n\n        if labels is not None:\n            # move labels to correct device to enable model parallelism\n            labels = labels.to(logits.device)\n            # Shift so that tokens < n predict n\n            shift_logits = logits[..., :-1, :].contiguous()\n            shift_labels = labels[..., 1:].contiguous()\n\n            # Flatten the tokens\n            loss = 
self.loss_fct(shift_logits.transpose(2, 1), shift_labels).mean(\n                dim=-1\n            )\n\n        return loss\n\n    def orpo_compute_logps(\n        self, prompt_attention_mask, chosen_inputs, chosen_attention_mask, logits\n    ):\n        # Get the shape of chosen_attention_mask[:, :-1]\n        chosen_shape = chosen_attention_mask[:, :-1].shape\n\n        # Calculate the padding size\n        pad_length = chosen_shape[1] - (prompt_attention_mask.shape[1] - 1)\n\n        # Pad prompt_attention_mask with zeros to match the desired shape\n        prompt_attention_mask_padded = torch.nn.functional.pad(\n            prompt_attention_mask[:, 1:], (0, pad_length), mode=\"constant\", value=0\n        )\n\n        # Perform the subtraction operation\n        mask = chosen_attention_mask[:, :-1] > prompt_attention_mask_padded\n\n        per_token_logps = torch.gather(\n            logits[:, :-1, :].log_softmax(-1),\n            dim=2,\n            index=(mask * chosen_inputs[:, 1:]).unsqueeze(2),\n        ).squeeze(2)\n        return torch.mul(per_token_logps, mask).sum(dim=1) / mask.sum(dim=1)\n\n    def orpo_compute_loss(\n        self,\n        model,\n        inputs,\n        return_outputs=False,\n        num_items_in_batch=None,\n    ):\n        concat_inputs = AxolotlTrainer.orpo_concatenate_inputs(\n            inputs,\n            label_pad_token=-100,\n            pad_token=self.tokenizer.pad_token_id,\n            device=self.accelerator.device,\n        )\n\n        # Perform a single forward pass\n        outputs = model(\n            **{\n                \"input_ids\": concat_inputs[\"input_ids\"],\n                \"attention_mask\": concat_inputs[\"attention_mask\"],\n                \"labels\": concat_inputs[\"labels\"],\n            },\n            output_hidden_states=True,\n        )\n\n        # Split the outputs for positive and negative examples\n        outputs_pos, outputs_neg = outputs.logits.chunk(2)\n\n        # Calculate NLL 
loss\n        pos_loss = self.orpo_compute_custom_loss(\n            logits=outputs_pos, labels=concat_inputs[\"input_ids\"].chunk(2)[0]\n        )\n\n        # Calculate Log Probability\n        pos_prob = self.orpo_compute_logps(\n            prompt_attention_mask=concat_inputs[\"prompt_attention_mask\"],\n            chosen_inputs=concat_inputs[\"input_ids\"].chunk(2)[0],\n            chosen_attention_mask=concat_inputs[\"attention_mask\"].chunk(2)[0],\n            logits=outputs_pos,\n        )\n        neg_prob = self.orpo_compute_logps(\n            prompt_attention_mask=concat_inputs[\"prompt_attention_mask\"],\n            chosen_inputs=concat_inputs[\"input_ids\"].chunk(2)[1],\n            chosen_attention_mask=concat_inputs[\"attention_mask\"].chunk(2)[1],\n            logits=outputs_neg,\n        )\n\n        # Calculate log odds\n        log_odds = (pos_prob - neg_prob) - (\n            torch.log(1 - torch.exp(pos_prob)) - torch.log(1 - torch.exp(neg_prob))\n        )\n        sig_ratio = torch.nn.functional.sigmoid(log_odds)\n        ratio = torch.log(sig_ratio)\n\n        # Calculate the Final Loss\n        loss = torch.mean(pos_loss - self.args.orpo_alpha * ratio).to(\n            dtype=torch.bfloat16\n        )\n\n        metrics = {}\n        metrics[\"chosen_geometric_mean\"] = torch.mean(pos_prob).cpu().item()\n        metrics[\"rejected_geometric_mean\"] = torch.mean(neg_prob).cpu().item()\n        metrics[\"log_odds_ratio\"] = torch.mean(ratio).cpu().item()\n        metrics[\"log_odds\"] = torch.mean(log_odds).cpu().item()\n        self.store_metrics(metrics, train_eval=\"train\")\n\n        return (loss, outputs_pos) if return_outputs else loss\n\n    @wraps(Trainer.push_to_hub)\n    def push_to_hub(self, *args, **kwargs) -> str:\n        \"\"\"\n        Overwrite the `push_to_hub` method in order to force-add the tags when pushing the\n        model on the Hub. 
Please refer to `~transformers.Trainer.push_to_hub` for more details.\n        \"\"\"\n        kwargs = sanitize_kwargs_for_ds_tagging(\n            dataset_tags=self.dataset_tags, kwargs=kwargs\n        )\n        kwargs = sanitize_kwargs_for_tagging(tag_names=self.tag_names, kwargs=kwargs)\n\n        return super().push_to_hub(*args, **kwargs)\n\n    @wraps(Trainer.create_accelerator_and_postprocess)\n    def create_accelerator_and_postprocess(self):\n        # cleanup the PartialState states so Accelerate automatically configures everything from the env vars\n        accelerator_config = self.args.accelerator_config.to_dict()\n        use_configured_state = accelerator_config.get(\"use_configured_state\", False)\n        if not use_configured_state:\n            AcceleratorState._reset_state(reset_partial_state=True)\n\n        super().create_accelerator_and_postprocess()\n\n    def additional_accelerator_args(\n        self, fp8: bool = False, enable_fsdp_float8_all_gather: bool = False, **kwargs\n    ) -> dict[str, Any]:\n        ret_kwargs = {}\n        if fp8:\n            from accelerate.utils import AORecipeKwargs\n            from torchao.float8 import Float8LinearConfig\n\n            # By default, Float8LinearConfig is instantiated using the \"tensorwise\"\n            # scaling strategy. 
See more details here:\n            # https://github.com/pytorch/ao/tree/main/torchao/float8.\n            config = Float8LinearConfig(\n                enable_fsdp_float8_all_gather=enable_fsdp_float8_all_gather,\n                force_recompute_fp8_weight_in_bwd=enable_fsdp_float8_all_gather is True,\n            )\n\n            ret_kwargs[\"mixed_precision\"] = \"fp8\"\n            ret_kwargs[\"kwargs_handlers\"] = [AORecipeKwargs(config=config)]  # type: ignore\n            os.environ[\"ACCELERATE_MIXED_PRECISION\"] = \"fp8\"\n\n        return ret_kwargs\n\n    def log(self, logs: dict[str, float], start_time: float | None = None) -> None:\n        \"\"\"\n        Log `logs` on the various objects watching training, including stored metrics.\n\n        Args:\n            logs: The values to log.\n            start_time: The start of training.\n        \"\"\"\n        # logs either has 'loss' or 'eval_loss'\n        train_eval = \"train\" if \"loss\" in logs else \"eval\"\n        metric_ndigits = int(os.getenv(\"AXOLOTL_METRIC_NDIGITS\", \"5\"))\n\n        for key, metric_data in self._stored_metrics[train_eval].items():\n            values = torch.tensor(metric_data[\"values\"])  # type: ignore[arg-type]\n            reduction_type = metric_data[\"reduction\"]\n\n            fn = REDUCTION_FNS.get(reduction_type)\n            if fn is None:\n                raise NotImplementedError(\n                    \"Metric reduction must be one of [mean, min, max, sum]\"\n                )\n            logs[key] = round(fn(values).item(), metric_ndigits)\n\n        if \"loss\" in logs:\n            try:\n                logs[\"ppl\"] = round(math.exp(logs[\"loss\"]), metric_ndigits)\n            except OverflowError:\n                logs[\"ppl\"] = float(\"inf\")\n        if \"eval_loss\" in logs:\n            try:\n                logs[\"eval_ppl\"] = round(math.exp(logs[\"eval_loss\"]), metric_ndigits)\n            except OverflowError:\n                
logs[\"eval_ppl\"] = float(\"inf\")\n\n        if is_main_process():\n            # Add memory usage\n            try:\n                active, allocated, reserved = get_gpu_memory_usage()\n                logs[\"memory/max_active (GiB)\"] = round(active, 2)\n                logs[\"memory/max_allocated (GiB)\"] = round(allocated, 2)\n                logs[\"memory/device_reserved (GiB)\"] = round(reserved, 2)\n            except (ValueError, TypeError, FileNotFoundError):\n                pass\n\n        if (\n            self.args.include_tkps\n            and train_eval == \"train\"\n            and hasattr(self.state, \"tokens\")\n        ):\n            # each rank will log its own tokens per second\n            # for logging_steps > 1 we obtain a moving average of this metric\n            logs[\"tokens/train_per_sec_per_gpu\"] = round(\n                self.state.last_tokens_per_second.item() / self.args.logging_steps, 2\n            )\n            if \"total\" in self.state.tokens:\n                logs[\"tokens/total\"] = int(self.state.tokens[\"total\"].item())\n            if \"trainable\" in self.state.tokens:\n                logs[\"tokens/trainable\"] = int(self.state.tokens[\"trainable\"].item())\n\n        del self._stored_metrics[train_eval]\n\n        return super().log(logs, start_time)\n\n    def store_metrics(\n        self,\n        metrics: dict[str, float] | dict[str, tuple[int | float, str]],\n        train_eval: Literal[\"train\", \"eval\"] = \"train\",\n        reduction: Literal[\"mean\", \"min\", \"max\", \"sum\"] = \"mean\",\n    ) -> None:\n        \"\"\"\n        Store metrics with specified reduction type.\n\n        Args:\n            metrics: Dictionary of metric names to values, or metric names to (value,\n                reduction_type) tuples.\n            train_eval: Whether this is for training or evaluation.\n        \"\"\"\n        for key, value in metrics.items():\n            if isinstance(value, tuple):\n                
value, _reduction = value  # type: ignore[assignment]\n            else:\n                value, _reduction = value, reduction\n\n            self._stored_metrics[train_eval][key][\"values\"].append(value)\n            self._stored_metrics[train_eval][key][\"reduction\"] = _reduction\n\n    def _save_checkpoint(self, model, trial, **kwargs):\n        # make sure the checkpoint dir exists, since trainer is flakey\n        checkpoint_folder = f\"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}\"\n        run_dir = self._get_output_dir(trial=trial)\n        output_dir = os.path.join(run_dir, checkpoint_folder)\n        os.makedirs(output_dir, exist_ok=True)\n\n        # Save total_tokens state if tracking is enabled\n        if self.args.include_tkps and hasattr(self.state, \"tokens\"):\n            tokens_state = {\n                \"total\": int(torch.as_tensor(self.state.tokens.get(\"total\", 0)).item()),\n                \"trainable\": int(\n                    torch.as_tensor(self.state.tokens.get(\"trainable\", 0)).item()\n                ),\n            }\n            tokens_state_path = os.path.join(output_dir, TOKENS_STATE_FILE)\n            with open(tokens_state_path, \"w\", encoding=\"utf-8\") as f:\n                json.dump(tokens_state, f)\n\n        return super()._save_checkpoint(model, trial, **kwargs)\n\n    # TODO(wing): remove once https://github.com/huggingface/transformers/pull/39866/files is merged\n    def _save(self, output_dir: Optional[str] = None, state_dict=None):\n        # If we are executing this function, we are the process zero, so we don't check for that.\n        output_dir = output_dir if output_dir is not None else self.args.output_dir\n        os.makedirs(output_dir, exist_ok=True)\n        LOG.info(f\"Saving model checkpoint to {output_dir}\")\n\n        # fix for Context Parallel save: CP eval invalidates tensor storage\n        # pointers, so clone to CPU to get fresh valid storage for safetensors\n        if (\n            
state_dict is not None\n            and self.axolotl_cfg\n            and self.axolotl_cfg.context_parallel_size\n            and self.axolotl_cfg.context_parallel_size > 1\n        ):\n            state_dict = {\n                k: v.detach().cpu() if isinstance(v, torch.Tensor) else v\n                for k, v in state_dict.items()\n            }\n\n        supported_classes = (\n            (PreTrainedModel,)\n            if not is_peft_available()\n            else (PreTrainedModel, PeftModel)\n        )\n        # Save a trained model and configuration using `save_pretrained()`.\n        # They can then be reloaded using `from_pretrained()`\n        if not isinstance(self.model, supported_classes):\n            if state_dict is None:\n                state_dict = self.model.state_dict()\n\n            if isinstance(\n                self.accelerator.unwrap_model(self.model, keep_torch_compile=False),\n                supported_classes,\n            ):\n                self.accelerator.unwrap_model(\n                    self.model, keep_torch_compile=False\n                ).save_pretrained(\n                    output_dir,\n                    state_dict=state_dict,\n                    is_main_process=self.accelerator.is_main_process,\n                )\n            else:\n                LOG.info(\n                    \"Trainer.model is not a `PreTrainedModel`, only saving its state dict.\"\n                )\n                safetensors.torch.save_file(\n                    state_dict,\n                    os.path.join(output_dir, SAFE_WEIGHTS_NAME),\n                    metadata={\"format\": \"pt\"},\n                )\n        else:\n            self.model.save_pretrained(\n                output_dir,\n                state_dict=state_dict,\n                is_main_process=self.accelerator.is_main_process,\n            )\n\n        if self.processing_class is not None:\n            self.processing_class.save_pretrained(output_dir)\n        elif (\n        
    self.data_collator is not None\n            and hasattr(self.data_collator, \"tokenizer\")\n            and self.data_collator.tokenizer is not None\n        ):\n            LOG.info(\n                \"Saving Trainer.data_collator.tokenizer by default as Trainer.processing_class is `None`\"\n            )\n            self.data_collator.tokenizer.save_pretrained(output_dir)\n\n        # Good practice: save your training arguments together with the trained model\n        torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))\n"
  },
  {
    "path": "src/axolotl/core/trainers/dpo/__init__.py",
    "content": "\"\"\"DPO Specific Strategy for training\"\"\"\n\nfrom axolotl.core.trainers.dpo.trainer import AxolotlDPOTrainer\nfrom axolotl.utils.schemas.enums import RLType\n\n\nclass DPOStrategy:\n    \"\"\"Strategy for DPO training\"\"\"\n\n    @classmethod\n    def get_trainer_class(cls):\n        return AxolotlDPOTrainer\n\n    @classmethod\n    def get_training_args_class(cls):\n        from axolotl.core.trainers.dpo.args import AxolotlDPOConfig\n\n        return AxolotlDPOConfig\n\n    @classmethod\n    def set_training_args_kwargs(cls, cfg):\n        training_args_kwargs = {}\n        if cfg.rl is RLType.IPO:\n            training_args_kwargs[\"loss_type\"] = \"ipo\"\n        # Label smoothing is not compatible with IPO\n        if cfg.rl is RLType.DPO and cfg.dpo_label_smoothing:\n            training_args_kwargs[\"label_smoothing\"] = cfg.dpo_label_smoothing\n        training_args_kwargs[\"max_length\"] = cfg.sequence_len\n        if cfg.dpo_use_weighting is not None:\n            training_args_kwargs[\"use_weighting\"] = cfg.dpo_use_weighting\n        if cfg.dpo_padding_free is not None:\n            training_args_kwargs[\"padding_free\"] = cfg.dpo_padding_free\n        if cfg.dpo_norm_loss is not None:\n            training_args_kwargs[\"dpo_norm_loss\"] = cfg.dpo_norm_loss\n        if cfg.dpo_use_liger_kernel is not None:\n            training_args_kwargs[\"use_liger_kernel\"] = cfg.dpo_use_liger_kernel\n        return training_args_kwargs\n"
  },
  {
    "path": "src/axolotl/core/trainers/dpo/args.py",
    "content": "\"\"\"\nAxolotl specific DPO args\n\"\"\"\n\nfrom dataclasses import dataclass\n\nfrom trl import DPOConfig\n\nfrom axolotl.core.training_args import AxolotlTrainingMixins\n\n\n@dataclass\nclass AxolotlDPOConfig(AxolotlTrainingMixins, DPOConfig):\n    \"\"\"\n    DPO config for DPO training\n    \"\"\"\n\n    dpo_norm_loss: bool | None = False\n"
  },
  {
    "path": "src/axolotl/core/trainers/dpo/trainer.py",
    "content": "\"\"\"DPO trainer for axolotl\"\"\"\n\nimport gc\nfrom functools import wraps\nfrom typing import Any, Dict, Union\n\nimport torch\nfrom torch import nn\nfrom trl import DPOTrainer\n\nfrom axolotl.core.trainers.mixins import (\n    DistributedParallelMixin,\n    RngLoaderMixin,\n    SchedulerMixin,\n)\nfrom axolotl.core.trainers.mixins.optimizer import OptimizerInitMixin, OptimizerMixin\nfrom axolotl.core.trainers.utils import (\n    sanitize_kwargs_for_ds_tagging,\n    sanitize_kwargs_for_tagging,\n)\n\n\nclass AxolotlDPOTrainer(\n    RngLoaderMixin,\n    SchedulerMixin,\n    OptimizerMixin,\n    OptimizerInitMixin,\n    DPOTrainer,\n    DistributedParallelMixin,\n):\n    \"\"\"Extend the base DPOTrainer for axolotl helpers.\"\"\"\n\n    tag_names = [\"axolotl\", \"dpo\"]\n\n    def __init__(self, *args, dataset_tags=None, **kwargs):\n        super().__init__(*args, **kwargs)\n\n        self.dataset_tags = dataset_tags\n        self.optimizer = None\n        self.model_accepts_loss_kwargs = False\n\n    @wraps(DPOTrainer.push_to_hub)\n    def push_to_hub(self, *args, **kwargs) -> str:\n        \"\"\"\n        Overwrite the `push_to_hub` method in order to force-add the tags when pushing\n        the model on the Hub. 
Please refer to `~transformers.Trainer.push_to_hub`\n        for more details.\n        \"\"\"\n        kwargs = sanitize_kwargs_for_ds_tagging(\n            dataset_tags=self.dataset_tags, kwargs=kwargs\n        )\n        kwargs = sanitize_kwargs_for_tagging(tag_names=self.tag_names, kwargs=kwargs)\n\n        return super().push_to_hub(*args, **kwargs)\n\n    @staticmethod\n    def tokenize_row(\n        features,\n        processing_class,\n        max_prompt_length: int | None = None,\n        max_completion_length: int | None = None,\n        add_special_tokens: bool = True,\n        is_chat: bool = False,\n    ) -> Dict:\n        res = DPOTrainer.tokenize_row(\n            features,\n            processing_class,\n            max_prompt_length=max_prompt_length,\n            max_completion_length=max_completion_length,\n            add_special_tokens=add_special_tokens,\n            is_chat=is_chat,\n        )\n        # fix when the tokenizer doesn't have a bos_token_id, e.g. Qwen\n        if processing_class.bos_token is None and res[\"prompt_input_ids\"][0] is None:\n            for key in res.keys():\n                res[key] = res[key][1:]\n\n        if processing_class.bos_token and processing_class.bos_token_id is not None:\n            # dpo trainer may incorrectly prepend the bos_token_id to the dpo outputs\n            if res[\"chosen_input_ids\"][0] == processing_class.bos_token_id:\n                res[\"chosen_input_ids\"] = res[\"chosen_input_ids\"][1:]\n            if res[\"rejected_input_ids\"][0] == processing_class.bos_token_id:\n                res[\"rejected_input_ids\"] = res[\"rejected_input_ids\"][1:]\n\n        return res\n\n    def training_step(\n        self,\n        model: nn.Module,\n        inputs: Dict[str, Union[torch.Tensor, Any]],\n        num_items_in_batch=None,\n    ) -> torch.Tensor:\n        loss: torch.Tensor = super().training_step(model, inputs, num_items_in_batch)\n        gc.collect()\n        
torch.cuda.empty_cache()\n        return loss\n\n    def concatenated_forward(\n        self,\n        model: nn.Module,\n        batch: dict[str, Union[list, torch.LongTensor]],\n        is_ref_model: bool = False,\n    ) -> dict[str, torch.Tensor]:\n        if self.args.dpo_norm_loss:\n            # fmt: off\n            loss_type: list[str] = self.loss_type  # type: ignore[has-type]\n            # fmt: on\n            # concatenated_forward handles avg token logprob for ipo case already\n            self.loss_type = [\"ipo\"]\n            res = super().concatenated_forward(model, batch, is_ref_model=is_ref_model)\n            self.loss_type = loss_type\n            return res\n        return super().concatenated_forward(model, batch, is_ref_model=is_ref_model)\n"
  },
  {
    "path": "src/axolotl/core/trainers/grpo/__init__.py",
    "content": "\"\"\"GRPO Specific Strategy for training\"\"\"\n\nimport importlib\nimport inspect\nimport os\nfrom typing import Any\n\nfrom huggingface_hub import snapshot_download\nfrom requests import HTTPError\nfrom trl.trainer.grpo_trainer import RewardFunc\n\nfrom axolotl.core.trainers.grpo.args import AxolotlAsyncGRPOConfig, AxolotlGRPOConfig\nfrom axolotl.core.trainers.grpo.trainer import (\n    AxolotlAsyncGRPOTrainer,\n    AxolotlGRPOSequenceParallelTrainer,\n    AxolotlGRPOTrainer,\n)\nfrom axolotl.utils.dict import DictDefault\nfrom axolotl.utils.logging import get_logger\nfrom axolotl.utils.schemas.trl import TRLConfig\nfrom axolotl.utils.schemas.vllm import VllmConfig\n\nLOG = get_logger(__name__)\n\n\nclass GRPOStrategy:\n    \"\"\"Strategy for GRPO training\"\"\"\n\n    @classmethod\n    def get_trainer_class(\n        cls,\n        sequence_parallel: bool,\n        async_grpo: bool = False,\n    ) -> (\n        type[AxolotlGRPOTrainer]\n        | type[AxolotlGRPOSequenceParallelTrainer]\n        | type[AxolotlAsyncGRPOTrainer]\n    ):\n        if sequence_parallel and async_grpo:\n            raise ValueError(\n                \"sequence_parallel and async_grpo cannot both be enabled. 
\"\n                \"Disable one of context_parallel_size > 1 or async_prefetch/use_data_producer.\"\n            )\n        if sequence_parallel:\n            return AxolotlGRPOSequenceParallelTrainer\n        if async_grpo:\n            return AxolotlAsyncGRPOTrainer\n        return AxolotlGRPOTrainer\n\n    @classmethod\n    def get_training_args_class(\n        cls, async_grpo: bool = False\n    ) -> type[AxolotlGRPOConfig] | type[AxolotlAsyncGRPOConfig]:\n        if async_grpo:\n            return AxolotlAsyncGRPOConfig\n        return AxolotlGRPOConfig\n\n    @classmethod\n    def set_training_args_kwargs(cls, cfg: DictDefault) -> dict[str, Any]:\n        grpo_args_kwargs: dict[str, Any] = {}\n\n        if not hasattr(cfg, \"trl\") or not cfg.trl:\n            return grpo_args_kwargs\n\n        trl: TRLConfig = cfg.trl  # type: ignore\n        vllm_cfg: VllmConfig = cfg.vllm  # type: ignore\n\n        if trl.use_vllm:\n            grpo_args_kwargs[\"use_vllm\"] = trl.use_vllm\n            if trl.vllm_mode:\n                grpo_args_kwargs[\"vllm_mode\"] = trl.vllm_mode\n            if trl.vllm_mode == \"colocate\":\n                grpo_args_kwargs[\"vllm_enable_sleep_mode\"] = trl.vllm_enable_sleep_mode  # type: ignore[attr-defined]\n                grpo_args_kwargs[\"vllm_gpu_memory_utilization\"] = (\n                    vllm_cfg.gpu_memory_utilization\n                )\n                grpo_args_kwargs[\"vllm_tensor_parallel_size\"] = (\n                    vllm_cfg.tensor_parallel_size\n                )\n            grpo_args_kwargs[\"vllm_server_host\"] = trl.vllm_server_host or trl.vllm.host  # type: ignore[attr-defined]\n            grpo_args_kwargs[\"vllm_server_port\"] = trl.vllm_server_port or trl.vllm.port  # type: ignore[attr-defined]\n            if trl.vllm_server_timeout:\n                grpo_args_kwargs[\"vllm_server_timeout\"] = trl.vllm_server_timeout\n            if trl.vllm_guided_decoding_regex:\n                
grpo_args_kwargs[\"vllm_guided_decoding_regex\"] = (\n                    trl.vllm_guided_decoding_regex\n                )\n\n        if trl.num_generations:\n            grpo_args_kwargs[\"num_generations\"] = trl.num_generations\n\n        if trl.sync_ref_model:\n            grpo_args_kwargs[\"sync_ref_model\"] = trl.sync_ref_model\n\n            if trl.ref_model_mixup_alpha:\n                grpo_args_kwargs[\"ref_model_mixup_alpha\"] = trl.ref_model_mixup_alpha\n\n            if trl.ref_model_sync_steps:\n                grpo_args_kwargs[\"ref_model_sync_steps\"] = trl.ref_model_sync_steps\n\n        grpo_args_kwargs[\"max_completion_length\"] = trl.max_completion_length\n        grpo_args_kwargs[\"log_completions\"] = trl.log_completions\n        grpo_args_kwargs[\"num_completions_to_print\"] = trl.num_completions_to_print\n\n        if cfg.context_parallel_size > 1:\n            grpo_args_kwargs[\"context_parallel_size\"] = cfg.context_parallel_size\n\n        if trl.importance_sampling_level is not None:\n            grpo_args_kwargs[\"importance_sampling_level\"] = (\n                trl.importance_sampling_level\n            )\n\n        if trl.reward_weights:\n            grpo_args_kwargs[\"reward_weights\"] = trl.reward_weights\n\n        if trl.scale_rewards is not None:\n            grpo_args_kwargs[\"scale_rewards\"] = trl.scale_rewards\n\n        if trl.loss_type is not None:\n            grpo_args_kwargs[\"loss_type\"] = trl.loss_type\n        if trl.mask_truncated_completions is not None:\n            grpo_args_kwargs[\"mask_truncated_completions\"] = (\n                trl.mask_truncated_completions\n            )\n\n        if trl.temperature is not None:\n            grpo_args_kwargs[\"temperature\"] = trl.temperature\n        if trl.top_p is not None:\n            grpo_args_kwargs[\"top_p\"] = trl.top_p\n        if trl.top_k is not None:\n            grpo_args_kwargs[\"top_k\"] = trl.top_k\n        if trl.min_p is not None:\n            
grpo_args_kwargs[\"min_p\"] = trl.min_p\n        if trl.repetition_penalty is not None:\n            grpo_args_kwargs[\"repetition_penalty\"] = trl.repetition_penalty\n\n        if trl.num_iterations is not None:\n            grpo_args_kwargs[\"num_iterations\"] = trl.num_iterations\n        if trl.epsilon is not None:\n            grpo_args_kwargs[\"epsilon\"] = trl.epsilon\n        if trl.epsilon_high is not None:\n            grpo_args_kwargs[\"epsilon_high\"] = trl.epsilon_high\n\n        if trl.use_liger_loss is not None:\n            grpo_args_kwargs[\"use_liger_kernel\"] = trl.use_liger_loss\n\n        if trl.multi_objective_aggregation is not None:\n            grpo_args_kwargs[\"multi_objective_aggregation\"] = (\n                trl.multi_objective_aggregation\n            )\n\n        # Async GRPO fields\n        if getattr(trl, \"use_data_producer\", None) is not None:\n            grpo_args_kwargs[\"use_data_producer\"] = trl.use_data_producer\n        if getattr(trl, \"async_prefetch\", None) is not None:\n            grpo_args_kwargs[\"async_prefetch\"] = trl.async_prefetch\n        if getattr(trl, \"prefetch_depth\", None) is not None:\n            grpo_args_kwargs[\"prefetch_depth\"] = trl.prefetch_depth\n        if getattr(trl, \"vllm_sync_interval\", None) is not None:\n            grpo_args_kwargs[\"vllm_sync_interval\"] = trl.vllm_sync_interval\n        if getattr(trl, \"streaming_partial_batch\", None) is not None:\n            grpo_args_kwargs[\"streaming_partial_batch\"] = trl.streaming_partial_batch\n        if getattr(trl, \"streaming_min_groups\", None) is not None:\n            grpo_args_kwargs[\"streaming_min_groups\"] = trl.streaming_min_groups\n        if getattr(trl, \"vllm_importance_sampling_correction\", None) is not None:\n            grpo_args_kwargs[\"vllm_importance_sampling_correction\"] = (\n                trl.vllm_importance_sampling_correction\n            )\n        if getattr(trl, \"vllm_importance_sampling_mode\", 
None) is not None:\n            grpo_args_kwargs[\"vllm_importance_sampling_mode\"] = (\n                trl.vllm_importance_sampling_mode\n            )\n        if getattr(trl, \"vllm_importance_sampling_cap\", None) is not None:\n            grpo_args_kwargs[\"vllm_importance_sampling_cap\"] = (\n                trl.vllm_importance_sampling_cap\n            )\n        if getattr(trl, \"off_policy_mask_threshold\", None) is not None:\n            grpo_args_kwargs[\"off_policy_mask_threshold\"] = (\n                trl.off_policy_mask_threshold\n            )\n        if getattr(trl, \"use_bias_correction_kl\", None) is not None:\n            grpo_args_kwargs[\"use_bias_correction_kl\"] = trl.use_bias_correction_kl\n\n        # Fast Async GRPO fields\n        if getattr(trl, \"reward_num_workers\", None) is not None:\n            grpo_args_kwargs[\"reward_num_workers\"] = trl.reward_num_workers\n        if getattr(trl, \"replay_buffer_size\", None) is not None:\n            grpo_args_kwargs[\"replay_buffer_size\"] = trl.replay_buffer_size\n        if getattr(trl, \"replay_recompute_logps\", None) is not None:\n            grpo_args_kwargs[\"replay_recompute_logps\"] = trl.replay_recompute_logps\n        if getattr(trl, \"reroll_start_fraction\", None) is not None:\n            grpo_args_kwargs[\"reroll_start_fraction\"] = trl.reroll_start_fraction\n        if getattr(trl, \"reroll_max_groups\", None) is not None:\n            grpo_args_kwargs[\"reroll_max_groups\"] = trl.reroll_max_groups\n        if getattr(trl, \"skip_zero_advantage_batches\", None) is not None:\n            grpo_args_kwargs[\"skip_zero_advantage_batches\"] = (\n                trl.skip_zero_advantage_batches\n            )\n        if getattr(trl, \"vllm_lora_sync\", None) is not None:\n            grpo_args_kwargs[\"vllm_lora_sync\"] = trl.vllm_lora_sync\n\n        return grpo_args_kwargs\n\n    @classmethod\n    def set_trainer_args(cls, cfg: DictDefault) -> list[Any]:\n        trainer_args = 
[]\n        if cfg.trl and cfg.trl.reward_funcs:\n            reward_funcs = []\n            for reward_func_fqn in cfg.trl.reward_funcs:\n                reward_funcs.append(cls.get_reward_func(reward_func_fqn))\n            trainer_args.append(reward_funcs)\n\n        return trainer_args\n\n    @classmethod\n    def set_trainer_kwargs(cls, cfg: DictDefault) -> dict[str, Any]:\n        trainer_kwargs = {}\n        if cfg.trl and cfg.trl.reward_processing_classes:\n            trainer_kwargs[\"reward_processing_classes\"] = (\n                cfg.trl.reward_processing_classes\n            )\n        if cfg.trl and cfg.trl.rollout_func:\n            trainer_kwargs[\"rollout_func\"] = cls.get_rollout_func(cfg.trl.rollout_func)\n\n        return trainer_kwargs\n\n    @classmethod\n    def get_collator(cls, *args, **kwargs):\n        # No data collation is needed in GRPO, handled by trl's trainer __init__\n        return None\n\n    @classmethod\n    def get_blocklist_args_kwargs(cls) -> list[str]:\n        return [\n            \"dataset_num_proc\",\n            \"max_length\",\n            \"include_tokens_per_second\",\n            \"max_prompt_length\",\n        ]\n\n    @classmethod\n    def get_reward_func(cls, reward_func_fqn: str) -> RewardFunc:\n        \"\"\"\n        Returns the reward function from the given fully qualified name, or the path to the reward function model.\n\n        Args:\n            reward_func_fqn (str): Fully qualified name of the reward function (e.g. 
r1_grpo.gsm8k_transform),\n                or a HF hub path to the reward model.\n\n        Returns:\n            RewardFunc: A callable that accepts prompts and completions and returns rewards,\n                or a path to a reward model.\n\n        Raises:\n            ValueError: If the reward function does not accept at least two arguments.\n        \"\"\"\n        try:\n            # use importlib to dynamically load the reward function from the module\n            reward_func_module_name = reward_func_fqn.split(\".\")[-1]\n            reward_func_module = importlib.import_module(\n                \".\".join(reward_func_fqn.split(\".\")[:-1])\n            )\n            reward_func = getattr(reward_func_module, reward_func_module_name)\n            if not len(inspect.signature(reward_func).parameters) >= 2:\n                raise ValueError(\n                    \"Reward function must accept at least two arguments: prompts: list and completions: list\"\n                )\n            return reward_func\n        except ModuleNotFoundError as exc:\n            # the user has passed a string (ideally indicating the path of a reward model)\n            # check if it's a local dir path and not empty dir to a reward model\n            pretrained_log_msg = f\"Reward function {reward_func_fqn} is a pre-trained model path - if this is unexpected, please check the reward function path.\"\n            if os.path.isdir(reward_func_fqn) and os.listdir(reward_func_fqn):\n                LOG.info(pretrained_log_msg)\n                return reward_func_fqn\n            try:\n                snapshot_download(reward_func_fqn, repo_type=\"model\")\n                LOG.info(pretrained_log_msg)\n                return reward_func_fqn\n            except HTTPError:\n                raise ValueError(\n                    f\"Reward function {reward_func_fqn} not found.\"\n                ) from exc\n\n    @classmethod\n    def get_rollout_func(cls, rollout_func_fqn: str):\n        
\"\"\"\n        Returns the rollout function from the given fully qualified name.\n\n        Args:\n            rollout_func_fqn (str): Fully qualified name of the rollout function\n                                    (e.g. my_module.my_rollout_func)\n\n        Returns:\n            Callable rollout function\n        \"\"\"\n        try:\n            rollout_func_module_name = rollout_func_fqn.split(\".\")[-1]\n            rollout_func_module = importlib.import_module(\n                \".\".join(rollout_func_fqn.split(\".\")[:-1])\n            )\n            rollout_func = getattr(rollout_func_module, rollout_func_module_name)\n\n            if not callable(rollout_func):\n                raise ValueError(\n                    f\"Rollout function {rollout_func_fqn} must be callable\"\n                )\n\n            return rollout_func\n\n        except ModuleNotFoundError as exc:\n            raise ValueError(f\"Rollout function {rollout_func_fqn} not found.\") from exc\n"
  },
  {
    "path": "src/axolotl/core/trainers/grpo/args.py",
    "content": "\"\"\"\nAxolotl Specific Training Args\n\"\"\"\n\nfrom dataclasses import dataclass\n\nfrom trl import GRPOConfig\n\nfrom axolotl.core.trainers.grpo.fast_async_trainer import FastAsyncGRPOConfig\nfrom axolotl.core.training_args import AxolotlTrainingMixins\n\n\n@dataclass\nclass AxolotlGRPOConfig(AxolotlTrainingMixins, GRPOConfig):\n    \"\"\"Axolotl GRPO Config for GRPO training\"\"\"\n\n    context_parallel_size: int | None = None\n\n\n@dataclass\nclass AxolotlAsyncGRPOConfig(AxolotlTrainingMixins, FastAsyncGRPOConfig):\n    \"\"\"Axolotl Async GRPO Config — adds async prefetch, streaming scoring, and IS correction.\"\"\"\n\n    context_parallel_size: int | None = None\n"
  },
  {
    "path": "src/axolotl/core/trainers/grpo/async_trainer.py",
    "content": "\"\"\"\nAsync GRPO training with streaming scoring and IS correction.\n\nWorks on stock TRL v0.29.0 and transformers v5.3.0 — no custom branches needed.\n\nFeatures:\n  - Async prefetch: background thread generates completions via vLLM while the main\n    thread trains on the previous rollout.\n  - Deferred scoring: rewards, advantages, and policy logprobs computed on the main\n    thread (thread-safe with GPU forward passes).\n  - Streaming group scoring: scores prompt groups incrementally so that reward\n    computation overlaps with the next group's logprob computation.\n  - Importance sampling (IS) correction: corrects for stale vLLM weights.\n  - Off-Policy Sequence Mask (OPSM): drops sequences with high KL + negative advantage.\n  - Configurable vLLM weight sync interval.\n\nClasses exported:\n  - AsyncGRPOConfig: GRPOConfig extended with async/streaming/IS fields\n  - AsyncGRPOTrainer: GRPOTrainer with async prefetch and IS correction\n  - ProducerConfig, DataProducer, BaseDataProducer, AsyncDataProducer: data producer protocol\n\"\"\"\n\nimport atexit\nimport concurrent.futures\nimport logging\nimport queue\nimport threading\nfrom abc import ABC, abstractmethod\nfrom collections import deque\nfrom contextlib import nullcontext\nfrom dataclasses import dataclass, field\nfrom typing import Any\n\nimport torch\nfrom torch.utils.data import DataLoader, Dataset\nfrom trl.extras.profiling import profiling_decorator\nfrom trl.trainer import GRPOConfig, GRPOTrainer\nfrom trl.trainer.utils import (\n    RepeatSampler,\n    entropy_from_logits,\n    nanmax,\n    nanmin,\n    nanstd,\n    pad,\n    selective_log_softmax,\n    shuffle_sequence_dict,\n    split_pixel_values_by_grid,\n    split_tensor_dict,\n    unsplit_pixel_values_by_grid,\n)\n\ntry:\n    from trl.data_utils import (\n        apply_chat_template,\n        is_conversational,\n        prepare_multimodal_messages,\n    )\nexcept ImportError:\n    from trl.chat_template_utils import 
apply_chat_template\n    from trl.data_utils import is_conversational, prepare_multimodal_messages\n\ntry:\n    from trl.models.utils import disable_gradient_checkpointing\nexcept ImportError:\n    from contextlib import contextmanager\n\n    @contextmanager\n    def disable_gradient_checkpointing(model, kwargs):\n        yield\n\n\ntry:\n    from accelerate.utils import gather_object\nexcept ImportError:\n    gather_object = None\n\ntry:\n    from peft import PeftModel\n    from trl.trainer.utils import use_adapter\nexcept ImportError:\n    PeftModel = None\n    use_adapter = nullcontext\n\ntry:\n    from liger_kernel.ops.grpo_loss import (\n        fused_selective_log_softmax as _fused_selective_log_softmax,\n    )\nexcept ImportError:\n    _fused_selective_log_softmax = None\n\n\n# ---------------------------------------------------------------------------\n# Config\n# ---------------------------------------------------------------------------\n\n\n@dataclass\nclass AsyncGRPOConfig(GRPOConfig):\n    \"\"\"GRPOConfig extended with async prefetch, streaming scoring, and IS correction fields.\n\n    Fields already present in stock GRPOConfig (e.g. 
``importance_sampling_level``,\n    ``multi_objective_aggregation``) are listed here for safety: if the stock version\n    does not define them, the defaults below ensure everything works.\n    \"\"\"\n\n    # --- Data producer ---\n    use_data_producer: bool = field(\n        default=False,\n        metadata={\n            \"help\": \"Use the GRPODataProducer protocol for online data generation.\"\n        },\n    )\n\n    # --- Async data production ---\n    async_prefetch: bool = field(\n        default=False,\n        metadata={\n            \"help\": \"Generate rollouts in a background thread while training on the previous rollout.\"\n        },\n    )\n    prefetch_depth: int = field(\n        default=1,\n        metadata={\"help\": \"Number of rollouts to prefetch ahead of training.\"},\n    )\n    vllm_sync_interval: int = field(\n        default=1,\n        metadata={\n            \"help\": \"Sync model weights to vLLM every N optimizer steps (async mode only).\"\n        },\n    )\n\n    # --- Streaming scoring ---\n    streaming_partial_batch: bool = field(\n        default=False,\n        metadata={\n            \"help\": \"Score prompt groups incrementally instead of the full batch at once.\"\n        },\n    )\n    streaming_min_groups: int = field(\n        default=1,\n        metadata={\"help\": \"Minimum prompt groups to score per streaming chunk.\"},\n    )\n\n    # --- vLLM importance sampling correction ---\n    vllm_importance_sampling_correction: bool = field(\n        default=True,\n        metadata={\n            \"help\": \"Apply IS correction for distribution mismatch between vLLM and training model.\"\n        },\n    )\n    vllm_importance_sampling_mode: str = field(\n        default=\"token_truncate\",\n        metadata={\n            \"help\": \"IS mode: token_truncate, token_mask, sequence_truncate, or sequence_mask.\"\n        },\n    )\n    vllm_importance_sampling_cap: float = field(\n        default=3.0,\n        
metadata={\"help\": \"Cap C for IS ratio clipping/masking.\"},\n    )\n\n    # --- Off-policy sequence mask (OPSM) ---\n    off_policy_mask_threshold: float | None = field(\n        default=None,\n        metadata={\"help\": \"KL threshold for OPSM (DeepSeek-V3.2). None = disabled.\"},\n    )\n\n    # --- Bias-corrected KL ---\n    use_bias_correction_kl: bool = field(\n        default=False,\n        metadata={\"help\": \"Apply IS correction to KL divergence term.\"},\n    )\n\n\n# ---------------------------------------------------------------------------\n# Data Producer Protocol (standalone — no transformers branch needed)\n# ---------------------------------------------------------------------------\n\nlogger = logging.getLogger(__name__)\n_dp_logger = logging.getLogger(__name__ + \".data_producer\")\n\n\n@dataclass\nclass ProducerConfig:\n    \"\"\"Configuration for a :class:`DataProducer`.\n\n    Args:\n        mini_epochs: Number of training passes over each produced dataset.\n        max_rollouts: Maximum number of produce-then-train rounds (None = unlimited).\n        steps_per_generation: Optimisation steps per produced dataset before regenerating.\n        num_iterations: Number of times to reuse each generation across optimisation steps.\n        async_prefetch: Produce the next dataset in a background thread.\n        prefetch_depth: How many rollouts to queue ahead when async.\n        sync_warmup_rollouts: Initial on-policy rollouts before switching to async.\n        eval_during_produce: Switch model to eval() during produce().\n        empty_cache_before_produce: torch.cuda.empty_cache() before produce().\n        empty_cache_after_produce: torch.cuda.empty_cache() after produce().\n    \"\"\"\n\n    mini_epochs: int = 1\n    max_rollouts: int | None = None\n    steps_per_generation: int | None = None\n    num_iterations: int = 1\n    async_prefetch: bool = False\n    prefetch_depth: int = 1\n    sync_warmup_rollouts: int = 0\n    
eval_during_produce: bool = True\n    empty_cache_before_produce: bool = False\n    empty_cache_after_produce: bool = False\n\n    def __post_init__(self):\n        if self.mini_epochs < 1:\n            raise ValueError(f\"mini_epochs must be >= 1, got {self.mini_epochs}\")\n        if self.max_rollouts is not None and self.max_rollouts < 1:\n            raise ValueError(\n                f\"max_rollouts must be >= 1 or None, got {self.max_rollouts}\"\n            )\n        if self.num_iterations < 1:\n            raise ValueError(f\"num_iterations must be >= 1, got {self.num_iterations}\")\n        if self.steps_per_generation is not None and self.steps_per_generation < 1:\n            raise ValueError(\n                f\"steps_per_generation must be >= 1 or None, got {self.steps_per_generation}\"\n            )\n        if self.prefetch_depth < 1:\n            raise ValueError(f\"prefetch_depth must be >= 1, got {self.prefetch_depth}\")\n        if self.sync_warmup_rollouts < 0:\n            raise ValueError(\n                f\"sync_warmup_rollouts must be >= 0, got {self.sync_warmup_rollouts}\"\n            )\n\n\nclass DataProducer(ABC):\n    \"\"\"Abstract base class for online data producers.\n\n    Subclass this and implement :meth:`produce` to supply fresh training data\n    each rollout round.\n    \"\"\"\n\n    config: ProducerConfig\n\n    @abstractmethod\n    def produce(\n        self,\n        model: Any,\n        global_step: int,\n        *,\n        processing_class: Any = None,\n        accelerator: Any = None,\n        args: Any = None,\n        **kwargs,\n    ) -> Dataset:\n        \"\"\"Generate a fresh training dataset.\"\"\"\n        ...\n\n\nclass BaseDataProducer(DataProducer):\n    \"\"\"Convenience base class with a default :class:`ProducerConfig` and lifecycle hooks.\"\"\"\n\n    def __init__(self, config: ProducerConfig | None = None):\n        self.config = config or ProducerConfig()\n\n    def on_rollout_begin(self, global_step: 
int) -> None:\n        \"\"\"Called before each produce() invocation.\"\"\"\n\n    def on_rollout_end(self, dataset: Dataset, global_step: int) -> None:\n        \"\"\"Called after each produce() invocation with the produced dataset.\"\"\"\n\n\nclass AsyncDataProducer:\n    \"\"\"Wraps a synchronous :class:`DataProducer` for background-thread data generation.\n\n    While the Trainer trains on the current rollout, this wrapper produces upcoming\n    datasets in a background thread.\n\n    FSDP compatibility: Background threads must NOT call cross-rank collectives\n    (gather_object, broadcast_object_list, FSDP all-gather) because the main thread\n    may be doing FSDP forward/backward concurrently, causing deadlocks. When\n    ``num_processes > 1``, only rank 0 runs BG generation; results are broadcast\n    to other ranks on the main thread when ``produce()`` is next called.\n    \"\"\"\n\n    def __init__(\n        self, inner: DataProducer, background_produce_kwargs: dict | None = None\n    ):\n        self._inner = inner\n        self._depth = inner.config.prefetch_depth\n        self._warmup_remaining = inner.config.sync_warmup_rollouts\n        self._background_kwargs = background_produce_kwargs or {}\n        self._executor = concurrent.futures.ThreadPoolExecutor(\n            max_workers=1, thread_name_prefix=\"async-producer\"\n        )\n        self._queue: deque[concurrent.futures.Future] = deque()\n        self._initialized = False\n        # Lock held by the background thread during vLLM generation.\n        # The main thread acquires this lock for weight sync to ensure\n        # merge_adapter/unmerge_adapter don't overlap with generation.\n        self._generate_lock = threading.Lock()\n        # Detected at first produce() call\n        self._num_processes: int | None = None\n        self._is_main: bool | None = None\n\n    @property\n    def config(self) -> ProducerConfig:\n        return self._inner.config\n\n    def produce(self, model: Any, 
global_step: int, **kwargs) -> Dataset:\n        \"\"\"Return the next dataset, blocking if the prefetch hasn't finished.\"\"\"\n        # Detect multi-process on first call\n        if self._num_processes is None:\n            accelerator = kwargs.get(\"accelerator\")\n            if accelerator is not None:\n                self._num_processes = accelerator.num_processes\n                self._is_main = accelerator.is_main_process\n            else:\n                self._num_processes = 1\n                self._is_main = True\n\n        # During warmup, produce synchronously (on-policy)\n        if self._warmup_remaining > 0:\n            self._warmup_remaining -= 1\n            _dp_logger.info(\n                f\"AsyncDataProducer: sync warmup rollout (remaining={self._warmup_remaining})\"\n            )\n            return self._inner.produce(model, global_step, **kwargs)\n\n        if not self._initialized:\n            dataset = self._inner.produce(model, global_step, **kwargs)\n            bg_kwargs = {**kwargs, **self._background_kwargs}\n            # With FSDP (multi-process), only submit BG tasks on rank 0.\n            # Non-rank-0 processes will receive data via broadcast.\n            if self._num_processes > 1:\n                bg_kwargs[\"_rank0_only\"] = True\n            for i in range(1, self._depth + 1):\n                self._queue.append(\n                    self._executor.submit(\n                        self._locked_produce, model, global_step + i, **bg_kwargs\n                    )\n                )\n            self._initialized = True\n            return dataset\n\n        # Get the pre-generated dataset from the BG thread\n        dataset = self._queue.popleft().result()\n\n        # With FSDP: BG thread only ran on rank 0. 
Broadcast to all ranks.\n        if self._num_processes > 1:\n            dataset = self._broadcast_dataset(dataset)\n\n        bg_kwargs = {**kwargs, **self._background_kwargs}\n        if self._num_processes > 1:\n            bg_kwargs[\"_rank0_only\"] = True\n        next_step = global_step + self._depth\n        self._queue.append(\n            self._executor.submit(self._locked_produce, model, next_step, **bg_kwargs)\n        )\n        return dataset\n\n    def _broadcast_dataset(self, dataset) -> Dataset:\n        \"\"\"Broadcast a prefetched dataset from rank 0 to all ranks (main thread).\n\n        Rank 0 has a full RolloutDataset from BG generation; other ranks have None.\n        After broadcast, tensors are moved to each rank's local device.\n        \"\"\"\n        import torch.distributed as dist\n\n        if not dist.is_initialized():\n            return dataset\n\n        # Rank 0 sends _data dict; others receive it\n        obj_list = [dataset._data if self._is_main else None]\n        dist.broadcast_object_list(obj_list, src=0)\n\n        data: dict[str, Any] = obj_list[0]  # type: ignore[assignment]\n\n        # Move tensors to local device (broadcast_object_list deserializes to CPU)\n        accelerator = self._inner._trainer.accelerator  # type: ignore[attr-defined]\n        device = accelerator.device\n        for key, val in data.items():\n            if isinstance(val, torch.Tensor) and val.device != device:\n                data[key] = val.to(device)\n\n        if not self._is_main:\n            from axolotl.core.trainers.grpo.async_trainer import RolloutDataset\n\n            dataset = RolloutDataset(data)\n        else:\n            # Rank 0 already has the dataset, but update _data with device-moved tensors\n            dataset._data = data\n        return dataset\n\n    def _locked_produce(self, model: Any, global_step: int, **kwargs) -> Dataset:\n        \"\"\"Run produce while holding the generate lock.\"\"\"\n        with 
self._generate_lock:\n            return self._inner.produce(model, global_step, **kwargs)\n\n    def on_rollout_begin(self, global_step: int) -> None:\n        if hasattr(self._inner, \"on_rollout_begin\"):\n            self._inner.on_rollout_begin(global_step)\n\n    def on_rollout_end(self, dataset: Dataset, global_step: int) -> None:\n        if hasattr(self._inner, \"on_rollout_end\"):\n            self._inner.on_rollout_end(dataset, global_step)\n\n    def shutdown(self) -> None:\n        \"\"\"Shut down the background thread pool and cancel pending futures.\"\"\"\n        for future in self._queue:\n            future.cancel()\n        self._queue.clear()\n        self._executor.shutdown(wait=False)\n\n\nclass DataProducerCallback:\n    \"\"\"Marker class: if a DataProducer also inherits from this, the Trainer will\n    automatically register it as a callback.\"\"\"\n\n    pass\n\n\n# ---------------------------------------------------------------------------\n# RolloutDataset + GRPODataProducer\n# ---------------------------------------------------------------------------\n\n\nclass RolloutDataset(Dataset):\n    \"\"\"A Dataset wrapping the output dict from _generate_and_score_completions.\n\n    Per-sample tensors are sliced by index; shared metadata is passed through.\n    \"\"\"\n\n    _ALWAYS_SHARED = frozenset(\n        {\"num_items_in_batch\", \"_pending_policy_logps\", \"_rank0_only\"}\n    )\n\n    def __init__(self, data: dict[str, Any]):\n        self._data = data\n        self._shared_keys: set[str] = set()\n        self._sample_keys: set[str] = set()\n\n        for key, val in data.items():\n            if key in self._ALWAYS_SHARED:\n                self._shared_keys.add(key)\n            elif not isinstance(val, torch.Tensor):\n                self._shared_keys.add(key)\n            elif val.dim() == 0:\n                self._shared_keys.add(key)\n            else:\n                self._sample_keys.add(key)\n\n        self._num_samples = 0\n  
      for key in self._sample_keys:\n            n = data[key].size(0)\n            if self._num_samples == 0:\n                self._num_samples = n\n            elif n != self._num_samples:\n                raise ValueError(\n                    f\"Inconsistent sample count: key '{key}' has {n}, expected {self._num_samples}\"\n                )\n        if self._num_samples == 0:\n            raise ValueError(\"No per-sample tensors found in rollout data\")\n\n    def __len__(self) -> int:\n        return self._num_samples\n\n    def __getitem__(self, idx: int) -> dict[str, Any]:\n        item: dict[str, Any] = {}\n        for key in self._sample_keys:\n            item[key] = self._data[key][idx]\n        for key in self._shared_keys:\n            item[key] = self._data[key]\n        return item\n\n\ndef make_rollout_collator(shared_keys: set[str]):\n    \"\"\"Return a collator that stacks per-sample tensors and passes shared keys through.\"\"\"\n\n    def _collate(batch: list[dict[str, Any]]) -> dict[str, Any]:\n        result: dict[str, Any] = {}\n        for key in batch[0]:\n            if key in shared_keys:\n                result[key] = batch[0][key]\n            else:\n                values = [item[key] for item in batch]\n                if isinstance(values[0], torch.Tensor):\n                    result[key] = torch.stack(values)\n                else:\n                    result[key] = values\n        return result\n\n    return _collate\n\n\nclass GRPODataProducer(BaseDataProducer):\n    \"\"\"Produces GRPO training rollouts using the trainer's generation pipeline.\n\n    Created before Trainer.__init__ completes; the trainer reference is injected\n    later via set_trainer().\n    \"\"\"\n\n    def __init__(\n        self,\n        config: ProducerConfig,\n        prompt_dataset,\n        *,\n        num_generations: int,\n        generation_batch_size: int,\n        train_batch_size: int,\n        steps_per_generation: int,\n        
shuffle_dataset: bool,\n        seed: int,\n    ):\n        super().__init__(config)\n        self._dataset = prompt_dataset\n        self._num_generations = num_generations\n        self._generation_batch_size = generation_batch_size\n        self._train_batch_size = train_batch_size\n        self._steps_per_generation = steps_per_generation\n        self._shuffle_dataset = shuffle_dataset\n        self._seed = seed\n        self._trainer: Any = None\n        self._prompt_dl: Any = None\n        self._prompt_iter: Any = None\n\n    def set_trainer(self, trainer) -> None:\n        \"\"\"Inject the live trainer reference and create the prompt DataLoader.\"\"\"\n        self._trainer = trainer\n        self._init_prompt_dataloader()\n\n    def _init_prompt_dataloader(self) -> None:\n        from functools import partial\n\n        from transformers.trainer_utils import seed_worker\n\n        trainer = self._trainer\n        sampler = RepeatSampler(\n            data_source=self._dataset,\n            mini_repeat_count=self._num_generations,\n            batch_size=self._generation_batch_size // self._num_generations,\n            repeat_count=1,\n            shuffle=self._shuffle_dataset,\n            seed=self._seed,\n        )\n\n        # Use identity collator (same as stock GRPOTrainer)\n        def _identity(x):\n            return x\n\n        dl = DataLoader(\n            self._dataset,\n            batch_size=self._train_batch_size * self._steps_per_generation,\n            sampler=sampler,\n            collate_fn=_identity,\n            num_workers=trainer.args.dataloader_num_workers,\n            pin_memory=trainer.args.dataloader_pin_memory,\n            persistent_workers=trainer.args.dataloader_persistent_workers,\n            worker_init_fn=partial(\n                seed_worker,\n                num_workers=trainer.args.dataloader_num_workers,\n                rank=trainer.args.process_index,\n            ),\n        )\n        self._prompt_dl = 
trainer.accelerator.prepare(dl)\n\n        # Don't let accelerator track this dataloader\n        acc_dls = trainer.accelerator._dataloaders\n        if self._prompt_dl in acc_dls:\n            acc_dls.remove(self._prompt_dl)\n\n        self._prompt_iter = iter(self._prompt_dl)\n\n    def produce(\n        self,\n        model: Any,\n        global_step: int,\n        *,\n        skip_policy_logps: bool = False,\n        processing_class: Any = None,\n        accelerator: Any = None,\n        args: Any = None,\n        _rank0_only: bool = False,\n        **kwargs,\n    ) -> RolloutDataset | None:\n        \"\"\"Generate a fresh GRPO training rollout.\"\"\"\n        is_main = self._trainer.accelerator.is_main_process\n\n        # FSDP rank0-only mode: non-rank-0 returns None (broadcast fills it later)\n        if _rank0_only and not is_main:\n            return None\n\n        try:\n            inputs = next(self._prompt_iter)\n        except StopIteration:\n            self._prompt_iter = iter(self._prompt_dl)\n            inputs = next(self._prompt_iter)\n\n        if skip_policy_logps:\n            # Async path: use _generate_only (generation without scoring) which\n            # works on stock TRL (no skip_policy_logps parameter needed).\n            output = self._trainer._generate_only(inputs, rank0_only=_rank0_only)\n        else:\n            # Sync path: full generation + scoring\n            output = self._trainer._generate_and_score_completions(inputs)\n\n            # Strip non-sequence metadata before shuffling\n            metadata = {}\n            for key in list(output.keys()):\n                val = output[key]\n                if not isinstance(val, (torch.Tensor, list)):\n                    metadata[key] = output.pop(key)\n                elif isinstance(val, torch.Tensor) and val.dim() == 0:\n                    metadata[key] = output.pop(key)\n\n            output = shuffle_sequence_dict(output)\n            output.update(metadata)\n\n        
return RolloutDataset(output)\n\n\n# ---------------------------------------------------------------------------\n# Trainer\n# ---------------------------------------------------------------------------\n\n\nclass AsyncGRPOTrainer(GRPOTrainer):\n    \"\"\"GRPOTrainer with async prefetch, streaming scoring, and IS correction.\n\n    Drop-in replacement: pass ``AsyncGRPOConfig`` as ``args`` and use this trainer\n    instead of ``GRPOTrainer``.\n    \"\"\"\n\n    def __init__(self, *args, **kwargs):\n        # When using native LoRA sync, skip the NCCL communicator init in VLLMGeneration.\n        # The communicator is not needed because weight sync happens via filesystem + HTTP,\n        # and it fails when vLLM and a trainer rank share the same CUDA device.\n        training_args = kwargs.get(\"args\") or (args[1] if len(args) > 1 else None)\n        if training_args is not None and getattr(\n            training_args, \"vllm_lora_sync\", False\n        ):\n            from trl.generation.vllm_generation import VLLMGeneration\n\n            _orig_init_vllm = VLLMGeneration._init_vllm\n\n            def _init_vllm_no_communicator(self_vllm):\n                \"\"\"Init vLLM client without NCCL communicator (LoRA sync uses filesystem).\"\"\"\n                if self_vllm.mode == \"server\" and self_vllm.accelerator.is_main_process:\n                    from trl.generation.vllm_client import VLLMClient\n\n                    if self_vllm.server_base_url is not None:\n                        base_url = self_vllm.server_base_url\n                    else:\n                        base_url = (\n                            f\"http://{self_vllm.server_host}:{self_vllm.server_port}\"\n                        )\n                    self_vllm.vllm_client = VLLMClient(\n                        base_url=base_url,\n                        group_port=self_vllm.group_port,\n                        connection_timeout=self_vllm.server_timeout,\n                    )\n                 
   # Deliberately skip init_communicator — no NCCL needed\n                elif self_vllm.mode != \"server\":\n                    _orig_init_vllm(self_vllm)\n\n            VLLMGeneration._init_vllm = _init_vllm_no_communicator\n\n        super().__init__(*args, **kwargs)\n\n        # FP8 models: zero out the pad token embedding so that padding\n        # positions have zero hidden states throughout the network.\n        # FP8 linear layers produce NaN on non-zero inputs at masked\n        # positions (the Triton fp8 matmul kernel can't handle the\n        # extreme values that accumulate at unattended positions).\n        self._zero_pad_embedding_for_fp8()\n\n        # Ensure custom attributes exist (stock GRPOTrainer.__init__ may not set them).\n        for attr, cfg_key, default in [\n            (\n                \"vllm_importance_sampling_correction\",\n                \"vllm_importance_sampling_correction\",\n                True,\n            ),\n            (\n                \"vllm_importance_sampling_mode\",\n                \"vllm_importance_sampling_mode\",\n                \"token_truncate\",\n            ),\n            (\"vllm_importance_sampling_cap\", \"vllm_importance_sampling_cap\", 3.0),\n            (\"off_policy_mask_threshold\", \"off_policy_mask_threshold\", None),\n        ]:\n            if not hasattr(self, attr):\n                setattr(self, attr, getattr(self.args, cfg_key, default))\n\n        # Async state\n        self._async_queue: queue.Queue | None = None\n        self._executor: concurrent.futures.ThreadPoolExecutor | None = None\n        self._prompt_iter = None\n        self._last_synced_step = -1\n        self._buffered_inputs: list | None = None  # override stock attr\n\n        # Data producer (the proper architecture for async generation)\n        self.data_producer = None\n        if getattr(self.args, \"use_data_producer\", False):\n            self.data_producer = self._create_data_producer(\n                
kwargs[\"args\"], kwargs[\"train_dataset\"]\n            )\n\n        if self.args.async_prefetch and self.data_producer is None:\n            # Legacy path: direct _prepare_inputs override without data producer\n            self._setup_async()\n\n    def _create_data_producer(self, args, train_dataset):\n        \"\"\"Create and return the GRPODataProducer (possibly wrapped in AsyncDataProducer).\"\"\"\n        producer_config = ProducerConfig(\n            mini_epochs=args.num_iterations,\n            max_rollouts=None,\n            eval_during_produce=False,\n            empty_cache_before_produce=True,\n            empty_cache_after_produce=True,\n            async_prefetch=args.async_prefetch,\n            prefetch_depth=args.prefetch_depth,\n        )\n        data_producer = GRPODataProducer(\n            config=producer_config,\n            prompt_dataset=train_dataset,\n            num_generations=self.num_generations,\n            generation_batch_size=args.generation_batch_size,\n            train_batch_size=args.per_device_train_batch_size,\n            steps_per_generation=args.steps_per_generation,\n            shuffle_dataset=getattr(self, \"shuffle_dataset\", True),\n            seed=args.seed,\n        )\n        data_producer.set_trainer(self)\n\n        if args.async_prefetch:\n            data_producer = AsyncDataProducer(\n                data_producer,\n                background_produce_kwargs={\"skip_policy_logps\": True},\n            )\n        return data_producer\n\n    # ------------------------------------------------------------------\n    # Async setup / teardown\n    # ------------------------------------------------------------------\n\n    def _setup_async(self):\n        \"\"\"Create background thread pool, prompt iterator, and pre-fill the async queue.\"\"\"\n        gen_batch_size = getattr(\n            self.args,\n            \"generation_batch_size\",\n            self._train_batch_size * 
self.args.gradient_accumulation_steps,\n        )\n        # RepeatSampler groups prompts with num_generations repetitions each.\n        # DataLoader batches the yielded indices into generation-sized batches.\n        sampler = RepeatSampler(\n            data_source=self.train_dataset,\n            mini_repeat_count=self.num_generations,\n            batch_size=gen_batch_size // self.num_generations,\n            repeat_count=10_000,  # effectively infinite\n            shuffle=True,\n            seed=self.args.seed,\n        )\n        self._prompt_dataloader = DataLoader(\n            self.train_dataset,\n            batch_size=gen_batch_size,\n            sampler=sampler,\n            collate_fn=self.data_collator,\n            num_workers=0,\n        )\n        self._prompt_iter = iter(self._prompt_dataloader)\n        self._async_queue = queue.Queue(maxsize=self.args.prefetch_depth)\n        self._executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)\n\n        # Pre-submit generations to fill the queue\n        for _ in range(self.args.prefetch_depth):\n            self._submit_generation()\n\n        atexit.register(self._shutdown_async)\n\n    def _shutdown_async(self):\n        if self._executor is not None:\n            self._executor.shutdown(wait=False, cancel_futures=True)\n            self._executor = None\n\n    def _submit_generation(self):\n        \"\"\"Submit the next background generation job.\"\"\"\n        batch = next(self._prompt_iter)\n        future = self._executor.submit(self._generate_only, batch)\n        self._async_queue.put(future)\n\n    # ------------------------------------------------------------------\n    # Weight sync\n    # ------------------------------------------------------------------\n\n    def _sync_peft_weights_no_merge(self):\n        \"\"\"Thread-safe weight sync: compute merged LoRA weights without in-place modification.\n\n        Required for FP8 models where merge_adapter() fails (addmm not 
implemented\n        for Float8), and also safe for concurrent use since it never modifies base\n        weights in-place.\n        \"\"\"\n        model = self.vllm_generation.model\n        accelerator = self.vllm_generation.accelerator\n        vllm_client = self.vllm_generation.vllm_client\n        fix_name = self.vllm_generation._fix_param_name_to_vllm\n\n        if not (self.vllm_generation.mode == \"server\" and accelerator.is_main_process):\n            return\n\n        # Build lookup: module_path -> (A, B, scaling) for all active LoRA layers\n        lora_info = {}\n        for mod_name, module in model.base_model.model.named_modules():\n            if not hasattr(module, \"lora_A\") or not hasattr(module, \"active_adapters\"):\n                continue\n            active = module.active_adapters[0]\n            if active not in module.lora_A:\n                continue\n            lora_info[mod_name] = (\n                module.lora_A[active].weight.data,\n                module.lora_B[active].weight.data,\n                module.scaling[active],\n            )\n\n        # Build lookup for FP8 scale_inv parameters (needed for dequantization)\n        scale_inv_lookup = {}\n        for pname, pparam in model.named_parameters():\n            if \"weight_scale_inv\" in pname:\n                # Map weight name -> scale_inv tensor\n                weight_name = pname.replace(\".weight_scale_inv\", \".weight\")\n                scale_inv_lookup[weight_name] = pparam.data\n\n        # Iterate all parameters, computing merged weights for LoRA layers.\n        # Skip LoRA-specific params and FP8 scale params (scales will be\n        # recomputed by vLLM when it receives the merged bf16 weight).\n        params_to_sync = []\n        for name, param in model.named_parameters():\n            vllm_name = name.removeprefix(\"base_model.model.\").replace(\n                \".base_layer\", \"\"\n            )\n            if model.prefix in vllm_name:\n               
 continue\n            if \"original_module\" in vllm_name:\n                continue\n            # Skip FP8 quantization scale parameters - they are recomputed\n            # on the vLLM side when we update the weight itself\n            if \"weight_scale_inv\" in vllm_name or \"input_scale\" in vllm_name:\n                continue\n            vllm_name = fix_name(vllm_name, extra_prefixes=[\"modules_to_save.default.\"])\n\n            data = param.data\n            compute_dtype = torch.bfloat16\n\n            if vllm_name.endswith(\".weight\"):\n                # Dequantize FP8 weights before merging\n                if data.dtype == torch.float8_e4m3fn and name in scale_inv_lookup:\n                    scale_inv = scale_inv_lookup[name]\n                    # Block dequantization: weight * scale_inv (with broadcasting)\n                    fp8_bf16 = data.to(compute_dtype)\n                    if scale_inv.dim() == 2 and fp8_bf16.dim() == 2:\n                        # Block-quantized: scale_inv shape (rows/block, cols/block)\n                        sr, sc = scale_inv.shape\n                        br = fp8_bf16.shape[0] // sr  # block height\n                        bc = fp8_bf16.shape[1] // sc  # block width\n                        # Reshape → multiply by block scale → reshape back\n                        data = (\n                            fp8_bf16.reshape(sr, br, sc, bc)\n                            * scale_inv[:, None, :, None].to(compute_dtype)\n                        ).reshape(fp8_bf16.shape)\n                    elif scale_inv.dim() <= 1:\n                        # Per-tensor or per-channel scale\n                        data = fp8_bf16 * scale_inv.to(compute_dtype)\n                    else:\n                        data = fp8_bf16\n                elif data.dtype == torch.float8_e4m3fn:\n                    # FP8 but no scale found - just cast (lossy)\n                    data = data.to(compute_dtype)\n\n                mod_path = vllm_name[: 
-len(\".weight\")]\n                if mod_path in lora_info:\n                    A, B, s = lora_info[mod_path]\n                    merged = data.to(compute_dtype) + s * (\n                        B.to(compute_dtype) @ A.to(compute_dtype)\n                    )\n                    data = merged\n\n            params_to_sync.append((vllm_name, data))\n\n        # Batch sync all params in one HTTP+NCCL call (vs individual calls)\n        if params_to_sync:\n            vllm_client.batch_update_named_params(params_to_sync)\n\n        # Reset prefix cache after weight update\n        vllm_client.reset_prefix_cache()\n\n    def _sync_lora_adapter(self):\n        \"\"\"Sync LoRA adapter to vLLM via filesystem (native LoRA mode).\n\n        Saves the PEFT adapter to a temp directory and POSTs the path to vLLM's\n        /set_lora_adapter/ endpoint. vLLM loads the adapter natively using Punica\n        kernels, avoiding the need to merge weights and NCCL-broadcast the full model.\n\n        Syncs only the LoRA adapter weights via filesystem instead of the full merged model via NCCL.\n\n        FSDP/DeepSpeed: All ranks must participate in the state_dict gather.\n        accelerator.get_state_dict() handles this (FSDP uses FullStateDictConfig\n        with rank0_only=True). 
Only rank 0 gets the full dict, writes files, and\n        does the HTTP POST.\n        \"\"\"\n        import os\n        import tempfile\n\n        accelerator = self.vllm_generation.accelerator\n        model = self.vllm_generation.model\n\n        if self.vllm_generation.mode != \"server\":\n            return\n\n        is_main = accelerator.is_main_process\n\n        # Increment adapter version (all ranks, kept in sync)\n        if not hasattr(self, \"_lora_sync_version\"):\n            self._lora_sync_version = 0\n            if is_main:\n                self._lora_sync_dir = tempfile.mkdtemp(prefix=\"lora_sync_\")\n            else:\n                self._lora_sync_dir = None\n            # Broadcast sync dir from rank 0 to all ranks\n            if accelerator.num_processes > 1:\n                import torch.distributed as dist\n\n                if dist.is_initialized():\n                    obj_list = [self._lora_sync_dir]\n                    dist.broadcast_object_list(obj_list, src=0)\n                    self._lora_sync_dir = obj_list[0]\n        self._lora_sync_version += 1\n\n        adapter_path = os.path.join(self._lora_sync_dir, f\"v{self._lora_sync_version}\")\n\n        # Gather state dict from all ranks (FSDP/DeepSpeed gather, rank0_only)\n        # All ranks must participate even though only rank 0 gets the result.\n        # Use self.model_wrapped (the DeepSpeed/FSDP engine) for get_state_dict,\n        # since it has the necessary hooks (e.g. 
zero_gather_16bit_weights_on_model_save).\n        # self.vllm_generation.model is the unwrapped PEFT model which lacks these.\n        wrapped_model = getattr(self, \"model_wrapped\", model)\n        state_dict = accelerator.get_state_dict(wrapped_model)\n\n        if is_main:\n            # Unwrap to access PEFT's save_pretrained\n            unwrapped = accelerator.unwrap_model(model)\n            unwrapped.save_pretrained(adapter_path, state_dict=state_dict)\n\n            import requests\n\n            vllm_client = self.vllm_generation.vllm_client\n            url = f\"{vllm_client.base_url}/set_lora_adapter/\"\n            response = requests.post(\n                url,\n                json={\n                    \"lora_name\": \"active_lora\",\n                    \"lora_int_id\": self._lora_sync_version,\n                    \"lora_path\": adapter_path,\n                },\n                timeout=30,\n            )\n            if response.status_code != 200:\n                logger.warning(\n                    \"Failed to set LoRA adapter: %s %s\",\n                    response.status_code,\n                    response.text,\n                )\n                return\n\n            # Reset prefix cache after adapter update\n            vllm_client.reset_prefix_cache()\n\n            # Clean up old adapter versions (keep only current)\n            if self._lora_sync_version > 1:\n                old_path = os.path.join(\n                    self._lora_sync_dir, f\"v{self._lora_sync_version - 1}\"\n                )\n                if os.path.exists(old_path):\n                    import shutil\n\n                    shutil.rmtree(old_path, ignore_errors=True)\n\n            logger.info(\n                \"Synced LoRA adapter v%d to vLLM (%s)\",\n                self._lora_sync_version,\n                adapter_path,\n            )\n\n        # Barrier to ensure all ranks complete before resuming forward passes.\n        # Without this, rank 1 may 
start a forward pass (triggering FSDP unshard)\n        # while rank 0 is still doing save_pretrained, causing FSDP all-gather deadlock.\n        if accelerator.num_processes > 1:\n            import torch.distributed as dist\n\n            if dist.is_initialized():\n                dist.barrier()\n\n    def _maybe_sync_vllm_weights(self):\n        \"\"\"Sync model weights to vLLM if the interval has elapsed.\n\n        Dispatches to one of three strategies:\n        - vllm_lora_sync: saves adapter to filesystem, vLLM loads natively\n        - PEFT no-merge: computes merged weights as new tensors, NCCL broadcast\n        - Non-PEFT: stock sync_weights via merge_adapter + NCCL\n        \"\"\"\n        if not (self.use_vllm and self.args.async_prefetch):\n            return\n        step = self.state.global_step\n        interval = self.args.vllm_sync_interval\n        if step != self._last_synced_step and step % interval == 0:\n            if getattr(self.args, \"vllm_lora_sync\", False):\n                if step == 0:\n                    logger.info(\"Skipping LoRA sync at step 0 (no training yet)\")\n                    self._last_synced_step = step\n                    return\n                # Native LoRA sync: save adapter to filesystem, vLLM loads it directly\n                self._sync_lora_adapter()\n            else:\n                from accelerate.utils import is_peft_model\n\n                use_no_merge = is_peft_model(self.vllm_generation.model)\n\n                if use_no_merge:\n                    # No-merge sync: computes merged weights as new tensors\n                    # (doesn't modify base weights in-place), so it's safe to\n                    # run concurrently with BG generation — no lock needed.\n                    self._sync_peft_weights_no_merge()\n                else:\n                    # Non-PEFT: use stock sync (acquires lock to avoid overlap)\n                    if self.data_producer is not None and hasattr(\n                  
      self.data_producer, \"_generate_lock\"\n                    ):\n                        with self.data_producer._generate_lock:\n                            self.vllm_generation.sync_weights()\n                    elif self._async_queue is not None:\n                        pending = list(self._async_queue.queue)\n                        for f in pending:\n                            if isinstance(f, concurrent.futures.Future):\n                                f.result()\n                        self.vllm_generation.sync_weights()\n                    else:\n                        self.vllm_generation.sync_weights()\n            self._last_synced_step = step\n\n    def _zero_pad_embedding_for_fp8(self):\n        \"\"\"Zero out the pad token embedding for FP8 models.\n\n        FP8 linear layers produce NaN when processing positions with\n        attention_mask=0 (the hidden states at those positions have\n        unconstrained values that overflow FP8 range during\n        quantization). 
By setting the pad token embedding to zeros,\n        padding positions start with zero hidden states and stay zero\n        through masked attention, preventing NaN from FP8 matmul.\n        \"\"\"\n        model = self.accelerator.unwrap_model(self.model)\n        # Check if model has FP8 weights\n        has_fp8 = any(\n            p.dtype == torch.float8_e4m3fn\n            for p in model.parameters()\n            if not p.requires_grad\n        )\n        if not has_fp8:\n            return\n\n        # Find the embedding layer\n        if hasattr(model, \"model\") and hasattr(model.model, \"embed_tokens\"):\n            embed = model.model.embed_tokens\n        elif hasattr(model, \"base_model\") and hasattr(model.base_model, \"model\"):\n            m = model.base_model.model\n            if hasattr(m, \"model\") and hasattr(m.model, \"embed_tokens\"):\n                embed = m.model.embed_tokens\n            else:\n                return\n        else:\n            return\n\n        pad_id = self.processing_class.pad_token_id\n        if pad_id is not None and pad_id < embed.weight.shape[0]:\n            with torch.no_grad():\n                embed.weight.data[pad_id].zero_()\n            import logging\n\n            logging.getLogger(\"async_grpo\").info(\n                f\"Zeroed pad token embedding (id={pad_id}) for FP8 NaN prevention\"\n            )\n\n    # ------------------------------------------------------------------\n    # Background-thread generation (no scoring)\n    # ------------------------------------------------------------------\n\n    def _generate_single_turn(self, prompts, **kwargs):\n        \"\"\"Override to prevent weight sync from background thread and to use\n        no-merge sync for PEFT models (FP8 models can't merge_adapter).\"\"\"\n        is_bg = threading.current_thread() is not threading.main_thread()\n        saved_step = None\n\n        if is_bg and self.use_vllm:\n            # Trick: match _last_loaded_step so the 
stock sync check is a no-op\n            saved_step = getattr(self, \"_last_loaded_step\", None)\n            self._last_loaded_step = self.state.global_step\n\n        # Permanently replace vllm_generation.sync_weights with our custom\n        # sync to avoid merge_adapter (fails on FP8 / races with training).\n        # For LoRA sync mode, make it a no-op here since _maybe_sync_vllm_weights\n        # handles the sync with proper interval tracking.\n        if not getattr(self, \"_patched_sync_weights\", False):\n            if self.use_vllm and hasattr(self, \"vllm_generation\"):\n                if getattr(self.args, \"vllm_lora_sync\", False):\n                    # No-op: LoRA sync is driven by _maybe_sync_vllm_weights\n                    self.vllm_generation.sync_weights = lambda: None\n                    self._patched_sync_weights = True\n                else:\n                    from accelerate.utils import is_peft_model\n\n                    if is_peft_model(self.vllm_generation.model):\n\n                        def _no_merge_sync():\n                            self._sync_peft_weights_no_merge()\n\n                        self.vllm_generation.sync_weights = _no_merge_sync\n                        self._patched_sync_weights = True\n\n        try:\n            return super()._generate_single_turn(prompts, **kwargs)\n        finally:\n            if saved_step is not None:\n                self._last_loaded_step = saved_step\n\n    def _generate_rank0_only(self, prompts):\n        \"\"\"Generate using vLLM directly on rank 0 without cross-rank collectives.\n\n        Called from BG thread in FSDP mode. 
Bypasses ``gather_object`` /\n        ``broadcast_object_list`` since the main thread may be running FSDP\n        collectives concurrently.\n\n        Returns the same tuple as ``_generate``.\n        \"\"\"\n        import copy\n\n        prompts = copy.deepcopy(prompts)\n\n        # Duplicate prompts for num_generations (same as TRL's gather+unique pattern)\n        num_generations = self.num_generations\n        unique_prompts = prompts[::num_generations]\n\n        # Build sampling params\n        vg = self.vllm_generation\n        sampling_params = {\n            \"n\": num_generations,\n            \"repetition_penalty\": vg.repetition_penalty,\n            \"temperature\": vg.temperature,\n            \"top_p\": vg.top_p,\n            \"top_k\": vg.top_k,\n            \"min_p\": 0.0 if vg.min_p is None else vg.min_p,\n            \"max_tokens\": vg.max_completion_length,\n            \"logprobs\": vg.logprobs,\n            \"structured_outputs_regex\": vg.structured_outputs_regex,\n            \"generation_kwargs\": vg.generation_kwargs,\n        }\n\n        # Call vLLM directly (no collectives)\n        from trl.data_utils import is_conversational\n\n        if is_conversational({\"prompt\": unique_prompts[0]}):\n            output = vg.vllm_client.chat(\n                messages=unique_prompts,\n                **sampling_params,\n                chat_template_kwargs=vg.chat_template_kwargs,\n                tools=vg.tools,\n                chat_template=vg.chat_template,\n            )\n        else:\n            output = vg.vllm_client.generate(prompts=unique_prompts, **sampling_params)\n\n        # vLLM returns 1 prompt_ids per unique prompt, but num_generations completion_ids.\n        # Duplicate prompt_ids to match completions (one per generation).\n        raw_prompt_ids = output[\"prompt_ids\"]\n        prompt_ids = [pid for pid in raw_prompt_ids for _ in range(num_generations)]\n        completion_ids = output[\"completion_ids\"]\n        
logprobs_raw = output[\"logprobs\"]\n        extra_fields = {\n            k: v\n            for k, v in output.items()\n            if k\n            not in {\"prompt_ids\", \"completion_ids\", \"logprobs\", \"logprob_token_ids\"}\n        }\n\n        # Extract top-1 logprob per token\n        logprobs = [[lp[0] for lp in seq] for seq in logprobs_raw]\n\n        # Decode completions\n        if is_conversational({\"prompt\": prompts[0]}):\n            contents = self.processing_class.batch_decode(\n                completion_ids, skip_special_tokens=True\n            )\n            completions = [[{\"role\": \"assistant\", \"content\": c}] for c in contents]\n        else:\n            completions = self.processing_class.batch_decode(\n                completion_ids, skip_special_tokens=True\n            )\n\n        tool_mask = extra_fields.pop(\"env_mask\", None)\n\n        # Compute total completion tokens locally (no gather)\n        total_completion_tokens = sum(len(ids) for ids in completion_ids)\n\n        return (\n            prompt_ids,\n            completion_ids,\n            tool_mask,\n            completions,\n            total_completion_tokens,\n            logprobs,\n            extra_fields,\n        )\n\n    def _generate_only(self, inputs, rank0_only=False):\n        \"\"\"Generate completions without scoring.  
Runs on background thread.\n\n        Mirrors the first half of ``_generate_and_score_completions`` (prompt\n        extraction → vLLM generation → tensor padding) and returns a deferred\n        output dict for main-thread scoring.\n\n        When ``rank0_only=True`` (FSDP mode), bypasses ``gather_object`` /\n        ``broadcast_object_list`` collectives and calls vLLM directly on rank 0.\n        Results are broadcast to other ranks on the main thread later.\n\n        Args:\n            inputs: list of dicts (one per sample), as yielded by the DataLoader\n                    with ``identity`` collate_fn.\n        \"\"\"\n        device = self.accelerator.device\n\n        prompts = [x[\"prompt\"] for x in inputs]\n\n        # --- Handle images (multimodal) ---\n        if \"images\" in inputs[0]:\n            images = [ex.get(\"images\") for ex in inputs]\n        elif \"image\" in inputs[0]:\n            images = [\n                [ex.get(\"image\")] if ex.get(\"image\") is not None else None\n                for ex in inputs\n            ]\n        else:\n            images = None\n        if images is not None and all(img == [] for img in images):\n            images = None\n\n        if images is not None:\n            if not is_conversational(inputs[0]):\n                raise ValueError(\"Multimodal training requires conversational prompts.\")\n            prompts = [\n                prepare_multimodal_messages(p, il)\n                for p, il in zip(prompts, images, strict=True)\n            ]\n\n        # --- Generate completions ---\n        if rank0_only:\n            # FSDP mode: call vLLM directly without cross-rank collectives\n            (\n                prompt_ids_list,\n                completion_ids_list,\n                tool_mask_list,\n                completions,\n                num_items_in_batch,\n                sampling_per_token_logps_list,\n                extra_fields,\n            ) = self._generate_rank0_only(prompts)\n       
 else:\n            (\n                prompt_ids_list,\n                completion_ids_list,\n                tool_mask_list,\n                completions,\n                num_items_in_batch,\n                sampling_per_token_logps_list,\n                extra_fields,\n            ) = self._generate(prompts)\n            # _generate gathers prompts from all ranks internally. Gather inputs\n            # to match the full-batch output size.\n            if self.accelerator.num_processes > 1:\n                from accelerate.utils import gather_object\n\n                inputs = gather_object(inputs)\n                prompts = [x[\"prompt\"] for x in inputs]\n\n        # --- Pad to tensors ---\n        prompt_ids = [torch.tensor(ids, device=device) for ids in prompt_ids_list]\n        prompt_mask = [torch.ones_like(ids, dtype=torch.long) for ids in prompt_ids]\n        prompt_ids = pad(\n            prompt_ids, padding_value=self.pad_token_id, padding_side=\"left\"\n        )\n        prompt_mask = pad(prompt_mask, padding_value=0, padding_side=\"left\")\n\n        completion_ids = [\n            torch.tensor(ids, device=device) for ids in completion_ids_list\n        ]\n        completion_mask = [\n            torch.ones_like(ids, dtype=torch.long) for ids in completion_ids\n        ]\n        completion_ids = pad(\n            completion_ids, padding_value=self.pad_token_id, padding_side=\"right\"\n        )\n        completion_mask = pad(completion_mask, padding_value=0, padding_side=\"right\")\n\n        if sampling_per_token_logps_list is not None:\n            sampling_logps = [\n                torch.tensor(lp, device=device) for lp in sampling_per_token_logps_list\n            ]\n            sampling_per_token_logps = pad(\n                sampling_logps, padding_value=0.0, padding_side=\"right\"\n            )\n        else:\n            sampling_per_token_logps = None\n\n        if tool_mask_list is not None:\n            tool_mask = [torch.tensor(m, 
device=device) for m in tool_mask_list]\n            tool_mask = pad(tool_mask, padding_value=1, padding_side=\"right\")\n        else:\n            tool_mask = None\n\n        # --- Mask truncated completions ---\n        if self.mask_truncated_completions:\n            eos_and_pad = [self.eos_token_id, self.pad_token_id]\n            is_trunc = torch.tensor(\n                [ids[-1] not in eos_and_pad for ids in completion_ids_list],\n                device=device,\n            )\n            completion_mask = completion_mask * (~is_trunc).unsqueeze(1).int()\n            if tool_mask is not None:\n                tool_mask = tool_mask * (~is_trunc).unsqueeze(1).int()\n\n        # --- Multimodal forward kwargs ---\n        num_images = [len(il) for il in images] if images is not None else None\n        if images is not None:\n            prompts_text = [\n                apply_chat_template(\n                    {\"prompt\": p},\n                    self.processing_class,\n                    tools=self.tools,\n                    **self.chat_template_kwargs,\n                )[\"prompt\"]\n                for p in prompts\n            ]\n            prompt_inputs = self.processing_class(\n                images=images, text=prompts_text, padding=True, return_tensors=\"pt\"\n            )\n            forward_kwargs = {\n                k: v.to(device) if isinstance(v, torch.Tensor) else v\n                for k, v in prompt_inputs.items()\n                if k not in (\"input_ids\", \"attention_mask\")\n            }\n        else:\n            forward_kwargs = {}\n\n        # Extend token_type_ids / mm_token_type_ids for completion tokens\n        for ttid_key in (\"token_type_ids\", \"mm_token_type_ids\"):\n            if ttid_key in forward_kwargs:\n                tt = forward_kwargs[ttid_key]\n                forward_kwargs[ttid_key] = torch.cat(\n                    [tt, tt.new_zeros(completion_ids.shape)], dim=1\n                )\n\n        # Merge 
extra_fields from rollout_func into inputs\n        if extra_fields:\n            for i, inp in enumerate(inputs):\n                for key, values in extra_fields.items():\n                    if isinstance(values, list) and i < len(values):\n                        inp[key] = values[i]\n                    elif not isinstance(values, list):\n                        inp[key] = values\n\n        # No explicit CUDA sync needed here — both threads share the\n        # default stream, so operations are naturally ordered.\n\n        # --- Construct deferred output ---\n        output = {\n            \"prompt_ids\": prompt_ids,\n            \"prompt_mask\": prompt_mask,\n            \"completion_ids\": completion_ids,\n            \"completion_mask\": completion_mask,\n            \"num_items_in_batch\": num_items_in_batch,\n            \"advantages\": torch.zeros(completion_ids.size(0), device=device),\n            # Sentinels for deferred scoring\n            \"_pending_policy_logps\": True,\n            \"_deferred_inputs\": inputs,\n            \"_deferred_prompts\": prompts,\n            \"_deferred_completions\": completions,\n            \"_deferred_completion_ids_list\": completion_ids_list,\n            \"_rank0_only\": rank0_only,\n        }\n        if sampling_per_token_logps is not None:\n            output[\"sampling_per_token_logps\"] = sampling_per_token_logps\n        if tool_mask is not None:\n            output[\"tool_mask\"] = tool_mask\n        if images is not None:\n            output[\"num_images\"] = num_images\n        for k in (\n            \"pixel_values\",\n            \"image_grid_thw\",\n            \"pixel_attention_mask\",\n            \"image_sizes\",\n            \"token_type_ids\",\n            \"mm_token_type_ids\",\n        ):\n            if k in forward_kwargs:\n                output[k] = forward_kwargs[k]\n        return output\n\n    # ------------------------------------------------------------------\n    # Hooks (overridden 
by subclasses like FastAsyncGRPOTrainer)\n    # ------------------------------------------------------------------\n\n    def _compute_rewards_for_batch(\n        self, inputs, prompts, completions, completion_ids_list\n    ):\n        \"\"\"Compute rewards for a batch. Override for parallel workers, caching, etc.\"\"\"\n        return self._calculate_rewards(\n            inputs, prompts, completions, completion_ids_list\n        )\n\n    def _launch_reward_workers(self, inputs, prompts, completions, completion_ids_list):\n        \"\"\"Launch reward computation in background. Override for parallel dispatch.\n\n        Default: no-op (rewards computed synchronously in _collect_reward_workers).\n        \"\"\"\n        self._pending_reward_args = (inputs, prompts, completions, completion_ids_list)\n\n    def _collect_reward_workers(\n        self, inputs, prompts, completions, completion_ids_list\n    ):\n        \"\"\"Collect reward results. Override to collect from parallel workers.\n\n        Default: compute rewards synchronously now.\n        \"\"\"\n        args = getattr(self, \"_pending_reward_args\", None)\n        if args is not None:\n            self._pending_reward_args = None\n            return self._compute_rewards_for_batch(*args)\n        return self._compute_rewards_for_batch(\n            inputs, prompts, completions, completion_ids_list\n        )\n\n    def _post_advantage_hook(\n        self,\n        data: dict,\n        rewards_per_func,\n        advantages,\n        inputs: list,\n        num_generations: int,\n        mode: str,\n        s_start: int | None = None,\n        s_end: int | None = None,\n        is_last_chunk: bool = True,\n    ) -> None:\n        \"\"\"Called after advantages are computed. 
Override for replay buffer, re-roll, etc.\"\"\"\n\n    # ------------------------------------------------------------------\n    # Main-thread scoring\n    # ------------------------------------------------------------------\n\n    @torch.no_grad()\n    def _compute_deferred_scores(self, rollout: dict) -> dict:\n        \"\"\"Compute rewards, advantages, policy logprobs, and IS ratio on the main thread.\n\n        Takes the deferred output from ``_generate_only`` and produces a fully\n        scored dict ready for ``split_tensor_dict`` → micro-batches.\n        \"\"\"\n        device = self.accelerator.device\n        batch_size = self.args.per_device_train_batch_size\n        num_generations = self.num_generations\n        mode = \"train\"\n\n        # --- Extract deferred data ---\n        data = rollout\n        inputs = data.pop(\"_deferred_inputs\")\n        prompts = data.pop(\"_deferred_prompts\")\n        completions = data.pop(\"_deferred_completions\")\n        completion_ids_list = data.pop(\"_deferred_completion_ids_list\")\n        rank0_only = data.pop(\"_rank0_only\", False)\n        del data[\"_pending_policy_logps\"]\n\n        prompt_ids = data[\"prompt_ids\"]\n        completion_ids = data[\"completion_ids\"]\n        prompt_mask = data[\"prompt_mask\"]\n        completion_mask = data[\"completion_mask\"]\n        prompt_completion_ids = torch.cat([prompt_ids, completion_ids], dim=1)\n        attention_mask = torch.cat([prompt_mask, completion_mask], dim=1)\n        logits_to_keep = completion_ids.size(1)\n\n        # Multimodal forward kwargs\n        forward_kwargs = {}\n        for key in (\n            \"pixel_values\",\n            \"image_grid_thw\",\n            \"pixel_attention_mask\",\n            \"image_sizes\",\n            \"token_type_ids\",\n            \"mm_token_type_ids\",\n        ):\n            if key in data:\n                forward_kwargs[key] = data[key]\n        num_images = data.get(\"num_images\")\n\n        # --- 
Launch rewards in parallel with logprobs ---\n        self._launch_reward_workers(inputs, prompts, completions, completion_ids_list)\n\n        # --- Policy logprobs ---\n        logprob_batch_size = min(batch_size * 4, len(prompt_ids))\n        with disable_gradient_checkpointing(\n            self.model, self.args.gradient_checkpointing_kwargs\n        ):\n            generate_every = self.args.steps_per_generation * self.num_iterations\n            if self.args.gradient_accumulation_steps % generate_every != 0 or (\n                self.use_vllm\n                and getattr(self, \"vllm_importance_sampling_correction\", False)\n            ):\n                old_per_token_logps, _ = self._get_per_token_logps_and_entropies(\n                    self.model,\n                    prompt_completion_ids,\n                    attention_mask,\n                    logits_to_keep,\n                    logprob_batch_size,\n                    num_images=num_images,\n                    **forward_kwargs,\n                )\n                data[\"old_per_token_logps\"] = old_per_token_logps\n            else:\n                old_per_token_logps = None\n\n            # Reference model logprobs\n            if self.beta != 0.0:\n                if self.ref_model is not None:\n                    ref_logps, _ = self._get_per_token_logps_and_entropies(\n                        self.ref_model,\n                        prompt_completion_ids,\n                        attention_mask,\n                        logits_to_keep,\n                        batch_size,\n                        num_images=num_images,\n                        **forward_kwargs,\n                    )\n                else:\n                    unwrapped = self.accelerator.unwrap_model(self.model)\n                    adapter_name = (\n                        \"ref\"\n                        if hasattr(unwrapped, \"peft_config\")\n                        and \"ref\" in unwrapped.peft_config\n                  
      else None\n                    )\n                    with use_adapter(unwrapped, adapter_name=adapter_name):\n                        ref_logps, _ = self._get_per_token_logps_and_entropies(\n                            self.model,\n                            prompt_completion_ids,\n                            attention_mask,\n                            logits_to_keep,\n                            batch_size,\n                            num_images=num_images,\n                            **forward_kwargs,\n                        )\n                data[\"ref_per_token_logps\"] = ref_logps\n\n        # --- IS ratio ---\n        if (\n            self.use_vllm\n            and getattr(self, \"vllm_importance_sampling_correction\", False)\n            and old_per_token_logps is not None\n            and \"sampling_per_token_logps\" in data\n        ):\n            sampling_logps = data[\"sampling_per_token_logps\"]\n            is_mask = (\n                completion_mask\n                if \"tool_mask\" not in data\n                else completion_mask * data[\"tool_mask\"]\n            )\n            per_token_logps_diff = (old_per_token_logps - sampling_logps) * is_mask\n\n            is_mode = getattr(self, \"vllm_importance_sampling_mode\", \"token_truncate\")\n            is_cap = getattr(self, \"vllm_importance_sampling_cap\", 3.0)\n            sequence_level_is = is_mode in (\"sequence_mask\", \"sequence_truncate\")\n            if sequence_level_is:\n                logps_diff = per_token_logps_diff.sum(dim=-1, keepdim=True)\n            else:\n                logps_diff = per_token_logps_diff\n\n            is_ratio = torch.exp(logps_diff)\n            if is_mode in (\"sequence_truncate\", \"token_truncate\"):\n                is_ratio = torch.clamp(is_ratio, max=is_cap)\n            elif is_mode in (\"sequence_mask\", \"token_mask\"):\n                is_ratio = is_ratio.masked_fill(is_ratio > is_cap, value=0.0)\n            
data[\"importance_sampling_ratio\"] = is_ratio\n\n        # --- Collect rewards (launched before logprobs, should be done) ---\n        rewards_per_func = self._collect_reward_workers(\n            inputs, prompts, completions, completion_ids_list\n        )\n        # In rank0_only mode, all ranks compute the same rewards on identical data.\n        # _calculate_rewards / _collect_reward_workers always `gather()` across ranks,\n        # which duplicates the rows (N_local * num_processes).  De-duplicate so that\n        # rewards_per_func matches the data dict (which has N_local rows).\n        if rank0_only and rewards_per_func.size(0) > len(prompts):\n            rewards_per_func = rewards_per_func[: len(prompts)]\n\n        # --- Advantages ---\n        if self.multi_objective_aggregation == \"sum_then_normalize\":\n            rewards = (\n                rewards_per_func * self.reward_weights.to(device).unsqueeze(0)\n            ).nansum(dim=1)\n            mean_grouped = (\n                rewards.view(-1, num_generations)\n                .mean(dim=1)\n                .repeat_interleave(num_generations)\n            )\n            if self.scale_rewards in (\"group\", \"none\"):\n                if num_generations > 1:\n                    std_rewards = (\n                        rewards.view(-1, num_generations)\n                        .std(dim=1)\n                        .repeat_interleave(num_generations)\n                    )\n                else:\n                    std_rewards = torch.zeros_like(rewards)\n            elif self.scale_rewards == \"batch\":\n                std_rewards = (\n                    rewards.std().expand_as(rewards)\n                    if rewards.numel() > 1\n                    else torch.zeros_like(rewards)\n                )\n            else:\n                raise ValueError(f\"Invalid scale_rewards: {self.scale_rewards}\")\n            advantages = rewards - mean_grouped\n            if self.scale_rewards != 
\"none\":\n                advantages = advantages / (std_rewards + 1e-4)\n            is_std_zero = torch.isclose(std_rewards, torch.zeros_like(std_rewards))\n\n        elif self.multi_objective_aggregation == \"normalize_then_sum\":\n            grouped = rewards_per_func.view(-1, num_generations, len(self.reward_funcs))\n            mean_k = torch.nanmean(grouped, dim=1, keepdim=True)\n            std_k = (\n                nanstd(grouped, dim=1, keepdim=True)\n                if num_generations > 1\n                else torch.zeros_like(mean_k)\n            )\n            reward_k = (grouped - mean_k) / (std_k + 1e-4)\n            reward_k = reward_k.view(-1, len(self.reward_funcs))\n            rewards = (reward_k * self.reward_weights.to(device).unsqueeze(0)).nansum(\n                dim=1\n            )\n            std_rewards = (\n                rewards.std().expand_as(rewards)\n                if rewards.numel() > 1\n                else torch.zeros_like(rewards)\n            )\n            advantages = (rewards - rewards.mean()) / (std_rewards + 1e-4)\n            is_std_zero = torch.isclose(std_rewards, torch.zeros_like(std_rewards))\n        else:\n            raise ValueError(\n                f\"Invalid multi_objective_aggregation: {self.multi_objective_aggregation}\"\n            )\n\n        # Slice for local process\n        # In rank0_only mode, all ranks already have identical data from broadcast,\n        # so no slicing needed. Otherwise, each rank takes its portion.\n        if rank0_only:\n            process_slice = slice(0, len(prompts))\n        else:\n            process_slice = slice(\n                self.accelerator.process_index * len(prompts),\n                (self.accelerator.process_index + 1) * len(prompts),\n            )\n        all_advantages = advantages.clone()\n        advantages = advantages[process_slice]\n        data[\"advantages\"] = advantages\n\n        # --- Post-advantage hook (for replay buffer, re-roll, etc.) 
---\n        self._post_advantage_hook(\n            data,\n            rewards_per_func,\n            advantages,\n            inputs,\n            num_generations,\n            mode,\n        )\n\n        # --- Metrics ---\n        for i, name in enumerate(self.reward_func_names):\n            self._metrics[mode][f\"rewards/{name}/mean\"].append(\n                torch.nanmean(rewards_per_func[:, i]).item()\n            )\n            self._metrics[mode][f\"rewards/{name}/std\"].append(\n                nanstd(rewards_per_func[:, i]).item()\n            )\n        agg_rewards = rewards_per_func.nansum(dim=1)\n        self._metrics[mode][\"reward\"].append(agg_rewards.mean().item())\n        self._metrics[mode][\"reward_std\"].append(agg_rewards.std().item())\n        self._metrics[mode][\"frac_reward_zero_std\"].append(\n            is_std_zero.float().mean().item()\n        )\n\n        # Token counting\n        total_prompt = self.accelerator.gather(prompt_mask.sum()).sum()\n        total_completion = self.accelerator.gather(completion_mask.sum()).sum()\n        self.state.num_input_tokens_seen += (total_prompt + total_completion).item()\n        self._metrics[mode][\"num_tokens\"] = [self.state.num_input_tokens_seen]\n\n        # Completion length metrics\n        comp_lengths = completion_mask.sum(dim=1)\n        agg_lengths = self.accelerator.gather(comp_lengths)\n        self._metrics[mode][\"completions/mean_length\"].append(\n            agg_lengths.float().mean().item()\n        )\n        self._metrics[mode][\"completions/min_length\"].append(\n            agg_lengths.float().min().item()\n        )\n        self._metrics[mode][\"completions/max_length\"].append(\n            agg_lengths.float().max().item()\n        )\n\n        eos_and_pad = [self.eos_token_id, self.pad_token_id]\n        is_trunc = torch.tensor(\n            [ids[-1].item() not in eos_and_pad for ids in completion_ids], device=device\n        )\n        agg_trunc = 
self.accelerator.gather(is_trunc)\n        self._metrics[mode][\"completions/clipped_ratio\"].append(\n            agg_trunc.float().mean().item()\n        )\n        term_lengths = agg_lengths[~agg_trunc]\n        if len(term_lengths) == 0:\n            term_lengths = torch.zeros(1, device=device)\n        self._metrics[mode][\"completions/mean_terminated_length\"].append(\n            term_lengths.float().mean().item()\n        )\n        self._metrics[mode][\"completions/min_terminated_length\"].append(\n            term_lengths.float().min().item()\n        )\n        self._metrics[mode][\"completions/max_terminated_length\"].append(\n            term_lengths.float().max().item()\n        )\n\n        # IS metrics\n        if \"importance_sampling_ratio\" in data and \"sampling_per_token_logps\" in data:\n            old_lp = data[\"old_per_token_logps\"]\n            samp_lp = data[\"sampling_per_token_logps\"]\n            mask = completion_mask.bool()\n            delta = torch.abs(old_lp - samp_lp)\n            delta_m = delta[mask]\n            md = (\n                torch.mean(delta_m)\n                if delta_m.numel() > 0\n                else torch.tensor(0.0, device=device)\n            )\n            xd = (\n                torch.max(delta_m)\n                if delta_m.numel() > 0\n                else torch.tensor(0.0, device=device)\n            )\n            self._metrics[mode][\"sampling/sampling_logp_difference/mean\"].append(\n                self.accelerator.gather(md).mean().item()\n            )\n            self._metrics[mode][\"sampling/sampling_logp_difference/max\"].append(\n                self.accelerator.gather(xd).max().item()\n            )\n            isr = data[\"importance_sampling_ratio\"]\n            is_mode = getattr(self, \"vllm_importance_sampling_mode\", \"token_truncate\")\n            if is_mode in (\"sequence_mask\", \"sequence_truncate\"):\n                flat_isr = isr.flatten()\n            else:\n              
  flat_isr = isr[mask]\n            if flat_isr.numel() > 0:\n                self._metrics[mode][\"sampling/importance_sampling_ratio/min\"].append(\n                    nanmin(self.accelerator.gather(torch.min(flat_isr))).item()\n                )\n                self._metrics[mode][\"sampling/importance_sampling_ratio/mean\"].append(\n                    self.accelerator.gather(torch.mean(flat_isr)).nanmean().item()\n                )\n                self._metrics[mode][\"sampling/importance_sampling_ratio/max\"].append(\n                    nanmax(self.accelerator.gather(torch.max(flat_isr))).item()\n                )\n\n        # Log prompt/completion texts\n        prompts_text = self.processing_class.batch_decode(\n            prompt_ids, skip_special_tokens=True\n        )\n        completions_text = self.processing_class.batch_decode(\n            completion_ids, skip_special_tokens=True\n        )\n        if gather_object is not None:\n            self._logs[\"prompt\"].extend(gather_object(prompts_text))\n            self._logs[\"completion\"].extend(gather_object(completions_text))\n        for i, name in enumerate(self.reward_func_names):\n            self._logs[\"rewards\"][name].extend(rewards_per_func[:, i].tolist())\n        self._logs[\"advantages\"].extend(all_advantages.tolist())\n\n        # Remove deferred keys\n        for k in list(data.keys()):\n            if k.startswith(\"_deferred\") or k == \"_pending_policy_logps\":\n                data.pop(k, None)\n\n        return data\n\n    @torch.no_grad()\n    def _compute_streaming_group_scores(\n        self,\n        data,\n        s_start,\n        s_end,\n        inputs,\n        prompts,\n        completions,\n        completion_ids_list,\n        is_last_chunk,\n        rank0_only=False,\n    ):\n        \"\"\"Score a chunk of prompt groups: rewards, policy logprobs, advantages.\n\n        Called during streaming scoring to incrementally score groups.\n        Writes results directly 
into ``data`` at positions ``s_start:s_end``.\n        \"\"\"\n        device = self.accelerator.device\n        batch_size = self.args.per_device_train_batch_size\n        num_generations = self.num_generations\n        mode = \"train\"\n        chunk_size = s_end - s_start\n\n        # --- Policy logprobs for this chunk ---\n        chunk_prompt_ids = data[\"prompt_ids\"][s_start:s_end]\n        chunk_completion_ids = data[\"completion_ids\"][s_start:s_end]\n        chunk_prompt_mask = data[\"prompt_mask\"][s_start:s_end]\n        chunk_completion_mask = data[\"completion_mask\"][s_start:s_end]\n        prompt_completion_ids = torch.cat(\n            [chunk_prompt_ids, chunk_completion_ids], dim=1\n        )\n        attention_mask = torch.cat([chunk_prompt_mask, chunk_completion_mask], dim=1)\n        logits_to_keep = chunk_completion_ids.size(1)\n\n        # Slice multimodal forward kwargs for this chunk\n        forward_kwargs = {}\n        for key in (\n            \"pixel_values\",\n            \"image_grid_thw\",\n            \"pixel_attention_mask\",\n            \"image_sizes\",\n            \"token_type_ids\",\n            \"mm_token_type_ids\",\n        ):\n            if key in data:\n                val = data[key]\n                if (\n                    isinstance(val, torch.Tensor)\n                    and val.dim() > 0\n                    and val.size(0) == len(data[\"prompt_ids\"])\n                ):\n                    forward_kwargs[key] = val[s_start:s_end]\n                else:\n                    forward_kwargs[key] = val\n        num_images = data.get(\"num_images\")\n        if (\n            num_images is not None\n            and hasattr(num_images, \"__getitem__\")\n            and len(num_images) == len(data[\"prompt_ids\"])\n        ):\n            num_images = num_images[s_start:s_end]\n\n        # --- Launch rewards in parallel with logprobs ---\n        self._launch_reward_workers(inputs, prompts, completions, 
completion_ids_list)\n\n        # --- Policy logprobs for this chunk (GPU, overlaps with BG rewards) ---\n        logprob_batch_size = min(batch_size * 2, chunk_size)\n        with disable_gradient_checkpointing(\n            self.model, self.args.gradient_checkpointing_kwargs\n        ):\n            generate_every = self.args.steps_per_generation * self.num_iterations\n            if self.args.gradient_accumulation_steps % generate_every != 0 or (\n                self.use_vllm\n                and getattr(self, \"vllm_importance_sampling_correction\", False)\n            ):\n                old_logps, _ = self._get_per_token_logps_and_entropies(\n                    self.model,\n                    prompt_completion_ids,\n                    attention_mask,\n                    logits_to_keep,\n                    logprob_batch_size,\n                    num_images=num_images,\n                    **forward_kwargs,\n                )\n                if \"old_per_token_logps\" not in data:\n                    total = len(data[\"prompt_ids\"])\n                    data[\"old_per_token_logps\"] = torch.zeros(\n                        total, old_logps.size(1), device=device, dtype=old_logps.dtype\n                    )\n                data[\"old_per_token_logps\"][s_start:s_end] = old_logps\n\n                # Compute IS ratio for this chunk\n                if \"sampling_per_token_logps\" in data:\n                    samp_chunk = data[\"sampling_per_token_logps\"][s_start:s_end]\n                    is_mask = (\n                        chunk_completion_mask\n                        if \"tool_mask\" not in data\n                        else (chunk_completion_mask * data[\"tool_mask\"][s_start:s_end])\n                    )\n                    diff = (old_logps - samp_chunk) * is_mask\n                    is_mode = getattr(\n                        self, \"vllm_importance_sampling_mode\", \"token_truncate\"\n                    )\n                    is_cap = 
getattr(self, \"vllm_importance_sampling_cap\", 3.0)\n                    seq_is = is_mode in (\"sequence_mask\", \"sequence_truncate\")\n                    logps_diff = diff.sum(dim=-1, keepdim=True) if seq_is else diff\n                    is_ratio = torch.exp(logps_diff)\n                    if is_mode in (\"sequence_truncate\", \"token_truncate\"):\n                        is_ratio = torch.clamp(is_ratio, max=is_cap)\n                    elif is_mode in (\"sequence_mask\", \"token_mask\"):\n                        is_ratio = is_ratio.masked_fill(is_ratio > is_cap, value=0.0)\n                    if \"importance_sampling_ratio\" not in data:\n                        total = len(data[\"prompt_ids\"])\n                        shape = (total, 1) if seq_is else (total, is_ratio.size(1))\n                        data[\"importance_sampling_ratio\"] = torch.ones(\n                            *shape, device=device, dtype=is_ratio.dtype\n                        )\n                    data[\"importance_sampling_ratio\"][s_start:s_end] = is_ratio\n\n            # Reference logprobs\n            if self.beta != 0.0:\n                if self.ref_model is not None:\n                    ref_logps, _ = self._get_per_token_logps_and_entropies(\n                        self.ref_model,\n                        prompt_completion_ids,\n                        attention_mask,\n                        logits_to_keep,\n                        batch_size,\n                        num_images=num_images,\n                        **forward_kwargs,\n                    )\n                else:\n                    unwrapped = self.accelerator.unwrap_model(self.model)\n                    adapter_name = (\n                        \"ref\"\n                        if hasattr(unwrapped, \"peft_config\")\n                        and \"ref\" in unwrapped.peft_config\n                        else None\n                    )\n                    with use_adapter(unwrapped, 
adapter_name=adapter_name):\n                        ref_logps, _ = self._get_per_token_logps_and_entropies(\n                            self.model,\n                            prompt_completion_ids,\n                            attention_mask,\n                            logits_to_keep,\n                            batch_size,\n                            num_images=num_images,\n                            **forward_kwargs,\n                        )\n                if \"ref_per_token_logps\" not in data:\n                    total = len(data[\"prompt_ids\"])\n                    data[\"ref_per_token_logps\"] = torch.zeros(\n                        total, ref_logps.size(1), device=device, dtype=ref_logps.dtype\n                    )\n                data[\"ref_per_token_logps\"][s_start:s_end] = ref_logps\n\n        # --- Collect rewards (should already be done, ran in parallel with logprobs) ---\n        rewards_per_func = self._collect_reward_workers(\n            inputs, prompts, completions, completion_ids_list\n        )\n        # De-duplicate gathered rewards when all ranks computed the same data.\n        # _calculate_rewards always gather()s, which duplicates rows in rank0_only mode.\n        if rewards_per_func.size(0) > chunk_size:\n            rewards_per_func = rewards_per_func[:chunk_size]\n\n        # --- Advantages (group-level normalization) ---\n        if self.multi_objective_aggregation == \"sum_then_normalize\":\n            rewards = (\n                rewards_per_func * self.reward_weights.to(device).unsqueeze(0)\n            ).nansum(dim=1)\n            mean_g = (\n                rewards.view(-1, num_generations)\n                .mean(dim=1)\n                .repeat_interleave(num_generations)\n            )\n            if num_generations > 1:\n                std_r = (\n                    rewards.view(-1, num_generations)\n                    .std(dim=1)\n                    .repeat_interleave(num_generations)\n                )\n  
          else:\n                std_r = torch.zeros_like(rewards)\n            advantages = rewards - mean_g\n            if self.scale_rewards != \"none\":\n                advantages = advantages / (std_r + 1e-4)\n            is_std_zero = torch.isclose(std_r, torch.zeros_like(std_r))\n\n        elif self.multi_objective_aggregation == \"normalize_then_sum\":\n            grouped = rewards_per_func.view(-1, num_generations, len(self.reward_funcs))\n            mean_k = torch.nanmean(grouped, dim=1, keepdim=True)\n            std_k = (\n                nanstd(grouped, dim=1, keepdim=True)\n                if num_generations > 1\n                else torch.zeros_like(mean_k)\n            )\n            reward_k = ((grouped - mean_k) / (std_k + 1e-4)).view(\n                -1, len(self.reward_funcs)\n            )\n            rewards = (reward_k * self.reward_weights.to(device).unsqueeze(0)).nansum(\n                dim=1\n            )\n            std_r = (\n                rewards.view(-1, num_generations)\n                .std(dim=1)\n                .repeat_interleave(num_generations)\n            )\n            mean_r = (\n                rewards.view(-1, num_generations)\n                .mean(dim=1)\n                .repeat_interleave(num_generations)\n            )\n            advantages = (rewards - mean_r) / (std_r + 1e-4)\n            is_std_zero = torch.isclose(std_r, torch.zeros_like(std_r))\n        else:\n            raise ValueError(\n                f\"Invalid multi_objective_aggregation: {self.multi_objective_aggregation}\"\n            )\n\n        if rank0_only:\n            process_slice = slice(0, len(prompts))\n        else:\n            process_slice = slice(\n                self.accelerator.process_index * len(prompts),\n                (self.accelerator.process_index + 1) * len(prompts),\n            )\n        advantages = advantages[process_slice]\n\n        if \"advantages\" not in data or not isinstance(data[\"advantages\"], 
torch.Tensor):\n            data[\"advantages\"] = torch.zeros(len(data[\"prompt_ids\"]), device=device)\n        data[\"advantages\"][s_start:s_end] = advantages\n\n        # --- Post-advantage hook (for replay buffer, re-roll, etc.) ---\n        self._post_advantage_hook(\n            data,\n            rewards_per_func,\n            advantages,\n            inputs,\n            num_generations,\n            mode,\n            s_start=s_start,\n            s_end=s_end,\n            is_last_chunk=is_last_chunk,\n        )\n\n        # --- Chunk metrics ---\n        for i, name in enumerate(self.reward_func_names):\n            self._metrics[mode][f\"rewards/{name}/mean\"].append(\n                torch.nanmean(rewards_per_func[:, i]).item()\n            )\n            self._metrics[mode][f\"rewards/{name}/std\"].append(\n                nanstd(rewards_per_func[:, i]).item()\n            )\n        agg_rewards = rewards_per_func.nansum(dim=1)\n        self._metrics[mode][\"reward\"].append(agg_rewards.mean().item())\n        self._metrics[mode][\"reward_std\"].append(agg_rewards.std().item())\n        self._metrics[mode][\"frac_reward_zero_std\"].append(\n            is_std_zero.float().mean().item()\n        )\n\n        # --- Full-batch metrics on last chunk ---\n        if is_last_chunk:\n            all_prompt_mask = data[\"prompt_mask\"]\n            all_completion_mask = data[\"completion_mask\"]\n            all_completion_ids = data[\"completion_ids\"]\n            total_p = self.accelerator.gather(all_prompt_mask.sum()).sum()\n            total_c = self.accelerator.gather(all_completion_mask.sum()).sum()\n            self.state.num_input_tokens_seen += (total_p + total_c).item()\n            self._metrics[mode][\"num_tokens\"] = [self.state.num_input_tokens_seen]\n\n            comp_lengths = all_completion_mask.sum(dim=1)\n            agg_lengths = self.accelerator.gather(comp_lengths)\n            
self._metrics[mode][\"completions/mean_length\"].append(\n                agg_lengths.float().mean().item()\n            )\n            self._metrics[mode][\"completions/min_length\"].append(\n                agg_lengths.float().min().item()\n            )\n            self._metrics[mode][\"completions/max_length\"].append(\n                agg_lengths.float().max().item()\n            )\n\n            eos_and_pad = [self.eos_token_id, self.pad_token_id]\n            is_trunc = torch.tensor(\n                [ids[-1].item() not in eos_and_pad for ids in all_completion_ids],\n                device=device,\n            )\n            agg_trunc = self.accelerator.gather(is_trunc)\n            self._metrics[mode][\"completions/clipped_ratio\"].append(\n                agg_trunc.float().mean().item()\n            )\n            term = agg_lengths[~agg_trunc]\n            if len(term) == 0:\n                term = torch.zeros(1, device=device)\n            self._metrics[mode][\"completions/mean_terminated_length\"].append(\n                term.float().mean().item()\n            )\n            self._metrics[mode][\"completions/min_terminated_length\"].append(\n                term.float().min().item()\n            )\n            self._metrics[mode][\"completions/max_terminated_length\"].append(\n                term.float().max().item()\n            )\n\n            # IS metrics\n            if (\n                self.use_vllm\n                and getattr(self, \"vllm_importance_sampling_correction\", False)\n                and \"sampling_per_token_logps\" in data\n                and \"old_per_token_logps\" in data\n            ):\n                old_lp = data[\"old_per_token_logps\"]\n                samp_lp = data[\"sampling_per_token_logps\"]\n                mask = all_completion_mask.bool()\n                delta = torch.abs(old_lp - samp_lp)[mask]\n                md = (\n                    torch.mean(delta)\n                    if delta.numel() > 0\n          
          else torch.tensor(0.0, device=device)\n                )\n                xd = (\n                    torch.max(delta)\n                    if delta.numel() > 0\n                    else torch.tensor(0.0, device=device)\n                )\n                self._metrics[mode][\"sampling/sampling_logp_difference/mean\"].append(\n                    self.accelerator.gather(md).mean().item()\n                )\n                self._metrics[mode][\"sampling/sampling_logp_difference/max\"].append(\n                    self.accelerator.gather(xd).max().item()\n                )\n                is_mode = getattr(\n                    self, \"vllm_importance_sampling_mode\", \"token_truncate\"\n                )\n                isr = data[\"importance_sampling_ratio\"]\n                flat = (\n                    isr.flatten()\n                    if is_mode in (\"sequence_mask\", \"sequence_truncate\")\n                    else isr[mask]\n                )\n                if flat.numel() > 0:\n                    self._metrics[mode][\n                        \"sampling/importance_sampling_ratio/min\"\n                    ].append(nanmin(self.accelerator.gather(torch.min(flat))).item())\n                    self._metrics[mode][\n                        \"sampling/importance_sampling_ratio/mean\"\n                    ].append(self.accelerator.gather(torch.mean(flat)).nanmean().item())\n                    self._metrics[mode][\n                        \"sampling/importance_sampling_ratio/max\"\n                    ].append(nanmax(self.accelerator.gather(torch.max(flat))).item())\n\n    def _score_streaming(self, rollout: dict) -> list[dict]:\n        \"\"\"Score a rollout using streaming group scoring.  
Returns list of micro-batches.\"\"\"\n        data = rollout\n        num_gen = self.num_generations\n        n_groups = len(data[\"prompt_ids\"]) // num_gen\n        batch_size = self.args.per_device_train_batch_size\n        min_groups = max(1, self.args.streaming_min_groups)\n\n        # Extract deferred data\n        inputs = data.pop(\"_deferred_inputs\")\n        prompts = data.pop(\"_deferred_prompts\")\n        completions = data.pop(\"_deferred_completions\")\n        completion_ids_list = data.pop(\"_deferred_completion_ids_list\")\n        rank0_only = data.pop(\"_rank0_only\", False)\n        del data[\"_pending_policy_logps\"]\n\n        all_micro_batches = []\n        shared_keys = {\"num_items_in_batch\"}\n\n        for chunk_start_g in range(0, n_groups, min_groups):\n            chunk_end_g = min(chunk_start_g + min_groups, n_groups)\n            s_start = chunk_start_g * num_gen\n            s_end = chunk_end_g * num_gen\n\n            self._compute_streaming_group_scores(\n                data=data,\n                s_start=s_start,\n                s_end=s_end,\n                inputs=inputs[s_start:s_end],\n                prompts=prompts[s_start:s_end],\n                completions=completions[s_start:s_end],\n                completion_ids_list=completion_ids_list[s_start:s_end],\n                is_last_chunk=(chunk_end_g == n_groups),\n                rank0_only=rank0_only,\n            )\n\n            # Yield micro-batches from this scored chunk\n            chunk_size = s_end - s_start\n            perm = torch.randperm(chunk_size)\n            for mb_off in range(0, chunk_size, batch_size):\n                mb_idx = perm[mb_off : mb_off + batch_size]\n                abs_idx = mb_idx + s_start\n                mb = {}\n                for key in data:\n                    if key.startswith(\"_\"):\n                        continue\n                    val = data[key]\n                    if key in shared_keys:\n                        
mb[key] = val\n                    elif isinstance(val, torch.Tensor) and val.dim() > 0:\n                        mb[key] = val[abs_idx]\n                    else:\n                        mb[key] = val\n                all_micro_batches.append(mb)\n\n        # Repeat for num_iterations\n        return all_micro_batches * self.num_iterations\n\n    # ------------------------------------------------------------------\n    # _prepare_inputs override\n    # ------------------------------------------------------------------\n\n    def _prepare_inputs(self, generation_batch):\n        \"\"\"Override to support data producer and async prefetch paths.\"\"\"\n        mode = \"train\" if self.model.training else \"eval\"\n\n        # --- Data producer path ---\n        if mode == \"train\" and self.data_producer is not None:\n            return self._prepare_inputs_data_producer(generation_batch)\n\n        # --- Legacy async prefetch path (no data producer) ---\n        if mode == \"train\" and self.args.async_prefetch:\n            return self._prepare_inputs_legacy_async(generation_batch)\n\n        # --- Stock path ---\n        return super()._prepare_inputs(generation_batch)\n\n    def _prepare_inputs_data_producer(self, generation_batch):\n        \"\"\"Data producer path: produce rollout, score deferred logps, split into micro-batches.\"\"\"\n        # Return from buffer if available\n        if self._buffered_inputs:\n            return self._buffered_inputs.pop(0)\n\n        # Produce a new rollout\n        self._maybe_sync_vllm_weights()\n\n        rollout_dataset = self.data_producer.produce(\n            self.model,\n            self.state.global_step,\n            processing_class=self.processing_class,\n            accelerator=self.accelerator,\n            args=self.args,\n        )\n\n        # Convert RolloutDataset back to a dict for scoring/splitting\n        rollout = rollout_dataset._data\n\n        # If async (skip_policy_logps=True), score deferred 
logps on main thread\n        if rollout.get(\"_pending_policy_logps\"):\n            if self.args.streaming_partial_batch:\n                micro_batches = self._score_streaming(rollout)\n            else:\n                scored = self._compute_deferred_scores(rollout)\n                scored = split_pixel_values_by_grid(scored)\n                scored = shuffle_sequence_dict(scored)\n                batches = split_tensor_dict(scored, self.args.steps_per_generation)\n                micro_batches = [unsplit_pixel_values_by_grid(b) for b in batches]\n                micro_batches = micro_batches * self.num_iterations\n        else:\n            # Sync path: data is already fully scored\n            rollout = split_pixel_values_by_grid(rollout)\n            batches = split_tensor_dict(rollout, self.args.steps_per_generation)\n            micro_batches = [unsplit_pixel_values_by_grid(b) for b in batches]\n            micro_batches = micro_batches * self.num_iterations\n\n        self._buffered_inputs = micro_batches[1:]\n        return micro_batches[0]\n\n    def _prepare_inputs_legacy_async(self, generation_batch):\n        \"\"\"Legacy async path: direct queue-based prefetch without data producer.\"\"\"\n        # Return from buffer if available\n        if self._buffered_inputs:\n            return self._buffered_inputs.pop(0)\n\n        # Need a new rollout\n        self._maybe_sync_vllm_weights()\n        future = self._async_queue.get()\n        rollout = future.result()\n        self._submit_generation()\n\n        if self.args.streaming_partial_batch:\n            micro_batches = self._score_streaming(rollout)\n        else:\n            scored = self._compute_deferred_scores(rollout)\n            scored = split_pixel_values_by_grid(scored)\n            scored = shuffle_sequence_dict(scored)\n            batches = split_tensor_dict(scored, self.args.steps_per_generation)\n            micro_batches = [unsplit_pixel_values_by_grid(b) for b in batches]\n       
     micro_batches = micro_batches * self.num_iterations\n\n        self._buffered_inputs = micro_batches[1:]\n\n        # Release cached CUDA memory from scoring\n        # before training allocations begin, reducing peak reserved memory.\n        torch.cuda.empty_cache()\n\n        return micro_batches[0]\n\n    @profiling_decorator\n    def _get_per_token_logps_and_entropies(\n        self,\n        model,\n        input_ids,\n        attention_mask,\n        logits_to_keep,\n        batch_size=None,\n        compute_entropy=False,\n        pixel_values=None,\n        image_grid_thw=None,\n        num_images=None,\n        pixel_attention_mask=None,\n        image_sizes=None,\n        token_type_ids=None,\n        mm_token_type_ids=None,\n    ) -> tuple[Any, torch.Tensor | None]:\n        \"\"\"Compute log-probs and (optionally) entropies for each token.\n\n        When running under no_grad (scoring path), bypasses accelerate's\n        ConvertOutputsToFp32 wrapper to avoid a fp32 copy of the\n        logits tensor.\n        \"\"\"\n        # Bypass accelerate's ConvertOutputsToFp32 wrapper which converts the\n        # entire (B, L, V) logits tensor from bf16 to fp32 — unnecessary and\n        # extremely wasteful for large vocabularies.\n        # Skip unwrapping for FSDP — parameters are only valid inside FSDP's\n        # forward context; unwrapping exposes flattened/sharded tensors.\n        if not self.is_fsdp_enabled:\n            model = self.accelerator.unwrap_model(model, keep_fp32_wrapper=False)\n        autocast_ctx = torch.autocast(\n            device_type=input_ids.device.type, dtype=torch.bfloat16\n        )\n\n        # Use Liger's Triton kernel in scoring path (no grad): fuses\n        # temperature + log_softmax + gather into a single kernel pass.\n        use_fused = (\n            self.use_liger_kernel\n            and _fused_selective_log_softmax is not None\n            and not torch.is_grad_enabled()\n        )\n\n        batch_size = 
batch_size or input_ids.size(0)\n        all_logps = []\n        all_entropies = []\n        with autocast_ctx:\n            for start in range(0, input_ids.size(0), batch_size):\n                input_ids_batch = input_ids[start : start + batch_size]\n                attention_mask_batch = attention_mask[start : start + batch_size]\n\n                # Build model inputs\n                model_inputs = {\n                    \"input_ids\": input_ids_batch,\n                    \"attention_mask\": attention_mask_batch,\n                }\n                if image_grid_thw is not None and pixel_values is not None:\n                    rows_per_image = image_grid_thw.prod(dim=-1)\n                    rows_per_sample = torch.split(rows_per_image, num_images)\n                    rows_per_sample = torch.stack([s.sum() for s in rows_per_sample])\n                    cum_rows = torch.cat(\n                        [\n                            torch.tensor([0], device=rows_per_sample.device),\n                            rows_per_sample.cumsum(0),\n                        ]\n                    )\n                    row_start, row_end = (\n                        cum_rows[start].item(),\n                        cum_rows[start + batch_size].item(),\n                    )\n                    model_inputs[\"pixel_values\"] = pixel_values[row_start:row_end]\n                    cum_imgs = torch.tensor([0] + num_images).cumsum(0)\n                    img_start, img_end = cum_imgs[start], cum_imgs[start + batch_size]\n                    model_inputs[\"image_grid_thw\"] = image_grid_thw[img_start:img_end]\n                elif pixel_values is not None:\n                    model_inputs[\"pixel_values\"] = pixel_values[\n                        start : start + batch_size\n                    ]\n                if pixel_attention_mask is not None:\n                    model_inputs[\"pixel_attention_mask\"] = pixel_attention_mask[\n                        start : start + 
batch_size\n                    ]\n                if image_sizes is not None:\n                    model_inputs[\"image_sizes\"] = image_sizes[\n                        start : start + batch_size\n                    ]\n                if token_type_ids is not None:\n                    model_inputs[\"token_type_ids\"] = token_type_ids[\n                        start : start + batch_size\n                    ]\n                if mm_token_type_ids is not None:\n                    model_inputs[\"mm_token_type_ids\"] = mm_token_type_ids[\n                        start : start + batch_size\n                    ]\n\n                if \"logits_to_keep\" in self.model_kwarg_keys:\n                    model_inputs[\"logits_to_keep\"] = logits_to_keep + 1\n\n                model_inputs[\"use_cache\"] = False\n\n                logits = model(**model_inputs).logits\n                completion_ids = input_ids_batch[:, -logits_to_keep:]\n                # FP8 models produce NaN logits at positions where\n                # attention_mask=0 (padding). 
Replace NaN with 0 so\n                # log_softmax yields uniform distribution for those positions.\n                # The completion_mask ensures these don't affect the loss.\n                logits = torch.nan_to_num(logits, nan=0.0)\n\n                if use_fused:\n                    logits = logits[:, -(logits_to_keep + 1) :, :]\n                    if not logits.is_contiguous():\n                        logits = logits.contiguous()\n                    logps = _fused_selective_log_softmax(\n                        logits, completion_ids, self.temperature\n                    )\n                    all_logps.append(logps)\n                else:\n                    logits = logits[:, :-1, :]\n                    logits = logits[:, -logits_to_keep:, :]\n                    logits.div_(self.temperature)\n                    logps = selective_log_softmax(logits, completion_ids)\n                    all_logps.append(logps)\n\n                    if compute_entropy:\n                        with torch.no_grad():\n                            entropies = entropy_from_logits(logits)\n                        all_entropies.append(entropies)\n\n        logps = torch.cat(all_logps, dim=0)\n        entropies = torch.cat(all_entropies, dim=0) if compute_entropy else None\n        return logps, entropies\n\n    # ------------------------------------------------------------------\n    # Loss override (adds IS ratio + OPSM)\n    # ------------------------------------------------------------------\n\n    @staticmethod\n    def get_off_policy_mask(\n        advantages,\n        per_token_logps,\n        sampling_per_token_logps,\n        mask,\n        off_policy_threshold,\n    ):\n        \"\"\"OPSM from DeepSeek-V3.2: drop sequences with negative advantage + high KL.\"\"\"\n        kl_div = sampling_per_token_logps - per_token_logps.detach()\n        seq_kl = (kl_div * mask).sum(dim=1, keepdim=True) / mask.sum(\n            dim=1, keepdim=True\n        ).clamp(min=1.0)\n   
     is_pos_adv = advantages >= 0\n        is_low_kl = seq_kl <= off_policy_threshold\n        return (is_pos_adv | is_low_kl).to(dtype=mask.dtype)\n\n    def _compute_loss(self, model, inputs):\n        \"\"\"Override to add IS ratio correction and off-policy sequence masking.\"\"\"\n        prompt_ids, prompt_mask = inputs[\"prompt_ids\"], inputs[\"prompt_mask\"]\n        completion_ids, completion_mask = (\n            inputs[\"completion_ids\"],\n            inputs[\"completion_mask\"],\n        )\n        input_ids = torch.cat([prompt_ids, completion_ids], dim=1)\n        attention_mask = torch.cat([prompt_mask, completion_mask], dim=1)\n        logits_to_keep = completion_ids.size(1)\n        mask = (\n            completion_mask\n            if \"tool_mask\" not in inputs\n            else completion_mask * inputs[\"tool_mask\"]\n        )\n\n        per_token_logps, entropies = self._get_per_token_logps_and_entropies(\n            model,\n            input_ids,\n            attention_mask,\n            logits_to_keep,\n            compute_entropy=True,\n            pixel_values=inputs.get(\"pixel_values\"),\n            image_grid_thw=inputs.get(\"image_grid_thw\"),\n            num_images=inputs.get(\"num_images\"),\n            pixel_attention_mask=inputs.get(\"pixel_attention_mask\"),\n            image_sizes=inputs.get(\"image_sizes\"),\n            token_type_ids=inputs.get(\"token_type_ids\"),\n            mm_token_type_ids=inputs.get(\"mm_token_type_ids\"),\n        )\n        if self.top_entropy_quantile < 1.0:\n            entropy_mask = self.get_high_entropy_mask(\n                entropies, mask, 1 - self.top_entropy_quantile\n            )\n        else:\n            entropy_mask = None\n\n        advantages = inputs[\"advantages\"]\n        if advantages.dim() == 1:\n            advantages = advantages.unsqueeze(1)\n\n        old_per_token_logps = inputs.get(\"old_per_token_logps\")\n        old_per_token_logps = (\n            
per_token_logps.detach()\n            if old_per_token_logps is None\n            else old_per_token_logps\n        )\n\n        # --- OPSM (off-policy sequence mask) ---\n        off_policy_mask = None\n        if getattr(self, \"off_policy_mask_threshold\", None) is not None:\n            sampling_per_token_logps = inputs.get(\n                \"sampling_per_token_logps\", old_per_token_logps\n            )\n            off_policy_mask = self.get_off_policy_mask(\n                advantages=advantages,\n                per_token_logps=per_token_logps,\n                sampling_per_token_logps=sampling_per_token_logps,\n                mask=mask,\n                off_policy_threshold=self.off_policy_mask_threshold,\n            )\n\n        # --- Importance weights ---\n        log_ratio = per_token_logps - old_per_token_logps\n        is_level = getattr(\n            self,\n            \"importance_sampling_level\",\n            getattr(self.args, \"importance_sampling_level\", \"token\"),\n        )\n        if is_level == \"token\":\n            log_importance_weights = log_ratio\n        elif is_level == \"sequence\":\n            log_importance_weights = (log_ratio * mask).sum(-1) / mask.sum(-1).clamp(\n                min=1.0\n            )\n            log_importance_weights = log_importance_weights.unsqueeze(-1)\n        else:\n            raise ValueError(f\"Unknown importance sampling level: {is_level}\")\n\n        coef_1 = torch.exp(log_importance_weights)\n\n        # --- KL divergence ---\n        if self.beta != 0.0:\n            ref_per_token_logps = inputs[\"ref_per_token_logps\"]\n            per_token_kl = (\n                torch.exp(ref_per_token_logps - per_token_logps)\n                - (ref_per_token_logps - per_token_logps)\n                - 1\n            )\n            if getattr(self.args, \"use_bias_correction_kl\", False):\n                per_token_kl = per_token_kl * coef_1\n\n        # --- Per-token loss ---\n        if 
self.loss_type == \"cispo\":\n            clamped = torch.clamp(coef_1, max=self.epsilon_high).detach()\n            per_token_loss = -clamped * advantages * per_token_logps\n        elif self.loss_type in (\"grpo\", \"bnpo\", \"dr_grpo\", \"dapo\", \"luspo\"):\n            coef_2 = torch.clamp(coef_1, 1 - self.epsilon_low, 1 + self.epsilon_high)\n            if self.args.delta is not None:\n                coef_1_c = torch.clamp(coef_1, max=self.args.delta)\n            else:\n                coef_1_c = coef_1\n            per_token_loss = -torch.min(coef_1_c * advantages, coef_2 * advantages)\n        elif self.loss_type == \"sapo\":\n            temps = torch.where(\n                advantages > 0,\n                self.args.sapo_temperature_pos,\n                self.args.sapo_temperature_neg,\n            )\n            soft = torch.sigmoid(temps * (coef_1 - 1)) * 4 / temps\n            per_token_loss = -soft * advantages\n        else:\n            raise ValueError(f\"Unknown loss type: {self.loss_type}\")\n\n        # --- Apply masks ---\n        if off_policy_mask is not None:\n            per_token_loss = per_token_loss * off_policy_mask\n        if entropy_mask is not None:\n            per_token_loss = per_token_loss * entropy_mask\n\n        # --- IS ratio correction (vLLM distribution mismatch) ---\n        if (\n            self.use_vllm\n            and getattr(self, \"vllm_importance_sampling_correction\", False)\n            and \"importance_sampling_ratio\" in inputs\n        ):\n            per_token_loss = per_token_loss * inputs[\"importance_sampling_ratio\"]\n\n        if self.beta != 0.0:\n            per_token_loss = per_token_loss + self.beta * per_token_kl\n\n        # --- Aggregate loss ---\n        mode = \"train\" if self.model.training else \"eval\"\n        normalizer = (\n            self.current_gradient_accumulation_steps if mode == \"train\" else 1.0\n        )\n\n        if self.loss_type in (\"grpo\", \"sapo\"):\n            
loss = (\n                (per_token_loss * mask).sum(-1) / mask.sum(-1).clamp(min=1.0)\n            ).mean() / normalizer\n        elif self.loss_type == \"bnpo\":\n            loss = (\n                (per_token_loss * mask).sum() / mask.sum().clamp(min=1.0) / normalizer\n            )\n        elif self.loss_type == \"dr_grpo\":\n            loss = (\n                (per_token_loss * mask).sum()\n                / (per_token_loss.size(0) * self.max_completion_length)\n                / normalizer\n            )\n        elif self.loss_type in (\"cispo\", \"dapo\"):\n            norm = inputs[\"num_items_in_batch\"] / self.accelerator.num_processes\n            loss = (per_token_loss * mask).sum() / norm\n        elif self.loss_type == \"luspo\":\n            loss = (per_token_loss * mask.sum(1, keepdim=True)).mean() / normalizer\n        else:\n            raise ValueError(f\"Unknown loss type: {self.loss_type}\")\n\n        # --- Metrics ---\n        completion_token_count = mask.sum().clamp(min=1.0)\n\n        def masked_batch_mean(x):\n            return (\n                x.mean()\n                if x.shape[1] == 1\n                else (x * mask).sum() / completion_token_count\n            )\n\n        if self.beta != 0.0:\n            mean_kl = masked_batch_mean(per_token_kl)\n            self._metrics[mode][\"kl\"].append(\n                self.accelerator.gather(mean_kl).nanmean().item()\n            )\n\n        mean_entropy = masked_batch_mean(entropies)\n        self._metrics[mode][\"entropy\"].append(\n            self.accelerator.gather(mean_entropy).nanmean().item()\n        )\n\n        if self.loss_type in (\"grpo\", \"bnpo\", \"dr_grpo\", \"dapo\", \"luspo\"):\n            is_low = (coef_1 < 1 - self.epsilon_low) & (advantages < 0)\n            is_high = (coef_1 > 1 + self.epsilon_high) & (advantages > 0)\n            is_region = is_low | is_high\n            low_clip = masked_batch_mean(is_low.float())\n            high_clip = 
masked_batch_mean(is_high.float())\n            clip_ratio = masked_batch_mean(is_region.float())\n            g_low = self.accelerator.gather(low_clip)\n            self._metrics[mode][\"clip_ratio/low_mean\"].append(g_low.nanmean().item())\n            self._metrics[mode][\"clip_ratio/low_min\"].append(nanmin(g_low).item())\n            g_high = self.accelerator.gather(high_clip)\n            self._metrics[mode][\"clip_ratio/high_mean\"].append(g_high.nanmean().item())\n            self._metrics[mode][\"clip_ratio/high_max\"].append(nanmax(g_high).item())\n            g_clip = self.accelerator.gather(clip_ratio)\n            self._metrics[mode][\"clip_ratio/region_mean\"].append(\n                g_clip.nanmean().item()\n            )\n        elif self.loss_type == \"cispo\":\n            is_cispo = (coef_1 > self.epsilon_high) & (advantages > 0)\n            cr = masked_batch_mean(is_cispo.float())\n            self._metrics[mode][\"cispo_clip_ratio\"].append(\n                self.accelerator.gather(cr).nanmean().item()\n            )\n\n        return loss\n"
  },
  {
    "path": "src/axolotl/core/trainers/grpo/fast_async_trainer.py",
    "content": "# Copyright 2020-2026 Axolotl AI. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"\nExperimental GRPO extensions: parallel reward workers, replay buffer,\ndeferred re-roll, and zero-advantage skipping.\n\nThese features are built as subclasses of GRPOTrainer and GRPODataProducer,\nusing the hook system (_compute_rewards_for_batch, _post_advantage_hook,\n_pre_produce_hook) defined in the base classes.\n\"\"\"\n\nfrom __future__ import annotations\n\nimport asyncio\nimport logging\nimport threading\nfrom dataclasses import dataclass, field\n\nimport torch\nfrom torch import nn\nfrom trl import GRPOTrainer\n\nfrom axolotl.core.trainers.grpo.async_trainer import (\n    AsyncGRPOConfig,\n    AsyncGRPOTrainer,\n    GRPODataProducer,\n)\nfrom axolotl.core.trainers.grpo.replay_buffer import ReplayBuffer\n\nlogger = logging.getLogger(__name__)\n\n\n# ---------------------------------------------------------------------------\n# Extended config\n# ---------------------------------------------------------------------------\n\n\n@dataclass\nclass FastAsyncGRPOConfig(AsyncGRPOConfig):\n    \"\"\"GRPOConfig with additional experimental parameters.\"\"\"\n\n    reward_num_workers: int = field(\n        default=1,\n        metadata={\n            \"help\": \"Number of persistent subprocess workers for parallel reward computation. 
Each worker has its \"\n            \"own main thread so signal.alarm() (used by math_verify) works correctly. Work is sharded across \"\n            \"workers by prompt groups. Only used with use_data_producer=True and non-nn.Module reward functions.\"\n        },\n    )\n    replay_buffer_size: int = field(\n        default=0,\n        metadata={\n            \"help\": \"[Experimental, disabled by default] Size of the replay buffer for storing high-signal rollout \"\n            \"groups. When > 0, groups with reward variance are cached and used to replace zero-signal groups \"\n            \"(where all rewards are identical). Set to 0 to disable. Only used with use_data_producer=True.\"\n        },\n    )\n    replay_recompute_logps: bool = field(\n        default=True,\n        metadata={\n            \"help\": \"When True (default), recompute old_per_token_logps for replayed groups using the current \"\n            \"training model. This fixes the importance sampling mismatch that occurs when replaying stale data. \"\n            \"Only relevant when replay_buffer_size > 0.\"\n        },\n    )\n    reroll_start_fraction: float = field(\n        default=0.5,\n        metadata={\n            \"help\": \"Fraction of total training steps after which deferred re-rolling begins. Zero-signal prompts \"\n            \"(where all rewards in a group are identical) are buffered and re-injected into later batches when the \"\n            \"model is more likely to solve them. Set to 1.0 to disable. Only used with use_data_producer=True.\"\n        },\n    )\n    reroll_max_groups: int = field(\n        default=1,\n        metadata={\n            \"help\": \"Maximum number of prompt groups to replace with re-roll candidates per batch. Higher values \"\n            \"increase data utilization but reduce prompt diversity. 
Only used with use_data_producer=True.\"\n        },\n    )\n    skip_zero_advantage_batches: bool = field(\n        default=True,\n        metadata={\n            \"help\": \"When True, skip gradient computation for micro-batches where all advantages are zero (no learning \"\n            \"signal). This avoids the forward/backward pass entirely when no learning signal is present. The step is \"\n            \"logged with skipped_zero_adv_batches=1 for monitoring.\"\n        },\n    )\n    vllm_lora_sync: bool = field(\n        default=False,\n        metadata={\n            \"help\": \"When True, sync LoRA adapter weights to vLLM via filesystem instead of merging into base model \"\n            \"and NCCL-broadcasting all parameters. vLLM loads the adapter natively using Punica kernels. \"\n            \"Requires vllm_serve_lora serve module (auto-selected when this is True). \"\n            \"Syncs only LoRA adapter weights (much smaller) vs full merged model. Legacy merge behavior is used when False.\"\n        },\n    )\n\n\n# ---------------------------------------------------------------------------\n# Extended data producer with re-roll injection\n# ---------------------------------------------------------------------------\n\n\nclass RerollDataProducer(GRPODataProducer):\n    \"\"\"GRPODataProducer that injects re-roll candidates into prompt batches.\n\n    Reads from the trainer's ``_reroll_buffer`` (populated by\n    ``GRPOExperimentalTrainer._post_advantage_hook``) and replaces the\n    last N prompt groups with previously-failed prompts.\n    \"\"\"\n\n    def _pre_produce_hook(self, inputs: list, global_step: int) -> list:\n        trainer = self._trainer\n        reroll_buf = getattr(trainer, \"_reroll_buffer\", None)\n        reroll_lock = getattr(trainer, \"_reroll_lock\", None)\n        if reroll_buf is None or reroll_lock is None:\n            return inputs\n\n        max_steps = getattr(trainer.args, \"max_steps\", -1)\n        start_frac = 
getattr(trainer.args, \"reroll_start_fraction\", 1.0)\n        max_groups = getattr(trainer.args, \"reroll_max_groups\", 1)\n        reroll_start_step = (\n            max(1, int(max_steps * start_frac)) if max_steps > 0 else float(\"inf\")\n        )\n\n        if global_step < reroll_start_step:\n            return inputs\n\n        with reroll_lock:\n            n_to_take = min(max_groups, len(reroll_buf))\n            reroll_prompts = [reroll_buf.pop(0) for _ in range(n_to_take)]\n\n        if reroll_prompts:\n            num_gen = self._num_generations\n            n_groups = len(inputs) // num_gen\n            for i, reroll_prompt in enumerate(reroll_prompts):\n                group_idx = n_groups - 1 - i\n                if group_idx < 0:\n                    break\n                start = group_idx * num_gen\n                for j in range(num_gen):\n                    inputs[start + j] = reroll_prompt\n            logger.info(\n                f\"[REROLL] Step {global_step}: replaced {len(reroll_prompts)}/{n_groups} prompt groups \"\n                f\"with deferred re-roll candidates ({len(reroll_buf)} remaining)\"\n            )\n\n        return inputs\n\n\n# ---------------------------------------------------------------------------\n# Persistent reward subprocess pool\n# ---------------------------------------------------------------------------\n\n\ndef _persistent_reward_worker(conn):\n    \"\"\"Long-lived reward worker. 
Receives work items, returns results.\"\"\"\n    while True:\n        try:\n            msg = conn.recv()\n        except EOFError:\n            break\n        if msg is None:  # Shutdown signal\n            break\n        (\n            reward_funcs,\n            prompts,\n            completions,\n            completion_ids_list,\n            inputs,\n            reward_func_names,\n        ) = msg\n        try:\n            keys = [\n                key\n                for key in inputs[0]\n                if key not in [\"prompt\", \"completion\", \"completion_ids\"]\n            ]\n            reward_kwargs = {key: [example[key] for example in inputs] for key in keys}\n            results = []\n            for reward_func, _reward_func_name in zip(\n                reward_funcs, reward_func_names, strict=True\n            ):\n                output = reward_func(\n                    prompts=prompts,\n                    completions=completions,\n                    completion_ids=completion_ids_list,\n                    **reward_kwargs,\n                )\n                results.append(\n                    [float(r) if r is not None else float(\"nan\") for r in output]\n                )\n            conn.send(results)\n        except Exception:\n            conn.send(None)\n\n\n# ---------------------------------------------------------------------------\n# Extended trainer\n# ---------------------------------------------------------------------------\n\n\nclass FastAsyncGRPOTrainer(AsyncGRPOTrainer):\n    \"\"\"GRPOTrainer with experimental extensions.\n\n    Adds:\n    - Parallel reward subprocess workers (``reward_num_workers``)\n    - Replay buffer for high-signal group reuse (``replay_buffer_size``)\n    - Deferred re-roll of failed prompts (``reroll_start_fraction``)\n    - Zero-advantage micro-batch skipping\n    \"\"\"\n\n    def __init__(self, *args, **kwargs):\n        # These must be initialized before super().__init__() because\n        # 
_create_data_producer (called during super().__init__) needs them.\n        self._reroll_buffer: list = []\n        self._reroll_lock = threading.Lock()\n\n        # Temporarily suppress the base class's Liger + OPSM validation check,\n        # since this subclass supports it via a custom compute_liger_loss override.\n        grpo_args = kwargs.get(\"args\")\n        if grpo_args is None:\n            for a in args:\n                if hasattr(a, \"off_policy_mask_threshold\"):\n                    grpo_args = a\n                    break\n        saved_threshold = None\n        if grpo_args is not None and getattr(grpo_args, \"use_liger_kernel\", False):\n            saved_threshold = grpo_args.off_policy_mask_threshold\n            grpo_args.off_policy_mask_threshold = None\n\n        super().__init__(*args, **kwargs)\n\n        if saved_threshold is not None:\n            grpo_args.off_policy_mask_threshold = saved_threshold\n            self.off_policy_mask_threshold = saved_threshold\n\n        # Replay buffer\n        if getattr(self.args, \"replay_buffer_size\", 0) > 0:\n            self._replay_buffer = ReplayBuffer(max_size=self.args.replay_buffer_size)\n        else:\n            self._replay_buffer = None\n        self._replay_recompute_logps = getattr(\n            self.args, \"replay_recompute_logps\", True\n        )\n\n        # Reward worker pool (lazy-initialized)\n        self._reward_workers = None\n\n    # -- Factory override: use RerollDataProducer ----------------------------\n\n    def _create_data_producer(self, args, train_dataset):\n        \"\"\"Override to use RerollDataProducer for re-roll prompt injection.\"\"\"\n        from axolotl.core.trainers.grpo.async_trainer import (\n            AsyncDataProducer,\n            ProducerConfig,\n        )\n\n        producer_config = ProducerConfig(\n            mini_epochs=args.num_iterations,\n            max_rollouts=None,\n            eval_during_produce=False,\n            
empty_cache_before_produce=True,\n            empty_cache_after_produce=True,\n            async_prefetch=args.async_prefetch,\n            prefetch_depth=args.prefetch_depth,\n        )\n        data_producer = RerollDataProducer(\n            config=producer_config,\n            prompt_dataset=train_dataset,\n            num_generations=self.num_generations,\n            generation_batch_size=args.generation_batch_size,\n            train_batch_size=args.per_device_train_batch_size,\n            steps_per_generation=args.steps_per_generation,\n            shuffle_dataset=self.shuffle_dataset,\n            seed=args.seed,\n        )\n        data_producer.set_trainer(self)\n        if args.async_prefetch:\n            data_producer = AsyncDataProducer(\n                data_producer,\n                background_produce_kwargs={\"skip_policy_logps\": True},\n            )\n        return data_producer\n\n    # -- Reward worker pool --------------------------------------------------\n\n    def _get_reward_workers(self):\n        \"\"\"Return a list of persistent reward worker subprocesses (lazy-initialized).\"\"\"\n        import multiprocessing as _mp\n\n        num_workers = getattr(self.args, \"reward_num_workers\", 1)\n        if num_workers < 1:\n            num_workers = 1\n\n        if self._reward_workers is not None:\n            alive = all(proc.is_alive() for conn, proc in self._reward_workers)\n            if alive and len(self._reward_workers) == num_workers:\n                return self._reward_workers\n            self._shutdown_reward_workers()\n\n        workers = []\n        for _ in range(num_workers):\n            parent_conn, child_conn = _mp.Pipe()\n            proc = _mp.Process(\n                target=_persistent_reward_worker, args=(child_conn,), daemon=True\n            )\n            proc.start()\n            child_conn.close()\n            workers.append((parent_conn, proc))\n\n        self._reward_workers = workers\n        return 
workers\n\n    def _shutdown_reward_workers(self):\n        \"\"\"Shut down all persistent reward workers.\"\"\"\n        if self._reward_workers is None:\n            return\n        for conn, proc in self._reward_workers:\n            try:\n                conn.send(None)\n                proc.join(timeout=5)\n            except Exception:\n                pass\n            try:\n                conn.close()\n            except Exception:\n                pass\n        self._reward_workers = None\n\n    # -- Hook overrides ------------------------------------------------------\n\n    def _compute_rewards_for_batch(\n        self, inputs, prompts, completions, completion_ids_list\n    ):\n        \"\"\"Dispatch rewards to parallel subprocess workers (synchronous wrapper).\"\"\"\n        self._launch_reward_workers(inputs, prompts, completions, completion_ids_list)\n        return self._collect_reward_workers(\n            inputs, prompts, completions, completion_ids_list\n        )\n\n    def _launch_reward_workers(self, inputs, prompts, completions, completion_ids_list):\n        \"\"\"Send reward work to subprocess workers (non-blocking).\n\n        Results are collected later by _collect_reward_workers, allowing GPU\n        logprob computation to overlap with CPU reward computation.\n        \"\"\"\n        reward_can_bg = all(\n            callable(rf)\n            and not isinstance(rf, nn.Module)\n            and not asyncio.iscoroutinefunction(rf)\n            for rf in self.reward_funcs\n        )\n        num_workers = getattr(self.args, \"reward_num_workers\", 1)\n\n        if not reward_can_bg or num_workers <= 1:\n            # Can't parallelize — store args for sync fallback in collect\n            self._reward_workers_used = None\n            self._pending_reward_args = (\n                inputs,\n                prompts,\n                completions,\n                completion_ids_list,\n            )\n            return\n\n        workers = 
self._get_reward_workers()\n        num_generations = self.num_generations\n        num_prompts = len(prompts)\n        num_groups = num_prompts // num_generations\n\n        # Shard by prompt groups across workers\n        groups_per_worker = max(1, (num_groups + len(workers) - 1) // len(workers))\n        workers_used = []\n        for w_idx, (conn, _proc) in enumerate(workers):\n            g_start = w_idx * groups_per_worker\n            g_end = min((w_idx + 1) * groups_per_worker, num_groups)\n            if g_start >= num_groups:\n                break\n            s_start = g_start * num_generations\n            s_end = g_end * num_generations\n            conn.send(\n                (\n                    self.reward_funcs,\n                    prompts[s_start:s_end],\n                    completions[s_start:s_end],\n                    completion_ids_list[s_start:s_end],\n                    inputs[s_start:s_end],\n                    self.reward_func_names,\n                )\n            )\n            workers_used.append(conn)\n\n        self._reward_workers_used = workers_used\n        self._pending_reward_args = (inputs, prompts, completions, completion_ids_list)\n\n    def _collect_reward_workers(\n        self, inputs, prompts, completions, completion_ids_list\n    ):\n        \"\"\"Collect reward results from subprocess workers (blocks until done).\"\"\"\n        from accelerate.utils import gather\n\n        workers_used = getattr(self, \"_reward_workers_used\", None)\n        args = getattr(self, \"_pending_reward_args\", None)\n        self._reward_workers_used = None\n        self._pending_reward_args = None\n\n        if workers_used is None:\n            # Sync fallback — compute on main thread\n            if args is not None:\n                return self._calculate_rewards(*args)\n            return self._calculate_rewards(\n                inputs, prompts, completions, completion_ids_list\n            )\n\n        device = 
self.accelerator.device\n        num_prompts = len(args[1]) if args else len(prompts)\n\n        # Collect results from workers\n        all_worker_results = []\n        any_failed = False\n        for conn in workers_used:\n            result = conn.recv()\n            if result is None:\n                any_failed = True\n                # Drain remaining workers to prevent stale results in pipes\n                for remaining_conn in workers_used:\n                    if remaining_conn is not conn:\n                        try:\n                            remaining_conn.recv()\n                        except Exception:\n                            pass\n                break\n            all_worker_results.append(result)\n\n        if not any_failed:\n            rewards_per_func = torch.zeros(\n                num_prompts, len(self.reward_funcs), device=device\n            )\n            offset = 0\n            for worker_result in all_worker_results:\n                chunk_size = len(worker_result[0])\n                for i, result in enumerate(worker_result):\n                    rewards_per_func[offset : offset + chunk_size, i] = torch.tensor(\n                        result, dtype=torch.float32, device=device\n                    )\n                offset += chunk_size\n            return gather(rewards_per_func)\n\n        # Fallback to main thread on failure\n        if args is not None:\n            return self._calculate_rewards(*args)\n        return self._calculate_rewards(\n            inputs, prompts, completions, completion_ids_list\n        )\n\n    def _post_advantage_hook(\n        self,\n        data: dict,\n        rewards_per_func,\n        advantages,\n        inputs: list,\n        num_generations: int,\n        mode: str,\n        s_start: int | None = None,\n        s_end: int | None = None,\n        is_last_chunk: bool = True,\n    ) -> None:\n        \"\"\"Replay buffer store/replace + re-roll buffering.\"\"\"\n        from 
trl.models.utils import disable_gradient_checkpointing\n\n        # -- Replay buffer: store high-signal groups --\n        if self._replay_buffer is not None:\n            local_grouped = rewards_per_func.view(\n                -1, num_generations, len(self.reward_funcs)\n            )\n            per_group_std = local_grouped.std(dim=1)\n            has_signal = (per_group_std > 0).any(dim=1)\n            offset = s_start or 0\n\n            if has_signal.any():\n                grouped_adv = advantages.view(-1, num_generations)\n                replay_scores = grouped_adv.abs().sum(dim=1) * per_group_std.sum(dim=1)\n                for group_idx in has_signal.nonzero(as_tuple=True)[0]:\n                    gi = group_idx.item()\n                    start = offset + gi * num_generations\n                    end = start + num_generations\n                    group_data = {}\n                    for key in data:\n                        val = data[key]\n                        if (\n                            isinstance(val, torch.Tensor)\n                            and val.dim() > 0\n                            and val.size(0) >= end\n                        ):\n                            group_data[key] = val[start:end].clone()\n                    self._replay_buffer.add(replay_scores[gi].item(), group_data)\n\n            # Replace zero-signal groups with high-signal replay buffer entries\n            # Only in non-streaming path (s_start is None) — streaming scores\n            # groups incrementally, so replacement + logprob recompute would be\n            # too expensive per chunk.\n            n_replaced = 0\n            if s_start is None:\n                no_signal = ~has_signal\n                replaced_ranges = []\n                if no_signal.any() and len(self._replay_buffer) > 0:\n                    for group_idx in no_signal.nonzero(as_tuple=True)[0]:\n                        sampled = self._replay_buffer.sample(1)\n                        if 
sampled is None:\n                            break\n                        sampled_group = sampled[0]\n                        gi = group_idx.item()\n                        start = offset + gi * num_generations\n                        end = start + num_generations\n                        for key, val in sampled_group.items():\n                            if key in data and isinstance(data[key], torch.Tensor):\n                                src = val.to(data[key].device)\n                                tgt_seq_len = (\n                                    data[key].size(1) if data[key].dim() > 1 else None\n                                )\n                                if start >= data[key].size(0) or end > data[key].size(\n                                    0\n                                ):\n                                    continue\n                                if tgt_seq_len is not None:\n                                    if src.size(1) <= tgt_seq_len:\n                                        data[key][start:end] = 0\n                                        data[key][start:end, : src.size(1)] = src\n                                    else:\n                                        data[key][start:end] = src[:, :tgt_seq_len]\n                                else:\n                                    data[key][start:end] = src\n                        replaced_ranges.append((start, end))\n                        n_replaced += 1\n\n                # Recompute old_per_token_logps for replayed groups\n                if (\n                    n_replaced > 0\n                    and self._replay_recompute_logps\n                    and \"old_per_token_logps\" in data\n                ):\n                    with (\n                        torch.no_grad(),\n                        disable_gradient_checkpointing(\n                            self.model, self.args.gradient_checkpointing_kwargs\n                        ),\n                    ):\n    
                    for r_start, r_end in replaced_ranges:\n                            r_ids = torch.cat(\n                                [\n                                    data[\"prompt_ids\"][r_start:r_end],\n                                    data[\"completion_ids\"][r_start:r_end],\n                                ],\n                                dim=1,\n                            )\n                            r_mask = torch.cat(\n                                [\n                                    data[\"prompt_mask\"][r_start:r_end],\n                                    data[\"completion_mask\"][r_start:r_end],\n                                ],\n                                dim=1,\n                            )\n                            r_logits_to_keep = data[\"completion_ids\"].size(1)\n                            r_fwd_kwargs = {}\n                            for fk in (\n                                \"pixel_values\",\n                                \"image_grid_thw\",\n                                \"pixel_attention_mask\",\n                                \"image_sizes\",\n                                \"token_type_ids\",\n                                \"mm_token_type_ids\",\n                            ):\n                                if fk in data:\n                                    r_fwd_kwargs[fk] = data[fk]\n                            r_logps, _ = self._get_per_token_logps_and_entropies(\n                                self.model,\n                                r_ids,\n                                r_mask,\n                                r_logits_to_keep,\n                                r_end - r_start,\n                                **r_fwd_kwargs,\n                            )\n                            data[\"old_per_token_logps\"][r_start:r_end] = r_logps\n\n                if n_replaced > 0:\n                    self._metrics[mode][\"replay_buffer_replacements\"].append(\n                        
float(n_replaced)\n                    )\n\n            if is_last_chunk:\n                self._metrics[mode][\"replay_buffer_size\"].append(\n                    float(len(self._replay_buffer))\n                )\n\n        # -- Re-roll buffer: store failed prompts --\n        if getattr(self.args, \"reroll_start_fraction\", 1.0) < 1.0:\n            grouped_rewards = rewards_per_func.view(\n                -1, num_generations, len(self.reward_funcs)\n            )\n            per_group_std = grouped_rewards.std(dim=1)\n            per_group_mean = grouped_rewards.mean(dim=1)\n            zero_signal = (per_group_std == 0).all(dim=1)\n            all_failed = (per_group_mean.abs() < 1e-6).all(dim=1)\n            should_reroll = zero_signal & all_failed\n            _n_buffered = 0\n            with self._reroll_lock:\n                for group_idx in should_reroll.nonzero(as_tuple=True)[0]:\n                    idx = group_idx.item() * num_generations\n                    if idx >= len(inputs):\n                        continue\n                    prompt_input = inputs[idx]\n                    self._reroll_buffer.append(prompt_input)\n                    _n_buffered += 1\n            if _n_buffered > 0:\n                self._metrics[mode][\"reroll_buffered\"].append(float(_n_buffered))\n            if is_last_chunk:\n                self._metrics[mode][\"reroll_buffer_size\"].append(\n                    float(len(self._reroll_buffer))\n                )\n\n    # -- Zero-advantage skipping + Liger OPSM ---------------------------------\n\n    def compute_liger_loss(self, unwrapped_model, inputs):\n        \"\"\"Liger loss with zero-adv skipping and off-policy sequence masking (OPSM).\n\n        The base class Liger path doesn't support OPSM because the fused kernel\n        doesn't expose per-token logprobs needed for the KL computation. 
This\n        override computes them via chunked lm_head matmul (no grad, low memory)\n        and applies the OPSM to the loss mask before calling the kernel.\n        \"\"\"\n        if self.args.skip_zero_advantage_batches and torch.all(\n            inputs[\"advantages\"] == 0\n        ):\n            mode = \"train\" if self.model.training else \"eval\"\n            self._metrics[mode][\"skipped_zero_adv_batches\"].append(1.0)\n            return torch.tensor(\n                0.0, device=inputs[\"advantages\"].device, requires_grad=True\n            )\n\n        if self.off_policy_mask_threshold is None:\n            return super().compute_liger_loss(unwrapped_model, inputs)\n\n        # OPSM path: need per_token_logps for KL, which Liger kernel doesn't provide\n        prompt_ids, prompt_mask = inputs[\"prompt_ids\"], inputs[\"prompt_mask\"]\n        completion_ids, completion_mask = (\n            inputs[\"completion_ids\"],\n            inputs[\"completion_mask\"],\n        )\n        input_ids = torch.cat([prompt_ids, completion_ids], dim=1)\n        attention_mask = torch.cat([prompt_mask, completion_mask], dim=1)\n        logits_to_keep = completion_ids.size(1)\n\n        last_hidden_state = self._get_last_hidden_state(\n            unwrapped_model,\n            input_ids,\n            attention_mask,\n            logits_to_keep,\n            inputs.get(\"pixel_values\"),\n            inputs.get(\"image_grid_thw\"),\n            inputs.get(\"pixel_attention_mask\"),\n            inputs.get(\"image_sizes\"),\n        )\n\n        loss_mask = (\n            completion_mask\n            if \"tool_mask\" not in inputs\n            else completion_mask * inputs[\"tool_mask\"]\n        )\n\n        # Compute per_token_logps via chunked lm_head matmul (no grad, low memory)\n        lm_weight = unwrapped_model.lm_head.weight\n        lm_bias = unwrapped_model.lm_head.bias\n        with torch.no_grad():\n            per_token_logps_chunks = []\n            for i 
in range(last_hidden_state.size(0)):\n                chunk_logits = torch.matmul(last_hidden_state[i : i + 1], lm_weight.t())\n                if lm_bias is not None:\n                    chunk_logits = chunk_logits + lm_bias\n                chunk_lps = (\n                    chunk_logits.float()\n                    .log_softmax(-1)\n                    .gather(-1, completion_ids[i : i + 1].unsqueeze(-1))\n                    .squeeze(-1)\n                )\n                per_token_logps_chunks.append(chunk_lps)\n                del chunk_logits\n            per_token_logps = torch.cat(per_token_logps_chunks, dim=0)\n\n        advantages = inputs[\"advantages\"]\n        if advantages.dim() == 1:\n            advantages_2d = advantages.unsqueeze(1)\n        else:\n            advantages_2d = advantages\n\n        sampling_per_token_logps = inputs.get(\"sampling_per_token_logps\")\n        if sampling_per_token_logps is None:\n            sampling_per_token_logps = inputs.get(\"old_per_token_logps\")\n        if sampling_per_token_logps is None:\n            sampling_per_token_logps = per_token_logps\n\n        off_policy_mask = GRPOTrainer.get_off_policy_mask(\n            advantages=advantages_2d,\n            per_token_logps=per_token_logps,\n            sampling_per_token_logps=sampling_per_token_logps,\n            mask=loss_mask,\n            off_policy_threshold=self.off_policy_mask_threshold,\n        )\n        loss_mask = loss_mask * off_policy_mask\n\n        # Call the Liger fused kernel with OPSM-modified mask\n        loss, metrics = self.liger_grpo_loss(\n            _input=last_hidden_state,\n            lin_weight=unwrapped_model.lm_head.weight,\n            selected_token_ids=completion_ids,\n            attention_mask=loss_mask,\n            advantages=inputs[\"advantages\"],\n            bias=unwrapped_model.lm_head.bias,\n            old_per_token_logps=inputs.get(\"old_per_token_logps\"),\n            
ref_per_token_logps=inputs.get(\"ref_per_token_logps\"),\n            vllm_is_ratio=inputs.get(\"importance_sampling_ratio\"),\n        )\n\n        mean_kl = metrics[0] if self.beta != 0.0 else None\n        clip_ratio = metrics[-1]\n\n        mode = \"train\" if self.model.training else \"eval\"\n        if self.beta != 0.0:\n            self._metrics[mode][\"kl\"].append(\n                self.accelerator.gather(mean_kl).mean().item()\n            )\n        self._metrics[mode][\"clip_ratio\"].append(\n            self.accelerator.gather(clip_ratio).mean().item()\n        )\n        normalizer = (\n            self.current_gradient_accumulation_steps if mode == \"train\" else 1.0\n        )\n        return loss / normalizer\n\n    def _compute_loss(self, model, inputs):\n        if self.args.skip_zero_advantage_batches and torch.all(\n            inputs[\"advantages\"] == 0\n        ):\n            mode = \"train\" if self.model.training else \"eval\"\n            self._metrics[mode][\"skipped_zero_adv_batches\"].append(1.0)\n            # Create zero loss with grad_fn. DeepSpeed requires grad_fn != None.\n            # With ZeRO-3, parameters are partitioned (shape=[0], requires_grad=False)\n            # so we can't just do `(p * 0).sum()`. Instead, do a tiny forward pass\n            # with a single token to create a proper computation graph.\n            prompt_ids = inputs[\"prompt_ids\"][:1, :1]  # (1, 1)\n            attn = torch.ones_like(prompt_ids)\n            with torch.amp.autocast(device_type=\"cuda\", dtype=torch.bfloat16):\n                out = model(input_ids=prompt_ids, attention_mask=attn)\n            return out.logits.sum() * 0\n        return super()._compute_loss(model, inputs)\n"
  },
  {
    "path": "src/axolotl/core/trainers/grpo/replay_buffer.py",
    "content": "\"\"\"Simple replay buffer for storing and sampling high-signal rollout groups.\"\"\"\n\nimport heapq\n\nimport torch\n\n\nclass ReplayBuffer:\n    \"\"\"Min-heap replay buffer that keeps the highest-scoring rollout groups.\n    Groups are scored by signal quality (advantage magnitude * reward variance).\n    When sampling, groups are drawn proportional to their scores.\n    \"\"\"\n\n    def __init__(self, max_size: int):\n        self.max_size = max_size\n        self._heap: list[tuple[float, int, dict]] = []  # min-heap of (score, id, data)\n        self._counter = 0  # unique tiebreaker for heap\n\n    def __len__(self):\n        return len(self._heap)\n\n    def add(self, score: float, data: dict):\n        \"\"\"Add a group to the buffer. If full, replaces lowest-scoring entry.\"\"\"\n        if self.max_size <= 0:\n            return\n        self._counter += 1\n        if len(self._heap) < self.max_size:\n            heapq.heappush(self._heap, (score, self._counter, data))\n        elif score > self._heap[0][0]:\n            heapq.heapreplace(self._heap, (score, self._counter, data))\n\n    def sample(self, num_samples: int) -> list[dict] | None:\n        \"\"\"Sample groups weighted by their scores. Returns None if buffer is empty.\"\"\"\n        if self.max_size <= 0 or not self._heap:\n            return None\n\n        scores = torch.tensor([item[0] for item in self._heap], dtype=torch.float32)\n        scores = scores.clamp(min=1e-8)  # avoid zero probabilities\n        probs = scores / scores.sum()\n        replacement = num_samples > len(self._heap)\n        indices = torch.multinomial(\n            probs, num_samples, replacement=replacement\n        ).tolist()\n        return [self._heap[i][2] for i in indices]\n"
  },
  {
    "path": "src/axolotl/core/trainers/grpo/sampler.py",
    "content": "\"\"\"Repeat random sampler (similar to the one implemented in\nhttps://github.com/huggingface/trl/blob/main/trl/trainer/grpo_trainer.py) that adds\nsequence parallelism functionality; i.e., duplicating data across ranks in the same\nsequence parallel group.\n\"\"\"\n\nfrom typing import Iterator, Sized\n\nimport torch\nfrom torch.utils.data import Sampler\n\n\nclass SequenceParallelRepeatRandomSampler(Sampler):\n    \"\"\"Sampler for GRPO training with sequence parallelism.\n\n    This sampler ensures:\n    - Ranks in the same sequence parallel (SP) group receive identical data.\n    - Each index is repeated multiple times for sampling different completions.\n    - Entire batches are repeated for reuse in multiple updates.\n    - Data is properly distributed across SP groups.\n\n    In the table below, the values represent dataset indices. Each SP group has\n    `context_parallel_size = 2` GPUs working together on the same data. There are 2\n    SP groups (SP0 and SP1), with `world_size = 4` total GPUs.\n\n                                               Sequence Parallel Groups\n                                        |       SP0        |       SP1        |\n                                        |  GPU 0  |  GPU 1 |  GPU 2  |  GPU 3 |\n                    global_step  step    <---> mini_repeat_count=3\n                                            <----------> batch_size=2 per SP group\n    grad_accum=2   ▲  ▲  0       0         [0 0 0  1 1 1]     [2 2 2  3 3 3]   <- SP groups get different data\n                   ▼  |  0       1         [0 0 0  1 1 1]     [2 2 2  3 3 3]   <- Same data for each SP group GPU\n                      |\n                      |  1       2         [0 0 0  1 1 1]     [2 2 2  3 3 3]   <- Repeat same indices for iterations\n    num_iterations=2  ▼  1       3         [0 0 0  1 1 1]     [2 2 2  3 3 3]   <- When using gradient accumulation\n\n                         2       4         [4 4 4  5 5 5]     [6 6 6  7 7 7]   <- New 
batch of data indices\n                         2       5         [4 4 4  5 5 5]     [6 6 6  7 7 7]\n                                            ...\n\n    Args:\n        dataset: Dataset to sample from.\n        mini_repeat_count: How many times to repeat each sample immediately.\n        world_size: Total number of processes.\n        rank: Rank of current process.\n        batch_size: Number of samples per batch.\n        repeat_count: How many times to repeat the full sampling process.\n        context_parallel_size: Number of ranks in a sequence parallel group.\n        shuffle: Whether to shuffle the dataset.\n        seed: Random seed for shuffling.\n        drop_last: Whether to drop the last incomplete batch.\n    \"\"\"\n\n    def __init__(\n        self,\n        dataset: Sized,\n        mini_repeat_count: int,\n        world_size: int,\n        rank: int,\n        batch_size: int = 1,\n        repeat_count: int = 1,\n        context_parallel_size: int = 1,\n        shuffle: bool = True,\n        seed: int = 0,\n        drop_last: bool = False,\n    ):\n        self.dataset = dataset\n        self.mini_repeat_count = mini_repeat_count\n        self.batch_size = batch_size\n        self.repeat_count = repeat_count\n        self.shuffle = shuffle\n        self.seed = seed\n        self.drop_last = drop_last\n        self.epoch = 0\n\n        self.world_size = world_size\n        self.rank = rank\n\n        # Sequence parallelism parameters\n        self.context_parallel_size = context_parallel_size\n        self.num_sp_groups = world_size // context_parallel_size\n        self.sp_group_id = rank // context_parallel_size\n\n        # Adjust dataset size for distributed sampling\n        self.num_samples = len(self.dataset)\n        self.total_size = self.num_samples\n\n        # Calculate effective number of samples per SP group\n        if (\n            self.drop_last\n            and self.total_size % (self.num_sp_groups * self.batch_size) != 0\n        
):\n            # Drop last incomplete batch if drop_last is True\n            self.num_samples_per_sp_group = (\n                self.total_size // self.batch_size // self.num_sp_groups\n            ) * self.batch_size\n        else:\n            # Round up to include last batch if drop_last is False\n            self.num_samples_per_sp_group = (\n                (self.total_size + self.batch_size * self.num_sp_groups - 1)\n                // (self.batch_size * self.num_sp_groups)\n                * self.batch_size\n            )\n\n        if shuffle:\n            self.generator = torch.Generator()\n            self.generator.manual_seed(seed)\n\n    def __iter__(self) -> Iterator[int]:\n        \"\"\"Creates iterator over dataset indices.\n\n        Returns:\n            Iterator that yields indices into the dataset.\n        \"\"\"\n        # Deterministically shuffle based on epoch and seed\n        if self.shuffle:\n            indices = torch.randperm(\n                self.num_samples, generator=self.generator\n            ).tolist()\n        else:\n            indices = list(range(self.num_samples))\n\n        # Add extra samples to make it evenly divisible by batch_size\n        if len(indices) % self.batch_size != 0:\n            padding = indices[: self.batch_size - len(indices) % self.batch_size]\n            indices += padding\n\n        # Subsample based on SP group ID\n        # Each SP group gets distinct batches of data\n        batch_indices = []\n        for i in range(0, len(indices), self.batch_size * self.num_sp_groups):\n            start_idx = i + self.sp_group_id * self.batch_size\n            end_idx = min(start_idx + self.batch_size, len(indices))\n            if start_idx < len(indices):\n                for j in range(self.batch_size):\n                    if start_idx + j < end_idx:\n                        batch_indices.append(indices[start_idx + j])\n\n        # Make sure batch_indices is exactly batch_size * 
num_batches_per_sp_group\n        if self.drop_last:\n            num_batches_per_sp_group = self.num_samples_per_sp_group // self.batch_size\n            target_len = self.batch_size * num_batches_per_sp_group\n            if len(batch_indices) > target_len:\n                batch_indices = batch_indices[:target_len]\n\n        # Apply the GRPO repeat pattern\n        final_indices = []\n        for _ in range(self.repeat_count):\n            for idx in batch_indices:\n                for _ in range(self.mini_repeat_count):\n                    final_indices.append(idx)\n\n        return iter(final_indices)\n\n    def __len__(self) -> int:\n        \"\"\"Returns the total length of the iterable including repetitions.\n\n        Returns:\n            Total number of samples.\n        \"\"\"\n        # Total length including all repetitions\n        return (\n            self.num_samples_per_sp_group * self.mini_repeat_count * self.repeat_count\n        )\n\n    def set_epoch(self, epoch: int) -> None:\n        \"\"\"Sets the epoch for this sampler.\n\n        Args:\n            epoch: Epoch number to use for shuffling.\n        \"\"\"\n        self.epoch = epoch\n"
  },
  {
    "path": "src/axolotl/core/trainers/grpo/trainer.py",
    "content": "\"\"\"Axolotl GRPO trainers (with and without sequence parallelism handling)\"\"\"\n\nimport warnings\nfrom functools import partial\nfrom typing import Any\n\nimport datasets\nimport torch\nimport torch.distributed as dist\nimport torch.utils.data\nfrom accelerate.utils import (\n    broadcast_object_list,\n    gather,\n    gather_object,\n    is_peft_available,\n)\nfrom datasets import Dataset, IterableDataset\nfrom torch import nn\nfrom torch.utils.data import (\n    BatchSampler,\n    DataLoader,\n    Sampler,\n)\nfrom transformers import (\n    PreTrainedModel,\n    PreTrainedTokenizerBase,\n    Trainer,\n    TrainerCallback,\n)\nfrom transformers.trainer_utils import seed_worker\nfrom trl import GRPOTrainer\nfrom trl.data_utils import (\n    apply_chat_template,\n    is_conversational,\n    maybe_apply_chat_template,\n)\nfrom trl.extras.profiling import profiling_context\nfrom trl.models import unwrap_model_for_generation\nfrom trl.trainer.grpo_config import GRPOConfig\nfrom trl.trainer.grpo_trainer import RewardFunc, nanstd\nfrom trl.trainer.utils import pad\n\nfrom axolotl.core.trainers.grpo.fast_async_trainer import FastAsyncGRPOTrainer\nfrom axolotl.core.trainers.grpo.sampler import SequenceParallelRepeatRandomSampler\nfrom axolotl.core.trainers.mixins import (\n    DistributedParallelMixin,\n    RngLoaderMixin,\n    SchedulerMixin,\n)\nfrom axolotl.core.trainers.mixins.optimizer import OptimizerInitMixin, OptimizerMixin\nfrom axolotl.monkeypatch.ring_attn import get_ring_attn_group\n\nif is_peft_available():\n    from peft import PeftConfig\n\n\nclass AxolotlGRPOTrainer(\n    RngLoaderMixin,\n    SchedulerMixin,\n    OptimizerMixin,\n    OptimizerInitMixin,\n    DistributedParallelMixin,\n    GRPOTrainer,\n):\n    \"\"\"Extend the base GRPOTrainer for axolotl helpers\"\"\"\n\n    _tag_names = [\"trl\", \"grpo\", \"axolotl\"]\n\n\nclass AxolotlAsyncGRPOTrainer(\n    RngLoaderMixin,\n    SchedulerMixin,\n    OptimizerMixin,\n    
OptimizerInitMixin,\n    DistributedParallelMixin,\n    FastAsyncGRPOTrainer,\n):\n    \"\"\"Extend AsyncGRPOTrainer with axolotl helpers\"\"\"\n\n    _tag_names = [\"trl\", \"grpo\", \"async\", \"axolotl\"]\n\n\nclass AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer):\n    \"\"\"Extend the base GRPOTrainer for sequence parallelism handling\"\"\"\n\n    def __init__(\n        self,\n        model: str | PreTrainedModel,\n        reward_funcs: RewardFunc | list[RewardFunc],\n        args: GRPOConfig | None = None,\n        train_dataset: Dataset | IterableDataset | None = None,\n        eval_dataset: (\n            Dataset | IterableDataset | dict[str, Dataset | IterableDataset] | None\n        ) = None,\n        processing_class: PreTrainedTokenizerBase | None = None,\n        reward_processing_classes: (\n            PreTrainedTokenizerBase | list[PreTrainedTokenizerBase] | None\n        ) = None,\n        callbacks: list[TrainerCallback] | None = None,\n        optimizers: tuple[\n            torch.optim.Optimizer | None, torch.optim.lr_scheduler.LambdaLR | None\n        ] = (None, None),\n        peft_config: \"PeftConfig | None\" = None,\n        optimizer_cls_and_kwargs: tuple[type, dict] | None = None,\n    ):\n        # First call the superclass constructor with all arguments\n        super().__init__(\n            model=model,\n            reward_funcs=reward_funcs,\n            args=args,\n            train_dataset=train_dataset,\n            eval_dataset=eval_dataset,\n            processing_class=processing_class,\n            reward_processing_classes=reward_processing_classes,\n            callbacks=callbacks,\n            optimizers=optimizers,\n            peft_config=peft_config,\n            optimizer_cls_and_kwargs=optimizer_cls_and_kwargs,\n        )\n\n        # Get number of SP groups (number of processes divided by SP degree)\n        num_processes = self.accelerator.num_processes\n        num_sp_groups = num_processes // 
self.args.context_parallel_size\n\n        # Calculate batch size per SP group (not per process)\n        sp_group_batch_size = self.args.per_device_train_batch_size * num_sp_groups\n        possible_values = [\n            n_gen\n            for n_gen in range(2, sp_group_batch_size + 1)\n            if (sp_group_batch_size) % n_gen == 0\n        ]\n\n        if self.num_generations not in possible_values:\n            raise ValueError(\n                f\"The batch size per SP group ({num_sp_groups} x \"\n                f\"{self.args.per_device_train_batch_size}) must be evenly divisible by \"\n                f\"the number of generations per prompt ({self.num_generations}). Given \"\n                \"the current configuration, the valid values for the number of \"\n                f\"generations are: {possible_values}.\"\n            )\n\n        if self.args.eval_strategy != \"no\":\n            # If sequence parallelism is enabled, calculate batch size per SP group\n            sp_group_eval_batch_size = args.per_device_eval_batch_size * num_sp_groups  # type: ignore[union-attr]\n            possible_values = [\n                n_gen\n                for n_gen in range(2, sp_group_eval_batch_size + 1)\n                if (sp_group_eval_batch_size) % n_gen == 0\n            ]\n\n            if self.num_generations not in possible_values:\n                raise ValueError(\n                    f\"With sequence parallelism (degree {self.args.context_parallel_size}), \"\n                    f\"the eval batch size per SP group ({num_sp_groups} x {self.args.per_device_eval_batch_size}) \"\n                    f\"must be evenly divisible by the number of generations per prompt \"\n                    f\"({self.num_generations}). 
Given the current eval batch size, \"\n                    f\"the valid values for the number of generations are: {possible_values}.\"\n                )\n\n        self.sp_group = None\n        self.rank = dist.get_rank()\n        self.world_size = dist.get_world_size()\n        self.local_rank = 0\n        self.local_world_size = 1\n\n    def train(self, *args, **kwargs):\n        # Initialize the SP group\n        self.sp_group = get_ring_attn_group()\n        self.rank = dist.get_rank()\n        self.world_size = dist.get_world_size()\n        self.local_rank = dist.get_rank(group=self.sp_group)\n        self.local_world_size = dist.get_world_size(group=self.sp_group)\n\n        return super().train(*args, **kwargs)\n\n    def _get_train_sampler(self) -> Sampler:\n        effective_batch_size = (\n            self.args.per_device_train_batch_size\n            * self.world_size\n            * self.args.gradient_accumulation_steps\n        )\n\n        return SequenceParallelRepeatRandomSampler(\n            dataset=self.train_dataset,\n            mini_repeat_count=self.num_generations,\n            world_size=self.world_size,\n            rank=self.rank,\n            batch_size=effective_batch_size\n            // self.num_generations\n            // self.args.context_parallel_size,\n            repeat_count=self.num_iterations * self.args.gradient_accumulation_steps,\n            context_parallel_size=self.args.context_parallel_size,\n            shuffle=True,\n            seed=self.args.seed,\n            drop_last=True,\n        )\n\n    def _create_dataloader_params(self, is_eval=False, custom_batch_size=None):\n        \"\"\"Create common dataloader parameters for train or eval.\"\"\"\n        batch_size = custom_batch_size or (\n            self.args.eval_batch_size if is_eval else self._train_batch_size\n        )\n\n        params = {\n            \"batch_size\": batch_size,\n            \"collate_fn\": self.data_collator,\n            \"num_workers\": 
self.args.dataloader_num_workers,\n            \"pin_memory\": self.args.dataloader_pin_memory,\n        }\n\n        # Add persistent workers only for training\n        if not is_eval and hasattr(self.args, \"dataloader_persistent_workers\"):\n            params[\"persistent_workers\"] = self.args.dataloader_persistent_workers\n\n        # Add prefetch factor if specified\n        if self.args.dataloader_prefetch_factor:\n            params[\"prefetch_factor\"] = self.args.dataloader_prefetch_factor\n\n        return params\n\n    def _prepare_dataloader(\n        self, dataset, sampler, is_eval=False, custom_batch_size=None\n    ):\n        \"\"\"Prepare a dataloader with the given dataset and sampler.\"\"\"\n        # Get base parameters\n        dataloader_params = self._create_dataloader_params(is_eval, custom_batch_size)\n\n        # Add sampler configuration\n        if not isinstance(dataset, torch.utils.data.IterableDataset):\n            if isinstance(sampler, BatchSampler):\n                # batch_size and batch_sampler are mutually exclusive\n                dataloader_params[\"batch_sampler\"] = sampler\n                del dataloader_params[\"batch_size\"]\n            else:\n                dataloader_params[\"sampler\"] = sampler\n                dataloader_params[\"drop_last\"] = self.args.dataloader_drop_last\n\n            if not is_eval:\n                dataloader_params[\"worker_init_fn\"] = partial(\n                    seed_worker,\n                    num_workers=self.args.dataloader_num_workers,\n                    rank=self.args.process_index,\n                )\n\n        # Create the dataloader\n        dataloader = DataLoader(dataset, **dataloader_params)\n\n        if self.args.sample_packing and (\n            (not is_eval and not self.args.pretraining)\n            or (is_eval and self.args.eval_sample_packing is not False)\n        ):\n            self.accelerator.even_batches = False\n\n        # Return unprepared dataloader if 
using sequence parallelism\n        # TODO(djsaunde): We might be able to use `accelerate`'s dataloader preparation\n        # if we use `dispatch_batches` and `slice_fn_for_dispatch` properly (i.e.,\n        # slice each batch along the sequence dimension).\n        if self.args.context_parallel_size > 1:\n            return dataloader\n\n        # Otherwise prepare with accelerator\n        return self.accelerator.prepare_data_loader(dataloader)\n\n    def get_train_dataloader(self) -> DataLoader:\n        \"\"\"Get dataloader for training\"\"\"\n        train_dataset = self.train_dataset\n\n        data_collator = self.data_collator  # type: ignore\n\n        # Handle dataset preprocessing\n        if isinstance(train_dataset, datasets.Dataset):\n            # Add debug print before any modifications\n            if self.args.sample_packing and not self.args.pretraining:\n                train_dataset = train_dataset.remove_columns([\"length\"])\n            if not self.args.sample_packing or self.args.pretraining:\n                train_dataset = self._remove_unused_columns(\n                    train_dataset, description=\"training\"\n                )\n        else:\n            self.data_collator = self._get_collator_with_removed_columns(\n                data_collator,\n                description=\"training\",\n            )\n\n        # Get sampler and create dataloader\n        sampler = self._get_train_sampler()\n        dataloader = self._prepare_dataloader(train_dataset, sampler, is_eval=False)\n\n        return dataloader\n\n    def _generate_and_score_completions(\n        self, inputs: list[dict[str, torch.Tensor | Any]]\n    ) -> dict[str, torch.Tensor | Any]:\n        device = self.accelerator.device\n        mode = \"eval\" if self.control.should_evaluate else \"train\"\n\n        prompts = [x[\"prompt\"] for x in inputs]\n        prompts_text = [\n            maybe_apply_chat_template(example, self.processing_class)[\"prompt\"]\n            for 
example in inputs\n        ]\n        prompt_inputs = self.processing_class(\n            text=prompts_text,\n            return_tensors=\"pt\",\n            padding=True,\n            padding_side=\"left\",\n            add_special_tokens=False,\n        )\n        prompt_inputs = Trainer._prepare_inputs(self, prompt_inputs)\n        prompt_ids, prompt_mask = (\n            prompt_inputs[\"input_ids\"],\n            prompt_inputs[\"attention_mask\"],\n        )\n\n        if self.max_prompt_length is not None:\n            prompt_ids = prompt_ids[:, -self.max_prompt_length :]\n            prompt_mask = prompt_mask[:, -self.max_prompt_length :]\n\n        # Generate completions using either vLLM or regular generation\n        if self.args.use_vllm:\n            # First, have main process load weights if needed\n\n            if self.state.global_step != self._last_loaded_step:  # type: ignore[has-type]\n                self._move_model_to_vllm()\n\n                self._last_loaded_step = self.state.global_step\n\n            # Generate completions using vLLM: gather all prompts and use them in a single call in the main process\n            all_prompts_text = gather_object(prompts_text)\n            if self.accelerator.is_main_process:\n                if self.args.context_parallel_size > 1:\n                    # Calculate sequence parallel group information\n                    world_size = self.accelerator.num_processes\n                    context_parallel_size = self.args.context_parallel_size\n                    num_sp_groups = world_size // context_parallel_size\n\n                    # Since processes in the same SP group have the same prompts, we need to ensure\n                    # we only take one copy of each prompt from each SP group\n                    ordered_set_of_prompts = []\n                    for sp_group_id in range(num_sp_groups):\n                        # Get the first process from each SP group (typically the group leader)\n            
            group_leader_rank = sp_group_id * context_parallel_size\n\n                        # Extract prompts from this SP group, accounting for num_generations duplicates\n                        # We only need prompts from one rank in each SP group\n                        group_prompts = all_prompts_text[\n                            group_leader_rank * len(prompts_text) : (\n                                group_leader_rank + 1\n                            )\n                            * len(prompts_text) : self.num_generations\n                        ]\n\n                        ordered_set_of_prompts.extend(group_prompts)\n                else:\n                    # Since 'prompts' contains 'num_generations' duplicates, we first take unique prompts, and generate\n                    # num_generations outputs for each one. This is faster than generating outputs for each duplicate\n                    # prompt individually.\n                    ordered_set_of_prompts = all_prompts_text[\n                        :: self.num_generations * self.args.context_parallel_size\n                    ]\n\n                with profiling_context(self, \"vLLM.generate\"):\n                    completion_ids = self.vllm_client.generate(\n                        prompts=ordered_set_of_prompts,\n                        n=self.num_generations,\n                        repetition_penalty=self.repetition_penalty,\n                        temperature=self.temperature,\n                        top_p=self.top_p,\n                        top_k=-1 if self.top_k is None else self.top_k,\n                        min_p=0.0 if self.min_p is None else self.min_p,\n                        max_tokens=self.max_completion_length,\n                        guided_decoding_regex=self.guided_decoding_regex,\n                    )\n            else:\n                completion_ids = [None] * (\n                    len(all_prompts_text) // self.args.context_parallel_size\n                )\n\n   
         # Broadcast the completions from the main process to all processes\n            completion_ids = broadcast_object_list(completion_ids, from_process=0)\n\n            # Determine the appropriate slice based on sequence parallelism\n            if self.args.context_parallel_size > 1:\n                # Calculate SP group ID (which group of ranks this rank belongs to)\n                sp_group_id = self.accelerator.process_index // self.local_world_size\n\n                # Calculate the start index for this SP group\n                sp_group_start = sp_group_id * len(prompts) * self.local_world_size\n\n                # All ranks in the same SP group get the same data slice\n                process_slice = slice(\n                    sp_group_start,\n                    sp_group_start + len(prompts),\n                )\n                completion_ids = completion_ids[process_slice]\n            else:\n                # Original behavior for non-sequence parallel case\n                process_slice = slice(\n                    self.accelerator.process_index * len(prompts),\n                    (self.accelerator.process_index + 1) * len(prompts),\n                )\n                completion_ids = completion_ids[process_slice]\n\n            # Pad the completions, and concatenate them with the prompts\n            completion_ids = [\n                torch.tensor(ids, device=device) for ids in completion_ids\n            ]\n            completion_ids = pad(\n                completion_ids, padding_value=self.processing_class.pad_token_id\n            )\n            prompt_completion_ids = torch.cat([prompt_ids, completion_ids], dim=1)\n        else:\n            # Regular generation path\n            with unwrap_model_for_generation(\n                self.model_wrapped,\n                self.accelerator,\n                gather_deepspeed3_params=self.args.ds3_gather_for_generation,\n            ) as unwrapped_model:\n                prompt_completion_ids = 
unwrapped_model.generate(\n                    prompt_ids,\n                    attention_mask=prompt_mask,\n                    generation_config=self.generation_config,\n                )\n\n            # Compute prompt length and extract completion ids\n            prompt_length = prompt_ids.size(1)\n            prompt_ids = prompt_completion_ids[:, :prompt_length]\n            completion_ids = prompt_completion_ids[:, prompt_length:]\n\n        # Mask everything after the first EOS token\n        is_eos = completion_ids == self.processing_class.eos_token_id\n        eos_idx = torch.full(\n            (is_eos.size(0),), is_eos.size(1), dtype=torch.long, device=device\n        )\n        eos_idx[is_eos.any(dim=1)] = is_eos.int().argmax(dim=1)[is_eos.any(dim=1)]\n        sequence_indices = torch.arange(is_eos.size(1), device=device).expand(\n            is_eos.size(0), -1\n        )\n        completion_mask = (sequence_indices <= eos_idx.unsqueeze(1)).int()\n\n        # If mask_truncated_completions is enabled, zero out truncated completions in completion_mask\n        if self.args.mask_truncated_completions:\n            truncated_completions = ~is_eos.any(dim=1)\n            completion_mask = (\n                completion_mask * (~truncated_completions).unsqueeze(1).int()\n            )\n\n        # Concatenate prompt_mask with completion_mask for logit computation\n        attention_mask = torch.cat([prompt_mask, completion_mask], dim=1)  # (B, P+C)\n\n        logits_to_keep = completion_ids.size(\n            1\n        )  # we only need to compute the logits for the completion tokens\n        batch_size = (\n            self.args.per_device_train_batch_size\n            if mode == \"train\"\n            else self.args.per_device_eval_batch_size\n        )\n\n        with torch.no_grad():\n            # When using num_iterations == 1, old_per_token_logps == per_token_logps, so we can skip it's\n            # computation here, and use per_token_logps.detach() 
instead.\n            if self.num_iterations > 1:\n                old_per_token_logps = self._get_per_token_logps(\n                    self.model,\n                    prompt_completion_ids,\n                    attention_mask,\n                    logits_to_keep,\n                    batch_size,\n                )\n            else:\n                old_per_token_logps = None\n\n            if self.beta == 0.0:\n                ref_per_token_logps = None\n            elif self.ref_model is not None:\n                ref_per_token_logps = self._get_per_token_logps(\n                    self.ref_model,\n                    prompt_completion_ids,\n                    attention_mask,\n                    logits_to_keep,\n                    batch_size,\n                )\n            else:\n                with self.accelerator.unwrap_model(self.model).disable_adapter():\n                    ref_per_token_logps = self._get_per_token_logps(\n                        self.model,\n                        prompt_completion_ids,\n                        attention_mask,\n                        logits_to_keep,\n                        batch_size,\n                    )\n\n        # Decode the generated completions\n        completions_text = self.processing_class.batch_decode(\n            completion_ids, skip_special_tokens=True\n        )\n        if is_conversational(inputs[0]):\n            completions = []\n            for prompt, completion in zip(prompts, completions_text, strict=False):\n                bootstrap = (\n                    prompt.pop()[\"content\"] if prompt[-1][\"role\"] == \"assistant\" else \"\"\n                )\n                completions.append(\n                    [{\"role\": \"assistant\", \"content\": bootstrap + completion}]\n                )\n        else:\n            completions = completions_text\n\n        rewards_per_func = torch.zeros(\n            len(prompts), len(self.reward_funcs), device=device\n        )\n        for i, 
(reward_func, reward_processing_class, reward_func_name) in enumerate(\n            zip(\n                self.reward_funcs,\n                self.reward_processing_classes,\n                self.reward_func_names,\n                strict=False,\n            )\n        ):\n            with profiling_context(self, reward_func_name):\n                if isinstance(\n                    reward_func, nn.Module\n                ):  # Module instead of PretrainedModel for compat with compiled models\n                    if is_conversational(inputs[0]):\n                        messages = [\n                            {\"messages\": p + c}\n                            for p, c in zip(prompts, completions, strict=False)\n                        ]\n                        texts = [\n                            apply_chat_template(x, reward_processing_class)[\"text\"]\n                            for x in messages\n                        ]\n                    else:\n                        texts = [\n                            p + c for p, c in zip(prompts, completions, strict=False)\n                        ]\n                    reward_inputs = reward_processing_class(\n                        text=texts,\n                        return_tensors=\"pt\",\n                        padding=True,\n                        padding_side=\"right\",\n                        add_special_tokens=False,\n                    )\n                    reward_inputs = Trainer._prepare_inputs(self, reward_inputs)\n                    with torch.inference_mode():\n                        rewards_per_func[:, i] = reward_func(**reward_inputs).logits[\n                            :, 0\n                        ]  # Shape (B*G,)\n                else:\n                    # Repeat all input columns (but \"prompt\" and \"completion\") to match the number of generations\n                    keys = [\n                        key for key in inputs[0] if key not in [\"prompt\", \"completion\"]\n       
             ]\n                    reward_kwargs = {\n                        key: [example[key] for example in inputs] for key in keys\n                    }\n                    output_reward_func = reward_func(\n                        prompts=prompts, completions=completions, **reward_kwargs\n                    )\n                    # Convert None values to NaN\n                    output_reward_func = [\n                        reward if reward is not None else torch.nan\n                        for reward in output_reward_func\n                    ]\n\n                    rewards_per_func[:, i] = torch.tensor(\n                        output_reward_func, dtype=torch.float32, device=device\n                    )\n\n        # If all reward functions return None for a given row, issue a detailed warning\n        if torch.isnan(rewards_per_func).all(dim=1).any():\n            nan_row_idx = (\n                torch.isnan(rewards_per_func).all(dim=1).nonzero(as_tuple=True)[0][0]\n            )\n            row_reward_kwargs = {\n                key: value[nan_row_idx] for key, value in reward_kwargs.items()\n            }\n            row_reward_kwargs[\"prompt\"] = prompts[nan_row_idx]\n            row_reward_kwargs[\"completion\"] = completions[nan_row_idx]\n            warnings.warn(\n                f\"All reward functions returned None for the following kwargs: {row_reward_kwargs}. 
\"\n                \"Please ensure that at least one reward function returns a valid reward.\",\n                stacklevel=2,\n            )\n\n        # Gather the reward per function: this part is crucial, because the rewards are normalized per group and the\n        # completions may be distributed across processes\n        rewards_per_func = gather(rewards_per_func)\n\n        # Apply weights to each reward function's output and sum\n        rewards = (\n            rewards_per_func * self.reward_weights.to(device).unsqueeze(0)\n        ).nansum(dim=1)\n\n        # Compute grouped-wise rewards\n        mean_grouped_rewards = rewards.view(-1, self.num_generations).mean(dim=1)\n        std_grouped_rewards = rewards.view(-1, self.num_generations).std(dim=1)\n\n        # Normalize the rewards to compute the advantages\n        mean_grouped_rewards = mean_grouped_rewards.repeat_interleave(\n            self.num_generations, dim=0\n        )\n        std_grouped_rewards = std_grouped_rewards.repeat_interleave(\n            self.num_generations, dim=0\n        )\n        advantages = rewards - mean_grouped_rewards\n        if self.args.scale_rewards:\n            advantages = advantages / (std_grouped_rewards + 1e-4)\n\n        # Slice to keep only the local part of the data\n        if self.args.context_parallel_size > 1:\n            # Calculate SP group ID (which group of ranks this rank belongs to)\n            sp_group_id = self.accelerator.process_index // self.local_world_size\n\n            # Calculate the start index for this SP group\n            sp_group_start = sp_group_id * len(prompts) * self.local_world_size\n\n            # All ranks in the same SP group get the same data slice\n            process_slice = slice(\n                sp_group_start,\n                sp_group_start + len(prompts),\n            )\n        else:\n            # Original behavior for non-sequence parallel case\n            process_slice = slice(\n                
self.accelerator.process_index * len(prompts),\n                (self.accelerator.process_index + 1) * len(prompts),\n            )\n        advantages = advantages[process_slice]\n\n        # Log the metrics\n        if mode == \"train\":\n            self._total_train_tokens += (\n                self.accelerator.gather_for_metrics(attention_mask.sum()).sum().item()\n            )\n        self._metrics[mode][\"num_tokens\"] = [self._total_train_tokens]\n\n        # log completion lengths, mean, min, max\n        agg_completion_mask = self.accelerator.gather_for_metrics(\n            completion_mask.sum(1)\n        )\n        self._metrics[mode][\"completions/mean_length\"].append(\n            agg_completion_mask.float().mean().item()\n        )\n        self._metrics[mode][\"completions/min_length\"].append(\n            agg_completion_mask.float().min().item()\n        )\n        self._metrics[mode][\"completions/max_length\"].append(\n            agg_completion_mask.float().max().item()\n        )\n\n        # identify sequences that terminated with EOS and log their lengths\n        agg_terminated_with_eos = self.accelerator.gather_for_metrics(is_eos.any(dim=1))\n        term_completion_mask = agg_completion_mask[agg_terminated_with_eos]\n        clipped_completions_ratio = 1 - len(term_completion_mask) / len(\n            agg_completion_mask\n        )\n        self._metrics[mode][\"completions/clipped_ratio\"].append(\n            clipped_completions_ratio\n        )\n        if len(term_completion_mask) == 0:\n            # edge case where no completed sequences are found\n            term_completion_mask = torch.zeros(1, device=device)\n        self._metrics[mode][\"completions/mean_terminated_length\"].append(\n            term_completion_mask.float().mean().item()\n        )\n        self._metrics[mode][\"completions/min_terminated_length\"].append(\n            term_completion_mask.float().min().item()\n        )\n        
self._metrics[mode][\"completions/max_terminated_length\"].append(\n            term_completion_mask.float().max().item()\n        )\n\n        # Calculate mean reward per function, but only for samples where the function was applied (non-NaN values)\n        for i, reward_func_name in enumerate(self.reward_func_names):\n            mean_rewards = torch.nanmean(rewards_per_func[:, i]).item()\n            self._metrics[mode][f\"rewards/{reward_func_name}/mean\"].append(mean_rewards)\n            std_rewards = nanstd(rewards_per_func[:, i]).item()\n            self._metrics[mode][f\"rewards/{reward_func_name}/std\"].append(std_rewards)\n        self._metrics[mode][\"reward\"].append(mean_grouped_rewards.mean().item())\n        self._metrics[mode][\"reward_std\"].append(std_grouped_rewards.mean().item())\n\n        # Log prompt and completion texts\n        self._textual_logs[\"prompt\"].extend(gather_object(prompts_text))\n        self._textual_logs[\"completion\"].extend(gather_object(completions_text))\n        for i, name in enumerate(self.reward_func_names):\n            self._textual_logs[\"rewards\"][name].extend(rewards_per_func[:, i].tolist())\n\n        return {\n            \"prompt_ids\": prompt_ids,\n            \"prompt_mask\": prompt_mask,\n            \"completion_ids\": completion_ids,\n            \"completion_mask\": completion_mask,\n            \"advantages\": advantages,\n            \"old_per_token_logps\": old_per_token_logps,\n            \"ref_per_token_logps\": ref_per_token_logps,\n        }\n"
  },
  {
    "path": "src/axolotl/core/trainers/mamba.py",
    "content": "\"\"\"Module for mamba trainer\"\"\"\n\nimport torch\n\nfrom axolotl.core.trainers.base import AxolotlTrainer\n\n\nclass AxolotlMambaTrainer(AxolotlTrainer):\n    \"\"\"Mamba specific trainer to handle loss calculation\"\"\"\n\n    tag_names = [\"axolotl\", \"mamba\"]\n\n    def compute_loss(\n        self,\n        model,\n        inputs,\n        return_outputs=False,\n        num_items_in_batch=None,\n    ):\n        input_ids = inputs.pop(\"input_ids\")\n        lm_logits = model(input_ids).logits\n\n        labels = input_ids.to(lm_logits.device)\n        shift_logits = lm_logits[:, :-1, :].contiguous()\n        labels = labels[:, 1:].contiguous()\n\n        loss_fct = torch.nn.CrossEntropyLoss()\n        lm_loss = loss_fct(\n            shift_logits.view(-1, shift_logits.size(-1)), labels.view(-1)\n        )\n\n        return lm_loss\n"
  },
  {
    "path": "src/axolotl/core/trainers/mixins/__init__.py",
    "content": "\"\"\"Init for axolotl.core.trainers.mixins\"\"\"\n\n# flake8: noqa\n\nfrom .activation_checkpointing import ActivationOffloadingMixin\nfrom .checkpoints import CheckpointSaveMixin\nfrom .distributed_parallel import DistributedParallelMixin\nfrom .optimizer import OptimizerMixin\nfrom .packing import PackingMixin\nfrom .rng_state_loader import RngLoaderMixin\nfrom .scheduler import SchedulerMixin\n"
  },
  {
    "path": "src/axolotl/core/trainers/mixins/activation_checkpointing.py",
    "content": "\"\"\"\nTrainer mixin for activation checkpointing w offloading\n\"\"\"\n\nimport contextlib\n\nfrom peft import PeftModel\nfrom torch import nn\nfrom torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (\n    apply_activation_checkpointing,\n)\nfrom torch.distributed.fsdp.wrap import ModuleWrapPolicy\nfrom transformers import GradientCheckpointingLayer, Trainer\nfrom trl.models.activation_offloading import (\n    NoOpManager,\n    OffloadActivations,\n    get_act_offloading_ctx_manager,\n)\n\nfrom axolotl.utils.logging import get_logger\n\nLOG = get_logger(__name__)\n\n\nclass ActivationOffloadingMixin(Trainer):\n    \"\"\"\n    Trainer mixin class for activation checkpointing w offloading\n    \"\"\"\n\n    def __init__(self, *args, **kwargs):\n        super().__init__(*args, **kwargs)\n        if self.args.activation_offloading:\n            if isinstance(self.model, PeftModel):\n                self.activation_offload_context = get_lora_act_offloading_ctx_manager(\n                    self.model, use_streams=True\n                )\n            else:\n                self.activation_offload_context = get_act_offloading_ctx_manager(\n                    self.model, use_streams=True\n                )\n        else:\n            self.activation_offload_context = contextlib.nullcontext()\n\n    def training_step(self, *args, **kwargs):\n        with self.activation_offload_context:\n            return super().training_step(*args, **kwargs)\n\n\ndef ac_wrap_hf_model(model: nn.Module, **kwargs):\n    auto_wrap_policy = ModuleWrapPolicy(set((GradientCheckpointingLayer,)))\n    apply_activation_checkpointing(model, auto_wrap_policy=auto_wrap_policy, **kwargs)\n\n\ndef get_lora_act_offloading_ctx_manager(\n    model: nn.Module,\n    use_pin_memory: bool = True,\n    use_streams: bool = True,\n    min_offload_size: int = 1024,\n    max_fwd_stash_size: int = 5,\n    warn_if_no_head: bool = True,\n) -> OffloadActivations:\n    \"\"\"\n    
Returns the activation offloading context manager for the model. All but the last output Linear in every step will\n    be offloaded.\n\n    If activation offloading is enabled, we return the OffloadActivations context manager. If activation offloading is\n    disabled, we return a NoOpManager context manager.\n\n    Args:\n        model (`nn.Module`):\n            Model to wrap with the activation offloading context manager.\n        use_pin_memory (`bool`, *optional*, defaults to `True`):\n            Whether to offloaded Tensor will be placed in pinned memory on the CPU. Pinned memory allows the Tensor to\n            be moved back onto GPU more quickly but is a limited resource.\n        use_streams (`bool`, *optional*, defaults to `True`):\n            Whether to use streams for performance optimization where the communications get overlapped with the\n            computation. Requires a torch build after torch-2.5.0.\n        min_offload_size (`int`, *optional*, defaults to `1024`):\n            Minimum number of bytes a Tensor must be in order to qualify for offloading. If the tensor is too small, we\n            do not want to waste bandwidth and resources moving it to CPU and back.\n        max_fwd_stash_size (`int`, *optional*, defaults to `5`):\n            Maximum size of the forward stash, or the maximum number of consecutive activations to keep alive during\n            the forward pass. This number must be at least 1. Keeping alive more activations will potentially allow\n            more overlap between the communication and compute streams at the cost of increasing memory usage. Keeping\n            alive fewer activations will conserve memory, but may cause poor overlap between the streams, increasing\n            runtime.\n        warn_if_no_head (`bool`, *optional*, defaults to `True`):\n            Whether to warn if no output head is detected. 
If set to `False`, no warning will be raised if no output\n            head is detected.\n\n    Returns:\n        `contextlib.ContextDecorator`:\n            Activation offloading context manager for the model.\n    \"\"\"\n\n    activations_handling_ctx = OffloadActivations(\n        use_pin_memory=use_pin_memory,\n        use_streams=use_streams,\n        min_offload_size=min_offload_size,\n        max_fwd_stash_size=max_fwd_stash_size,\n    )\n\n    # Below is our hack to disable offloading the last output Linear in every\n    # step, as the cost for offloading the activation and then soon after bringing\n    # it back is expensive.\n    output_head_detected = False\n    noop_ctx = NoOpManager()\n\n    # Try to get the actual model if it's wrapped\n    unwrapped_model = model\n    if hasattr(unwrapped_model, \"module\"):\n        unwrapped_model = unwrapped_model.module\n    # check for PEFT models\n    if hasattr(unwrapped_model, \"base_model\") and hasattr(\n        unwrapped_model, \"peft_config\"\n    ):\n        unwrapped_model = unwrapped_model.base_model\n\n    # Check for different types of output heads\n    if hasattr(unwrapped_model, \"output\"):\n        if isinstance(unwrapped_model.output, nn.Module):\n            unwrapped_model.output.register_forward_pre_hook(\n                lambda *args: noop_ctx.__enter__()\n            )\n            unwrapped_model.output.register_forward_hook(\n                lambda *args: noop_ctx.__exit__(), always_call=True\n            )\n            output_head_detected = True\n        elif hasattr(unwrapped_model.output, \"linear\") and isinstance(\n            unwrapped_model.output.linear, nn.Module\n        ):\n            unwrapped_model.output.linear.register_forward_pre_hook(\n                lambda *args: noop_ctx.__enter__()\n            )\n            unwrapped_model.output.linear.register_forward_hook(\n                lambda *args: noop_ctx.__exit__(), always_call=True\n            )\n            
output_head_detected = True\n\n    # Check for HuggingFace model output heads\n    elif hasattr(unwrapped_model, \"lm_head\"):\n        unwrapped_model.lm_head.register_forward_pre_hook(\n            lambda *args: noop_ctx.__enter__()\n        )\n        unwrapped_model.lm_head.register_forward_hook(\n            lambda *args: noop_ctx.__exit__(), always_call=True\n        )\n        output_head_detected = True\n\n    # Check for decoder-based models\n    elif hasattr(unwrapped_model, \"decoder\"):\n        decoder = unwrapped_model.decoder\n        if hasattr(decoder, \"output\"):\n            decoder.output.register_forward_pre_hook(lambda *args: noop_ctx.__enter__())\n            decoder.output.register_forward_hook(\n                lambda *args: noop_ctx.__exit__(), always_call=True\n            )\n            output_head_detected = True\n        # Some models have lm_head in the decoder\n        elif hasattr(decoder, \"lm_head\"):\n            decoder.lm_head.register_forward_pre_hook(\n                lambda *args: noop_ctx.__enter__()\n            )\n            decoder.lm_head.register_forward_hook(\n                lambda *args: noop_ctx.__exit__(), always_call=True\n            )\n            output_head_detected = True\n\n    # Check for transformer models with final layer norm\n    elif hasattr(unwrapped_model, \"final_layer_norm\") or hasattr(\n        unwrapped_model, \"ln_f\"\n    ):\n        final_norm = (\n            getattr(unwrapped_model, \"final_layer_norm\", None) or unwrapped_model.ln_f\n        )\n        final_norm.register_forward_pre_hook(lambda *args: noop_ctx.__enter__())\n        final_norm.register_forward_hook(\n            lambda *args: noop_ctx.__exit__(), always_call=True\n        )\n        output_head_detected = True\n\n    # Check for models with head module\n    elif hasattr(unwrapped_model, \"head\") and isinstance(\n        unwrapped_model.head, nn.Module\n    ):\n        unwrapped_model.head.register_forward_pre_hook(\n   
         lambda *args: noop_ctx.__enter__()\n        )\n        unwrapped_model.head.register_forward_hook(\n            lambda *args: noop_ctx.__exit__(), always_call=True\n        )\n        output_head_detected = True\n\n    if not output_head_detected and warn_if_no_head:\n        LOG.warning(\n            \"During activation offloading, no output head was detected. If your model has an output head, it will be \"\n            \"offloaded. This usually greatly slows training, given the large vocabulary size. To change this \"\n            \"behavior, set your output head as model.output and make it an nn.Module. You can disable this warning by \"\n            \"passing `warn_if_no_head=False`.\"\n        )\n\n    for name, module in unwrapped_model.named_modules():\n        # Disable offloading for any Liger modules\n        if \"liger\" in name.lower():\n            module.register_forward_pre_hook(lambda *args: noop_ctx.__enter__())\n            module.register_forward_hook(\n                lambda *args: noop_ctx.__exit__(), always_call=True\n            )\n        # disable offloading for any submodules to fix LoRA training\n        if name.endswith(\"._checkpoint_wrapped_module\"):\n            for _, sub_module in module.named_modules():\n                sub_module.register_forward_pre_hook(lambda *args: noop_ctx.__enter__())\n                sub_module.register_forward_hook(\n                    lambda *args: noop_ctx.__exit__(), always_call=True\n                )\n\n    return activations_handling_ctx\n"
  },
  {
    "path": "src/axolotl/core/trainers/mixins/checkpoints.py",
    "content": "\"\"\"Custom handling to not fail training if fsdp optimizer is not savable\"\"\"\n\nfrom transformers import Trainer\n\nfrom axolotl.utils.logging import get_logger\n\nLOG = get_logger(__name__)\n\n\nclass CheckpointSaveMixin(Trainer):\n    \"\"\"Mixin to handle saving the optimizer and scheduler if they are not savable.\"\"\"\n\n    def _save_optimizer_and_scheduler(self, output_dir):\n        try:\n            super()._save_optimizer_and_scheduler(output_dir)\n        except (NotImplementedError, KeyError) as exc:\n            # TODO: fix fsdp2 optimizer saving\n            LOG.warning_once(\n                f\"Trainer does not support saving optimizer and scheduler:  {exc}\\n\"\n                \"Optimizer and scheduler states were not saved - resuming from checkpoints \"\n                \"for this training run will not be possible.\",\n            )\n"
  },
  {
    "path": "src/axolotl/core/trainers/mixins/distributed_parallel.py",
    "content": "\"\"\"\nMixin for correctly saving fsdp\n\"\"\"\n\nfrom accelerate import PartialState\nfrom transformers import Trainer\n\n\nclass DistributedParallelMixin(Trainer):\n    \"\"\"\n    Mixin for correctly saving fsdp\n    \"\"\"\n\n    def _save(self, output_dir: str | None = None, state_dict=None):\n        if (\n            state_dict is None\n            and self.accelerator.parallelism_config\n            and self.accelerator.parallelism_config.dp_shard_enabled\n        ):\n            state_dict = self.accelerator.get_state_dict(self.model)\n        super()._save(output_dir, state_dict=state_dict)\n\n    def create_accelerator_and_postprocess(self):\n        super().create_accelerator_and_postprocess()\n        if (\n            self.accelerator.distributed_type == \"FSDP\"\n            and self.accelerator.state.fsdp_plugin is None\n        ):\n            # handle Context Parallelism without FSDP\n            self.accelerator.state.distributed_type = \"MULTI_GPU\"\n            self.accelerator.state._shared_state[\"distributed_type\"] = \"MULTI_GPU\"\n            PartialState().distributed_type = \"MULTI_GPU\"\n"
  },
  {
    "path": "src/axolotl/core/trainers/mixins/optimizer.py",
    "content": "\"\"\"Module for Axolotl trainer optimizer mixin\"\"\"\n\nfrom peft.optimizers import create_loraplus_optimizer\nfrom torch import nn\nfrom transformers.trainer import Trainer\nfrom transformers.utils import is_sagemaker_mp_enabled\n\nfrom axolotl.integrations.base import BaseOptimizerFactory\nfrom axolotl.utils.logging import get_logger\n\nif is_sagemaker_mp_enabled():\n    import smdistributed.modelparallel.torch as smp\n\nLOG = get_logger(__name__)\n\n\nclass OptimizerMixin(Trainer):\n    \"\"\"Mixin class for shared handling of building custom optimizers\"\"\"\n\n    args = None  # type: \"AxolotlTrainingArguments\"  # type: ignore[name-defined]\n\n    def create_optimizer_grouped_parameters(\n        self, opt_model, optimizer_kwargs\n    ) -> list[dict]:\n        decay_parameters = self.get_decay_parameter_names(opt_model)\n        params: dict = {\n            \"to_weight_decay\": {},  # LayerNorm and bias\n            \"embeddings\": {},  # lm_head, embed_tokens,\n            \"no_weight_decay\": {},\n        }\n        lr_groups_lookup = {}\n        lr_groups_learning_rates = {}\n        if self.args.lr_groups:\n            for lr_group in self.args.lr_groups:\n                group_name = lr_group[\"name\"]\n                group_modules = lr_group[\"modules\"]\n                for module in group_modules:\n                    lr_groups_lookup[module] = group_name\n                lr_groups_learning_rates[group_name] = lr_group[\"lr\"]\n                params[f\"to_weight_decay_{group_name}\"] = {}\n\n        for name, param in opt_model.named_parameters():\n            if not param.requires_grad:\n                continue\n            if name.endswith(\"modules_to_save.default.weight\") or any(\n                embed_name in name for embed_name in [\"embed_tokens\", \"lm_head\"]\n            ):\n                params[\"embeddings\"][name] = param\n            elif name in decay_parameters:\n                lr_group_modules = [\n          
          group_modules\n                    for group_modules in lr_groups_lookup\n                    if group_modules in name\n                ]\n                if lr_groups_lookup and any(lr_group_modules):\n                    lr_group_module = lr_group_modules[0]\n                    group_name = lr_groups_lookup[lr_group_module]\n                    params[f\"to_weight_decay_{group_name}\"][name] = param\n                else:\n                    params[\"to_weight_decay\"][name] = param\n            else:\n                params[\"no_weight_decay\"][name] = param\n        optimizer_grouped_parameters = []\n        if params[\"to_weight_decay\"]:\n            optimizer_grouped_parameters.append(\n                {\n                    \"params\": list(params[\"to_weight_decay\"].values()),\n                    \"weight_decay\": self.args.weight_decay,\n                    \"lr\": optimizer_kwargs[\"lr\"],\n                }\n            )\n        if params[\"embeddings\"]:\n            lr = optimizer_kwargs[\"lr\"]\n            if self.args.embedding_lr_scale:\n                lr *= self.args.embedding_lr_scale\n            elif self.args.embedding_lr:\n                lr = self.args.embedding_lr\n            optimizer_grouped_parameters.append(\n                {\n                    \"params\": list(params[\"embeddings\"].values()),\n                    \"weight_decay\": 0.0,\n                    \"lr\": lr,\n                }\n            )\n        if params[\"no_weight_decay\"]:\n            optimizer_grouped_parameters.append(\n                {\n                    \"params\": list(params[\"no_weight_decay\"].values()),\n                    \"weight_decay\": 0.0,\n                    \"lr\": optimizer_kwargs[\"lr\"],\n                }\n            )\n        for group_name, group_lr in lr_groups_learning_rates.items():\n            if params[f\"to_weight_decay_{group_name}\"]:\n                optimizer_grouped_parameters.append(\n                 
   {\n                        \"params\": list(\n                            params[f\"to_weight_decay_{group_name}\"].values()\n                        ),\n                        \"weight_decay\": self.args.weight_decay,\n                        \"lr\": group_lr,\n                    }\n                )\n\n        return optimizer_grouped_parameters\n\n    def create_optimizer(self, model=None):\n        if (\n            self.args.loraplus_lr_ratio is None\n            and self.args.embedding_lr_scale is None\n            and self.args.embedding_lr is None\n            and self.args.lr_groups is None\n            and self.optimizer_cls_and_kwargs is None\n        ):\n            return super().create_optimizer(model=model)\n\n        opt_model = self.model if model is None else model\n\n        if (\n            not self.optimizer\n            and self.optimizer_cls_and_kwargs is not None\n            and issubclass(self.optimizer_cls_and_kwargs[0], BaseOptimizerFactory)\n        ):\n            optimizer_factory_cls, optimizer_kwargs = self.optimizer_cls_and_kwargs\n            self.optimizer = optimizer_factory_cls()(\n                opt_model, self.args, **optimizer_kwargs\n            )\n\n        if not self.optimizer:\n            if self.optimizer_cls_and_kwargs is not None:\n                optimizer_cls, optimizer_kwargs = self.optimizer_cls_and_kwargs\n            else:\n                optimizer_cls, optimizer_kwargs = self.get_optimizer_cls_and_kwargs(\n                    self.args, opt_model\n                )\n\n            optimizer_grouped_parameters = self.create_optimizer_grouped_parameters(\n                opt_model, optimizer_kwargs\n            )\n\n            if self.args.loraplus_lr_ratio is not None:\n                loraplus_lr_ratio = getattr(self.args, \"loraplus_lr_ratio\", None)\n                loraplus_lr_embedding = getattr(\n                    self.args, \"loraplus_lr_embedding\", 1e-6\n                )\n                
self.optimizer = create_loraplus_optimizer(\n                    opt_model,\n                    optimizer_cls,\n                    loraplus_lr_ratio=loraplus_lr_ratio,\n                    loraplus_lr_embedding=loraplus_lr_embedding,\n                    **optimizer_kwargs,\n                )\n            else:\n                # Overwrite `params` in case it's created by `get_optimizer_cls_and_kwargs`\n                # e.g. for GaLore optimizer.\n                if \"params\" in optimizer_kwargs:\n                    optimizer_grouped_parameters = optimizer_kwargs.pop(\"params\")\n\n                # Overwrite `model` in case it's created by `get_optimizer_cls_and_kwargs`\n                # e.g. for LOMO optimizer.\n                if \"model\" in optimizer_kwargs:\n                    optimizer_grouped_parameters = optimizer_kwargs.pop(\"model\")\n\n                # For layer-wise dummy optimizers we overwrite optimizer_grouped_parameters with `optimizer_dict`\n                # to avoid arguments conflicts.\n                if \"optimizer_dict\" in optimizer_kwargs:\n                    optimizer_grouped_parameters = optimizer_kwargs.pop(\n                        \"optimizer_dict\"\n                    )\n\n                self.optimizer = optimizer_cls(\n                    optimizer_grouped_parameters, **optimizer_kwargs\n                )\n\n            if optimizer_cls.__name__ == \"Adam8bit\":\n                import bitsandbytes\n\n                manager = bitsandbytes.optim.GlobalOptimManager.get_instance()\n\n                skipped = 0\n                for module in opt_model.modules():\n                    if isinstance(module, nn.Embedding):\n                        skipped += sum(\n                            {\n                                p.data_ptr(): p.numel() for p in module.parameters()\n                            }.values()\n                        )\n                        LOG.info(f\"skipped {module}: {skipped / 2**20}M params\")\n 
                       manager.register_module_override(\n                            module, \"weight\", {\"optim_bits\": 32}\n                        )\n                        LOG.debug(f\"bitsandbytes: will optimize {module} in fp32\")\n                LOG.info(f\"skipped: {skipped / 2**20}M params\")\n\n        if is_sagemaker_mp_enabled():\n            self.optimizer = smp.DistributedOptimizer(self.optimizer)\n\n        return self.optimizer\n\n\nclass OptimizerInitMixin:\n    \"\"\"\n    Mixin to handle common optimizer initialization logic for Trainers (mostly TRL) that do not\n    accept optimizer_cls_and_kwargs as kwarg in constructor.\n    \"\"\"\n\n    def __init__(self, *args, **kwargs):\n        optimizer_cls_and_kwargs = kwargs.pop(\"optimizer_cls_and_kwargs\", None)\n        super().__init__(*args, **kwargs)\n        if (\n            optimizer_cls_and_kwargs\n            and self.optimizer_cls_and_kwargs is None\n            and self.optimizer is None\n        ):\n            self.optimizer_cls_and_kwargs = optimizer_cls_and_kwargs\n"
  },
  {
    "path": "src/axolotl/core/trainers/mixins/packing.py",
    "content": "\"\"\"Trainer mixin to support packing\"\"\"\n\nfrom transformers import Trainer\n\n\nclass PackingMixin(Trainer):\n    \"\"\"\n    Trainer mixin to support packing\n    \"\"\"\n\n    def _set_signature_columns_if_needed(self):\n        \"\"\"\n        Extend the parent's signature-column detection for sample packing.\n\n        When sample packing is enabled together with\n        ``sample_packing_drop_attention_mask``, ``attention_mask`` is removed from\n        the signature columns so it is dropped from the inputs. Column order is\n        preserved, and a signature without ``attention_mask`` is left as-is.\n        \"\"\"\n        super()._set_signature_columns_if_needed()\n        if (\n            self._signature_columns\n            and self.args.sample_packing\n            and self.args.sample_packing_drop_attention_mask\n        ):\n            # Filter instead of set.remove(): remove() raises KeyError when the\n            # signature has no attention_mask, and a set loses column order.\n            self._signature_columns = [\n                column\n                for column in self._signature_columns\n                if column != \"attention_mask\"\n            ]\n"
  },
  {
    "path": "src/axolotl/core/trainers/mixins/rng_state_loader.py",
    "content": "\"\"\"\nTemporary fix/override for bug in resume from checkpoint\n\nSee https://github.com/huggingface/transformers/pull/37162\n\nTODO: Remove when upstream added PR to release\n\"\"\"\n\nimport os\nimport random\n\nimport numpy as np\nimport torch\nfrom transformers import Trainer, is_torch_npu_available\nfrom transformers.trainer import safe_globals\nfrom transformers.trainer_pt_utils import set_rng_state_for_device\nfrom transformers.training_args import ParallelMode\n\nfrom axolotl.utils.logging import get_logger\n\nLOG = get_logger(__name__)\n\n\nclass RngLoaderMixin(Trainer):\n    \"\"\"\n    mixin for method override to load RNG states from a checkpoint\n    \"\"\"\n\n    def _load_rng_state(self, checkpoint):\n        # Load RNG states from `checkpoint`\n        if checkpoint is None:\n            return\n\n        if self.args.world_size > 1:\n            process_index = self.args.process_index\n            rng_file = os.path.join(checkpoint, f\"rng_state_{process_index}.pth\")\n            if not os.path.isfile(rng_file):\n                LOG.info(\n                    f\"Didn't find an RNG file for process {process_index}, if you are resuming a training that \"\n                    \"wasn't launched in a distributed fashion, reproducibility is not guaranteed.\"\n                )\n                return\n        else:\n            rng_file = os.path.join(checkpoint, \"rng_state.pth\")\n            if not os.path.isfile(rng_file):\n                LOG.info(\n                    \"Didn't find an RNG file, if you are resuming a training that was launched in a distributed \"\n                    \"fashion, reproducibility is not guaranteed.\"\n                )\n                return\n\n        # Use safe_globals to ensure numpy RNG states can be deserialized safely under PyTorch 2.6+,\n        # which requires allowlisted classes when loading with weights_only=True.\n        with safe_globals():\n            checkpoint_rng_state = 
torch.load(rng_file)  # nosec B614\n        random.setstate(checkpoint_rng_state[\"python\"])\n        np.random.set_state(checkpoint_rng_state[\"numpy\"])\n        torch.random.set_rng_state(checkpoint_rng_state[\"cpu\"])\n\n        is_distributed = self.args.parallel_mode == ParallelMode.DISTRIBUTED\n        if torch.cuda.is_available():\n            set_rng_state_for_device(\n                \"CUDA\", torch.cuda, checkpoint_rng_state, is_distributed\n            )\n        if is_torch_npu_available():\n            set_rng_state_for_device(\n                \"NPU\", torch.npu, checkpoint_rng_state, is_distributed\n            )\n"
  },
  {
    "path": "src/axolotl/core/trainers/mixins/scheduler.py",
    "content": "\"\"\"Module for Axolotl trainer scheduler mixin\"\"\"\n\nimport torch\nfrom torch.optim.lr_scheduler import LRScheduler, OneCycleLR\nfrom transformers.trainer import Trainer\n\nfrom axolotl.integrations.base import PluginManager\nfrom axolotl.utils.logging import get_logger\nfrom axolotl.utils.schedulers import (\n    JaggedLRRestartScheduler,\n    RexLR,\n    get_cosine_schedule_with_min_lr,\n    get_cosine_schedule_with_quadratic_warmup,\n    get_cosine_schedule_with_warmup_decay_constant,\n)\n\nLOG = get_logger(__name__)\n\n\nclass SchedulerMixin(Trainer):\n    \"\"\"\n    Mixin class for scheduler setup in CausalTrainer.\n    \"\"\"\n\n    args = None  # type: \"AxolotlTrainingArguments\"  # type: ignore[name-defined]\n\n    def create_scheduler(\n        self, num_training_steps: int, optimizer: None | torch.optim.Optimizer = None\n    ) -> LRScheduler:\n        \"\"\"\n        Set up the scheduler. The optimizer of the trainer must have been set up either before this method is called or\n        passed as an argument.\n\n        Args:\n            num_training_steps (int): The number of training steps to do.\n            optimizer (torch.optim.Optimizer): The training optimizer\n        \"\"\"\n        use_cosine_quadratic = (\n            self.args.lr_scheduler_type == \"cosine\"\n            and self.args.lr_quadratic_warmup is True\n        )\n\n        use_cosine_min_lr = (\n            self.args.lr_scheduler_type == \"cosine\"\n            and self.args.cosine_min_lr_ratio is not None\n        )\n\n        if optimizer is None:\n            if self.optimizer is None:\n                raise ValueError(\n                    \"Optimizer must be set before calling create_scheduler or passed as an argument.\"\n                )\n            optimizer = self.optimizer\n\n        # fmt: off\n        if self.lr_scheduler is None:  # type: ignore\n            # fmt: on\n            plugin_manager = PluginManager.get_instance()\n            
lr_scheduler: LRScheduler | None = plugin_manager.create_lr_scheduler(\n                trainer=self,\n                optimizer=optimizer,\n                num_training_steps=num_training_steps\n            )\n            if lr_scheduler is not None:\n                LOG.info(f\"Using plugin-created lr_scheduler: {lr_scheduler}\")\n                self.lr_scheduler = lr_scheduler\n            elif self.args.alternate_lr_scheduler_type == \"one_cycle\":\n                num_warmup_steps = self.args.get_warmup_steps(num_training_steps)\n                pct_start = num_warmup_steps / num_training_steps\n                extra_lr_kwargs = {}\n                if \"pct_start\" not in self.args.lr_scheduler_kwargs:\n                    extra_lr_kwargs[\"pct_start\"] = pct_start\n                if \"anneal_strategy\" not in self.args.lr_scheduler_kwargs:\n                    extra_lr_kwargs[\"anneal_strategy\"] = \"cos\"\n\n                self.lr_scheduler = OneCycleLR(\n                    optimizer,\n                    max_lr=self.args.learning_rate,\n                    total_steps=num_training_steps,\n                    **extra_lr_kwargs,\n                    **self.args.lr_scheduler_kwargs,\n                )\n            elif self.args.alternate_lr_scheduler_type == \"rex\":\n                if use_cosine_min_lr:\n                    assert 0 <= self.args.cosine_min_lr_ratio <= 1.0, \"cosine_min_lr_ratio must be between 0.0 and 1.0\"\n\n                self.lr_scheduler = RexLR(\n                    optimizer=optimizer,\n                    max_lr=self.args.learning_rate,\n                    min_lr=0 if not use_cosine_min_lr else (\n                        self.args.learning_rate * self.args.cosine_min_lr_ratio),\n                    total_steps=num_training_steps,\n                    num_warmup_steps=self.args.get_warmup_steps(num_training_steps),\n                )\n            elif use_cosine_quadratic:\n                if use_cosine_min_lr:\n                
    LOG.warning(\n                        \"Both cosine quadratic warmup and min lr detected. Using quadratic warmup.\")\n\n                self.lr_scheduler = get_cosine_schedule_with_quadratic_warmup(\n                    optimizer,\n                    num_warmup_steps=self.args.get_warmup_steps(num_training_steps),\n                    num_training_steps=num_training_steps,\n                )\n            elif self.args.cosine_min_lr_ratio and self.args.cosine_constant_lr_ratio and use_cosine_min_lr:\n                assert 0 <= self.args.cosine_min_lr_ratio <= 1.0, \"cosine_min_lr_ratio must be between 0.0 and 1.0\"\n                assert 0 <= self.args.cosine_constant_lr_ratio <= 1.0, \"cosine_constant_lr_ratio must be between 0.0 and 1.0\"\n                self.lr_scheduler = get_cosine_schedule_with_warmup_decay_constant(\n                    optimizer,\n                    num_warmup_steps=self.args.get_warmup_steps(num_training_steps),\n                    num_training_steps=num_training_steps,\n                    min_lr_ratio=self.args.cosine_min_lr_ratio,\n                    constant_lr_ratio=self.args.cosine_constant_lr_ratio,\n                )\n            elif self.args.cosine_min_lr_ratio and use_cosine_min_lr:\n                assert 0 <= self.args.cosine_min_lr_ratio <= 1.0, \"cosine_min_lr_ratio must be between 0.0 and 1.0\"\n                self.lr_scheduler = get_cosine_schedule_with_min_lr(\n                    optimizer,\n                    num_warmup_steps=self.args.get_warmup_steps(num_training_steps),\n                    num_training_steps=num_training_steps,\n                    min_lr_ratio=self.args.cosine_min_lr_ratio,\n                )\n            else:\n                super().create_scheduler(num_training_steps, optimizer=optimizer)\n        else:\n            if use_cosine_quadratic:\n                LOG.warning(\n                    \"axolotl's cosine scheduler with quadratic warmup not used (e.g., because of 
deepspeed).\")\n\n            if use_cosine_min_lr:\n                LOG.warning(\n                    \"axolotl's cosine scheduler with min lr not used (e.g., because of deepspeed).\")\n\n        if self.args.jagged_restart_steps:\n            warmup_steps = (\n                self.args.jagged_restart_warmup_steps or 10\n            )\n            anneal_steps = (\n                self.args.jagged_restart_anneal_steps or 1\n            )\n            if not self.lr_scheduler:\n                super().create_scheduler(num_training_steps, optimizer)\n            self.lr_scheduler = JaggedLRRestartScheduler(\n                optimizer,\n                self.lr_scheduler,\n                self.args.jagged_restart_steps,\n                warmup_steps,\n                anneal_steps,\n                min_lr_scale=self.args.cosine_min_lr_ratio or 0.001,\n            )\n\n        return self.lr_scheduler  # type: ignore\n"
  },
  {
    "path": "src/axolotl/core/trainers/trl.py",
    "content": "\"\"\"Module for TRL RL trainers\"\"\"\n\nfrom trl import RewardTrainer\nfrom trl.experimental.cpo import CPOTrainer\nfrom trl.experimental.kto import KTOTrainer\nfrom trl.experimental.orpo import ORPOTrainer\nfrom trl.experimental.prm import PRMTrainer\n\nfrom axolotl.core.trainers.mixins import DistributedParallelMixin, RngLoaderMixin\nfrom axolotl.core.trainers.mixins.optimizer import OptimizerInitMixin, OptimizerMixin\nfrom axolotl.core.trainers.mixins.scheduler import SchedulerMixin\n\n\nclass AxolotlORPOTrainer(\n    RngLoaderMixin,\n    SchedulerMixin,\n    OptimizerMixin,\n    OptimizerInitMixin,\n    DistributedParallelMixin,\n    ORPOTrainer,\n):\n    \"\"\"\n    Extend the base ORPOTrainer for axolotl helpers\n    \"\"\"\n\n    tag_names = [\"axolotl\", \"orpo\"]\n\n\nclass AxolotlKTOTrainer(\n    RngLoaderMixin,\n    SchedulerMixin,\n    OptimizerMixin,\n    OptimizerInitMixin,\n    DistributedParallelMixin,\n    KTOTrainer,\n):\n    \"\"\"\n    Extend the base KTOTrainer for axolotl helpers\n    \"\"\"\n\n    tag_names = [\"axolotl\", \"kto\"]\n\n\nclass AxolotlCPOTrainer(\n    RngLoaderMixin,\n    SchedulerMixin,\n    OptimizerMixin,\n    OptimizerInitMixin,\n    DistributedParallelMixin,\n    CPOTrainer,\n):\n    \"\"\"\n    Extend the base CPOTrainer for axolotl helpers\n    \"\"\"\n\n    tag_names = [\"axolotl\", \"cpo\"]\n\n\nclass AxolotlRewardTrainer(\n    RngLoaderMixin,\n    SchedulerMixin,\n    OptimizerMixin,\n    OptimizerInitMixin,\n    DistributedParallelMixin,\n    RewardTrainer,\n):\n    \"\"\"\n    Extend the base RewardTrainer for axolotl helpers\n    \"\"\"\n\n    tag_names = [\"axolotl\", \"reward\"]\n\n\nclass AxolotlPRMTrainer(\n    RngLoaderMixin,\n    SchedulerMixin,\n    OptimizerMixin,\n    OptimizerInitMixin,\n    DistributedParallelMixin,\n    PRMTrainer,\n):\n    \"\"\"\n    Extend the base trl.PRMTrainer for axolotl helpers\n    \"\"\"\n\n    tag_names = [\"axolotl\", \"prm\"]\n"
  },
  {
    "path": "src/axolotl/core/trainers/utils.py",
    "content": "\"\"\"Utils for Axolotl trainers\"\"\"\n\n\ndef _merge_tags(kwargs, key, tags):\n    \"\"\"\n    Merge ``tags`` into ``kwargs[key]`` and return ``kwargs``.\n\n    A single string tag is treated as a one-element list. Nothing happens when\n    ``tags`` or ``kwargs`` is ``None``. Fresh lists are always stored, so neither\n    the caller's ``tags`` (e.g. a class-level ``tag_names`` list) nor an existing\n    ``kwargs[key]`` list is mutated or aliased. ``kwargs`` itself is updated in\n    place.\n    \"\"\"\n    if isinstance(tags, str):\n        tags = [tags]\n    if tags is None or kwargs is None:\n        return kwargs\n\n    tags = list(tags)\n    existing = kwargs.get(key)\n    if key not in kwargs:\n        kwargs[key] = tags\n    elif isinstance(existing, list):\n        kwargs[key] = existing + tags\n    elif isinstance(existing, str):\n        kwargs[key] = tags + [existing]\n    return kwargs\n\n\ndef sanitize_kwargs_for_tagging(tag_names, kwargs=None):\n    \"\"\"Merge ``tag_names`` into ``kwargs[\"tags\"]``.\"\"\"\n    return _merge_tags(kwargs, \"tags\", tag_names)\n\n\ndef sanitize_kwargs_for_ds_tagging(dataset_tags, kwargs=None):\n    \"\"\"Merge ``dataset_tags`` into ``kwargs[\"dataset_tags\"]``.\"\"\"\n    return _merge_tags(kwargs, \"dataset_tags\", dataset_tags)\n"
  },
  {
    "path": "src/axolotl/core/training_args.py",
    "content": "\"\"\"\nextra axolotl specific training args\n\"\"\"\n\nfrom __future__ import annotations\n\nfrom dataclasses import dataclass, field\nfrom typing import Optional, Type\n\nfrom transformers import TrainingArguments\nfrom trl import RewardConfig\nfrom trl.experimental.cpo import CPOConfig\nfrom trl.experimental.kto import KTOConfig\nfrom trl.experimental.orpo import ORPOConfig\nfrom trl.experimental.prm import PRMConfig\n\nfrom axolotl.integrations.config import merge_training_args\n\nAxolotlTrainingMixins: Type = merge_training_args()\n\n\n@dataclass\nclass AxolotlTrainingArguments(AxolotlTrainingMixins, TrainingArguments):\n    \"\"\"\n    Training arguments for Causal trainer\n\n    This code is duplicated due to HF TrainingArguments not setting output_dir with a\n    default value so it can't be used as a mixin.\n    \"\"\"\n\n\n@dataclass\nclass AxolotlORPOConfig(AxolotlTrainingMixins, ORPOConfig):\n    \"\"\"\n    ORPO config for ORPO training\n    \"\"\"\n\n\n@dataclass\nclass AxolotlKTOConfig(AxolotlTrainingMixins, KTOConfig):\n    \"\"\"\n    KTO config for KTO training\n    \"\"\"\n\n\n@dataclass\nclass AxolotlCPOConfig(AxolotlTrainingMixins, CPOConfig):\n    \"\"\"\n    CPO config for CPO training\n    \"\"\"\n\n    simpo_gamma: Optional[float] = field(\n        default=None,\n        metadata={\"help\": \"simpo gamma parameter\"},\n    )\n\n\n@dataclass\nclass AxolotlRewardConfig(AxolotlTrainingMixins, RewardConfig):\n    \"\"\"\n    Reward config for Reward training\n    \"\"\"\n\n\n@dataclass\nclass AxolotlPRMConfig(AxolotlTrainingMixins, PRMConfig):\n    \"\"\"\n    PRM config for PRM training\n    \"\"\"\n"
  },
  {
    "path": "src/axolotl/core/training_args_base.py",
    "content": "\"\"\"\nBase Axolotl Training Mixins shared across various trainer configs\n\"\"\"\n\nfrom dataclasses import dataclass, field\nfrom typing import Optional\n\nfrom PIL.Image import Resampling\n\n\n@dataclass\nclass AxolotlTrainingMixins:\n    \"\"\"\n    Mixin class for the Axolotl training args.\n    \"\"\"\n\n    model_type: Optional[str] = field(\n        default=None, metadata={\"help\": \"HF model configuration model_type.\"}\n    )\n    lr_quadratic_warmup: bool = field(\n        default=False,\n        metadata={\"help\": \"Use quadratic warmup for cosine scheduling.\"},\n    )\n    pretraining: bool = field(\n        default=False,\n        metadata={\n            \"help\": \"Indicates to trainer whether we are doing continued pretraining.\"\n        },\n    )\n    sample_packing: bool = field(\n        default=False,\n        metadata={\"help\": \"Use sample packing for efficient training.\"},\n    )\n    sample_packing_sequentially: bool = field(\n        default=False,\n        metadata={\n            \"help\": \"Use next-fit sample packing that preserves the order of samples coming from the sampler. 
Use in combination with curriculum_sampling for fully sequential packing.\"\n        },\n    )\n    sample_packing_mp_start_method: str | None = field(\n        default=None,\n        metadata={\"help\": \"The multiprocessing start method to use.\"},\n    )\n    sample_packing_drop_attention_mask: bool = field(\n        default=False,\n        metadata={\"help\": \"Drop attention mask from inputs when using packing.\"},\n    )\n    multipack_real_batches: bool = field(\n        default=False,\n        metadata={\"help\": \"Use real batches for efficient training.\"},\n    )\n    include_tkps: bool = field(\n        default=True,\n        metadata={\n            \"help\": \"Whether to include tokens per second in the training metrics.\"\n        },\n    )\n    eval_sample_packing: Optional[bool] = field(\n        default=None,\n        metadata={\"help\": \"Use sample packing for efficient evals.\"},\n    )\n    sample_packing_efficiency: float = field(\n        default=1.0,\n        metadata={\"help\": \"Sample packing efficiency for calculating batch length.\"},\n    )\n    sample_packing_bin_size: int = field(\n        default=200,\n        metadata={\n            \"help\": \"The max number of samples that packed sample can contain after packing. Increase for better packing.\"\n        },\n    )\n    sample_packing_group_size: int = field(\n        default=100000,\n        metadata={\n            \"help\": \"The number of samples to group together for packing. 
Increase for better packing.\"\n        },\n    )\n    max_seq_length: int = field(\n        default=2048,\n        metadata={\"help\": \"The maximum sequence length the model can handle\"},\n    )\n    dataset_num_proc: int | None = field(\n        default=None,\n        metadata={\"help\": \"The number of processes to use for data processing\"},\n    )\n    relora_steps: Optional[int] = field(\n        default=None,\n        metadata={\"help\": \"how often to reset for ReLoRA\"},\n    )\n    relora_prune_ratio: Optional[float] = field(\n        default=0.9,\n        metadata={\"help\": \"prune ratio for magnitude pruning of the optimizer\"},\n    )\n    jagged_restart_steps: Optional[int] = field(\n        default=None,\n        metadata={\"help\": \"how often to reset for jagged restarts\"},\n    )\n    jagged_restart_warmup_steps: Optional[int] = field(\n        default=None,\n        metadata={\n            \"help\": \"how many warmup steps to take after reset for jagged restarts\"\n        },\n    )\n    jagged_restart_anneal_steps: Optional[int] = field(\n        default=None,\n        metadata={\n            \"help\": \"how many anneal steps to take before reset for jagged restarts\"\n        },\n    )\n    bench_split: Optional[str] = field(\n        default=\"eval\", metadata={\"help\": \"The benchmark split to run on\"}\n    )\n    bench_dataset: Optional[str] = field(\n        default=\"pharaouk/dharma-1/dharma_1_mini.json\",\n        metadata={\n            \"help\": \"Benchmark dataset to use: options are `mmlu-zs`, `mmlu-fs`, or the full path to the dataset file\"\n        },\n    )\n    do_bench_eval: Optional[bool] = field(\n        default=False, metadata={\"help\": \"Whether to run the Benchmark evaluation.\"}\n    )\n    do_causal_lm_eval: Optional[bool] = field(\n        default=False, metadata={\"help\": \"Whether to run the Causal LM evaluation.\"}\n    )\n    max_bench_samples: Optional[int] = field(\n        default=None,\n        
metadata={\n            \"help\": \"If set, only evaluates on `max_bench_samples` of the benchmark dataset.\"\n        },\n    )\n    bench_source_max_len: int = field(\n        default=2048, metadata={\"help\": \"Maximum source sequence length for bench.\"}\n    )\n    dataloader_prefetch_factor: Optional[int] = field(\n        default=None,\n        metadata={\"help\": \"prefetch_factor argument to the dataloader\"},\n    )\n    cosine_min_lr_ratio: Optional[float] = field(\n        default=None,\n        metadata={\"help\": \"Minimum learning rate is min_lr_ratio * learning_rate\"},\n    )\n    cosine_constant_lr_ratio: Optional[float] = field(\n        default=None,\n        metadata={\n            \"help\": \"Starting constant learning rate step is cosine_constant_lr_ratio * max_steps\"\n        },\n    )\n    loraplus_lr_ratio: Optional[float] = field(\n        default=None, metadata={\"help\": \"loraplus learning rate ratio lr_B / lr_A.\"}\n    )\n    loraplus_lr_embedding: Optional[float] = field(\n        default=1e-6,\n        metadata={\"help\": \"loraplus learning rate for lora embedding layers.\"},\n    )\n    embedding_lr_scale: Optional[float] = field(\n        default=None,\n        metadata={\"help\": \"Scale the learning rate for the embedding layers.\"},\n    )\n    lr_groups: Optional[list[dict]] = field(\n        default=None,\n        metadata={\"help\": \"Specify learning rate groups for with different LRs.\"},\n    )\n    embedding_lr: Optional[float] = field(\n        default=None,\n        metadata={\"help\": \"absolute learning rate for the embedding layers.\"},\n    )\n    qlora: bool = field(\n        default=False,\n        metadata={\"help\": \"whether this is a qlora training\"},\n    )\n    orpo_alpha: Optional[float] = field(\n        default=None,\n    )\n    lisa_n_layers: Optional[int] = field(\n        default=None,\n        metadata={\"help\": \"the number of activate layers in LISA\"},\n    )\n    lisa_step_interval: 
Optional[int] = field(\n        default=None,\n        metadata={\"help\": \"how often to switch layers in LISA\"},\n    )\n    lisa_layers_attribute: Optional[str] = field(\n        default=None,\n        metadata={\"help\": \"path under the model to access the layers\"},\n    )\n    curriculum_sampling: Optional[bool] = field(\n        default=None,\n        metadata={\"help\": \"whether to use sequential sampling for curriculum learning\"},\n    )\n    alternate_lr_scheduler_type: Optional[str] = field(\n        default=None,\n        metadata={\n            \"help\": \"workaround to pass an alternate lr scheduler to the HF trainer\"\n        },\n    )\n    chat_template: Optional[str] = field(\n        default=None,\n        metadata={\"help\": \"Chat template converting chat messages to text\"},\n    )\n\n    # kd_ce_alpha: Optional[float] = field(\n    #     default=None,\n    #     metadata={\n    #         \"help\": \"The alpha scaling parameter for SFT cross entropy loss when using KD\"\n    #     },\n    # )\n    #\n    # kd_alpha: Optional[float] = field(\n    #     default=1.0,\n    #     metadata={\"help\": \"The alpha scaling parameter for KD loss\"},\n    # )\n    #\n    # kd_temperature: Optional[float] = field(\n    #     default=1.0,\n    #     metadata={\n    #         \"help\": \"the temperature parameter for KL divergence loss when using KD\"\n    #     },\n    # )\n\n    adam_beta3: Optional[float] = field(\n        default=None,\n        metadata={\n            \"help\": \"The beta3 hyperparameter used in some optimizers such as CAME\"\n        },\n    )\n    adam_epsilon2: Optional[float] = field(\n        default=None,\n        metadata={\n            \"help\": \"The epsilon2 hyperparameter used in some optimizers such as CAME\"\n        },\n    )\n\n    activation_offloading: bool | None = field(\n        default=None,\n        metadata={\"help\": \"Use activation offloading with CUDA streams for training.\"},\n    )\n\n    # multi-modal 
section\n\n    image_size: int | tuple[int, int] | None = field(\n        default=None,\n        metadata={\"help\": \"The size of the image to resize to\"},\n    )\n\n    image_resize_algorithm: Resampling | None = field(\n        default=None,\n        metadata={\"help\": \"The algorithm to use for image resizing\"},\n    )\n\n    # end of multi-modal section\n\n    dion_learning_rate: float | None = field(\n        default=None,\n        metadata={\"help\": \"The learning rate for Dion\"},\n    )\n    dion_momentum: float | None = field(\n        default=None,\n        metadata={\"help\": \"The momentum for Dion\"},\n    )\n    dion_rank_fraction: float | None = field(\n        default=None,\n    )\n    dion_rank_multiple_of: int | None = field(\n        default=None,\n    )\n"
  },
  {
    "path": "src/axolotl/datasets.py",
    "content": "\"\"\"\nModule containing dataset functionality.\n\nWe want this to be a wrapper for an existing dataset that we have loaded. Lets use the\nconcept of middlewares to wrap each dataset. We'll use the collators later on to pad the\ndatasets.\n\"\"\"\n\nfrom datasets import Dataset, IterableDataset\n\nfrom axolotl.utils.logging import get_logger\n\nfrom .prompt_tokenizers import PromptTokenizingStrategy\n\nLOG = get_logger(__name__)\n\n\nclass TokenizedPromptDataset(Dataset):\n    \"\"\"Dataset that returns tokenized prompts from a stream of text files.\n\n    Args:\n        prompt_tokenizer: The prompt tokenizing method for processing the data.\n        dataset: Dataset with text files.\n        process_count: Number of processes to use for tokenizing.\n        keep_in_memory: Whether to keep the tokenized dataset in memory.\n    \"\"\"\n\n    def __init__(\n        self,\n        prompt_tokenizer: PromptTokenizingStrategy,\n        dataset: Dataset,\n        process_count: int | None = None,\n        keep_in_memory: bool | None = False,\n        **kwargs,\n    ):\n        self.prompt_tokenizer = prompt_tokenizer\n        self.process_count = process_count\n        self.keep_in_memory = keep_in_memory\n        super().__init__(\n            self.process(dataset).data,\n            **kwargs,\n        )\n\n    def process(self, dataset):\n        features = dataset.features.keys()\n\n        map_kwargs = {}\n        if self.prompt_tokenizer.supports_batched:\n            map_kwargs[\"batched\"] = True\n            map_kwargs[\"batch_size\"] = 1_000\n\n        if (\n            hasattr(self.prompt_tokenizer, \"filter_rows\")\n            and self.prompt_tokenizer.filter_rows\n        ):\n            dataset = dataset.filter(\n                self.prompt_tokenizer.filter_rows,\n                num_proc=self.process_count,\n                desc=\"Strategy Filtering Rows\",\n            )\n\n        return dataset.map(\n            
self.prompt_tokenizer.tokenize_prompt,\n            num_proc=self.process_count,\n            remove_columns=features,\n            keep_in_memory=self.keep_in_memory,\n            desc=\"Tokenizing Prompts\",\n            **map_kwargs,\n        )\n\n\ndef wrap_dataset_for_tokenized_prompt(\n    prompt_tokenizer: PromptTokenizingStrategy,\n    dataset: Dataset | IterableDataset,\n    **kwargs,\n):\n    if isinstance(dataset, IterableDataset):\n        map_kwargs = {}\n        if prompt_tokenizer.supports_batched:\n            map_kwargs[\"batched\"] = True\n        features = list(dataset.features.keys())\n        return dataset.map(\n            prompt_tokenizer.tokenize_prompt,\n            remove_columns=features,\n            **map_kwargs,\n        )\n    return TokenizedPromptDataset(prompt_tokenizer, dataset, **kwargs)\n"
  },
  {
    "path": "src/axolotl/evaluate.py",
    "content": "\"\"\"Module for evaluating models.\"\"\"\n\nimport csv\nimport os\nimport sys\nfrom pathlib import Path\nfrom typing import Dict, Optional\n\nimport torch\nfrom datasets import Dataset\nfrom transformers.trainer import Trainer\n\nfrom axolotl.telemetry.errors import send_errors\nfrom axolotl.train import (\n    TrainDatasetMeta,\n    setup_model_and_tokenizer,\n)\nfrom axolotl.utils.dict import DictDefault\nfrom axolotl.utils.distributed import cleanup_distributed\nfrom axolotl.utils.logging import get_logger\nfrom axolotl.utils.trainer import setup_trainer\n\nproject_root = os.path.abspath(os.path.join(os.path.dirname(__file__), \"..\"))\nsrc_dir = os.path.join(project_root, \"src\")\nsys.path.insert(0, src_dir)\n\nLOG = get_logger(__name__)\n\n\ndef evaluate_dataset(\n    trainer: Trainer, dataset: Dataset, dataset_type: str, flash_optimum: bool = False\n) -> Optional[Dict[str, float]]:\n    \"\"\"Helper function to evaluate a single dataset.\n\n    Args:\n        trainer: The trainer instance.\n        dataset: Dataset to evaluate.\n        dataset_type: Type of dataset ('train' or 'eval').\n        flash_optimum: Whether to use flash optimum.\n\n    Returns:\n        Dictionary of metrics or None if dataset is None.\n    \"\"\"\n    if dataset is None:\n        return None\n\n    LOG.info(f\"Starting {dataset_type} set evaluation...\")\n\n    if flash_optimum:\n        with torch.backends.cuda.sdp_kernel(\n            enable_flash=True,\n            enable_math=True,\n            enable_mem_efficient=True,\n        ):\n            metrics = trainer.evaluate(dataset, metric_key_prefix=dataset_type)\n    else:\n        metrics = trainer.evaluate(dataset, metric_key_prefix=dataset_type)\n\n    LOG.info(f\"{dataset_type.capitalize()} set evaluation completed!\")\n    LOG.info(f\"{dataset_type.capitalize()} Metrics:\")\n    for key, value in metrics.items():\n        LOG.info(f\"{key}: {value}\")\n\n    return metrics\n\n\n@send_errors\ndef 
evaluate(*, cfg: DictDefault, dataset_meta: TrainDatasetMeta) -> Dict[str, float]:\n    \"\"\"\n    Evaluate a model on training and validation datasets.\n\n    Args:\n        cfg: Dictionary mapping `axolotl` config keys to values.\n        dataset_meta: Dataset metadata containing training and evaluation datasets.\n\n    Returns:\n        Dictionary mapping metric names to their values.\n    \"\"\"\n    # Load tokenizer, processor and model\n    LOG.debug(\"loading model for evaluation...\")\n    model, tokenizer, _, processor = setup_model_and_tokenizer(cfg)\n\n    # Get datasets\n\n    train_dataset = dataset_meta.train_dataset\n    eval_dataset = dataset_meta.eval_dataset\n    total_num_steps = dataset_meta.total_num_steps\n\n    # Set up trainer\n    trainer = setup_trainer(\n        cfg=cfg,\n        train_dataset=train_dataset,\n        eval_dataset=eval_dataset,\n        model=model,\n        tokenizer=tokenizer,\n        processor=processor,\n        total_num_steps=total_num_steps,\n    )\n\n    # Evaluate datasets\n    all_metrics = {}\n    train_metrics = evaluate_dataset(trainer, train_dataset, \"train\", cfg.flash_optimum)\n    eval_metrics = evaluate_dataset(trainer, eval_dataset, \"eval\", cfg.flash_optimum)\n\n    if train_metrics:\n        all_metrics.update(train_metrics)\n    if eval_metrics:\n        all_metrics.update(eval_metrics)\n\n    # Save metrics to CSV if output directory is specified and we have metrics\n    if cfg.output_dir and (train_metrics or eval_metrics):\n        output_dir = Path(cfg.output_dir)\n        output_dir.mkdir(parents=True, exist_ok=True)\n\n        metrics_file = output_dir / \"eval_summary.csv\"\n        with metrics_file.open(\"w\", newline=\"\", encoding=\"utf-8\") as file:\n            writer = csv.writer(file)\n            writer.writerow([\"metric\", \"training\", \"validation\"])\n\n            # Get unique metric names (removing prefixes) from available metrics\n            train_metric_names = {\n        
        k.replace(\"train_\", \"\"): k for k in (train_metrics or {})\n            }\n            eval_metric_names = {\n                k.replace(\"eval_\", \"\"): k for k in (eval_metrics or {})\n            }\n            all_metric_names = sorted(\n                set(train_metric_names.keys()) | set(eval_metric_names.keys())\n            )\n\n            for metric_name in all_metric_names:\n                train_value = (\n                    train_metrics.get(train_metric_names.get(metric_name, \"\"), \"\")\n                    if train_metrics\n                    else \"\"\n                )\n                eval_value = (\n                    eval_metrics.get(eval_metric_names.get(metric_name, \"\"), \"\")\n                    if eval_metrics\n                    else \"\"\n                )\n                writer.writerow([metric_name, train_value, eval_value])\n\n        LOG.info(f\"Evaluation results saved to {metrics_file}\")\n\n    del model\n    del tokenizer\n\n    cleanup_distributed()\n\n    return all_metrics\n"
  },
  {
    "path": "src/axolotl/integrations/LICENSE.md",
    "content": "### AXOLOTL COMMUNITY LICENSE AGREEMENT\n\nThis Axolotl Community License Agreement (“Agreement”) is entered into by and between Axolotl AI Corp. (“Axolotl”) and\nany individual or entity (“Licensee”) who wishes to use the Software (as defined below) in accordance with the terms\nand conditions set forth in this Agreement.\n\n1.  Definitions\n    1.1 “Licensee” refers to any individual or entity who has obtained a copy of the Software under this Agreement.\n    1.2 “Plugin Integration” means independent integration software modules which may or may not be offered by Axolotl,\n        which may be licensed separately by their respective  authors and/or licensors.\n    1.3 “Software” refers to the specific sub-directory of the Axolotl, Inc. software located at\n        https://github.com/axolotl-ai-cloud/axolotl/tree/main/src/axolotl/integrations and its subdirectories which\n        permits Plugin Integrations to integrate with the Axolotl service.\n2.  Grant of License\n    2.1\tAxolotl hereby grants Licensee a worldwide, non-exclusive, royalty-free, license to use, copy, modify, merge,\n        publish, distribute, sublicense, and/or otherwise exploit the Software, subject to the following conditions:\n        - Licensee must comply with all the terms and conditions of this Agreement.\n        - Licensee must include the original copyright notice and disclaimer of warranty in all copies or substantial\n          portions of the Software.\n    2.2 Licensee may use the Software for any lawful purpose, except as restricted in Section 3.\n3.  
Restrictions\n    3.1 Licensee shall not use the Software for any activity that constitutes a commercial activity of offering for\n        free or for sale any services, platform, or equivalent  to third parties for the purposes of allowing such\n        third parties to fine-tune artificial intelligence models.\n    3.2 Licensee shall not:\n        - Use the Software for any illegal or unauthorized purpose.\n        - Reverse engineer, decompile, or disassemble the Software.\n        - Remove or modify any copyright, trademark, or other proprietary notices contained in the Software.\n        - Use the Software in a way that could damage, disable, overburden, or impair the functionality of the\n          Software or interfere with any third-party use of the Software.\n    3.3 Axolotl reserves the right to restrict certain Plugin Integrations for use with the Software. To the extent Licensee integrates a permitted, applicable Plugin Integration with the Software, Licensee shall comply with any additional terms and conditions imposed by the licensors of such Plugin Integration for use of such Plugin Integrations. Licensee shall contact Axolotl if it has questions about whether its use of the Software falls beyond the scope of this Agreement.\n4.  Intellectual Property Rights\n    4.1 Axolotl and its contributors retain all intellectual property rights in and to the Software. Licensee\n        acknowledges that this Agreement does not transfer any ownership rights or intellectual property rights to\n        Licensee.\n5.  Disclaimer of Warranty\n    5.1 THE SOFTWARE IS PROVIDED “AS IS,” WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED\n        TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, AND NON-INFRINGEMENT. 
IN NO EVENT SHALL\n        THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES, OR OTHER LIABILITY, WHETHER IN AN ACTION OF\n        CONTRACT, TORT, OR OTHERWISE, ARISING FROM, OUT OF, OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER\n        DEALINGS IN THE SOFTWARE.\n6.  Termination\n    6.1 Axolotl may terminate this Agreement at any time if Licensee fails to comply with any of the terms and\n        conditions set forth herein. Upon termination, Licensee shall cease all use of the Software and destroy any\n        copies in its possession.\n7.  Governing Law\n    7.1 This Agreement shall be governed by and construed in accordance with the laws of the State of California,\n        without regards to conflicts of laws provisions thereof.\n8.  Entire Agreement\n    8.1 This Agreement constitutes the entire agreement between Axolotl and Licensee with respect to the subject matter\n        hereof and supersedes all prior or contemporaneous understandings or agreements between the parties concerning\n        the Software, whether written or oral. Axolotl may update the terms of this Agreement from time to time, and\n        Licensee’s continued use of the Software after any such updates shall constitute acceptance of updated terms\n        on a go-forward basis.  Axolotl will use commercially reasonable efforts to provide Licensee notice of any\n        material updates. By using the Software, Licensee acknowledges that it has read, understood, and agrees to be\n        bound by the terms and conditions of this Agreement.\n\nThis Agreement was last updated on August 23, 2024.\n"
  },
  {
    "path": "src/axolotl/integrations/__init__.py",
    "content": ""
  },
  {
    "path": "src/axolotl/integrations/base.py",
    "content": "# Copyright 2024 Axolotl AI. All rights reserved.\n#\n# This software may be used and distributed according to\n# the terms of the Axolotl Community License Agreement (the \"License\");\n# you may not use this file except in compliance with the License.\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS, WITHOUT\n# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the\n# License for the specific language governing permissions and limitations under\n# the License.\n\n\"\"\"Base class for all plugins.\n\nA plugin is a reusable, modular, and self-contained piece of code that extends the functionality of Axolotl.\nPlugins can be used to integrate third-party models, modify the training process, or add new features.\n\nTo create a new plugin, you need to inherit from the BasePlugin class and implement the required methods.\n\"\"\"\n\nfrom __future__ import annotations\n\nimport collections\nimport importlib\nimport traceback\nfrom typing import TYPE_CHECKING, Callable, OrderedDict, Union\n\nfrom peft import PeftModel\nfrom torch import nn\nfrom torch.optim import Optimizer\nfrom torch.optim.lr_scheduler import LRScheduler\nfrom transformers import PreTrainedModel, Trainer\nfrom transformers.trainer_pt_utils import get_parameter_names\n\nfrom axolotl.utils.dict import DictDefault\nfrom axolotl.utils.logging import get_logger\n\nLOG = get_logger(__name__)\n\nif TYPE_CHECKING:\n    from axolotl.common.datasets import TrainDatasetMeta\n\n\nclass BasePlugin:\n    \"\"\"Base class for all plugins. Defines the interface for plugin methods.\n\n    A plugin is a reusable, modular, and self-contained piece of code that extends\n    the functionality of Axolotl. 
Plugins can be used to integrate third-party models,\n    modify the training process, or add new features.\n\n    To create a new plugin, you need to inherit from the BasePlugin class and\n    implement the required methods.\n\n    Note:\n        Plugin methods include:\n        - register(cfg): Registers the plugin with the given configuration.\n        - load_datasets(cfg): Loads and preprocesses the dataset for training.\n        - pre_model_load(cfg): Performs actions before the model is loaded.\n        - post_model_build(cfg, model): Performs actions after the model is loaded, but\n            before LoRA adapters are applied.\n        - pre_lora_load(cfg, model): Performs actions before LoRA weights are loaded.\n        - post_lora_load(cfg, model): Performs actions after LoRA weights are loaded.\n        - post_model_load(cfg, model): Performs actions after the model is loaded,\n            inclusive of any adapters.\n        - post_trainer_create(cfg, trainer): Performs actions after the trainer is\n            created.\n        - create_optimizer(cfg, trainer): Creates and returns an optimizer for training.\n        - create_lr_scheduler(cfg, trainer, optimizer, num_training_steps): Creates and\n            returns a learning rate scheduler.\n        - add_callbacks_pre_trainer(cfg, model): Adds callbacks to the trainer before\n            training.\n        - add_callbacks_post_trainer(cfg, trainer): Adds callbacks to the trainer after\n            training.\n    \"\"\"\n\n    def __init__(self):\n        \"\"\"Initializes the BasePlugin.\"\"\"\n\n    def register(self, cfg: dict):\n        \"\"\"Registers the plugin with the given configuration as an unparsed dict.\n\n        Args:\n            cfg: The configuration for the plugin.\n        \"\"\"\n\n    def get_input_args(self) -> str | None:\n        \"\"\"Returns a pydantic model for the plugin's input arguments.\"\"\"\n\n    def get_training_args_mixin(self) -> str | None:\n        \"\"\"\n        
Returns a dataclass model for the plugin's training arguments.\n        \"\"\"\n\n    def load_datasets(\n        self, cfg: DictDefault, preprocess: bool = False\n    ) -> Union[\"TrainDatasetMeta\", None]:\n        \"\"\"Loads and preprocesses the dataset for training.\n\n        Args:\n            cfg: The configuration for the plugin.\n            preprocess: Whether this is the preprocess step of the datasets.\n\n        Returns:\n            dataset_meta: The metadata for the training dataset.\n        \"\"\"\n\n    def pre_model_load(self, cfg: DictDefault):\n        \"\"\"Performs actions before the model is loaded.\n\n        Args:\n            cfg: The configuration for the plugin.\n        \"\"\"\n\n    def post_model_build(self, cfg: DictDefault, model: PreTrainedModel):\n        \"\"\"Performs actions after the model is built/loaded, but before any adapters are applied.\n\n        Args:\n            cfg: The configuration for the plugin.\n        \"\"\"\n\n    def pre_lora_load(self, cfg: DictDefault, model: PreTrainedModel):\n        \"\"\"Performs actions before LoRA weights are loaded.\n\n        Args:\n            cfg: The configuration for the plugin.\n            model: The loaded model.\n        \"\"\"\n\n    def post_lora_load(self, cfg: DictDefault, model: PreTrainedModel | PeftModel):\n        \"\"\"Performs actions after LoRA weights are loaded.\n\n        Args:\n            cfg: The configuration for the plugin.\n            model: The loaded model.\n        \"\"\"\n\n    def post_model_load(self, cfg: DictDefault, model: PreTrainedModel | PeftModel):\n        \"\"\"Performs actions after the model is loaded.\n\n        Args:\n            cfg: The configuration for the plugin.\n            model: The loaded model.\n        \"\"\"\n\n    def get_trainer_cls(self, cfg: DictDefault) -> type[Trainer] | None:\n        \"\"\"Returns a custom class for the trainer.\n\n        Args:\n            cfg: The global axolotl configuration.\n\n        
Returns:\n            The first non-`None` trainer class returned by a plugin.\n        \"\"\"\n\n    def post_trainer_create(self, cfg: DictDefault, trainer: Trainer):\n        \"\"\"Performs actions after the trainer is created.\n\n        Args:\n            cfg: The configuration for the plugin.\n            trainer: The trainer object for training.\n        \"\"\"\n\n    def get_training_args(self, cfg: DictDefault):\n        \"\"\"\n        Returns custom training arguments to set on TrainingArgs.\n\n        Args:\n            cfg: The global axolotl configuration.\n\n        Returns:\n            object: dict containing the training arguments.\n        \"\"\"\n\n    def get_collator_cls_and_kwargs(self, cfg: DictDefault, is_eval: bool = False):\n        \"\"\"\n        Returns a custom class for the collator.\n\n        Args:\n            cfg: The global axolotl configuration.\n            is_eval: Whether this is an eval split.\n\n        Returns:\n            class: The class for the collator.\n        \"\"\"\n\n    def create_optimizer(self, cfg: DictDefault, trainer: Trainer) -> Optimizer | None:\n        \"\"\"Creates and returns an optimizer for training.\n\n        Args:\n            cfg: The configuration for the plugin.\n            trainer: The trainer object for training.\n\n        Returns:\n            The created optimizer.\n        \"\"\"\n\n    def create_lr_scheduler(\n        self,\n        cfg: DictDefault,\n        trainer: Trainer,\n        optimizer: Optimizer,\n        num_training_steps: int,\n    ) -> LRScheduler | None:\n        \"\"\"Creates and returns a learning rate scheduler.\n\n        Args:\n            cfg: The configuration for the plugin.\n            trainer: The trainer object for training.\n            optimizer: The optimizer for training.\n            num_training_steps: Total number of training steps\n\n        Returns:\n            The created learning rate scheduler.\n        \"\"\"\n\n    def 
add_callbacks_pre_trainer(\n        self, cfg: DictDefault, model: PreTrainedModel\n    ) -> list[Callable]:\n        \"\"\"Set up callbacks before creating the trainer.\n\n        Args:\n            cfg: The configuration for the plugin.\n            model: The loaded model.\n\n        Returns:\n            A list of callback functions to be added to the `TrainingArgs`.\n        \"\"\"\n        return []\n\n    def add_callbacks_post_trainer(\n        self, cfg: DictDefault, trainer: Trainer\n    ) -> list[Callable]:\n        \"\"\"Adds callbacks to the trainer after creating the trainer. This is useful for\n        callbacks that require access to the model or trainer.\n\n        Args:\n            cfg: The configuration for the plugin.\n            trainer: The trainer object for training.\n\n        Returns:\n            A list of callback functions to be added\n        \"\"\"\n        return []\n\n    def post_train(self, cfg: DictDefault, model: PreTrainedModel | PeftModel):\n        \"\"\"Performs actions after training is complete.\n\n        Args:\n            cfg: The axolotl configuration.\n            model: The loaded model.\n        \"\"\"\n\n    def post_train_unload(self, cfg: DictDefault):\n        \"\"\"Performs actions after training is complete and the model is unloaded.\n\n        Args:\n            cfg: The configuration for the plugin.\n        \"\"\"\n\n\ndef load_plugin(plugin_name: str) -> BasePlugin:\n    \"\"\"Loads a plugin based on the given plugin name.\n\n    The plugin name should be in the format \"module_name.class_name\". This function\n    splits the plugin name into module and class, imports the module, retrieves the\n    class from the module, and creates an instance of the class.\n\n    Args:\n        plugin_name: The name of the plugin to be loaded. 
The name should be in the\n            format \"module_name.class_name\".\n\n    Returns:\n        An instance of the loaded plugin.\n\n    Raises:\n        ImportError: If the plugin module cannot be imported.\n    \"\"\"\n    # split the plugin name into module and class\n    module_name, class_name = plugin_name.rsplit(\".\", 1)\n\n    # import the module\n    try:\n        module = importlib.import_module(module_name)\n    except ModuleNotFoundError as orig_exc:\n        try:\n            if not module_name.startswith(\"axolotl.integrations.\"):\n                module = importlib.import_module(\"axolotl.integrations.\" + module_name)\n            else:\n                raise orig_exc\n        except ModuleNotFoundError as exc:\n            raise orig_exc from exc\n\n    # instantiate the class\n    plugin_class = getattr(module, class_name)\n    # create an instance of the class\n    plugin = plugin_class()\n\n    return plugin\n\n\nclass PluginManager:\n    \"\"\"The `PluginManager` class is responsible for loading and managing plugins. 
It\n    should be a singleton so it can be accessed from anywhere in the codebase.\n\n    Attributes:\n        plugins: A list of loaded plugins.\n\n    Note:\n        Key methods include:\n        - get_instance(): Static method to get the singleton instance of `PluginManager`.\n        - register(plugin_name: str): Registers a new plugin by its name.\n        - pre_model_load(cfg): Calls the pre_model_load method of all registered plugins.\n    \"\"\"\n\n    plugins: OrderedDict[str, BasePlugin] = collections.OrderedDict()\n\n    _instance: PluginManager | None = None\n    _cfg: DictDefault | None = None\n\n    def __new__(cls):\n        \"\"\"Creates a new instance of PluginManager if it doesn't exist yet.\"\"\"\n        if cls._instance is None:\n            cls._instance = super(PluginManager, cls).__new__(cls)\n            cls._instance.plugins: OrderedDict[str, BasePlugin] = (\n                collections.OrderedDict()\n            )\n        return cls._instance\n\n    @staticmethod\n    def get_instance() -> \"PluginManager\":\n        \"\"\"Returns the singleton instance of PluginManager. 
If the instance doesn't\n        exist, it creates a new one.\n        \"\"\"\n        if PluginManager._instance is None:\n            PluginManager()\n        return PluginManager._instance  # type: ignore\n\n    @property\n    def cfg(self):\n        return self._cfg\n\n    @cfg.setter\n    def cfg(self, cfg):\n        self._cfg = cfg\n\n    def register(self, plugin_name: str):\n        \"\"\"Registers a new plugin by its name.\n\n        Args:\n            plugin_name: The name of the plugin to be registered.\n\n        Note:\n            An `ImportError` raised while loading the plugin is not propagated: it\n            is logged, its traceback is printed, and the plugin is not registered.\n        \"\"\"\n        try:\n            LOG.info(f\"Attempting to load plugin: {plugin_name}\")\n            plugin = load_plugin(plugin_name)\n            self.plugins[plugin_name] = plugin\n            LOG.info(f\"Plugin loaded successfully: {plugin_name}\")\n        except ImportError as exc:\n            LOG.error(f\"Failed to load plugin: {plugin_name}\")\n            # print stacktrace\n            traceback.print_exc()\n            print(f\"Error: {exc}\")\n\n    def get_input_args(self) -> list[str]:\n        \"\"\"Returns the import paths of the Pydantic input-argument classes of all\n        registered plugins.\n\n        Returns:\n            A list of dotted `module.ClassName` paths, one per plugin that provides\n            input arguments.\n        \"\"\"\n        input_args = []\n        for plugin in self.plugins.values():\n            input_args_from_plugin = plugin.get_input_args()\n            if input_args_from_plugin is not None:\n                input_args.append(input_args_from_plugin)\n        return input_args\n\n    def get_training_args_mixin(self):\n        \"\"\"\n        Returns the import paths of the training-args mixin dataclasses of all\n        registered plugins.\n\n        Returns:\n        list[str]: Dotted `module.ClassName` paths, one per plugin that provides a\n        mixin.\n        \"\"\"\n        training_args = []\n        for plugin in self.plugins.values():\n            training_args_from_plugin = plugin.get_training_args_mixin()\n            if 
training_args_from_plugin is not None:\n                training_args.append(training_args_from_plugin)\n        return training_args\n\n    def load_datasets(\n        self, cfg: DictDefault, preprocess: bool = False\n    ) -> Union[\"TrainDatasetMeta\", None]:\n        \"\"\"Calls the load_datasets method of each registered plugin.\n\n        Args:\n            cfg: The configuration for the plugins.\n            preprocess: Whether this is preprocess step of the datasets.\n\n        Returns:\n            The dataset metadata loaded from all registered plugins.\n        \"\"\"\n        return_ds_meta = None\n        for plugin in self.plugins.values():\n            dataset_meta = plugin.load_datasets(cfg, preprocess)\n            if dataset_meta is not None:\n                if return_ds_meta is None:\n                    return_ds_meta = dataset_meta\n                else:\n                    raise RuntimeError(\"Multiple plugins loaded datasets\")\n        return return_ds_meta\n\n    def pre_model_load(self, cfg: DictDefault):\n        \"\"\"Calls the pre_model_load method of all registered plugins.\n\n        Args:\n            cfg: The configuration for the plugins.\n        \"\"\"\n        for plugin in self.plugins.values():\n            plugin.pre_model_load(cfg)\n\n    def post_model_build(self, cfg: DictDefault, model: PreTrainedModel):\n        \"\"\"Calls the `post_model_build` method of all registered plugins after the\n        model has been built / loaded, but before any adapters have been applied.\n\n        Args:\n            cfg: The configuration for the plugins.\n            model: The loaded model.\n        \"\"\"\n        for plugin in self.plugins.values():\n            plugin.post_model_build(cfg, model)\n\n    def pre_lora_load(self, cfg: DictDefault, model: PreTrainedModel):\n        \"\"\"Calls the `pre_lora_load` method of all registered plugins.\n\n        Args:\n            cfg: The configuration for the plugins.\n            model: 
The loaded model.\n        \"\"\"\n        for plugin in self.plugins.values():\n            plugin.pre_lora_load(cfg, model)\n\n    def post_lora_load(self, cfg: DictDefault, model: PreTrainedModel | PeftModel):\n        \"\"\"Calls the `post_lora_load` method of all registered plugins.\n\n        Args:\n            cfg: The configuration for the plugins.\n            model: The loaded model.\n        \"\"\"\n        for plugin in self.plugins.values():\n            plugin.post_lora_load(cfg, model)\n\n    def post_model_load(self, cfg: DictDefault, model: PreTrainedModel | PeftModel):\n        \"\"\"Calls the `post_model_load` method of all registered plugins after the model\n        has been loaded inclusive of any adapters.\n\n        Args:\n            cfg: The configuration for the plugins.\n            model: The loaded model.\n        \"\"\"\n        for plugin in self.plugins.values():\n            plugin.post_model_load(cfg, model)\n\n    def get_trainer_cls(self, cfg: DictDefault) -> Trainer | None:\n        \"\"\"Calls the `get_trainer_cls` method of all registered plugins and returns the\n        first non-`None` trainer class.\n\n        Args:\n            cfg: The configuration for the plugins.\n\n        Returns:\n            The first non-`None` trainer class returned by a plugin.\n        \"\"\"\n        for plugin in self.plugins.values():\n            trainer_cls = plugin.get_trainer_cls(cfg)\n            if trainer_cls is not None:\n                return trainer_cls\n        return None\n\n    def get_training_args(self, cfg):\n        \"\"\"\n        Calls the get_training_args method of all registered plugins and returns the combined training arguments.\n\n        Parameters:\n        cfg (dict): The configuration for the plugins.\n\n        Returns:\n        object: The training arguments\n        \"\"\"\n        training_args_kwargs = {}\n        for plugin in self.plugins.values():\n            training_args = 
plugin.get_training_args(cfg)\n            if training_args is not None:\n                training_args_kwargs.update(training_args)\n\n        return training_args_kwargs\n\n    def get_collator_cls_and_kwargs(self, cfg, is_eval=False):\n        \"\"\"\n        Calls the get_collator_cls_and_kwargs method of all registered plugins and returns the first non-None result.\n\n        Parameters:\n        cfg (dict): The configuration for the plugins.\n        is_eval (bool): Whether this is an eval split.\n\n        Returns:\n        tuple: The `(collator_cls, collator_kwargs)` pair from the first plugin that\n        provides one, or None if none was found.\n        \"\"\"\n        for plugin in self.plugins.values():\n            collator = plugin.get_collator_cls_and_kwargs(cfg, is_eval=is_eval)\n            if collator is not None:\n                collator_cls, collator_kwargs = collator\n                return collator_cls, collator_kwargs\n        return None\n\n    def post_trainer_create(self, cfg: DictDefault, trainer: Trainer):\n        \"\"\"Calls the `post_trainer_create` method of all registered plugins.\n\n        Args:\n            cfg: The configuration for the plugins.\n            trainer: The trainer object for training.\n        \"\"\"\n        for plugin in self.plugins.values():\n            plugin.post_trainer_create(cfg, trainer)\n\n    def create_optimizer(self, trainer: Trainer) -> Optimizer | None:\n        \"\"\"Calls the `create_optimizer` method of all registered plugins and returns\n        the first non-`None` optimizer.\n\n        Args:\n            trainer: The trainer object for training.\n\n        Returns:\n            The created optimizer, or `None` if none was found.\n        \"\"\"\n        for plugin in self.plugins.values():\n            optimizer = plugin.create_optimizer(self.cfg, trainer)\n            if optimizer is not None:\n                return optimizer\n        return None\n\n    def create_lr_scheduler(\n        self, trainer: Trainer, optimizer: Optimizer, 
num_training_steps: int\n    ) -> LRScheduler | None:\n        \"\"\"Calls the `create_lr_scheduler` method of all registered plugins and returns\n        the first non-`None` scheduler.\n\n        Args:\n            trainer: The trainer object for training.\n            optimizer: The optimizer for training.\n            num_training_steps: Total number of training steps.\n\n        Returns:\n            The created learning rate scheduler, or `None` if not found.\n        \"\"\"\n        for plugin in self.plugins.values():\n            scheduler: LRScheduler | None = plugin.create_lr_scheduler(\n                self.cfg,\n                trainer=trainer,\n                optimizer=optimizer,\n                num_training_steps=num_training_steps,\n            )\n            if scheduler is not None:\n                return scheduler\n        return None\n\n    def add_callbacks_pre_trainer(\n        self, cfg: DictDefault, model: PreTrainedModel\n    ) -> list[Callable]:\n        \"\"\"Calls the add_callbacks_pre_trainer method of all registered plugins.\n\n        Args:\n            cfg: The configuration for the plugins.\n            model: The loaded model.\n\n        Returns:\n            A list of callback functions to be added to the `TrainingArgs`.\n        \"\"\"\n        callbacks = []\n        for plugin in self.plugins.values():\n            plugin_callbacks = plugin.add_callbacks_pre_trainer(cfg, model)\n            if plugin_callbacks:  # if the plugin returned a list of callbacks\n                callbacks.extend(plugin_callbacks)\n        return callbacks\n\n    def add_callbacks_post_trainer(\n        self, cfg: DictDefault, trainer: Trainer\n    ) -> list[Callable]:\n        \"\"\"Calls the `add_callbacks_post_trainer` method of all registered plugins.\n\n        Args:\n            cfg: The configuration for the plugins.\n            trainer: The trainer object for training.\n\n        Returns:\n            A list of callback functions to be added to the `TrainingArgs`.\n        \"\"\"\n        callbacks = 
[]\n        for plugin in self.plugins.values():\n            plugin_callbacks = plugin.add_callbacks_post_trainer(cfg, trainer)\n            if plugin_callbacks:\n                callbacks.extend(plugin_callbacks)\n        return callbacks\n\n    def post_train(self, cfg: DictDefault, model: PreTrainedModel | PeftModel):\n        \"\"\"Calls the post_train method of all registered plugins.\n\n        Args:\n            cfg: The configuration for the plugins.\n            model: The loaded model.\n        \"\"\"\n        for plugin in self.plugins.values():\n            plugin.post_train(cfg, model)\n\n    def post_train_unload(self, cfg: DictDefault):\n        \"\"\"Calls the post_train_unload method of all registered plugins.\n\n        Args:\n            cfg: The configuration for the plugins.\n        \"\"\"\n        for plugin in self.plugins.values():\n            plugin.post_train_unload(cfg)\n\n\nclass BaseOptimizerFactory:\n    \"\"\"Base class for factories to create custom optimizers\"\"\"\n\n    def __call__(\n        self, opt_model, training_args, **optimizer_kwargs\n    ) -> Optimizer | None:\n        pass\n\n    # duplicated from transformers\n    def get_decay_parameter_names(self, model) -> list[str]:\n        \"\"\"\n        Get all parameter names that weight decay will be applied to.\n\n        This function filters out parameters in two ways:\n        1. By layer type (parameters of `nn.LayerNorm` modules)\n        2. By parameter name patterns (names containing 'bias', 'layernorm', or\n           'rmsnorm', or having a 'norm' / '_norm' name component)\n        \"\"\"\n        forbidden_name_patterns = [\n            r\"bias\",\n            r\"layernorm\",\n            r\"rmsnorm\",\n            r\"(?:^|\\.)norm(?:$|\\.)\",\n            r\"_norm(?:$|\\.)\",\n        ]\n        decay_parameters = get_parameter_names(\n            model, [nn.LayerNorm], forbidden_name_patterns\n        )\n        return decay_parameters\n"
  },
  {
    "path": "src/axolotl/integrations/config.py",
    "content": "# Copyright 2024 Axolotl AI. All rights reserved.\n#\n# This software may be used and distributed according to\n# the terms of the Axolotl Community License Agreement (the \"License\");\n# you may not use this file except in compliance with the License.\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS, WITHOUT\n# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the\n# License for the specific language governing permissions and limitations under\n# the License.\n\n\"\"\"\nModule to handle merging the plugins' input arguments with the base configurations.\n\nThis was moved here to prevent circular imports.\n\"\"\"\n\nfrom typing import Any, Dict, List, Type\n\nfrom axolotl.utils.schemas.config import (\n    AxolotlConfigWCapabilities as AxolotlConfigWCapabilitiesBase,\n    AxolotlInputConfig as AxolotlInputConfigBase,\n)\n\n\ndef merge_input_args():\n    \"\"\"\n    Merges input arguments from registered plugins with the base configurations.\n\n    This function retrieves the input arguments from registered plugins using the PluginManager.\n    It then dynamically creates new classes, AxolotlConfigWCapabilities and AxolotlInputConfig,\n    that inherit from the base configurations and include the input arguments from the plugins.\n\n    Returns:\n    tuple: A tuple containing the newly created classes, AxolotlConfigWCapabilities and AxolotlInputConfig.\n    \"\"\"\n    from axolotl.integrations.base import PluginManager\n\n    plugin_manager = PluginManager.get_instance()\n    input_args: List[str] = plugin_manager.get_input_args()\n    plugin_classes = []\n    dynamic_input = \"\"\n    for plugin_args in input_args:\n        plugin_module, plugin_cls = plugin_args.rsplit(\".\", 1)\n        dynamic_input += f\"from {plugin_module} import {plugin_cls}\\n\"\n        plugin_classes.append(plugin_cls)\n    if dynamic_input:\n        
dynamic_input += f\"class AxolotlConfigWCapabilities(AxolotlConfigWCapabilitiesBase, {', '.join(plugin_classes)}):\\n    pass\\n\"\n        dynamic_input += f\"class AxolotlInputConfig(AxolotlInputConfigBase, {', '.join(plugin_classes)}):\\n    pass\\n\"\n\n        namespace: Dict[Any, Any] = {}\n        exec(dynamic_input, globals(), namespace)  # nosec B102\n        AxolotlInputConfig = namespace[\"AxolotlInputConfig\"]\n        AxolotlConfigWCapabilities = namespace[\"AxolotlConfigWCapabilities\"]\n        return AxolotlConfigWCapabilities, AxolotlInputConfig\n    return AxolotlConfigWCapabilitiesBase, AxolotlInputConfigBase\n\n\ndef merge_training_args() -> Type:\n    \"\"\"\n    Merges training-argument mixins from registered plugins with the base AxolotlTrainingMixins.\n\n    This function retrieves the training-args mixin import paths from registered plugins using the PluginManager.\n    It then dynamically creates a new class, AxolotlTrainingMixins,\n    that inherits from the base AxolotlTrainingMixins and from each plugin's mixin class.\n\n    Returns:\n    Type: The newly created AxolotlTrainingMixins class, or the base AxolotlTrainingMixins\n    class unchanged if no plugin provides a mixin.\n    \"\"\"\n\n    from axolotl.core.training_args_base import (\n        AxolotlTrainingMixins as AxolotlTrainingMixinsBase,\n    )\n    from axolotl.integrations.base import PluginManager\n\n    plugin_manager = PluginManager.get_instance()\n    training_args_mixins: List[str] = plugin_manager.get_training_args_mixin()\n    mixin_classes = []\n    dynamic_input = \"\"\n    for plugin_args in training_args_mixins:\n        plugin_module, plugin_cls = plugin_args.rsplit(\".\", 1)\n        dynamic_input += f\"from {plugin_module} import {plugin_cls}\\n\"\n        mixin_classes.append(plugin_cls)\n    if dynamic_input:\n        dynamic_input += f\"class AxolotlTrainingMixins(AxolotlTrainingMixinsBase, {', '.join(mixin_classes)}):\\n    pass\\n\"\n\n        namespace: Dict[Any, Any] = {}\n        local_vars = {\"AxolotlTrainingMixinsBase\": 
AxolotlTrainingMixinsBase}\n        exec(dynamic_input, {**globals(), **local_vars}, namespace)  # nosec B102\n        AxolotlTrainingMixins = namespace[\"AxolotlTrainingMixins\"]\n        return AxolotlTrainingMixins\n    return AxolotlTrainingMixinsBase\n"
  },
  {
    "path": "src/axolotl/integrations/cut_cross_entropy/ACKNOWLEDGEMENTS.md",
    "content": "Acknowledgements\n\nPortions of this Cut Cross Entropy Software may utilize the following copyrighted\nmaterial, the use of which is hereby acknowledged.\n\n\n------\n\n\nPyTorch\n\n    From PyTorch:\n\n    Copyright (c) 2016-     Facebook, Inc            (Adam Paszke)\n    Copyright (c) 2014-     Facebook, Inc            (Soumith Chintala)\n    Copyright (c) 2011-2014 Idiap Research Institute (Ronan Collobert)\n    Copyright (c) 2012-2014 Deepmind Technologies    (Koray Kavukcuoglu)\n    Copyright (c) 2011-2012 NEC Laboratories America (Koray Kavukcuoglu)\n    Copyright (c) 2011-2013 NYU                      (Clement Farabet)\n    Copyright (c) 2006-2010 NEC Laboratories America (Ronan Collobert, Leon Bottou, Iain Melvin, Jason Weston)\n    Copyright (c) 2006      Idiap Research Institute (Samy Bengio)\n    Copyright (c) 2001-2004 Idiap Research Institute (Ronan Collobert, Samy Bengio, Johnny Mariethoz)\n\n    From Caffe2:\n\n    Copyright (c) 2016-present, Facebook Inc. All rights reserved.\n\n    All contributions by Facebook:\n    Copyright (c) 2016 Facebook Inc.\n\n    All contributions by Google:\n    Copyright (c) 2015 Google Inc.\n    All rights reserved.\n\n    All contributions by Yangqing Jia:\n    Copyright (c) 2015 Yangqing Jia\n    All rights reserved.\n\n    All contributions by Kakao Brain:\n    Copyright 2019-2020 Kakao Brain\n\n    All contributions by Cruise LLC:\n    Copyright (c) 2022 Cruise LLC.\n    All rights reserved.\n\n    All contributions by Arm:\n    Copyright (c) 2021, 2023-2024 Arm Limited and/or its affiliates\n\n    All contributions from Caffe:\n    Copyright(c) 2013, 2014, 2015, the respective contributors\n    All rights reserved.\n\n    All other contributions:\n    Copyright(c) 2015, 2016 the respective contributors\n    All rights reserved.\n\n    Caffe2 uses a copyright model similar to Caffe: each contributor holds\n    copyright over their contributions to Caffe2. 
The project versioning records\n    all such contribution and copyright details. If a contributor wants to further\n    mark their specific copyright on a particular contribution, they should\n    indicate their copyright solely in the commit message of the change when it is\n    committed.\n\n    All rights reserved.\n\n    Redistribution and use in source and binary forms, with or without\n    modification, are permitted provided that the following conditions are met:\n\n    1. Redistributions of source code must retain the above copyright\n    notice, this list of conditions and the following disclaimer.\n\n    2. Redistributions in binary form must reproduce the above copyright\n    notice, this list of conditions and the following disclaimer in the\n    documentation and/or other materials provided with the distribution.\n\n    3. Neither the names of Facebook, Deepmind Technologies, NYU, NEC Laboratories America\n    and IDIAP Research Institute nor the names of its contributors may be\n    used to endorse or promote products derived from this software without\n    specific prior written permission.\n\n    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\"\n    AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE\n    IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE\n    ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE\n    LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR\n    CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF\n    SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS\n    INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN\n    CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)\n    ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE\n    POSSIBILITY OF SUCH DAMAGE.\n\n\nTriton\n\n    /*\n    * Copyright 2018-2020 Philippe Tillet\n    * Copyright 2020-2022 OpenAI\n    *\n    * Permission is hereby granted, free of charge, to any person obtaining\n    * a copy of this software and associated documentation files\n    * (the \"Software\"), to deal in the Software without restriction,\n    * including without limitation the rights to use, copy, modify, merge,\n    * publish, distribute, sublicense, and/or sell copies of the Software,\n    * and to permit persons to whom the Software is furnished to do so,\n    * subject to the following conditions:\n    *\n    * The above copyright notice and this permission notice shall be\n    * included in all copies or substantial portions of the Software.\n    *\n    * THE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND,\n    * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF\n    * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.\n    * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY\n    * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,\n    * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE\n    * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.\n    */\n\n\nTransformers\n\n    Copyright 2018- The Hugging Face team. 
All rights reserved.\n\n                                    Apache License\n                            Version 2.0, January 2004\n                            http://www.apache.org/licenses/\n\n    TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION\n\n    1. Definitions.\n\n        \"License\" shall mean the terms and conditions for use, reproduction,\n        and distribution as defined by Sections 1 through 9 of this document.\n\n        \"Licensor\" shall mean the copyright owner or entity authorized by\n        the copyright owner that is granting the License.\n\n        \"Legal Entity\" shall mean the union of the acting entity and all\n        other entities that control, are controlled by, or are under common\n        control with that entity. For the purposes of this definition,\n        \"control\" means (i) the power, direct or indirect, to cause the\n        direction or management of such entity, whether by contract or\n        otherwise, or (ii) ownership of fifty percent (50%) or more of the\n        outstanding shares, or (iii) beneficial ownership of such entity.\n\n        \"You\" (or \"Your\") shall mean an individual or Legal Entity\n        exercising permissions granted by this License.\n\n        \"Source\" form shall mean the preferred form for making modifications,\n        including but not limited to software source code, documentation\n        source, and configuration files.\n\n        \"Object\" form shall mean any form resulting from mechanical\n        transformation or translation of a Source form, including but\n        not limited to compiled object code, generated documentation,\n        and conversions to other media types.\n\n        \"Work\" shall mean the work of authorship, whether in Source or\n        Object form, made available under the License, as indicated by a\n        copyright notice that is included in or attached to the work\n        (an example is provided in the Appendix below).\n\n        \"Derivative 
Works\" shall mean any work, whether in Source or Object\n        form, that is based on (or derived from) the Work and for which the\n        editorial revisions, annotations, elaborations, or other modifications\n        represent, as a whole, an original work of authorship. For the purposes\n        of this License, Derivative Works shall not include works that remain\n        separable from, or merely link (or bind by name) to the interfaces of,\n        the Work and Derivative Works thereof.\n\n        \"Contribution\" shall mean any work of authorship, including\n        the original version of the Work and any modifications or additions\n        to that Work or Derivative Works thereof, that is intentionally\n        submitted to Licensor for inclusion in the Work by the copyright owner\n        or by an individual or Legal Entity authorized to submit on behalf of\n        the copyright owner. For the purposes of this definition, \"submitted\"\n        means any form of electronic, verbal, or written communication sent\n        to the Licensor or its representatives, including but not limited to\n        communication on electronic mailing lists, source code control systems,\n        and issue tracking systems that are managed by, or on behalf of, the\n        Licensor for the purpose of discussing and improving the Work, but\n        excluding communication that is conspicuously marked or otherwise\n        designated in writing by the copyright owner as \"Not a Contribution.\"\n\n        \"Contributor\" shall mean Licensor and any individual or Legal Entity\n        on behalf of whom a Contribution has been received by Licensor and\n        subsequently incorporated within the Work.\n\n    2. Grant of Copyright License. 
Subject to the terms and conditions of\n        this License, each Contributor hereby grants to You a perpetual,\n        worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n        copyright license to reproduce, prepare Derivative Works of,\n        publicly display, publicly perform, sublicense, and distribute the\n        Work and such Derivative Works in Source or Object form.\n\n    3. Grant of Patent License. Subject to the terms and conditions of\n        this License, each Contributor hereby grants to You a perpetual,\n        worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n        (except as stated in this section) patent license to make, have made,\n        use, offer to sell, sell, import, and otherwise transfer the Work,\n        where such license applies only to those patent claims licensable\n        by such Contributor that are necessarily infringed by their\n        Contribution(s) alone or by combination of their Contribution(s)\n        with the Work to which such Contribution(s) was submitted. If You\n        institute patent litigation against any entity (including a\n        cross-claim or counterclaim in a lawsuit) alleging that the Work\n        or a Contribution incorporated within the Work constitutes direct\n        or contributory patent infringement, then any patent licenses\n        granted to You under this License for that Work shall terminate\n        as of the date such litigation is filed.\n\n    4. Redistribution. 
You may reproduce and distribute copies of the\n        Work or Derivative Works thereof in any medium, with or without\n        modifications, and in Source or Object form, provided that You\n        meet the following conditions:\n\n        (a) You must give any other recipients of the Work or\n            Derivative Works a copy of this License; and\n\n        (b) You must cause any modified files to carry prominent notices\n            stating that You changed the files; and\n\n        (c) You must retain, in the Source form of any Derivative Works\n            that You distribute, all copyright, patent, trademark, and\n            attribution notices from the Source form of the Work,\n            excluding those notices that do not pertain to any part of\n            the Derivative Works; and\n\n        (d) If the Work includes a \"NOTICE\" text file as part of its\n            distribution, then any Derivative Works that You distribute must\n            include a readable copy of the attribution notices contained\n            within such NOTICE file, excluding those notices that do not\n            pertain to any part of the Derivative Works, in at least one\n            of the following places: within a NOTICE text file distributed\n            as part of the Derivative Works; within the Source form or\n            documentation, if provided along with the Derivative Works; or,\n            within a display generated by the Derivative Works, if and\n            wherever such third-party notices normally appear. The contents\n            of the NOTICE file are for informational purposes only and\n            do not modify the License. 
You may add Your own attribution\n            notices within Derivative Works that You distribute, alongside\n            or as an addendum to the NOTICE text from the Work, provided\n            that such additional attribution notices cannot be construed\n            as modifying the License.\n\n        You may add Your own copyright statement to Your modifications and\n        may provide additional or different license terms and conditions\n        for use, reproduction, or distribution of Your modifications, or\n        for any such Derivative Works as a whole, provided Your use,\n        reproduction, and distribution of the Work otherwise complies with\n        the conditions stated in this License.\n\n    5. Submission of Contributions. Unless You explicitly state otherwise,\n        any Contribution intentionally submitted for inclusion in the Work\n        by You to the Licensor shall be under the terms and conditions of\n        this License, without any additional terms or conditions.\n        Notwithstanding the above, nothing herein shall supersede or modify\n        the terms of any separate license agreement you may have executed\n        with Licensor regarding such Contributions.\n\n    6. Trademarks. This License does not grant permission to use the trade\n        names, trademarks, service marks, or product names of the Licensor,\n        except as required for reasonable and customary use in describing the\n        origin of the Work and reproducing the content of the NOTICE file.\n\n    7. Disclaimer of Warranty. Unless required by applicable law or\n        agreed to in writing, Licensor provides the Work (and each\n        Contributor provides its Contributions) on an \"AS IS\" BASIS,\n        WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or\n        implied, including, without limitation, any warranties or conditions\n        of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A\n        PARTICULAR PURPOSE. 
You are solely responsible for determining the\n        appropriateness of using or redistributing the Work and assume any\n        risks associated with Your exercise of permissions under this License.\n\n    8. Limitation of Liability. In no event and under no legal theory,\n        whether in tort (including negligence), contract, or otherwise,\n        unless required by applicable law (such as deliberate and grossly\n        negligent acts) or agreed to in writing, shall any Contributor be\n        liable to You for damages, including any direct, indirect, special,\n        incidental, or consequential damages of any character arising as a\n        result of this License or out of the use or inability to use the\n        Work (including but not limited to damages for loss of goodwill,\n        work stoppage, computer failure or malfunction, or any and all\n        other commercial damages or losses), even if such Contributor\n        has been advised of the possibility of such damages.\n\n    9. Accepting Warranty or Additional Liability. While redistributing\n        the Work or Derivative Works thereof, You may choose to offer,\n        and charge a fee for, acceptance of support, warranty, indemnity,\n        or other liability obligations and/or rights consistent with this\n        License. 
However, in accepting such obligations, You may act only\n        on Your own behalf and on Your sole responsibility, not on behalf\n        of any other Contributor, and only if You agree to indemnify,\n        defend, and hold each Contributor harmless for any liability\n        incurred by, or claims asserted against, such Contributor by reason\n        of your accepting any such warranty or additional liability.\n\n    END OF TERMS AND CONDITIONS\n\n    APPENDIX: How to apply the Apache License to your work.\n\n        To apply the Apache License to your work, attach the following\n        boilerplate notice, with the fields enclosed by brackets \"[]\"\n        replaced with your own identifying information. (Don't include\n        the brackets!)  The text should be enclosed in the appropriate\n        comment syntax for the file format. We also recommend that a\n        file or class name and description of purpose be included on the\n        same \"printed page\" as the copyright notice for easier\n        identification within third-party archives.\n\n    Copyright [yyyy] [name of copyright owner]\n\n    Licensed under the Apache License, Version 2.0 (the \"License\");\n    you may not use this file except in compliance with the License.\n    You may obtain a copy of the License at\n\n        http://www.apache.org/licenses/LICENSE-2.0\n\n    Unless required by applicable law or agreed to in writing, software\n    distributed under the License is distributed on an \"AS IS\" BASIS,\n    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n    See the License for the specific language governing permissions and\n    limitations under the License.\n"
  },
  {
    "path": "src/axolotl/integrations/cut_cross_entropy/LICENSE",
    "content": "Copyright (C) 2024 Apple Inc. All Rights Reserved.\n\nIMPORTANT:  This Apple software is supplied to you by Apple\nInc. (\"Apple\") in consideration of your agreement to the following\nterms, and your use, installation, modification or redistribution of\nthis Apple software constitutes acceptance of these terms.  If you do\nnot agree with these terms, please do not use, install, modify or\nredistribute this Apple software.\n\nIn consideration of your agreement to abide by the following terms, and\nsubject to these terms, Apple grants you a personal, non-exclusive\nlicense, under Apple's copyrights in this original Apple software (the\n\"Apple Software\"), to use, reproduce, modify and redistribute the Apple\nSoftware, with or without modifications, in source and/or binary forms;\nprovided that if you redistribute the Apple Software in its entirety and\nwithout modifications, you must retain this notice and the following\ntext and disclaimers in all such redistributions of the Apple Software.\nNeither the name, trademarks, service marks or logos of Apple Inc. may\nbe used to endorse or promote products derived from the Apple Software\nwithout specific prior written permission from Apple.  Except as\nexpressly stated in this notice, no other rights or licenses, express or\nimplied, are granted by Apple herein, including but not limited to any\npatent rights that may be infringed by your derivative works or by other\nworks in which the Apple Software may be incorporated.\n\nThe Apple Software is provided by Apple on an \"AS IS\" basis.  
APPLE\nMAKES NO WARRANTIES, EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION\nTHE IMPLIED WARRANTIES OF NON-INFRINGEMENT, MERCHANTABILITY AND FITNESS\nFOR A PARTICULAR PURPOSE, REGARDING THE APPLE SOFTWARE OR ITS USE AND\nOPERATION ALONE OR IN COMBINATION WITH YOUR PRODUCTS.\n\nIN NO EVENT SHALL APPLE BE LIABLE FOR ANY SPECIAL, INDIRECT, INCIDENTAL\nOR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF\nSUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS\nINTERRUPTION) ARISING IN ANY WAY OUT OF THE USE, REPRODUCTION,\nMODIFICATION AND/OR DISTRIBUTION OF THE APPLE SOFTWARE, HOWEVER CAUSED\nAND WHETHER UNDER THEORY OF CONTRACT, TORT (INCLUDING NEGLIGENCE),\nSTRICT LIABILITY OR OTHERWISE, EVEN IF APPLE HAS BEEN ADVISED OF THE\nPOSSIBILITY OF SUCH DAMAGE.\n\n\n-------------------------------------------------------------------------------\nSOFTWARE DISTRIBUTED WITH CUT CROSS ENTROPY:\n\nThe Cut Cross Entropy software includes a number of subcomponents with separate\ncopyright notices and license terms - please see the file ACKNOWLEDGEMENTS.md.\n-------------------------------------------------------------------------------\n"
  },
  {
    "path": "src/axolotl/integrations/cut_cross_entropy/README.md",
    "content": "# Cut Cross Entropy\n\nCut Cross Entropy (CCE) reduces VRAM usage through optimization on the cross-entropy operation during loss calculation.\n\nSee https://github.com/apple/ml-cross-entropy\n\n## Requirements\n\n- PyTorch 2.4.0 or higher\n\n## Installation\n\nRun the following command to install `cut_cross_entropy[transformers]` if you don't have it already.\n\n- If you are in dev environment\n```bash\npython scripts/cutcrossentropy_install.py | sh\n```\n\n- If you are installing from pip\n```bash\npip3 uninstall -y cut-cross-entropy && pip3 install \"cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@63b15e6\"\n```\n\n## Usage\n\n```yaml\nplugins:\n  - axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin\n```\n\n## Supported Models\n\n- afmoe\n- apertus\n- arcee\n- cohere\n- cohere2\n- deepseek_v3\n- exaone4\n- gemma\n- gemma2\n- gemma3\n- gemma3_text\n- gemma3n\n- gemma3n_text\n- glm\n- glm4\n- glm4_moe\n- glm4_moe_lite\n- glm46v\n- glm4v\n- glm4v_moe\n- glm_image\n- glm_moe_dsa\n- gpt_oss\n- granite\n- granitemoe\n- granitemoehybrid\n- granitemoeshared\n- hunyuan_v1_dense\n- hunyuan_v1_moe\n- internvl\n- kimi_linear\n- lfm2\n- lfm2_moe\n- lfm2_vl\n- llama\n- llama4\n- llama4_text\n- llava\n- ministral\n- ministral3\n- mistral\n- mistral3\n- mistral4\n- mixtral\n- mllama\n- nemotron_h\n- olmo\n- olmo2\n- olmo3\n- olmoe\n- phi\n- phi3\n- phi4_multimodal\n- qwen2\n- qwen2_5_vl\n- qwen2_moe\n- qwen2_vl\n- qwen3\n- qwen3_5\n- qwen3_5_text\n- qwen3_5_moe\n- qwen3_5_moe_text\n- qwen3_moe\n- qwen3_next\n- qwen3_vl\n- qwen3_vl_moe\n- seed_oss\n- smollm3\n- step3p5\n- voxtral\n\n## Citation\n\n```bib\n@article{wijmans2024cut,\n  author       = {Erik Wijmans and\n                  Brody Huval and\n                  Alexander Hertzberg and\n                  Vladlen Koltun and\n                  Philipp Kr\\\"ahenb\\\"uhl},\n  title        = {Cut Your Losses in Large-Vocabulary Language Models},\n  
journal      = {arXiv},\n  year         = {2024},\n  url          = {https://arxiv.org/abs/2411.09009},\n}\n```\n"
  },
  {
    "path": "src/axolotl/integrations/cut_cross_entropy/__init__.py",
    "content": "# Copyright 2024 Axolotl AI. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"\nModule for the Plugin for Cut Cross Entropy integration with Axolotl.\n\nCut Cross Entropy is an optimized implementation of cross entropy loss\nfrom Apple's ML team.\n\"\"\"\n\nimport importlib\nfrom functools import partial\n\nimport torch\n\nfrom axolotl.integrations.base import BasePlugin\nfrom axolotl.utils import get_pytorch_version\nfrom axolotl.utils.callbacks.models import get_causal_lm_model_cls_prefix\nfrom axolotl.utils.logging import get_logger\n\nfrom .args import CutCrossEntropyArgs as CutCrossEntropyArgs\n\nLOG = get_logger(__name__)\n\n_CCE_INSTALL_MESSAGE = (\n    \"Please install Axolotl's fork of cut_cross_entropy with transformers support using \"\n    '`pip install \"cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@63b15e6\"`'\n)\n\n\nclass CutCrossEntropyPlugin(BasePlugin):\n    \"\"\"\n    Plugin for Cut Cross Entropy integration with Axolotl.\n    \"\"\"\n\n    def get_input_args(self):\n        return \"axolotl.integrations.cut_cross_entropy.CutCrossEntropyArgs\"\n\n    def _check_requirements(self):\n        \"\"\"Check if all requirements are met.\"\"\"\n        # Check PyTorch version\n\n        major, minor, _ = get_pytorch_version()\n        if (major, minor) < (2, 4):\n            raise ImportError(\n                \"Cut Cross Entropy requires PyTorch >= 2.4.0. 
\"\n                f\"Current version: {torch.__version__}\"\n            )\n\n        # Check if cut_cross_entropy is installed\n        cce_spec = importlib.util.find_spec(\"cut_cross_entropy\")\n        if cce_spec is None:\n            raise ImportError(_CCE_INSTALL_MESSAGE)\n\n        cce_spec_transformers = importlib.util.find_spec(\n            \"cut_cross_entropy.transformers\"\n        )\n        if cce_spec_transformers is None:\n            raise ImportError(\n                \"Transformers support is not installed. \" + _CCE_INSTALL_MESSAGE\n            )\n\n        # Check if Axolotl's cce fork is installed\n        try:\n            from cut_cross_entropy.transformers.patch import AXOLOTL_CCE_FORK\n\n            if not AXOLOTL_CCE_FORK:\n                raise ImportError\n        except ImportError as e:\n            raise ImportError(\n                \"Axolotl's fork of cut_cross_entropy is not installed. \"\n                + _CCE_INSTALL_MESSAGE\n            ) from e\n\n    def pre_model_load(self, cfg):\n        \"\"\"Apply cut cross entropy before model loading if enabled.\"\"\"\n        if cfg.cut_cross_entropy:\n            self._check_requirements()\n            self.patch_llama_like(cfg.model_config_type)\n\n            from cut_cross_entropy.transformers.patch import cce_patch\n\n            LOG.info(\n                f\"Applying Cut Cross Entropy to model type: {cfg.model_config_type}\"\n            )\n\n            # The patch checks model_type internally\n\n            cce_patch(\n                cfg.model_config_type,\n                remote_model_id=cfg.base_model if cfg.trust_remote_code else None,\n            )\n\n    def patch_llama_like(\n        self,\n        model_type_to_patch: str,\n    ) -> None:\n        \"\"\"\n        Generic patch for model architectures with causal lm similar to llama\n        \"\"\"\n        from cut_cross_entropy.transformers.patch import PATCH_FNS\n\n        def patch_generic(\n            
maybe_model,\n            patch_options,\n            remote_model_id: str | None,\n            model_type: str,\n        ):\n            import cut_cross_entropy.transformers.llama\n            from cut_cross_entropy.transformers.llama import cce_forward\n\n            try:\n                # Dynamically import the module and CausalLM class\n                module_path = f\"transformers.models.{model_type}.modeling_{model_type}\"\n                model_cls_prefix, _ = get_causal_lm_model_cls_prefix(model_type)\n                module = __import__(\n                    module_path, fromlist=[f\"{model_cls_prefix}ForCausalLM\"]\n                )\n                model_cls = getattr(module, f\"{model_cls_prefix}ForCausalLM\")\n\n                cut_cross_entropy.transformers.llama._PATCH_OPTS = patch_options\n\n                model_cls.forward = cce_forward\n\n            except (ImportError, AttributeError) as e:\n                raise RuntimeError(\n                    f\"Could not import ForCausalLM class for model_type: {model_type}. \"\n                    f\"Error: {str(e)}\"\n                ) from e\n\n        if model_type_to_patch not in PATCH_FNS:\n            LOG.warning_once(\n                \"Setting up generic cce patch for model type: %s\", model_type_to_patch\n            )\n            LOG.warning_once(\n                f\"Generic Cut Cross Entropy + {model_type_to_patch} support is experimental and may not work as expected.\"\n            )\n            PATCH_FNS[model_type_to_patch] = partial(\n                patch_generic, model_type=model_type_to_patch\n            )\n"
  },
  {
    "path": "src/axolotl/integrations/cut_cross_entropy/args.py",
    "content": "# Copyright 2024 Axolotl AI. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"\nModule for handling Cut Cross Entropy input arguments.\n\"\"\"\n\nfrom typing import Optional\n\nfrom pydantic import BaseModel, model_validator\n\nfrom axolotl.utils.logging import get_logger\n\nLOG = get_logger(__name__)\n\n\nclass CutCrossEntropyArgs(BaseModel):\n    \"\"\"\n    Input args for Cut Cross Entropy.\n    \"\"\"\n\n    cut_cross_entropy: Optional[bool] = True\n\n    @model_validator(mode=\"before\")\n    @classmethod\n    def check_dtype_is_half(cls, data):\n        if data.get(\"cut_cross_entropy\") and not (data.get(\"bf16\") or data.get(\"fp16\")):\n            raise ValueError(\n                \"Cut Cross Entropy requires fp16/bf16 training for backward pass. \"\n                \"Please set `bf16` or `fp16` to `True`.\"\n            )\n\n        return data\n\n    @model_validator(mode=\"before\")\n    @classmethod\n    def check_chunked_cross_entropy_not_set(cls, data):\n        if data.get(\"chunked_cross_entropy\"):\n            raise ValueError(\n                \"Cut Cross Entropy does not support chunked cross entropy. \"\n                \"Please set `chunked_cross_entropy` to `False` or disable Cut Cross Entropy.\"\n            )\n        return data\n"
  },
  {
    "path": "src/axolotl/integrations/densemixer/README.md",
    "content": "# DenseMixer\n\nSee [DenseMixer](https://github.com/yaof20/DenseMixer/)\n\n# Usage\n\nSimply add the following to your axolotl YAML config:\n\n```yaml\nplugins:\n  - axolotl.integrations.densemixer.DenseMixerPlugin\n```\n"
  },
  {
    "path": "src/axolotl/integrations/densemixer/__init__.py",
    "content": "\"\"\"Integration entry point for the DenseMixer plugin.\"\"\"\n\nfrom .plugin import DenseMixerPlugin\n\n__all__ = [\"DenseMixerPlugin\"]\n"
  },
  {
    "path": "src/axolotl/integrations/densemixer/args.py",
    "content": "\"\"\"Pydantic models for DenseMixer plugin\"\"\"\n\nfrom pydantic import BaseModel\n\n\nclass DenseMixerArgs(BaseModel):\n    \"\"\"\n    Args for DenseMixer\n    \"\"\"\n\n    dense_mixer: bool = True\n"
  },
  {
    "path": "src/axolotl/integrations/densemixer/plugin.py",
    "content": "\"\"\"DenseMixer plugin for Axolotl\"\"\"\n\nimport importlib\n\nfrom axolotl.integrations.base import BasePlugin\nfrom axolotl.utils.logging import get_logger\n\nLOG = get_logger(__name__)\n\n\nclass DenseMixerPlugin(BasePlugin):\n    \"\"\"\n    Plugin for DenseMixer\n    \"\"\"\n\n    def get_input_args(self) -> str | None:\n        return \"axolotl.integrations.densemixer.args.DenseMixerArgs\"\n\n    def pre_model_load(self, cfg):\n        \"\"\"Apply densemixer patches before model loading if enabled.\"\"\"\n        if cfg.dense_mixer:\n            if not importlib.util.find_spec(\"densemixer\"):\n                raise RuntimeError(\n                    \"DenseMixer is not installed. Install it with `pip install densemixer`\"\n                )\n\n            from densemixer.patching import (\n                apply_olmoe_patch,\n                apply_qwen2_moe_patch,\n                apply_qwen3_moe_patch,\n            )\n\n            LOG.info(\n                f\"Applying DenseMixer patches for model type: {cfg.model_config_type}\"\n            )\n\n            if cfg.model_config_type == \"olmoe\":\n                apply_olmoe_patch()\n            if cfg.model_config_type == \"qwen2_moe\":\n                apply_qwen2_moe_patch()\n            if cfg.model_config_type == \"qwen3_moe\":\n                apply_qwen3_moe_patch()\n"
  },
  {
    "path": "src/axolotl/integrations/diffusion/README.md",
    "content": "# Diffusion LM Training Plugin for Axolotl\n\nThis plugin enables diffusion language model training using an approach inspired by\nLLaDA (Large Language Diffusion Models) within Axolotl.\n\n## Overview\n\nLLaDA is a diffusion-based approach to language model training that uses:\n- **Random token masking** during training instead of next-token prediction\n- **Bidirectional attention** to allow the model to attend to the full context\n- **Importance weighting** based on masking probabilities for stable training\n\nThis approach can lead to more robust language models with better understanding of\nbidirectional context.\n\n## Installation\n\nThe plugin is included with Axolotl. See our\n[installation docs](https://docs.axolotl.ai/docs/installation.html).\n\n## Quickstart\n\nTrain with an example config (Llama‑3.2 1B):\n   - Pretrain: `axolotl train examples/llama-3/diffusion-3.2-1b-pretrain.yaml`\n   - SFT: `axolotl train examples/llama-3/diffusion-3.2-1b-sft.yaml`\n\n### Basic Configuration\n\nYou can also modify your existing configs to enable / customize diffusion training.\n\nAdd the following to your Axolotl config:\n\n```yaml\n# Enable diffusion LM training plugin\nplugins:\n  - axolotl.integrations.diffusion.DiffusionPlugin\n```\n\nAnd, configure the nested `diffusion` block (defaults shown):\n\n```yaml\ndiffusion:\n  noise_schedule: linear  # or \"cosine\"\n  min_mask_ratio: 0.1\n  max_mask_ratio: 0.9\n  num_diffusion_steps: 128\n  eps: 1e-3\n  importance_weighting: true\n\n  # Mask token (training auto-adds if missing, avoid pad/eos)\n  mask_token_str: \"<|diffusion_mask|>\"\n  # Or use an existing special token id (e.g., 128002 for Llama-3.x)\n  # mask_token_id: 128002\n\n  # Sample generation during training (optional)\n  generate_samples: true\n  generation_interval: 100\n  num_generation_samples: 3\n  generation_steps: 128\n  generation_temperature: 0.0\n  generation_max_length: 100\n```\n\n## Supported Models\n\nAny models that support 4D 
attention masks should work out of the box. If not, please\ncreate an [issue](https://github.com/axolotl-ai-cloud/axolotl/issues) or open a\n[PR](https://github.com/axolotl-ai-cloud/axolotl/compare)!\n\n## How It Works\n\n### Random Masking\nDuring training, tokens are randomly masked:\n- Sample timestep `t` uniformly from [0, 1]\n- Calculate masking probability: `p = (1 - eps) * t + eps`\n- Randomly mask tokens with probability `p`\n\n### Diffusion Loss\n\nLoss is computed only on masked tokens with (optional) importance weighting:\n\n```python\nloss = sum(cross_entropy(pred, target) / p_mask) / total_tokens\n```\n\n## Sample Generation\n\nWhen `diffusion.generate_samples: true`, the plugin generates samples during training:\n\n```\nSample 1:\n   Original (45 tokens): The quick brown fox jumps over the lazy dog...\n   Masked (18/45 tokens, 40.0%): The [MASK] [MASK] fox [MASK] over [MASK] lazy [MASK]...\n   Generated: The quick brown fox jumps over the lazy dog...\n```\n\nSamples are logged to console and wandb (if enabled).\n\n## Inference\n\nDiffusion inference is integrated into the standard Axolotl CLI. 
Use the same config\nyou trained with and run:\n\n```\naxolotl inference path/to/your-config.yaml\n```\n\nOptionally, pass `--gradio` to use a simple web interface.\n\nInteractive controls (prefix the prompt with commands):\n- `:complete N` → completion mode with N new masked tokens appended (default 64)\n- `:mask R` → random masking mode with target mask ratio R in [0.0, 1.0]\n\nExample session:\n\n```\n================================================================================\nCommands:\n:complete N -> completion mode with N tokens (default 64)\n:mask R     -> random masking with ratio R (0.0–1.0)\n================================================================================\nGive me an instruction (Ctrl + D to submit):\n\n:mask 0.4 The quick brown fox jumps over the lazy dog\n\nMasked (40.0%):\nThe [MASK] brown [MASK] jumps over the [MASK] dog\n\nGenerated:\nThe quick brown fox jumps over the loud dog\n```\n\n## Metrics and Monitoring\n\nThe plugin adds (or modifies) several metrics to track diffusion training:\n\n- `train/loss`: Weighted diffusion loss\n- `train/accuracy`: Accuracy on masked tokens\n- `train/mask_ratio`: Average fraction of tokens masked\n- `train/num_masked_tokens`: Number of tokens masked\n- `train/avg_p_mask`: Average masking probability\n- `train/ce_loss`: Unweighted cross-entropy loss\n- `train/importance_weight_avg`: Average importance weight\n\n## Limitations\n\n- No flash attention support\n- No RL training support\n\n## References\n\n- [LLaDA Paper](https://arxiv.org/abs/2502.09992)\n- [Axolotl Documentation](https://docs.axolotl.ai/)\n- [API reference for plugin](https://docs.axolotl.ai/docs/api/integrations.diffusion.args.html#axolotl.integrations.diffusion.args)\n"
  },
  {
    "path": "src/axolotl/integrations/diffusion/__init__.py",
    "content": "\"\"\"Diffusion LM training plugin init.\"\"\"\n\nfrom .args import DiffusionArgs, DiffusionConfig\nfrom .callbacks import DiffusionGenerationCallback\nfrom .generation import generate\nfrom .plugin import DiffusionPlugin\nfrom .trainer import DiffusionTrainer\nfrom .utils import create_bidirectional_attention_mask, resolve_mask_token_id\n\n__all__ = [\n    \"DiffusionArgs\",\n    \"DiffusionPlugin\",\n    \"DiffusionTrainer\",\n    \"generate\",\n    \"resolve_mask_token_id\",\n    \"create_bidirectional_attention_mask\",\n    \"DiffusionGenerationCallback\",\n    \"DiffusionConfig\",\n]\n"
  },
  {
    "path": "src/axolotl/integrations/diffusion/args.py",
    "content": "\"\"\"Config args for diffusion LM training (nested under `diffusion:`).\"\"\"\n\nfrom __future__ import annotations\n\nfrom typing import Literal\n\nfrom pydantic import BaseModel, Field, model_validator\n\n\nclass DiffusionConfig(BaseModel):\n    \"\"\"Nested diffusion configuration available under the `diffusion` key.\"\"\"\n\n    # Noise schedule config\n    noise_schedule: Literal[\"linear\", \"cosine\"] = Field(\n        default=\"linear\", description=\"Type of noise schedule for diffusion training\"\n    )\n    min_mask_ratio: float = Field(\n        default=0.1,\n        ge=0.0,\n        le=1.0,\n        description=\"Minimum masking ratio for diffusion noise schedule\",\n    )\n    max_mask_ratio: float = Field(\n        default=0.9,\n        ge=0.0,\n        le=1.0,\n        description=\"Maximum masking ratio for diffusion noise schedule\",\n    )\n    num_diffusion_steps: int = Field(\n        default=128, ge=1, description=\"Number of diffusion timesteps\"\n    )\n    eps: float = Field(\n        default=1e-3,\n        ge=0.0,\n        le=1.0,\n        description=\"Epsilon value for minimum masking probability in forward process\",\n    )\n\n    # Training config\n    importance_weighting: bool = Field(\n        default=True,\n        description=\"Apply importance weighting to loss based on masking probability\",\n    )\n    mask_token_id: int | None = Field(\n        default=None,\n        description=(\n            \"Token ID to use for masking. Unset by default; can use one of the \"\n            \"tokenizer's special tokens here.\"\n        ),\n    )\n    mask_token_str: str | None = Field(\n        default=None,\n        description=(\n            \"Token string to use as a mask. If `mask_token_id` is invalid or unset, \"\n            \"this token will be ensured to exist as an additional special token and \"\n            \"used. 
If absent, a default '<|diffusion_mask|>' will be added.\"\n        ),\n    )\n\n    # Sample generation config\n    generate_samples: bool = Field(\n        default=True, description=\"Enable sample generation during training\"\n    )\n    generation_interval: int = Field(\n        default=100, ge=1, description=\"Generate samples every N steps\"\n    )\n    num_generation_samples: int = Field(\n        default=3, ge=1, description=\"Number of samples to generate each time\"\n    )\n    generation_steps: int = Field(\n        default=128, ge=1, description=\"Number of diffusion steps for generation\"\n    )\n    generation_temperature: float = Field(\n        default=0.0,\n        ge=0.0,\n        description=\"Temperature for generation sampling (0.0 = deterministic)\",\n    )\n    generation_max_length: int = Field(\n        default=100, ge=1, description=\"Maximum sequence length for generation\"\n    )\n\n    @model_validator(mode=\"after\")\n    def _validate_mask_ratios(self) -> \"DiffusionConfig\":\n        if self.min_mask_ratio > self.max_mask_ratio:\n            raise ValueError(\"min_mask_ratio must be ≤ max_mask_ratio\")\n        return self\n\n\nclass DiffusionArgs(BaseModel):\n    \"\"\"Plugin entry that exposes the nested `diffusion` block to the core config.\"\"\"\n\n    diffusion: DiffusionConfig = Field(\n        default_factory=DiffusionConfig,\n        description=\"Diffusion training configuration. Only nested block is supported.\",\n    )\n"
  },
  {
    "path": "src/axolotl/integrations/diffusion/callbacks.py",
    "content": "\"\"\"Callbacks for diffusion training.\"\"\"\n\nimport logging\nimport sys\n\nimport wandb\nfrom colorama import Fore, Style\nfrom transformers.trainer_callback import TrainerCallback, TrainerControl, TrainerState\nfrom transformers.training_args import TrainingArguments\n\nfrom .generation import generate_samples\n\n# Simpler logger for more readable sample generation\nlogger = logging.getLogger(__name__)\nif not logger.handlers:\n    handler = logging.StreamHandler(sys.stdout)\n    handler.setFormatter(logging.Formatter(\"%(message)s\"))\n    logger.addHandler(handler)\n    logger.propagate = False\nlogger.setLevel(logging.INFO)\n\n\nclass DiffusionGenerationCallback(TrainerCallback):\n    \"\"\"Callback for generating samples during diffusion training.\"\"\"\n\n    def __init__(self, trainer):\n        self.trainer = trainer\n\n    def on_step_end(\n        self,\n        args: TrainingArguments,\n        state: TrainerState,\n        control: TrainerControl,\n        **kwargs,\n    ):\n        \"\"\"Generate samples at specified intervals.\"\"\"\n        if (\n            state.global_step > 0\n            and state.global_step % self.trainer.cfg.diffusion.generation_interval == 0\n        ):\n            if not self.trainer.state.is_world_process_zero:\n                return\n\n            # Use eval dataloader if available, otherwise use train dataloader\n            dataloader = None\n            try:\n                if getattr(self.trainer, \"eval_dataset\", None) is not None:\n                    dataloader = self.trainer.get_eval_dataloader()\n            except Exception:\n                dataloader = None\n            if dataloader is None:\n                dataloader = self.trainer.get_train_dataloader()\n\n            # Generate samples\n            diffusion_cfg = self.trainer.cfg.diffusion\n            samples = generate_samples(\n                model=self.trainer.model,\n                tokenizer=self.trainer.processing_class,\n 
               dataloader=dataloader,\n                num_generation_samples=diffusion_cfg.num_generation_samples,\n                max_length=diffusion_cfg.generation_max_length,\n                num_diffusion_steps=diffusion_cfg.generation_steps,\n                temperature=diffusion_cfg.generation_temperature,\n                mask_token_id=diffusion_cfg.mask_token_id,\n            )\n\n            # Log samples\n            self._log_samples(samples, state.global_step)\n\n    def _log_samples(self, samples: list, step: int):\n        \"\"\"Log generated samples.\"\"\"\n        if not samples:\n            return\n\n        logger.info(\"=\" * 60)\n        logger.info(\"GENERATED SAMPLES\")\n        logger.info(\"=\" * 60)\n\n        for i, sample_data in enumerate(samples, 1):\n            original = sample_data[\"original\"]\n            masked = sample_data[\"masked\"]\n            generated = sample_data[\"generated\"]\n            mask_ratio = sample_data[\"mask_ratio\"]\n            masked_tokens = sample_data[\"masked_tokens\"]\n            total_tokens = sample_data[\"total_tokens\"]\n\n            logger.info(f\"\\nSample {i}:\")\n            logger.info(f\"\\tOriginal ({total_tokens} tokens): {original}\")\n            logger.info(\n                f\"\\tMasked ({masked_tokens}/{total_tokens} tokens, \"\n                f\"{mask_ratio:.1%}): {masked}\"\n            )\n\n            try:\n                gen_ids = sample_data.get(\"generated_ids\")\n                orig_ids = sample_data.get(\"orig_ids\")\n                masked_positions = set(sample_data.get(\"masked_positions\") or [])\n                if isinstance(gen_ids, list) and isinstance(orig_ids, list):\n                    styles: list[str] = []\n                    for i, tid in enumerate(gen_ids):\n                        if i in masked_positions:\n                            if i < len(orig_ids) and tid == orig_ids[i]:\n                                styles.append(\"green\")\n         
                   elif i < len(orig_ids):\n                                styles.append(\"red\")\n                            else:\n                                styles.append(\"normal\")\n                        else:\n                            same = i < len(orig_ids) and tid == orig_ids[i]\n                            styles.append(\"dim\" if same else \"normal\")\n\n                    spans: list[tuple[str, int, int]] = []\n                    if gen_ids:\n                        cur = styles[0]\n                        start = 0\n                        for i in range(1, len(gen_ids)):\n                            s = styles[i]\n                            if s != cur:\n                                spans.append((cur, start, i))\n                                cur, start = s, i\n                        spans.append((cur, start, len(gen_ids)))\n\n                    parts = []\n                    for style_name, a, b in spans:\n                        chunk_text = self.trainer.processing_class.decode(\n                            gen_ids[a:b], skip_special_tokens=False\n                        )\n                        if style_name == \"green\":\n                            parts.append(Fore.GREEN + chunk_text + Style.RESET_ALL)\n                        elif style_name == \"red\":\n                            parts.append(Fore.RED + chunk_text + Style.RESET_ALL)\n                        else:\n                            if style_name == \"dim\":\n                                parts.append(Style.DIM + chunk_text + Style.RESET_ALL)\n                            else:\n                                parts.append(chunk_text)\n                    logger.info(\"\\tGenerated:\\n%s\", \"\".join(parts))\n                else:\n                    logger.info(f\"\\tGenerated: {generated}\")\n            except Exception:\n                logger.info(f\"\\tGenerated: {generated}\")\n\n        logger.info(\"=\" * 60)\n\n        if 
self.trainer.cfg.use_wandb:\n            if wandb.run is not None:\n                wandb.log(\n                    {\n                        \"generated_samples\": wandb.Table(\n                            columns=[\n                                \"step\",\n                                \"original\",\n                                \"masked\",\n                                \"generated\",\n                                \"mask_ratio\",\n                                \"masked_tokens\",\n                                \"total_tokens\",\n                            ],\n                            data=[\n                                [\n                                    step,\n                                    sample[\"original\"],\n                                    sample[\"masked\"],\n                                    sample[\"generated\"],\n                                    f\"{sample['mask_ratio']:.1%}\",\n                                    sample[\"masked_tokens\"],\n                                    sample[\"total_tokens\"],\n                                ]\n                                for sample in samples\n                            ],\n                        )\n                    },\n                    step=step,\n                )\n"
  },
  {
    "path": "src/axolotl/integrations/diffusion/generation.py",
    "content": "\"\"\"Sample generation utilities for diffusion training.\"\"\"\n\nimport re\nfrom typing import Any, List, Literal, Optional\n\nimport torch\n\nfrom axolotl.utils.logging import get_logger\n\nfrom .utils import create_bidirectional_attention_mask, shift_logits_to_input_positions\n\nLOG = get_logger(__name__)\n\n\ndef generate_samples(\n    model: torch.nn.Module,\n    tokenizer: Any,\n    dataloader: Optional[Any] = None,\n    num_generation_samples: int = 3,\n    max_length: int = 100,\n    num_diffusion_steps: int = 128,\n    temperature: float = 0.0,\n    mask_token_id: int = 32000,\n    mode: Literal[\"random\", \"completion\"] = \"random\",\n    completion_tokens: int = 0,\n    target_mask_ratio: Optional[float] = None,\n) -> List[dict]:\n    \"\"\"\n    Generate text samples using the diffusion model by randomly masking sequences from\n    the given dataset and running the reverse diffusion process.\n\n    Args:\n        model: The wrapped or unwrapped model\n        tokenizer: Tokenizer for encoding/decoding\n        dataloader: Validation dataloader (for sampling sequences)\n        num_generation_samples: Number of samples to generate\n        max_length: Maximum length of sequences to use\n        num_diffusion_steps: Number of diffusion steps for generation\n        temperature: Temperature for sampling (0.0 = deterministic)\n        mask_token_id: Token ID used for masking\n\n    Returns:\n        List of dictionaries with original text, masked text, and generated text\n    \"\"\"\n    if dataloader is None:\n        LOG.warning(\"No validation dataloader provided, cannot generate samples\")\n        return []\n\n    unwrapped_model = model.module if hasattr(model, \"module\") else model\n    training = unwrapped_model.training\n    unwrapped_model.eval()\n\n    # Resolve device robustly (some modules don't expose `.device`)\n    device = getattr(unwrapped_model, \"device\", None)\n    if device is None:\n        try:\n            
device = next(unwrapped_model.parameters()).device\n        except StopIteration:\n            device = torch.device(\"cpu\")\n    generations = []\n\n    # Sample sequences from validation dataset\n    sampled_sequences = _sample_sequences_from_dataloader(\n        dataloader, num_generation_samples, max_length, device\n    )\n    LOG.info(f\"Sampled {len(sampled_sequences)} sequences from validation dataset\")\n\n    # Generate samples using reverse diffusion process\n    with torch.no_grad():\n        for sample in sampled_sequences:\n            if isinstance(sample, dict):\n                original_sequence = sample.get(\"input_ids\")\n                labels_seq = sample.get(\"labels\")\n                attn_seq = sample.get(\"attention_mask\")\n            else:\n                original_sequence = sample\n                labels_seq = None\n                attn_seq = None\n            generation_result = generate(\n                unwrapped_model,\n                tokenizer,\n                original_sequence,\n                num_diffusion_steps,\n                temperature,\n                mask_token_id,\n                mode=mode,\n                completion_tokens=completion_tokens,\n                target_mask_ratio=target_mask_ratio,\n                labels=labels_seq,\n                attention_mask=attn_seq,\n            )\n            generations.append(generation_result)\n\n    # Restore prior training state\n    if training:\n        unwrapped_model.train()\n    else:\n        unwrapped_model.eval()\n\n    return generations\n\n\ndef _sample_sequences_from_dataloader(\n    dataloader: Any, num_samples: int, max_length: int, device: torch.device\n) -> List[Any]:\n    \"\"\"Sample sequences from validation dataloader.\"\"\"\n    sampled_sequences: list[dict[str, torch.Tensor] | torch.Tensor] = []\n    sample_count = 0\n\n    # Skip a random number of batches (we could be more clever about this)\n    skip_batches = torch.randint(0, 10, 
(1,)).item()\n    batch_count = 0\n\n    for batch in dataloader:\n        # Skip some batches for variety\n        if batch_count < skip_batches:\n            batch_count += 1\n            continue\n\n        if sample_count >= num_samples:\n            break\n\n        batch_count += 1\n        input_ids = batch[\"input_ids\"]\n        attention_mask = batch.get(\"attention_mask\")\n        labels = batch.get(\"labels\")\n\n        # Randomly sample from sequences in this batch\n        batch_indices = torch.randperm(input_ids.size(0)).tolist()\n\n        for i in batch_indices:\n            if sample_count >= num_samples:\n                break\n\n            # Get actual sequence length (non-padded)\n            if attention_mask is not None:\n                seq_len = attention_mask[i].sum().item()\n            else:\n                seq_len = input_ids.size(1)\n\n            if seq_len < 10:\n                continue\n\n            # Determine truncation length\n            max_total = min(seq_len, max_length)\n            if labels is not None:\n                labels_i = labels[i][:seq_len]\n                answer_mask = labels_i != -100\n                if not answer_mask.any():\n                    # No answer tokens; skip for SFT masking\n                    continue\n                first_ans_idx = int(\n                    torch.nonzero(answer_mask, as_tuple=False)[0].item()\n                )\n                prompt_len = first_ans_idx\n                if prompt_len >= max_total:\n                    # Prompt alone reaches cap; cannot include any answer\n                    continue\n                remaining_answer = int(answer_mask[prompt_len:].sum().item())\n                allowed_answer = max_total - prompt_len\n                take_answer = min(remaining_answer, allowed_answer)\n                if take_answer <= 0:\n                    continue\n                actual_length = prompt_len + take_answer\n            else:\n                
actual_length = max_total\n\n            # Extract the (possibly truncated) sequence\n            sequence = input_ids[i][:actual_length].unsqueeze(0).to(device)\n            attn_seq = (\n                attention_mask[i][:actual_length].unsqueeze(0).to(device)\n                if attention_mask is not None\n                else None\n            )\n            if labels is not None:\n                labels_seq = labels[i][:actual_length].unsqueeze(0).to(device)\n                sampled_sequences.append(\n                    {\n                        \"input_ids\": sequence,\n                        \"labels\": labels_seq,\n                        \"attention_mask\": attn_seq,\n                    }\n                )\n            else:\n                if attn_seq is not None:\n                    sampled_sequences.append(\n                        {\"input_ids\": sequence, \"attention_mask\": attn_seq}\n                    )\n                else:\n                    sampled_sequences.append(sequence)\n            sample_count += 1\n\n    return sampled_sequences\n\n\ndef generate(\n    model: torch.nn.Module,\n    tokenizer: Any,\n    original_sequence: torch.Tensor,\n    num_diffusion_steps: int,\n    temperature: float,\n    mask_token_id: int,\n    *,\n    mode: Literal[\"random\", \"completion\"] = \"random\",\n    completion_tokens: int = 0,\n    target_mask_ratio: Optional[float] = None,\n    labels: Optional[torch.Tensor] = None,\n    attention_mask: Optional[torch.Tensor] = None,\n) -> dict:\n    \"\"\"Generate a single sample using reverse diffusion.\"\"\"\n    # Get original text for comparison\n    original_text = tokenizer.decode(\n        original_sequence[0].cpu(), skip_special_tokens=True\n    )\n\n    # Build masked sequence\n    if (\n        labels is not None\n        and labels.numel() > 0\n        and (labels == -100).any()\n        and (labels != -100).any()\n    ):\n        # SFT case: completely mask all answer tokens (labels != -100)\n 
       total_tokens = original_sequence.size(1)\n        masked_indices = (labels != -100).to(dtype=torch.bool)\n        masked_sequence = original_sequence.clone()\n        masked_sequence[masked_indices] = mask_token_id\n        masked_tokens = int(masked_indices.sum().item())\n        mask_ratio = masked_tokens / max(int(total_tokens), 1)\n    elif mode == \"completion\" and completion_tokens > 0:\n        # Append mask tokens to the right for completion\n        total_tokens = original_sequence.size(1) + int(completion_tokens)\n        masked_indices = torch.zeros(\n            1, total_tokens, dtype=torch.bool, device=original_sequence.device\n        )\n        masked_indices[0, -int(completion_tokens) :] = True\n\n        append = torch.full(\n            (1, int(completion_tokens)), mask_token_id, device=original_sequence.device\n        )\n        masked_sequence = torch.cat([original_sequence, append], dim=1)\n        masked_tokens = int(completion_tokens)\n        mask_ratio = masked_tokens / total_tokens\n    else:\n        # Apply random masking with optional fixed ratio\n        total_tokens = original_sequence.size(1)\n        if target_mask_ratio is None:\n            min_ratio, max_ratio = 0.1, 0.7\n            target_mask_ratio = (\n                torch.rand(1).item() * (max_ratio - min_ratio) + min_ratio\n            )\n        target_masked_tokens = max(1, int(total_tokens * float(target_mask_ratio)))\n\n        # Create random mask indices\n        mask_positions = torch.randperm(total_tokens)[:target_masked_tokens]\n        masked_indices = torch.zeros(\n            1, total_tokens, dtype=torch.bool, device=original_sequence.device\n        )\n        masked_indices[0, mask_positions] = True\n\n        # Create masked sequence\n        masked_sequence = original_sequence.clone()\n        masked_sequence[masked_indices] = mask_token_id\n\n        # Calculate actual mask ratio\n        masked_tokens = masked_indices.sum().item()\n        
mask_ratio = masked_tokens / total_tokens\n\n    # Get masked text for comparison\n    masked_text = tokenizer.decode(masked_sequence[0].cpu(), skip_special_tokens=False)\n    masked_text = _clean_masked_text(masked_text, tokenizer, mask_token_id)\n\n    # Run reverse diffusion process\n    sequence = masked_sequence.clone()\n    attention_mask = create_bidirectional_attention_mask(\n        sequence, attention_mask, sample_packing=attention_mask is not None\n    )\n    for step in range(num_diffusion_steps):\n        sequence = _diffusion_step(\n            model,\n            sequence,\n            step,\n            num_diffusion_steps,\n            temperature,\n            mask_token_id,\n            attention_mask,\n        )\n    generated_text = tokenizer.decode(sequence[0].cpu(), skip_special_tokens=True)\n\n    # Collect diagnostic info\n    final_ids = sequence[0].detach().cpu().tolist()\n    orig_ids_for_render = original_sequence[0].detach().cpu().tolist()\n    if masked_indices is not None:\n        masked_positions = (\n            torch.where(masked_indices[0])[0].detach().cpu().tolist()\n            if masked_indices.ndim == 2\n            else []\n        )\n    else:\n        masked_positions = []\n\n    result = {\n        \"original\": original_text,\n        \"masked\": masked_text,\n        \"generated\": generated_text,\n        \"mask_ratio\": mask_ratio,\n        \"masked_tokens\": masked_tokens,\n        \"total_tokens\": total_tokens,\n        \"generated_ids\": final_ids,\n        \"masked_positions\": masked_positions,\n        \"orig_ids\": orig_ids_for_render,\n        \"formatted\": (\n            f\"Original: '{original_text}' → Masked: '{masked_text}' \"\n            f\"({mask_ratio:.1%}) → Generated: '{generated_text}'\"\n        ),\n    }\n\n    return result\n\n\ndef _clean_masked_text(masked_text: str, tokenizer: Any, mask_token_id: int) -> str:\n    \"\"\"Clean up masked text for display.\"\"\"\n    mask_token_repr = 
tokenizer.decode([mask_token_id], skip_special_tokens=False)\n    cleaned = masked_text.replace(mask_token_repr, \"[MASK]\")\n\n    # Remove literal special token strings\n    if hasattr(tokenizer, \"special_tokens_map\"):\n        for token_value in tokenizer.special_tokens_map.values():\n            if token_value and isinstance(token_value, str):\n                cleaned = cleaned.replace(token_value, \"\")\n\n    # Normalize whitespace but preserve newlines\n    cleaned = cleaned.replace(\"\\r\\n\", \"\\n\").replace(\"\\r\", \"\\n\")\n    cleaned = re.sub(r\"[ \\t]+\", \" \", cleaned)\n    cleaned = \"\\n\".join(line.rstrip() for line in cleaned.split(\"\\n\")).strip()\n    return cleaned\n\n\ndef _diffusion_step(\n    model: torch.nn.Module,\n    sequence: torch.Tensor,\n    step: int,\n    num_diffusion_steps: int,\n    temperature: float,\n    mask_token_id: int,\n    attention_mask: torch.Tensor | None = None,\n) -> torch.Tensor:\n    \"\"\"Perform a single diffusion step with remasking.\"\"\"\n    # Only process if there are masked tokens remaining\n    current_mask = sequence == mask_token_id\n    if not current_mask.any():\n        return sequence\n\n    # Create or use provided attention mask\n    if attention_mask is None:\n        batch_size, seq_len = sequence.shape\n        attention_mask = torch.ones(\n            batch_size, 1, seq_len, seq_len, dtype=torch.bool, device=sequence.device\n        )\n\n    # Forward pass\n    outputs = model(input_ids=sequence, attention_mask=attention_mask)\n    logits = shift_logits_to_input_positions(outputs.logits)\n\n    # Only sample at currently masked positions\n    if current_mask.any():\n        masked_logits = logits[current_mask]\n\n        # Apply temperature scaling\n        if temperature > 0:\n            scaled_logits = masked_logits / temperature\n        else:\n            scaled_logits = masked_logits\n\n        # Suppress mask token in outputs\n        scaled_logits[:, mask_token_id] = 
-float(\"inf\")\n\n        if temperature > 0:\n            # Add Gumbel noise for sampling\n            gumbel_noise = -torch.log(\n                -torch.log(torch.rand_like(scaled_logits, dtype=torch.float32))\n            )\n            gumbel_logits = scaled_logits + gumbel_noise\n            predicted_tokens = torch.argmax(gumbel_logits, dim=-1)\n        else:\n            predicted_tokens = torch.argmax(scaled_logits, dim=-1)\n\n        # Calculate probabilities for confidence scoring\n        probs = torch.softmax(scaled_logits, dim=-1)\n        predicted_token_probs = probs[range(len(predicted_tokens)), predicted_tokens]\n\n        # Determine how many tokens to unmask this step\n        remaining_masked = current_mask.sum().item()\n        if step == num_diffusion_steps - 1:\n            num_to_unmask = remaining_masked\n        else:\n            unmask_ratio = 1.0 / (num_diffusion_steps - step)\n            num_to_unmask = max(1, int(remaining_masked * unmask_ratio))\n\n        # Select highest confidence predictions to unmask\n        if num_to_unmask >= remaining_masked:\n            sequence[current_mask] = predicted_tokens\n        else:\n            _, top_indices = predicted_token_probs.topk(num_to_unmask)\n            mask_positions = torch.where(current_mask)[1]\n            positions_to_unmask = mask_positions[top_indices]\n            sequence[0, positions_to_unmask] = predicted_tokens[top_indices]\n\n    return sequence\n"
  },
  {
    "path": "src/axolotl/integrations/diffusion/plugin.py",
    "content": "\"\"\"Diffusion LM training plugin for Axolotl.\"\"\"\n\nfrom peft import PeftModel\nfrom transformers import PreTrainedModel\n\nfrom axolotl.integrations.base import BasePlugin\nfrom axolotl.utils.dict import DictDefault\nfrom axolotl.utils.logging import get_logger\n\nfrom .trainer import DiffusionTrainer\n\nLOG = get_logger(__name__)\n\n\nclass DiffusionPlugin(BasePlugin):\n    \"\"\"\n    Plugin for diffusion language model training.\n\n    This plugin enables diffusion-based training using the LLaDA approach, which uses\n    random masking and bidirectional attention to train language models.\n    \"\"\"\n\n    def __init__(self):\n        super().__init__()\n        self.cfg = None\n\n    def get_input_args(self) -> str:\n        \"\"\"Returns the pydantic model for LLaDA plugin arguments.\"\"\"\n        return \"axolotl.integrations.diffusion.DiffusionArgs\"\n\n    def post_model_load(self, cfg: DictDefault, model: PreTrainedModel | PeftModel):\n        \"\"\"Perform actions after model is loaded.\"\"\"\n        self.cfg = cfg\n\n    def get_trainer_cls(self, cfg: DictDefault) -> type[DiffusionTrainer] | None:\n        \"\"\"Return custom trainer class for diffusion training.\"\"\"\n        return DiffusionTrainer\n\n    def post_trainer_create(self, cfg: DictDefault, trainer: DiffusionTrainer):\n        \"\"\"Configure trainer after creation.\"\"\"\n        trainer.set_config(cfg)\n"
  },
  {
    "path": "src/axolotl/integrations/diffusion/trainer.py",
    "content": "\"\"\"Custom trainer for diffusion LM training.\"\"\"\n\nfrom typing import Any, Literal\n\nimport torch\nimport torch.nn.functional as F\nfrom torch import nn\n\nfrom axolotl.core.trainers.base import AxolotlTrainer\nfrom axolotl.utils.dict import DictDefault\nfrom axolotl.utils.logging import get_logger\n\nfrom .callbacks import DiffusionGenerationCallback\nfrom .utils import create_bidirectional_attention_mask, shift_logits_to_input_positions\n\nLOG = get_logger(__name__)\n\n\nclass DiffusionTrainer(AxolotlTrainer):\n    \"\"\"Custom trainer for diffusion LM training that overrides loss computation.\"\"\"\n\n    def __init__(self, *args, **kwargs):\n        super().__init__(*args, **kwargs)\n        self.cfg = None\n        self._special_token_ids = None\n\n    def set_config(self, config: DictDefault):\n        \"\"\"Set config for diffusion training.\"\"\"\n        self.cfg = config\n        self._cache_special_token_ids()\n        self._resolve_mask_token_id()\n\n        token_id = int(getattr(self.cfg.diffusion, \"mask_token_id\", 0))\n        LOG.info(f\"Diffusion: using mask_token_id={token_id}\")\n\n        if getattr(config.diffusion, \"generate_samples\", True):\n            generation_callback = DiffusionGenerationCallback(self)\n            self.add_callback(generation_callback)\n\n    def _resolve_mask_token_id(self) -> None:\n        \"\"\"Ensure mask_token_id is valid for the current tokenizer.\"\"\"\n        from .utils import resolve_mask_token_id\n\n        tokenizer = getattr(self, \"processing_class\", None)\n        if tokenizer is None:\n            return\n\n        mid = resolve_mask_token_id(\n            tokenizer,\n            self.cfg,\n            allow_add=True,\n            model=getattr(self, \"model\", None),\n        )\n        try:\n            self.cfg.diffusion.mask_token_id = int(mid)\n        except Exception:\n            pass\n\n    def compute_loss(\n        self,\n        model: nn.Module,\n        
inputs: dict[str, torch.Tensor],\n        return_outputs: bool = False,\n        num_items_in_batch: torch.Tensor | None = None,\n    ) -> torch.Tensor | tuple[torch.Tensor, dict[str, torch.Tensor]]:\n        \"\"\"Override compute_loss to use diffusion loss.\"\"\"\n        input_ids = inputs.get(\"input_ids\")\n        attention_mask = inputs.get(\"attention_mask\")\n        labels = inputs.get(\"labels\")\n\n        if input_ids is None:\n            raise ValueError(\"input_ids is required for diffusion training\")\n\n        loss, outputs = self._compute_diffusion_loss(\n            model, input_ids, attention_mask, labels\n        )\n\n        if return_outputs:\n            return loss, outputs\n        return loss\n\n    def _cache_special_token_ids(self):\n        \"\"\"Cache special token IDs to avoid repeated tokenizer access.\"\"\"\n        if self.processing_class is None:\n            self._special_token_ids = set()\n            return\n\n        tokenizer = self.processing_class\n        special_tokens = set()\n\n        if hasattr(tokenizer, \"bos_token_id\") and tokenizer.bos_token_id is not None:\n            special_tokens.add(tokenizer.bos_token_id)\n        if hasattr(tokenizer, \"eos_token_id\") and tokenizer.eos_token_id is not None:\n            special_tokens.add(tokenizer.eos_token_id)\n        if hasattr(tokenizer, \"pad_token_id\") and tokenizer.pad_token_id is not None:\n            special_tokens.add(tokenizer.pad_token_id)\n\n        self._special_token_ids = special_tokens\n\n    def _forward_process(\n        self,\n        input_ids: torch.Tensor,\n        attention_mask: torch.Tensor | None = None,\n        labels: torch.Tensor | None = None,\n        eps: float = 1e-3,\n    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:\n        \"\"\"\n        Forward noising process. 
A timestep is sampled along the process, and tokens are\n        masked with probability determined by the configured noise schedule.\n\n        Args:\n            input_ids: Input token ids [batch_size, seq_len].\n            attention_mask: Attention mask [batch_size, seq_len].\n            labels: Labels for SFT training [batch_size, seq_len].\n            eps: Small epsilon value for minimum masking probability.\n\n        Returns:\n            noisy_batch: Input with some tokens masked.\n            masked_indices: Boolean mask indicating which tokens were masked.\n            p_mask: Masking probabilities for each token [batch_size, seq_len].\n        \"\"\"\n        batch_size, seq_len = input_ids.shape\n        device = input_ids.device\n\n        # Sample random timesteps for each sample in batch\n        t = torch.rand(batch_size, device=device)\n        p_mask = (1 - eps) * t + eps  # [batch_size]\n        p_mask = p_mask[:, None].repeat(1, seq_len)  # [batch_size, seq_len]\n\n        # Don't mask padding tokens if attention_mask is provided\n        if attention_mask is not None:\n            valid_mask = attention_mask.bool()\n            p_mask = p_mask * valid_mask.float()\n\n        # Create mask to exclude special tokens\n        special_token_mask = torch.zeros_like(input_ids, dtype=torch.bool)\n        if self._special_token_ids:\n            for token_id in self._special_token_ids:\n                special_token_mask |= input_ids == token_id\n\n        # Create random mask based on p_mask\n        masked_indices = torch.rand((batch_size, seq_len), device=device) < p_mask\n        masked_indices = masked_indices & ~special_token_mask\n        if attention_mask is not None:\n            masked_indices = masked_indices & attention_mask.bool()\n\n        # For SFT data, only mask answer tokens\n        if labels is not None:\n            answer_mask = labels != -100\n            masked_indices = masked_indices & answer_mask\n\n        # Create 
masked input\n        mask_token_id = int(self.cfg.diffusion.mask_token_id)\n        mask_value = torch.full_like(input_ids, mask_token_id)\n        noisy_batch = torch.where(masked_indices, mask_value, input_ids)\n\n        return noisy_batch, masked_indices, p_mask\n\n    def _compute_diffusion_loss(\n        self,\n        model: nn.Module,\n        input_ids: torch.Tensor,\n        attention_mask: torch.Tensor | None = None,\n        labels: torch.Tensor | None = None,\n    ) -> tuple[torch.Tensor, torch.Tensor | Any]:\n        \"\"\"\n        Compute diffusion loss.\n\n        Args:\n            model: The model to compute loss for.\n            input_ids: Ground truth token ids [batch_size, seq_len].\n            attention_mask: Attention mask [batch_size, seq_len].\n            labels: Labels for SFT training [batch_size, seq_len].\n\n        Returns:\n            loss: Cross-entropy loss.\n            metrics: Dictionary of metrics.\n        \"\"\"\n        # Short-circuit empty sequences\n        if input_ids is None or input_ids.numel() == 0 or input_ids.shape[1] == 0:\n            zero = torch.tensor(\n                0.0,\n                device=(input_ids.device if input_ids is not None else None),\n                requires_grad=True,\n            )\n            return zero, {}\n\n        # If an attention_mask is provided and all positions are padding for every\n        # sample in this batch, skip the step.\n        if attention_mask is not None:\n            if attention_mask.dim() == 2 and (attention_mask.sum(dim=1) == 0).all():\n                zero = torch.tensor(0.0, device=input_ids.device, requires_grad=True)\n                return zero, {}\n\n        # Apply forward process\n        noisy_batch, masked_indices, p_mask = self._forward_process(\n            input_ids, attention_mask, labels, self.cfg.diffusion.eps\n        )\n\n        # Create bidirectional attention mask\n        bidirectional_mask = create_bidirectional_attention_mask(\n    
        input_ids, attention_mask, sample_packing=self.cfg.sample_packing\n        )\n\n        # Forward pass\n        outputs = model(\n            input_ids=noisy_batch.long(),\n            attention_mask=bidirectional_mask,\n        )\n        logits = shift_logits_to_input_positions(outputs.logits)\n\n        if masked_indices.sum() > 0:\n            valid_indices = torch.where(masked_indices)\n            batch_indices, seq_indices = valid_indices\n\n            masked_logits = logits[batch_indices, seq_indices]\n            masked_targets = input_ids[batch_indices, seq_indices]\n            masked_p_mask = p_mask[batch_indices, seq_indices]\n\n            # Compute cross-entropy loss without reduction\n            token_loss = F.cross_entropy(\n                masked_logits.float(), masked_targets, reduction=\"none\"\n            )\n\n            if self.cfg.diffusion.importance_weighting:\n                masked_p_mask = masked_p_mask.float()\n                weighted_loss = token_loss / masked_p_mask\n            else:\n                weighted_loss = token_loss\n\n            if labels is not None:\n                # For SFT data: normalize by answer token count per sample\n                answer_mask = labels != -100\n                answer_lengths = answer_mask.sum(dim=1).float()  # [batch_size]\n\n                # Get batch indices for masked tokens\n                masked_batch_indices = batch_indices\n\n                # Sum losses per sample and divide by answer length\n                batch_size = input_ids.shape[0]\n                loss_per_sample = torch.zeros(batch_size, device=input_ids.device)\n                for i in range(batch_size):\n                    sample_mask = masked_batch_indices == i\n                    if sample_mask.sum() > 0:\n                        sample_loss = weighted_loss[sample_mask].sum()\n                        denom = answer_lengths[i].clamp(min=1.0)\n                        loss_per_sample[i] = sample_loss / 
denom\n\n                loss = loss_per_sample.mean()\n            else:\n                # Non-SFT: when importance weighting is enabled, use unbiased estimator\n                # (sum(loss/p) / total_tokens). Otherwise, average over masked tokens\n                # for stable scaling across varying mask ratios.\n                if self.cfg.diffusion.importance_weighting:\n                    loss = weighted_loss.sum() / (\n                        input_ids.shape[0] * input_ids.shape[1]\n                    )\n                else:\n                    loss = weighted_loss.mean()\n\n            ce_loss = token_loss.mean()\n\n            # Compute accuracy on masked tokens\n            with torch.no_grad():\n                pred_tokens = masked_logits.argmax(dim=-1)\n                accuracy = (pred_tokens == masked_targets).float().mean()\n        else:\n            loss = torch.tensor(0.0, device=input_ids.device, requires_grad=True)\n            accuracy = torch.tensor(0.0, device=input_ids.device)\n            ce_loss = torch.tensor(0.0, device=input_ids.device)\n            masked_p_mask = torch.tensor(1.0, device=input_ids.device)\n\n        avg_p_mask = (\n            p_mask[masked_indices].mean().item() if masked_indices.any() else 0.0\n        )\n        metrics = {\n            \"loss\": loss.item(),\n            \"accuracy\": accuracy.item(),\n            \"mask_ratio\": masked_indices.float().mean().item(),\n            \"num_masked_tokens\": (masked_indices.sum().item(), \"sum\"),\n            \"avg_p_mask\": avg_p_mask,\n            \"ce_loss\": ce_loss.item(),\n        }\n\n        # If doing SFT training, log answer-specific metrics\n        if self.cfg.datasets is not None:\n            with torch.no_grad():\n                answer_mask = labels != -100\n                answer_lengths = answer_mask.sum(dim=1).float()  # type: ignore\n                total_answer_tokens = answer_mask.sum().item()  # type: ignore\n                total_tokens = 
labels.numel()  # type: ignore\n                metrics[\"answer_ratio\"] = total_answer_tokens / max(total_tokens, 1)\n                metrics[\"avg_answer_length\"] = answer_lengths.mean().item()\n\n        if self.cfg.diffusion.importance_weighting:\n            metrics[\"importance_weight_avg\"] = (1.0 / masked_p_mask).mean().item()\n\n        train_eval: Literal[\"train\", \"eval\"] = \"train\" if model.training else \"eval\"\n        self.store_metrics(metrics, train_eval=train_eval)\n\n        return loss, outputs\n"
  },
  {
    "path": "src/axolotl/integrations/diffusion/utils.py",
    "content": "\"\"\"Shared utilities for diffusion integration.\"\"\"\n\nfrom __future__ import annotations\n\nfrom typing import Any, Optional\n\nimport torch\n\nfrom axolotl.utils.dict import DictDefault\n\n\ndef resolve_mask_token_id(\n    tokenizer: Any,\n    cfg: DictDefault,\n    *,\n    allow_add: bool,\n    model: Any | None = None,\n    default_token: str = \"<|diffusion_mask|>\",\n) -> int:\n    \"\"\"Resolve mask token id. Training may add a new special token; inference won't.\"\"\"\n    # Determine vocab size if available\n    vocab_size = None\n    if tokenizer is not None:\n        if hasattr(tokenizer, \"vocab_size\") and tokenizer.vocab_size is not None:\n            try:\n                vocab_size = int(tokenizer.vocab_size)  # type: ignore[arg-type]\n            except Exception:\n                vocab_size = None\n        elif hasattr(tokenizer, \"__len__\"):\n            try:\n                vocab_size = int(len(tokenizer))\n            except Exception:\n                vocab_size = None\n\n    # Use explicit id from config if provided\n    diffusion_cfg = getattr(cfg, \"diffusion\", None)\n    # Fallback to top-level attr names only if nested missing (shouldn't happen)\n    cfg_id = (\n        getattr(diffusion_cfg, \"mask_token_id\", None)\n        if diffusion_cfg is not None\n        else getattr(cfg, \"diffusion_mask_token_id\", None)\n    )\n    if isinstance(cfg_id, int) and cfg_id >= 0:\n        if vocab_size is None or cfg_id < vocab_size:\n            return int(cfg_id)\n\n    def _existing_special_token_id(token_str: str | None) -> int | None:\n        \"\"\"Attempt to resolve an existing special token string to a real ID.\"\"\"\n        if not token_str or not hasattr(tokenizer, \"convert_tokens_to_ids\"):\n            return None\n        try:\n            token_id = tokenizer.convert_tokens_to_ids(token_str)\n        except Exception:\n            return None\n\n        if not isinstance(token_id, int) or token_id < 0:\n       
     return None\n\n        # Ensure it's registered as special and not UNK, and within vocab\n        unk_id = getattr(tokenizer, \"unk_token_id\", None)\n        specials = set(getattr(tokenizer, \"all_special_tokens\", []) or [])\n        addl = set(getattr(tokenizer, \"additional_special_tokens\", []) or [])\n        is_special = token_str in specials or token_str in addl\n        in_vocab = vocab_size is None or token_id < vocab_size\n        if (\n            (unk_id is not None and token_id == unk_id)\n            or not is_special\n            or not in_vocab\n        ):\n            return None\n        return token_id\n\n    # Try mask token string if provided\n    token_str = (\n        getattr(diffusion_cfg, \"mask_token_str\", None)\n        if diffusion_cfg is not None\n        else getattr(cfg, \"diffusion_mask_token_str\", None)\n    )\n    for candidate in (token_str, default_token):\n        token_id = _existing_special_token_id(candidate)\n        if isinstance(token_id, int):\n            try:\n                if diffusion_cfg is None:\n                    cfg.diffusion_mask_token_id = int(token_id)  # legacy fallback\n                else:\n                    diffusion_cfg.mask_token_id = int(token_id)\n            except Exception:\n                pass\n            return int(token_id)\n\n    # Optionally add and return a dedicated special token during training\n    if allow_add and hasattr(tokenizer, \"add_special_tokens\"):\n        token_to_add = token_str or default_token\n        try:\n            tokenizer.add_special_tokens({\"additional_special_tokens\": [token_to_add]})\n\n            # Resize embeddings if possible\n            if (\n                model is not None\n                and hasattr(tokenizer, \"__len__\")\n                and hasattr(model, \"resize_token_embeddings\")\n            ):\n                try:\n                    model.resize_token_embeddings(len(tokenizer))\n                except Exception:\n           
         pass\n            new_id = tokenizer.convert_tokens_to_ids(token_to_add)\n            if isinstance(new_id, int) and new_id >= 0:\n                try:\n                    if diffusion_cfg is None:\n                        cfg.diffusion_mask_token_id = int(new_id)  # legacy fallback\n                    else:\n                        diffusion_cfg.mask_token_id = int(new_id)\n                except Exception:\n                    pass\n                return int(new_id)\n        except Exception:\n            pass\n\n    # Fallback to unk or 0 (do not update cfg)\n    fallback = getattr(tokenizer, \"unk_token_id\", 0) or 0\n    return int(fallback)\n\n\ndef create_bidirectional_attention_mask(\n    input_ids: torch.Tensor,\n    attention_mask: Optional[torch.Tensor] = None,\n    sample_packing: bool = False,\n) -> torch.Tensor:\n    \"\"\"\n    Create bidirectional attention mask to override default causal masking.\n    Handles sample-packed sequences where different samples are identified\n    by different attention mask values.\n\n    Args:\n        input_ids: Input token ids [batch_size, seq_len]\n        attention_mask: Attention mask [batch_size, seq_len]\n        sample_packing: Whether sample packing is enabled\n\n    Returns:\n        bidirectional_mask: 4D attention mask [batch_size, 1, seq_len, seq_len]\n    \"\"\"\n    batch_size, seq_len = input_ids.shape\n    device = input_ids.device\n\n    if attention_mask is None or not sample_packing:\n        return torch.ones(\n            batch_size, 1, seq_len, seq_len, dtype=torch.bool, device=device\n        )\n\n    # Handle sample packing: tokens can only attend within their sample\n    mask_i = attention_mask.unsqueeze(2)  # [batch_size, seq_len, 1]\n    mask_j = attention_mask.unsqueeze(1)  # [batch_size, 1, seq_len]\n\n    # Tokens can attend to each other if they have the same non-zero sample ID\n    bidirectional_mask = (mask_i == mask_j) & (mask_i > 0)\n\n    # Add head dimension: 
[batch_size, 1, seq_len, seq_len]\n    return bidirectional_mask.unsqueeze(1)\n\n\ndef shift_logits_to_input_positions(logits: torch.Tensor) -> torch.Tensor:\n    \"\"\"Align next-token logits with their input token positions for diffusion.\"\"\"\n    if logits.size(1) <= 1:\n        return logits\n    return torch.cat([logits[:, :1], logits[:, :-1]], dim=1)\n"
  },
  {
    "path": "src/axolotl/integrations/grokfast/LICENSE",
    "content": "MIT License\n\nCopyright (c) 2024 Jaerin Lee, Bong Gyun Kang, Kihoon Kim, Kyoung Mu Lee\n\nPermission is hereby granted, free of charge, to any person obtaining a copy\nof this software and associated documentation files (the \"Software\"), to deal\nin the Software without restriction, including without limitation the rights\nto use, copy, modify, merge, publish, distribute, sublicense, and/or sell\ncopies of the Software, and to permit persons to whom the Software is\nfurnished to do so, subject to the following conditions:\n\nThe above copyright notice and this permission notice shall be included in all\ncopies or substantial portions of the Software.\n\nTHE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\nIMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\nFITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\nAUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\nLIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\nOUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE\nSOFTWARE.\n"
  },
  {
    "path": "src/axolotl/integrations/grokfast/README.md",
    "content": "# Grokfast Optimizer\n\nSee https://github.com/ironjr/grokfast\n\n## Usage\n\n```yaml\nplugins:\n  - axolotl.integrations.grokfast.GrokfastPlugin\n\ngrokfast_alpha: 0.98  # EMA momentum of the gradient filter (default)\ngrokfast_lamb: 2.0  # amplification factor for the filtered gradient (default)\n```\n\n## Citation\n\n```bib\n@article{lee2024grokfast,\n    title={{Grokfast}: Accelerated Grokking by Amplifying Slow Gradients},\n    author={Lee, Jaerin and Kang, Bong Gyun and Kim, Kihoon and Lee, Kyoung Mu},\n    journal={arXiv preprint arXiv:2405.20233},\n    year={2024}\n}\n```\n"
  },
  {
    "path": "src/axolotl/integrations/grokfast/__init__.py",
    "content": "\"\"\"\nGrokfast plugin for Axolotl\n\"\"\"\n\nfrom transformers.trainer_callback import TrainerCallback\n\nfrom axolotl.utils.logging import get_logger\n\nfrom ..base import BasePlugin\nfrom .args import GrokfastArgs as GrokfastArgs\nfrom .optimizer import gradfilter_ema\n\nLOG = get_logger(__name__)\n\n# Fallbacks matching the GrokfastArgs defaults, used when a config value is null.\nDEFAULT_GROKFAST_ALPHA = 0.98\nDEFAULT_GROKFAST_LAMB = 2.0\n\n\nclass GrokfastCallbackHandler(TrainerCallback):\n    \"\"\"\n    Transformers trainer callback that applies the Grokfast EMA gradient filter.\n\n    Before each optimizer step the model's gradients are passed through\n    `gradfilter_ema`, which adds `lamb` times an exponential moving average\n    (momentum `alpha`) of past gradients.\n    \"\"\"\n\n    def __init__(\n        self,\n        *args_,\n        alpha=DEFAULT_GROKFAST_ALPHA,\n        lamb=DEFAULT_GROKFAST_LAMB,\n        **kwargs,\n    ):\n        super().__init__(*args_, **kwargs)\n        # per-parameter EMA state returned by gradfilter_ema; None until the first step\n        self.grads = None\n        self.alpha = alpha\n        self.lamb = lamb\n\n    def on_train_begin(self, *args_, **kwargs):\n        # start every training run with a fresh EMA\n        self.grads = None\n\n    def on_pre_optimizer_step(self, args_, state, control, **kwargs):\n        model = kwargs.pop(\"model\")\n        self.grads = gradfilter_ema(model, self.grads, alpha=self.alpha, lamb=self.lamb)\n        return control\n\n\nclass GrokfastPlugin(BasePlugin):\n    \"\"\"\n    Plugin for Grokfast optimizer integration with Axolotl.\n    \"\"\"\n\n    def get_input_args(self):\n        return \"axolotl.integrations.grokfast.GrokfastArgs\"\n\n    def add_callbacks_post_trainer(self, cfg, trainer):\n        LOG.info(\"Adding Grokfast callback to the trainer\")\n        # GrokfastArgs types both fields as Optional, so an explicit null reaches\n        # here as None; fall back to the defaults rather than passing None into\n        # the gradient filter.\n        alpha = cfg.grokfast_alpha\n        if alpha is None:\n            alpha = DEFAULT_GROKFAST_ALPHA\n        lamb = cfg.grokfast_lamb\n        if lamb is None:\n            lamb = DEFAULT_GROKFAST_LAMB\n        callback = GrokfastCallbackHandler(alpha=alpha, lamb=lamb)\n        return [callback]\n"
  },
  {
    "path": "src/axolotl/integrations/grokfast/args.py",
    "content": "\"\"\"\nConfig args for the Grokfast plugin.\n\"\"\"\n\nfrom typing import Optional\n\nfrom pydantic import BaseModel\n\n\nclass GrokfastArgs(BaseModel):\n    \"\"\"\n    User-facing config options for the Grokfast EMA gradient filter.\n    \"\"\"\n\n    # EMA momentum: weight kept on the running gradient average at each step\n    grokfast_alpha: Optional[float] = 0.98\n    # amplification factor applied to the EMA before it is added to the gradient\n    grokfast_lamb: Optional[float] = 2.0\n"
  },
  {
    "path": "src/axolotl/integrations/grokfast/optimizer.py",
    "content": "# Copyright: MIT License (c) 2024 Jaerin Lee, Bong Gyun Kang, Kihoon Kim, Kyoung Mu Lee\n# Reference: https://github.com/ironjr/grokfast\n\nfrom collections import deque\nfrom typing import Dict, Literal, Optional\n\nimport torch\nimport torch.nn as nn\n\n\ndef gradfilter_ma(\n    m: nn.Module,\n    grads: Optional[Dict[str, deque]] = None,\n    window_size: int = 100,\n    lamb: float = 5.0,\n    filter_type: Literal[\"mean\", \"sum\"] = \"mean\",\n    warmup: bool = True,\n    trigger: bool = False,  # For ablation study.\n) -> Dict[str, deque]:\n    \"\"\"\n    Grokfast-MA: amplify slow gradients using a moving window of past gradients.\n\n    Args:\n        m: Model whose parameter gradients are modified in place.\n        grads: Per-parameter gradient history from the previous call, or None.\n        window_size: Number of past gradients kept per parameter.\n        lamb: Amplification factor applied to the window mean/sum.\n        filter_type: Reduce the window with \"mean\" or \"sum\".\n        warmup: If True, only filter once a parameter's window is full.\n        trigger: Ablation switch; with warmup=True it disables filtering.\n\n    Returns:\n        The updated history, to be passed back on the next call.\n\n    Raises:\n        ValueError: If filtering is applied with an unrecognized filter_type.\n    \"\"\"\n    if grads is None:\n        grads = {}\n\n    for n, p in m.named_parameters():\n        if p.requires_grad and p.grad is not None:\n            # Parameters that start receiving gradients after the first call\n            # (e.g. unfrozen mid-training) get a fresh window instead of a KeyError.\n            if n not in grads:\n                grads[n] = deque(maxlen=window_size)\n            grads[n].append(p.grad.data.detach())  # .cpu())\n\n            # Modify the gradients (precedence as in the reference implementation).\n            if (not warmup) or (len(grads[n]) == window_size and not trigger):\n                if filter_type == \"mean\":\n                    avg = sum(grads[n]) / len(grads[n])\n                elif filter_type == \"sum\":\n                    avg = sum(grads[n])\n                else:\n                    raise ValueError(f\"Unrecognized filter_type {filter_type}\")\n                p.grad.data = p.grad.data + avg * lamb\n\n    return grads\n\n\ndef gradfilter_ema(\n    m: nn.Module,\n    grads: Optional[Dict[str, torch.Tensor]] = None,\n    alpha: float = 0.98,\n    lamb: float = 2.0,\n) -> Dict[str, torch.Tensor]:\n    \"\"\"\n    Grokfast-EMA: amplify slow gradients using an exponential moving average.\n\n    For every trainable parameter that has a gradient:\n        ema = ema * alpha + grad * (1 - alpha)\n        grad = grad + ema * lamb\n\n    Args:\n        m: Model whose parameter gradients are modified in place.\n        grads: Per-parameter EMA state from the previous call, or None.\n        alpha: EMA momentum.\n        lamb: Amplification factor applied to the EMA.\n\n    Returns:\n        The updated EMA state, to be passed back on the next call.\n    \"\"\"\n    if grads is None:\n        grads = {}\n\n    for n, p in m.named_parameters():\n        if p.requires_grad and p.grad is not None:\n            grad = p.grad.data.detach()\n            # Seed the EMA with the current gradient the first time a parameter\n            # is seen (first call, or a parameter that only now has a gradient).\n            ema = grads.get(n, grad)\n            grads[n] = ema * alpha + grad * (1 - alpha)\n            p.grad.data = p.grad.data + grads[n] * lamb\n\n    return grads\n"
  },
  {
    "path": "src/axolotl/integrations/kd/README.md",
    "content": "# Knowledge Distillation\n\n## Usage\n\n```yaml\nplugins:\n  - \"axolotl.integrations.kd.KDPlugin\"\n\nkd_trainer: True\nkd_ce_alpha: 0.1\nkd_alpha: 0.9\nkd_temperature: 1.0\n\ntorch_compile: True  # torch>=2.6.0, recommended to reduce vram\n\ndatasets:\n  - path: ...\n    type: \"axolotl.integrations.kd.chat_template\"\n    field_messages: \"messages_combined\"\n    logprobs_field: \"llm_text_generation_vllm_logprobs\"  # for kd only, field of logprobs\n```\n\nAn example dataset can be found at [`axolotl-ai-co/evolkit-logprobs-pipeline-75k-v2-sample`](https://huggingface.co/datasets/axolotl-ai-co/evolkit-logprobs-pipeline-75k-v2-sample)\n"
  },
  {
    "path": "src/axolotl/integrations/kd/__init__.py",
    "content": "# Copyright 2024 Axolotl AI. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"\nPlugin init to add KD support to Axolotl.\n\"\"\"\n\nfrom typing import Any\n\nfrom transformers import Trainer\n\nfrom axolotl.integrations.base import BasePlugin\nfrom axolotl.integrations.kd.callbacks import KDTemperatureSchedulerCallback\n\nfrom .args import KDArgs as KDArgs\n\n\nclass KDPlugin(BasePlugin):\n    \"\"\"\n    Plugin for KD support in Axolotl.\n    \"\"\"\n\n    def get_input_args(self):\n        return \"axolotl.integrations.kd.KDArgs\"\n\n    def get_training_args_mixin(self):\n        return \"axolotl.integrations.kd.args.KDTrainingArgsMixin\"\n\n    def get_trainer_cls(self, cfg):\n        \"\"\"Return the KD trainer class when `kd_trainer` is set, else None.\"\"\"\n        if cfg.kd_trainer:\n            from .trainer import AxolotlKDTrainer\n\n            return AxolotlKDTrainer\n        return None\n\n    def get_training_args(self, cfg):\n        \"\"\"Copy the KD loss settings from the config into the training args.\"\"\"\n        return {\n            \"kd_ce_alpha\": cfg.kd_ce_alpha,\n            \"kd_alpha\": cfg.kd_alpha,\n            \"kd_temperature\": cfg.kd_temperature,\n            \"kd_beta\": cfg.kd_beta,\n            \"kd_normalize_topk\": cfg.kd_normalize_topk,\n        }\n\n    def get_collator_cls_and_kwargs(self, cfg, is_eval=False):\n        \"\"\"\n        Select the data collator class and its kwargs for KD.\n\n        Returns (None, None) when `kd_trainer` is not set. When\n        `kd_online_server_base_url` is set the online-teacher collator is used\n        regardless of sample packing; otherwise packed batches get the\n        batch-sampler collator and everything else `DataCollatorForKD`.\n        \"\"\"\n        if not cfg.kd_trainer:\n            return None, None\n\n        from .collator import DataCollatorForKD, KDBatchSamplerDataCollatorForSeq2Seq\n\n        use_batch_sampler_collator = False\n        if is_eval is False and cfg.sample_packing:\n            use_batch_sampler_collator = True\n        if cfg.eval_sample_packing and is_eval:\n            use_batch_sampler_collator = True\n\n        if cfg.kd_online_server_base_url:\n            from .collator_online_teacher import OnlineTeacherCollator\n\n            return OnlineTeacherCollator, {\n                \"kd_online_server_base_url\": cfg.kd_online_server_base_url,\n                \"kd_online_topk\": cfg.kd_online_topk,\n                \"kd_temperature\": cfg.kd_temperature,\n                \"kd_online_server\": cfg.kd_online_server,\n                \"kd_online_timeout\": cfg.kd_online_timeout,\n                \"kd_normalize_topk\": cfg.kd_normalize_topk,\n            }\n\n        if use_batch_sampler_collator:\n            return KDBatchSamplerDataCollatorForSeq2Seq, {}\n        return DataCollatorForKD, {}\n\n    def pre_model_load(self, cfg):\n        from .kernels.models import apply_kernel\n\n        apply_kernel(cfg.model_config_type)\n\n    def add_callbacks_post_trainer(self, cfg: Any, trainer: Trainer) -> list:\n        \"\"\"\n        Adds the KD temperature scheduler callback to the Trainer instance.\n\n        The callback is only added for online KD (`kd_online_server_base_url`)\n        with `kd_temperature_min` set; it anneals the collator's `kd_temperature`\n        from `cfg.kd_temperature` towards `cfg.kd_temperature_min`.\n\n        Args:\n            cfg (Any): Configuration object containing the KD settings.\n            trainer (Trainer): Huggingface Trainer instance.\n\n        Returns:\n            list: List containing the configured callback instances (empty when\n            temperature scheduling is not enabled).\n        \"\"\"\n        if cfg.kd_temperature_min is not None and cfg.kd_online_server_base_url:\n            callback = KDTemperatureSchedulerCallback(\n                cfg.kd_temperature,\n                cfg.kd_temperature_min,\n                trainer,\n            )\n            return [callback]\n\n        return []\n"
  },
  {
    "path": "src/axolotl/integrations/kd/args.py",
    "content": "# Copyright 2024 Axolotl AI. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"\nPlugin args for KD support.\n\"\"\"\n\nfrom dataclasses import dataclass\nfrom enum import Enum\n\nfrom pydantic import BaseModel, Field\n\n\nclass InferenceServerType(str, Enum):\n    \"\"\"\n    Online inference server types, used to handle different request args\n    \"\"\"\n\n    vllm = \"vllm\"\n    sglang = \"sglang\"\n\n\nclass KDArgs(BaseModel):\n    \"\"\"\n    Input args for knowledge distillation.\n    \"\"\"\n\n    kd_trainer: bool | None = None  # whether to use KD trainer\n    kd_ce_alpha: float | None = (\n        None  # loss coefficient for cross-entropy loss during KD\n    )\n    kd_alpha: float | None = None  # loss coefficient for KD loss\n    kd_temperature: float | None = None  # temperature for sampling during KD\n    kd_beta: float | None = 0.0  # beta coefficient for ratio of fwd and reverse KL\n    kd_normalize_topk: bool | None = (\n        None  # whether to normalize student logits during KD\n    )\n\n    # TODO online kd\n    kd_online_server_base_url: str | None = None\n    kd_online_topk: int | None = None\n    kd_online_server: InferenceServerType | None = Field(\n        default_factory=lambda: InferenceServerType.vllm\n    )\n    kd_online_timeout: int | None = 120\n    kd_temperature_min: float | None = (\n        None  # kd temperature scheduling during online kd\n    )\n\n\n@dataclass\nclass KDTrainingArgsMixin:\n    \"\"\"\n    Additional args for KD training.\n\n    Field types mirror the matching KDArgs fields.\n    \"\"\"\n\n    kd_ce_alpha: float | None = (\n        None  # loss coefficient for cross-entropy loss during KD\n    )\n    kd_alpha: float | None = None  # loss coefficient for KD loss\n    kd_temperature: float | None = None  # temperature for sampling during KD\n    kd_beta: float | None = None  # beta coefficient for ratio of fwd and reverse KL\n    kd_normalize_topk: bool | None = (\n        None  # whether to normalize student logits during KD\n    )\n"
  },
  {
    "path": "src/axolotl/integrations/kd/callbacks.py",
    "content": "\"\"\"\nTransformers trainer callbacks to schedule the KD temperature during training\n\"\"\"\n\nimport math\n\nfrom transformers.trainer_callback import TrainerCallback\n\n\nclass KDTemperatureSchedulerCallback(TrainerCallback):\n    \"\"\"\n    KD temperature scheduler callback for the trainer.\n\n    Cosine-anneals the temperature from `temperature_start` to `temperature_min`\n    over `state.max_steps` and writes it to `trainer.data_collator.kd_temperature`\n    when the collator has that attribute.\n    \"\"\"\n\n    def __init__(self, temperature_start, temperature_min, trainer):\n        self.temperature_start = temperature_start\n        self.temperature_min = temperature_min\n        self.temperature = temperature_start\n\n        self.trainer = trainer\n\n    def on_step_end(self, args, state, control, **kwargs):\n        # cosine decay temperature over the max steps\n        max_steps = state.max_steps or 0\n        if max_steps <= 0:\n            # no usable schedule length (avoids ZeroDivisionError); keep the\n            # current temperature\n            return\n\n        # clamp so steps beyond max_steps stay at temperature_min instead of\n        # climbing back up the cosine curve\n        progress = min(max(state.global_step / max_steps, 0.0), 1.0)\n        # Cosine decay factor: 0.5 * (1 + cos(pi * progress))\n        # This factor goes from 1 (at progress=0) to 0 (at progress=1)\n        decay_factor = 0.5 * (1.0 + math.cos(math.pi * progress))\n        self.temperature = self.temperature_start - (\n            (self.temperature_start - self.temperature_min) * (1.0 - decay_factor)\n        )\n\n        if hasattr(self.trainer.data_collator, \"kd_temperature\"):\n            self.trainer.data_collator.kd_temperature = self.temperature\n"
  },
  {
    "path": "src/axolotl/integrations/kd/chat_template.py",
    "content": "# Copyright 2024 Axolotl AI. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"\nChat template prompt strategy loader with KD support\n\"\"\"\n\nimport logging\nfrom typing import Any, Dict\n\nimport torch\n\nfrom axolotl.prompt_strategies.chat_template import ChatTemplateStrategy, StrategyLoader\n\nLOG = logging.getLogger(__name__)\n\n\nclass ChatTemplateStrategyWithKD(ChatTemplateStrategy):\n    \"\"\"\n    Handle fields for logprob KD\n    \"\"\"\n\n    def __init__(\n        self,\n        prompter,\n        tokenizer,\n        train_on_inputs,\n        sequence_len,\n        roles_to_train=None,\n        train_on_eos=None,\n        train_on_eot=None,\n        eot_tokens=None,\n        split_thinking: bool | None = False,\n        logprobs_field=\"logprobs\",\n        gen_temperature=1.0,\n        kd_temperature=1.0,\n    ):\n        self.logprobs_field = logprobs_field\n        self.gen_temperature = gen_temperature\n        self.kd_temperature = kd_temperature\n\n        super().__init__(\n            prompter,\n            tokenizer,\n            train_on_inputs,\n            sequence_len,\n            roles_to_train=roles_to_train,\n            train_on_eos=train_on_eos,\n            train_on_eot=train_on_eot,\n            eot_tokens=eot_tokens,\n            split_thinking=split_thinking,\n        )\n\n    @property\n    def supports_batched(self) -> bool:\n        # batching doesn't work well for logprob data\n        return False\n\n    def _init_kd_targets(self, sample):\n        \"\"\"\n        Shared setup for `transform_logprobs` implementations.\n\n        Pops the raw logprobs from `sample`, picks the per-position top-k width,\n        aligns the logprobs with `input_ids` (left-padding or truncating), and\n        builds the padding rows plus the full-length target mask.\n\n        Returns:\n            (logprobs, top_k, target_logprobs, target_token_ids, target_mask):\n            `logprobs` truncated to `top_k` entries per position, and the three\n            target lists pre-filled with the padding rows (`target_mask` already\n            covers every input position).\n\n        Raises:\n            ValueError: If no position has a non-empty top-k list.\n        \"\"\"\n        logprobs = sample.pop(self.logprobs_field)\n        target_seq_len = len(logprobs)\n        input_seq_len = len(sample[\"input_ids\"])\n        input_padding_len = input_seq_len - target_seq_len\n        # get non-zero top-k (prune None logprobs from vllm data step)\n        top_k_vals = [len(row) for row in logprobs if row is not None and len(row)]\n        if not top_k_vals:\n            raise ValueError(\"No non-zero top-k logprobs found.\")\n        max_top_k = max(set(top_k_vals), key=top_k_vals.count)\n        min_top_k = min(set(top_k_vals), key=top_k_vals.count)\n        top_k = min(max_top_k, min_top_k)\n\n        target_logprobs = []\n        target_token_ids = []\n        target_mask = []\n\n        if input_padding_len < 0:\n            # logprobs is longer than input_ids, so keep only the first\n            # input_seq_len positions (slice from the left/beginning)\n            logprobs = logprobs[:input_seq_len]\n            input_padding_len = 0\n\n        # truncate the second dimension of the logprobs to top_k\n        logprobs = [row[:top_k] for row in logprobs]\n\n        # fill with -inf for padding_len tokens for top_k tokens (masked out);\n        # we shift for causal models in the trainer, so start the range from 0\n        for _ in range(input_padding_len):\n            target_logprobs.append([-float(\"inf\")] * top_k)\n            target_token_ids.append(list(range(top_k)))\n            target_mask.append([0] * top_k)\n\n        for position in range(input_padding_len, input_seq_len):\n            if sample[\"labels\"][position] == -100:\n                target_mask.append([0] * top_k)\n            else:\n                target_mask.append([1] * top_k)\n\n        return logprobs, top_k, target_logprobs, target_token_ids, target_mask\n\n    def _rescale_logprobs(self, position_logprobs):\n        \"\"\"\n        Re-temper one position's teacher logprobs from T1 to T2.\n\n        T1 is `self.gen_temperature` and T2 is `self.kd_temperature`, using\n        p_{T2}(k) = [p_{T1}(k)]^(T1 / T2) / Z over the top-k entries.\n        Returns the rescaled logprobs as a list of floats.\n        \"\"\"\n        # Convert to a tensor for easier manipulation\n        position_logprobs_tensor = torch.tensor(position_logprobs, dtype=torch.float)\n        # Convert from log to probability\n        teacher_probs_t1 = position_logprobs_tensor.exp()\n        # normalize probabilities to sum to 1 in case they aren't already\n        teacher_probs_t1_sum = teacher_probs_t1.sum(dim=0, keepdim=True)\n        if teacher_probs_t1_sum > 1e-9:\n            teacher_probs_t1 = teacher_probs_t1 / teacher_probs_t1_sum\n        if self.kd_temperature != self.gen_temperature:\n            # Exponentiate by factor (T1 / T2)\n            exponent = self.gen_temperature / self.kd_temperature\n            teacher_probs_t2 = teacher_probs_t1**exponent\n        else:\n            teacher_probs_t2 = teacher_probs_t1\n        # Re-normalize\n        teacher_probs_t2 = teacher_probs_t2 / teacher_probs_t2.sum(dim=0, keepdim=True)\n        # Convert back to log\n        return torch.log(teacher_probs_t2).tolist()\n\n    def transform_logprobs(self, sample):\n        \"\"\"\n        Transform logprobs to target format for KD training\n\n        Each position of `sample[self.logprobs_field]` is a list of entries with a\n        `\"logprob\"` value and a `\"token\"` string in the `token_id:###` format.\n        \"\"\"\n        logprobs, _, target_logprobs, target_token_ids, target_mask = (\n            self._init_kd_targets(sample)\n        )\n\n        for token_pos_logprobs in logprobs:\n            position_logprobs = []\n            position_token_ids = []\n            for entry in token_pos_logprobs:\n                position_logprobs.append(entry[\"logprob\"])\n                # Parse token_id from the \"token_id:###\" format\n                position_token_ids.append(int(entry[\"token\"].split(\":\")[1]))\n\n            target_logprobs.append(self._rescale_logprobs(position_logprobs))\n            target_token_ids.append(position_token_ids)\n\n        # Update sample with transformed logprobs\n        sample[\"target_logprobs\"] = target_logprobs\n        sample[\"target_token_ids\"] = target_token_ids\n        sample[\"target_mask\"] = target_mask\n\n        return sample\n\n    def _tokenize_single_prompt(self, prompt):\n        logprobs = prompt.pop(self.logprobs_field)\n        tokenized_prompt = super()._tokenize_single_prompt(prompt)\n        tokenized_prompt[self.logprobs_field] = logprobs\n\n        # let subclasses add fields before transform\n        tokenized_prompt = self._prepare_kd_fields(tokenized_prompt, prompt)\n\n        tokenized_prompt = self.transform_logprobs(tokenized_prompt)\n        return tokenized_prompt\n\n    def _prepare_kd_fields(self, tokenized_prompt, original_prompt):\n        \"\"\"\n        Hook for subclasses to prepare additional KD fields before transform\n        \"\"\"\n        return tokenized_prompt\n\n\nclass ChatTemplateStrategyWithKDv2(ChatTemplateStrategyWithKD):\n    \"\"\"\n    Strategy for datasets with complete structured KD logprob data\n    \"\"\"\n\n    def transform_logprobs(self, sample):\n        \"\"\"\n        Transform logprobs to target format for KD training\n\n        Each position of `sample[self.logprobs_field]` is a list of logprob values,\n        paired with the same position of the pre-tokenized\n        `sample[\"target_token_ids\"]`.\n        \"\"\"\n        logprobs, top_k, target_logprobs, target_token_ids, target_mask = (\n            self._init_kd_targets(sample)\n        )\n\n        for token_pos_logprobs, pos_target_token_ids in zip(\n            logprobs, sample[\"target_token_ids\"], strict=False\n        ):\n            target_logprobs.append(self._rescale_logprobs(token_pos_logprobs))\n            # truncate token ids like the logprobs so both stay top_k wide\n            target_token_ids.append(pos_target_token_ids[:top_k])\n\n        # Update sample with transformed logprobs\n        sample[\"target_logprobs\"] = target_logprobs\n        sample[\"target_token_ids\"] = target_token_ids\n        sample[\"target_mask\"] = target_mask\n\n        return sample\n\n    def _prepare_kd_fields(self, tokenized_prompt, original_prompt):\n        \"\"\"\n        Add pre-tokenized target_token_ids for v2 format\n        \"\"\"\n        target_token_ids = original_prompt.pop(\"target_token_ids\", None)\n        if target_token_ids is not None:\n            tokenized_prompt[\"target_token_ids\"] = target_token_ids\n        return tokenized_prompt\n\n\nclass KDStrategyLoader(StrategyLoader):\n    \"\"\"\n    Load ChatTemplateStrategy with KD support using StrategyLoader.\n    \"\"\"\n\n    def _get_strategy_cls(self, cfg):\n        return ChatTemplateStrategyWithKD\n\n    def _get_strategy_params(self, cfg, ds_cfg: Dict[str, Any]):\n        strategy_params = super()._get_strategy_params(cfg, ds_cfg)\n        if logprobs_field := ds_cfg.get(\"logprobs_field\"):\n            strategy_params[\"logprobs_field\"] = logprobs_field\n        if gen_temperature := ds_cfg.get(\"temperature\"):\n            strategy_params[\"gen_temperature\"] = gen_temperature\n        if kd_temperature := cfg.get(\"kd_temperature\"):\n            strategy_params[\"kd_temperature\"] = kd_temperature\n\n        return strategy_params\n\n\nclass KDStrategyLoaderV2(KDStrategyLoader):\n    \"\"\"\n    Load KD chat template datasets with pre-tokenized logprob data\n    \"\"\"\n\n    def _get_strategy_cls(self, cfg):\n        return ChatTemplateStrategyWithKDv2\n\n\nload_legacy = KDStrategyLoader()\nload = KDStrategyLoaderV2()\n"
  },
  {
    "path": "src/axolotl/integrations/kd/collator.py",
    "content": "# Copyright 2024 Axolotl AI. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"\nDataCollator for axolotl to handle KD fields without using -inf for padding,\nand with a target_mask to identify padded positions.\n\"\"\"\n\nfrom dataclasses import dataclass\nfrom typing import Any, Optional, Union\n\nimport numpy as np\nimport torch\nfrom transformers import PreTrainedTokenizerBase\nfrom transformers.utils import PaddingStrategy\n\nfrom axolotl.utils.collators.batching import DataCollatorForSeq2Seq\n\n\n@dataclass\nclass DataCollatorForKD(DataCollatorForSeq2Seq):\n    \"\"\"\n    Data collator for KD, including handling KD-specific fields.\n\n    This version avoids using -inf and instead uses a large negative value (-1e9) for\n    padding target_logprobs. The teacher-provided target_mask is padded with 0 so that\n    padded entries can be told apart from valid ones.\n    \"\"\"\n\n    tokenizer: PreTrainedTokenizerBase\n    model: Optional[Any] = None\n    padding: Union[bool, str, PaddingStrategy] = True\n    max_length: Optional[int] = None\n    pad_to_multiple_of: Optional[int] = None\n    label_pad_token_id: int = -100\n    position_pad_token_id: int = 0\n    return_tensors: str = \"pt\"\n\n    def __init__(self, *args, **kwargs):\n        super().__init__(*args, **kwargs)\n        self.tokenizer.deprecation_warnings[\"Asking-to-pad-a-fast-tokenizer\"] = True\n\n    def _round_up_length(self, length: int) -> int:\n        \"\"\"Round `length` up to a multiple of `pad_to_multiple_of` when it is set.\"\"\"\n        if self.pad_to_multiple_of is None:\n            return length\n        multiple = self.pad_to_multiple_of\n        return ((length + multiple - 1) // multiple) * multiple\n\n    def _pad_teacher_data(\n        self,\n        target_logprobs_list,\n        target_token_ids_list,\n        target_mask_list,\n        max_seq_len: int,\n        padding_side: str,\n    ):\n        \"\"\"\n        Pad per-sample teacher data to a dense (batch, max_seq_len, max_k) layout.\n\n        Each position's top-k entries are padded (logprob -1e9, token id 0, mask 0) or\n        truncated to the batch-wide max_k. Whole padding positions are then added on\n        the side given by `padding_side`, mirroring how labels/position_ids are padded,\n        so teacher positions stay aligned with the student tokens.\n\n        Returns:\n            Tuple of (target_logprobs float tensor, target_token_ids long tensor,\n            target_mask int tensor).\n        \"\"\"\n        max_k = max(\n            (len(seq_k) for seq in target_logprobs_list for seq_k in seq), default=0\n        )\n\n        padded_target_logprobs = []\n        padded_target_token_ids = []\n        padded_teacher_mask_list = []\n\n        for t_logprobs, t_ids, t_mask in zip(\n            target_logprobs_list,\n            target_token_ids_list,\n            target_mask_list,\n            strict=False,\n        ):\n            t_logprobs_padded = []\n            t_ids_padded = []\n            t_mask_padded = []\n\n            for lp, ids, mask in zip(t_logprobs, t_ids, t_mask, strict=False):\n                lp_len = len(lp)\n                if lp_len < max_k:\n                    # Use -1e9 for padding logprobs and 0 for token_ids / mask\n                    pad_len = max_k - lp_len\n                    lp = lp + [-1e9] * pad_len\n                    ids = ids + [0] * pad_len\n                    mask = mask + [0] * pad_len\n                else:\n                    lp = lp[:max_k]\n                    ids = ids[:max_k]\n                    mask = mask[:max_k]\n\n                t_logprobs_padded.append(lp)\n                t_ids_padded.append(ids)\n                t_mask_padded.append(mask)\n\n            seq_len_diff = max_seq_len - len(t_logprobs_padded)\n            if seq_len_diff > 0:\n                pad_logprobs = [[-1e9] * max_k for _ in range(seq_len_diff)]\n                pad_ids = [[0] * max_k for _ in range(seq_len_diff)]\n                pad_mask = [[0] * max_k for _ in range(seq_len_diff)]\n                if padding_side == \"right\":\n                    t_logprobs_padded.extend(pad_logprobs)\n                    t_ids_padded.extend(pad_ids)\n                    t_mask_padded.extend(pad_mask)\n                else:\n                    # Left padding: prepend, otherwise every teacher position would be\n                    # shifted relative to the left-padded labels / input_ids.\n                    t_logprobs_padded = pad_logprobs + t_logprobs_padded\n                    t_ids_padded = pad_ids + t_ids_padded\n                    t_mask_padded = pad_mask + t_mask_padded\n\n            padded_target_logprobs.append(t_logprobs_padded)\n            padded_target_token_ids.append(t_ids_padded)\n            padded_teacher_mask_list.append(t_mask_padded)\n\n        return (\n            torch.tensor(padded_target_logprobs, dtype=torch.float),\n            torch.tensor(padded_target_token_ids, dtype=torch.long),\n            torch.tensor(padded_teacher_mask_list, dtype=torch.int),\n        )\n\n    def __call__(self, features, return_tensors=None):\n        \"\"\"\n        Collate `features` (a list of dicts) into a padded batch.\n\n        `labels` and `position_ids` are padded here according to the tokenizer's\n        `padding_side`; the teacher fields (`target_logprobs`, `target_token_ids`,\n        `target_mask`) are popped, padded and converted to tensors; the remaining\n        fields are padded by `tokenizer.pad`. The input dicts are modified in place.\n        \"\"\"\n        if return_tensors is None:\n            return_tensors = self.return_tensors\n\n        padding_side = self.tokenizer.padding_side\n        max_len = 0\n\n        # Pad labels and position_ids first\n        for feature_name, pad_token_id in [\n            (\"labels\", self.label_pad_token_id),\n            (\"position_ids\", self.position_pad_token_id),\n        ]:\n            if feature_name in features[0]:\n                feat = [f[feature_name] for f in features]\n                max_len = self._round_up_length(max(len(x) for x in feat))\n\n                for f in features:\n                    remainder = [pad_token_id] * (max_len - len(f[feature_name]))\n                    if isinstance(f[feature_name], list):\n                        f[feature_name] = (\n                            f[feature_name] + remainder\n                            if padding_side == \"right\"\n                            else remainder + f[feature_name]\n                        )\n                    else:\n                        # If they are numpy arrays\n                        if padding_side == \"right\":\n                            f[feature_name] = np.concatenate(\n                                [f[feature_name], remainder]\n                            ).astype(np.int64)\n                        else:\n                            f[feature_name] = np.concatenate(\n                                [remainder, f[feature_name]]\n                            ).astype(np.int64)\n\n        # Handle target_logprobs, target_token_ids and target_mask manually\n        has_teacher_data = (\"target_logprobs\" in features[0]) and (\n            \"target_token_ids\" in features[0]\n        )\n\n        if has_teacher_data:\n            # Extract and remove from features so tokenizer.pad never sees them\n            target_logprobs_list = []\n            target_token_ids_list = []\n            target_mask_list = []\n            for f in features:\n                target_logprobs_list.append(f.pop(\"target_logprobs\"))\n                target_token_ids_list.append(f.pop(\"target_token_ids\"))\n                target_mask_list.append(f.pop(\"target_mask\"))\n\n            # Match the padded labels/position_ids length when available; otherwise\n            # fall back to the longest teacher sequence, honoring pad_to_multiple_of.\n            max_teacher_seq_len = max_len or self._round_up_length(\n                max((len(seq) for seq in target_logprobs_list), default=0)\n            )\n            (\n                padded_target_logprobs,\n                padded_target_token_ids,\n                padded_teacher_mask_list,\n            ) = self._pad_teacher_data(\n                target_logprobs_list,\n                target_token_ids_list,\n                target_mask_list,\n                max_teacher_seq_len,\n                padding_side,\n            )\n\n        # Pad using tokenizer for regular fields\n        features = self.tokenizer.pad(\n            features,\n            padding=self.padding,\n            max_length=self.max_length,\n            pad_to_multiple_of=self.pad_to_multiple_of,\n            return_tensors=return_tensors,\n        )\n\n        # Add back teacher data if present\n        if has_teacher_data:\n            features[\"target_logprobs\"] = padded_target_logprobs\n            features[\"target_token_ids\"] = padded_target_token_ids\n            features[\"target_mask\"] = padded_teacher_mask_list\n\n        # Prepare decoder_input_ids if the model supports it\n        if (\n            \"labels\" in features\n            and self.model is not None\n            and hasattr(self.model, \"prepare_decoder_input_ids_from_labels\")\n        ):\n            decoder_input_ids = self.model.prepare_decoder_input_ids_from_labels(\n                labels=features[\"labels\"]\n            )\n            features[\"decoder_input_ids\"] = decoder_input_ids\n\n        return features\n\n\nclass KDBatchSamplerDataCollatorForSeq2Seq(DataCollatorForKD):\n    \"\"\"\n    Collator for multipack (batch of sub-batches) specifically for KD.\n    Adapts DataCollatorForKD so it can pack multiple sequences in a single batch item.\n    \"\"\"\n\n    def __call__(self, features, return_tensors=None):\n        \"\"\"\n        Expects that `features` could be either:\n          - a single list of dicts, OR\n          - a list of lists of dicts (the \"sub-batches\" to be packed).\n        \"\"\"\n        # 1) If we are *not* dealing with multiple sequences per batch element,\n        #    just pass straight to parent.\n        if not isinstance(features[0], list):\n            return super().__call__(features, return_tensors=return_tensors)\n\n        # 2) Otherwise, we *are* dealing with multiple sequences in each batch item.\n        #    We want to produce a single \"merged\" feature dict for each sub-batch.\n        out_features = [{} for _ in features]\n\n        for i, sub_features in enumerate(features):\n            # sub_features is a list of dicts, each dict = one sequence's features.\n            # They are merged into out_features[i] by straightforward concatenation.\n            for field_name in sub_features[0].keys():\n                # The per-sequence \"length\" field is not carried into the packed item.\n                if field_name == \"length\":\n                    continue\n\n                # KD fields are per-position lists (one top-k list per position), so\n                # packing just chains the positions of every sequence together.\n                if field_name in [\"target_logprobs\", \"target_token_ids\", \"target_mask\"]:\n                    combined = []\n                    for feat in sub_features:\n                        combined.extend(feat[field_name])\n                    out_features[i][field_name] = combined\n\n                elif field_name == \"attention_mask\":\n                    # Apply a (j+1) factor to differentiate each sub-sample\n                    # within this merged batch item.\n                    arrays = []\n                    for j, feat in enumerate(sub_features):\n                        if field_name in feat:\n                            arrays.append((j + 1) * np.array(feat[field_name]))\n                    out_features[i][field_name] = np.concatenate(arrays)\n                else:\n                    # By default, concatenate sequence-like values (lists, numpy\n                    # arrays, tensors), e.g. input_ids, labels or position_ids.\n                    arrays = []\n                    for feat in sub_features:\n                        value = feat.get(field_name)\n                        if not isinstance(value, (list, np.ndarray, torch.Tensor)):\n                            continue\n                        # Skip scalars and empty values (value[0] would raise, and an\n                        # empty float array would upcast integer ids on concatenation).\n                        if getattr(value, \"ndim\", 1) == 0 or len(value) == 0:\n                            continue\n                        # Skip non-numeric payloads such as strings or message dicts.\n                        if isinstance(value[0], (dict, str)):\n                            continue\n                        arrays.append(np.array(value))\n                    if arrays:\n                        out_features[i][field_name] = np.concatenate(arrays)\n\n        # 3) Now call the parent collator, which will do:\n        #    - padding of labels/position_ids\n        #    - KD-specific padding for target_logprobs, target_token_ids, etc.\n        #    - final conversion to return_tensors\n        return super().__call__(out_features, return_tensors=return_tensors)\n"
  },
  {
    "path": "src/axolotl/integrations/kd/collator_online_teacher.py",
    "content": "\"\"\"\nPacked data loader for online teacher training supporting vllm and sglang.\n\"\"\"\n\nimport hashlib\nimport hmac\nimport logging\nfrom typing import Any, Dict, List, Optional\n\nimport requests\nimport torch\nfrom orjson import orjson\n\nfrom axolotl.integrations.kd.collator import KDBatchSamplerDataCollatorForSeq2Seq\nfrom axolotl.integrations.kd.utils import normalize_logprobs\nfrom axolotl.utils.data.utils import retry_on_request_exceptions\n\nLOG = logging.getLogger(__name__)\n\n\ndef hmac_sha_from_int_list(int_list, key, hash_func=hashlib.sha256):\n    \"\"\"\n    Create HMAC-SHA hash from a list of integers\n\n    Args:\n        int_list: List of integers\n        key: Secret key (string or bytes)\n        hash_func: Hash function (default: sha256)\n\n    Returns:\n        HMAC digest as hex string\n    \"\"\"\n    # Convert key to bytes if it's a string\n    if isinstance(key, str):\n        key = key.encode(\"utf-8\")\n\n    # Convert list of ints to bytes\n    # Method 1: Convert each int to bytes and concatenate\n    data = b\"\".join(i.to_bytes(4, byteorder=\"big\") for i in int_list)\n\n    # Create HMAC\n    h = hmac.new(key, data, hash_func)\n    return h.hexdigest()\n\n\nclass OnlineTeacherCollator(KDBatchSamplerDataCollatorForSeq2Seq):\n    \"\"\"\n    Collator for online teacher training.\n    \"\"\"\n\n    DEFAULT_LABEL_PAD_TOKEN_ID: int = -100\n\n    def __init__(\n        self,\n        *args: Any,\n        kd_online_server_base_url: Optional[str] = None,\n        kd_online_topk: Optional[int] = None,\n        kd_temperature: Optional[float] = 1.0,\n        kd_online_server: Optional[str] = \"vllm\",\n        kd_online_timeout: Optional[int] = 120,\n        kd_cache_dir: Optional[str] = None,\n        kd_normalize_topk: Optional[bool] = True,\n        **kwargs: Any,\n    ):\n        super().__init__(*args, **kwargs)\n\n        if kd_online_server_base_url is None:\n            raise ValueError(\n                
\"kd_online_server_base_url must be provided for OnlineTeacherDataloader\"\n            )\n        if kd_online_topk is None or kd_online_topk <= 0:\n            raise ValueError(\n                \"kd_online_topk must be a positive integer for OnlineTeacherDataloader\"\n            )\n\n        self.kd_online_server_base_url = kd_online_server_base_url.rstrip(\"/\")\n        self.kd_online_topk = kd_online_topk\n        self.kd_temperature = kd_temperature\n        self.kd_online_server = kd_online_server\n        self.http_session = requests.Session()\n        self.kd_online_timeout = kd_online_timeout\n        self.kd_cache_dir = kd_cache_dir\n        self.kd_normalize_topk = kd_normalize_topk\n\n    def _normalize_logprobs(self, raw_logprobs: List[float]) -> List[float]:\n        \"\"\"\n        Re-normalizes top-k raw logprobs as probabilities, and converts back to logprobs.\n        \"\"\"\n        if not raw_logprobs or self.kd_online_topk == 0:\n            return (\n                [-float(\"inf\")] * self.kd_online_topk if self.kd_online_topk > 0 else []\n            )\n\n        raw_logprobs_tensor = torch.tensor(raw_logprobs, dtype=torch.float32)\n        return normalize_logprobs(raw_logprobs_tensor, self.kd_online_topk).tolist()\n\n    @retry_on_request_exceptions(max_retries=10, delay=5)\n    def fetch_online_logprobs_sglang(\n        self, batch_input_ids: List[List[int]], labels: List[List[int]]\n    ):\n        \"\"\"\n        Fetches logprobs from an online teacher served by sglang for a batch of input_ids.\n        Assumes API returns token IDs as strings in logprob dictionary keys.\n        \"\"\"\n        api_endpoint = f\"{self.kd_online_server_base_url}/generate\"\n\n        payload = {\n            \"input_ids\": batch_input_ids,\n            \"return_logprob\": True,\n            \"top_logprobs_num\": self.kd_online_topk,\n            \"logprob_start_len\": 0,\n            \"return_text_in_logprobs\": True,\n            \"echo\": True,\n   
         \"sampling_params\": {\n                \"max_new_tokens\": 0,\n                \"temperature\": self.kd_temperature,\n                \"skip_special_tokens\": False,\n            },\n        }\n\n        # Initialize with empty lists, so if API call fails, these are returned.\n        ret_data_target_token_ids: List[List[List[int]]] = []\n        ret_data_target_logprobs: List[List[List[float]]] = []\n        ret_data_target_mask: List[List[List[int]]] = []\n\n        try:\n            response = self.http_session.post(\n                api_endpoint, json=payload, timeout=self.kd_online_timeout\n            )\n            response.raise_for_status()\n            api_data: list[dict] = response.json()\n\n            # Ensure api_data is a list, and its length matches batch_input_ids\n            if not isinstance(api_data, list) or len(api_data) != len(batch_input_ids):\n                LOG.error(\n                    f\"API response format error. Expected a list of {len(batch_input_ids)} \"\n                    f\"items, got {type(api_data)} with length {len(api_data) if isinstance(api_data, list) else 'N/A'}.\"\n                )\n                # Return empty data; items processed later will get default empty KD fields\n                return {\n                    \"target_token_ids\": ret_data_target_token_ids,\n                    \"target_logprobs\": ret_data_target_logprobs,\n                    \"target_mask\": ret_data_target_mask,\n                }\n\n            for sequence_data, seq_input_ids, seq_labels in zip(\n                api_data, batch_input_ids, labels, strict=False\n            ):\n                current_target_logprobs = []\n                current_target_token_ids = []\n                current_target_mask = []\n\n                meta_info = sequence_data.pop(\"meta_info\", {})\n                # Ensure input_top_logprobs is a list\n                input_top_logprobs: Optional[list[None | list[tuple]]] = meta_info.pop(\n        
            \"input_top_logprobs\", []\n                )\n                if not isinstance(input_top_logprobs, list):\n                    LOG.warning(\n                        f\"Received non-list input_top_logprobs: {input_top_logprobs}. Skipping sequence.\"\n                    )\n                    input_top_logprobs = []  # Treat as empty\n\n                # basic check that the logprob data len matches the input len, so no need to handle padding\n                assert len(seq_input_ids) == len(input_top_logprobs)\n\n                for i, _, label in zip(\n                    range(len(seq_input_ids)), seq_input_ids, seq_labels, strict=False\n                ):\n                    if i < len(input_top_logprobs) and input_top_logprobs[i] is None:\n                        # this is always the case for the first token.\n                        # there is never logprob data for the first token since that's a true input\n                        # so we replace the None value with padding data\n                        current_target_logprobs.append(\n                            [-float(\"inf\")] * self.kd_online_topk\n                        )\n                        current_target_token_ids.append([0] * self.kd_online_topk)\n                        current_target_mask.append([0] * self.kd_online_topk)\n                    elif (\n                        i < len(input_top_logprobs)\n                        and input_top_logprobs[i] is not None\n                    ):\n                        pos_top_logprobs_data = input_top_logprobs[i]\n                        # Ensure pos_top_logprobs_data is a list of lists as expected\n                        if not (\n                            isinstance(pos_top_logprobs_data, list)\n                            and all(\n                                isinstance(item, list) for item in pos_top_logprobs_data\n                            )\n                            and len(pos_top_logprobs_data) > 0\n                
            and len(pos_top_logprobs_data[0]) == 3\n                        ):  # [logprob, token_id, token_str]\n                            LOG.warning(\n                                f\"Malformed pos_top_logprobs_data: {pos_top_logprobs_data}. Padding this position.\"\n                            )\n                            current_target_logprobs.append(\n                                [-float(\"inf\")] * self.kd_online_topk\n                            )\n                            current_target_token_ids.append([0] * self.kd_online_topk)\n                            current_target_mask.append([0] * self.kd_online_topk)\n                            continue\n\n                        # pos_top_logprobs: list of logprobs, pos_token_ids: list of token_ids\n                        pos_logprobs_raw, pos_token_ids, _ = [\n                            list(row)\n                            for row in zip(*pos_top_logprobs_data, strict=False)\n                        ]\n\n                        # Ensure correct length (top_k)\n                        if len(pos_logprobs_raw) < self.kd_online_topk:\n                            pad_len = self.kd_online_topk - len(pos_logprobs_raw)\n                            pos_logprobs_raw.extend([-float(\"inf\")] * pad_len)\n                            pos_token_ids.extend([0] * pad_len)  # Pad with 0 token_id\n\n                        # truncate to top_k in case the response was longer\n                        current_target_token_ids.append(\n                            pos_token_ids[: self.kd_online_topk]\n                        )\n\n                        if self.kd_normalize_topk:\n                            normalized_logprobs_for_position = self._normalize_logprobs(\n                                pos_logprobs_raw[: self.kd_online_topk]\n                            )\n                            current_target_logprobs.append(\n                                normalized_logprobs_for_position\n                    
        )\n                        else:\n                            current_target_logprobs.append(\n                                pos_logprobs_raw[: self.kd_online_topk]\n                            )\n\n                        # Mask depends on the corresponding label for the student\n                        if label == self.DEFAULT_LABEL_PAD_TOKEN_ID:\n                            current_target_mask.append([0] * self.kd_online_topk)\n                        else:\n                            current_target_mask.append([1] * self.kd_online_topk)\n                    else:\n                        # Pad if no logprobs for this position (either due to length mismatch or None entry)\n                        current_target_logprobs.append(\n                            [-float(\"inf\")] * self.kd_online_topk\n                        )\n                        current_target_token_ids.append([0] * self.kd_online_topk)\n                        current_target_mask.append([0] * self.kd_online_topk)\n\n                ret_data_target_token_ids.append(current_target_token_ids)\n                ret_data_target_logprobs.append(current_target_logprobs)\n                ret_data_target_mask.append(current_target_mask)\n\n        except requests.exceptions.RequestException as e:\n            LOG.error(f\"Error fetching logprobs from online teacher: {e}\")\n            raise e\n            # ret_logprobs_data will be returned with empty lists, handled by the caller.\n        except Exception as e:  # Catch other potential errors during processing\n            LOG.error(\n                f\"Unexpected error processing API response in fetch_online_logprobs: {e}\",\n                exc_info=True,\n            )\n            raise e\n\n        return {\n            \"target_token_ids\": ret_data_target_token_ids,\n            \"target_logprobs\": ret_data_target_logprobs,\n            \"target_mask\": ret_data_target_mask,\n        }\n\n    
@retry_on_request_exceptions(max_retries=10, delay=5)\n    def fetch_online_logprobs_vllm(\n        self, batch_input_ids: List[List[int]], labels: List[List[int]]\n    ):\n        \"\"\"\n        Fetches logprobs from an online teacher served by vllm for a batch of input_ids.\n        Assumes API returns token IDs as strings in logprob dictionary keys.\n        \"\"\"\n        api_endpoint = f\"{self.kd_online_server_base_url}/v1/completions\"\n\n        payload = {\n            \"prompt\": batch_input_ids,\n            \"echo\": True,\n            \"logprobs\": True,\n            \"prompt_logprobs\": self.kd_online_topk,\n            \"top_logprobs\": self.kd_online_topk,\n            \"max_new_tokens\": 0,\n            \"skip_special_tokens\": False,\n            \"temperature\": self.kd_temperature,\n            \"sampling_params\": {\n                \"max_tokens\": 0,\n            },\n        }\n\n        # Initialize with empty lists, so if API call fails, these are returned.\n        ret_data_target_token_ids: List[List[List[int]]] = []\n        ret_data_target_logprobs: List[List[List[float]]] = []\n        ret_data_target_mask: List[List[List[int]]] = []\n\n        try:\n            headers = {\"Accept-Encoding\": \"deflate, gzip, br, zstd\"}\n            response = self.http_session.post(\n                api_endpoint,\n                json=payload,\n                headers=headers,\n                timeout=self.kd_online_timeout,\n            )\n            response.raise_for_status()\n            api_data: dict = orjson.loads(response.content)\n            choices: list[dict] = api_data[\"choices\"]\n\n            # Ensure api_data is a list, and its length matches batch_input_ids\n            if not isinstance(choices, list) or len(choices) != len(batch_input_ids):\n                LOG.error(\n                    f\"API response format error. 
Expected a list of {len(batch_input_ids)} \"\n                    f\"items, got {type(api_data)} with length {len(api_data) if isinstance(api_data, list) else 'N/A'}.\"\n                )\n                # Return empty data; items processed later will get default empty KD fields\n                return {\n                    \"target_token_ids\": ret_data_target_token_ids,\n                    \"target_logprobs\": ret_data_target_logprobs,\n                    \"target_mask\": ret_data_target_mask,\n                }\n\n            for sequence_data, seq_input_ids, seq_labels in zip(\n                choices, batch_input_ids, labels, strict=False\n            ):\n                # seq_input_ids: List[int]\n                # seq_labels: List[int]\n\n                current_target_logprobs = []\n                current_target_token_ids = []\n                current_target_mask = []\n\n                # Ensure input_top_logprobs is a list\n                input_top_logprobs: Optional[list[None | dict[str, dict]]] = (\n                    sequence_data.pop(\"prompt_logprobs\", [])\n                )\n\n                if not isinstance(input_top_logprobs, list):\n                    LOG.warning(\n                        f\"Received non-list input_top_logprobs: {input_top_logprobs}. 
Skipping sequence.\"\n                    )\n                    input_top_logprobs = []  # Treat as empty\n\n                # basic check that the logprob data len matches the input len, so no need to handle padding\n                assert len(seq_input_ids) == len(input_top_logprobs)\n\n                seq_len = len(seq_input_ids)\n\n                for i, _, label in zip(\n                    range(seq_len), seq_input_ids, seq_labels, strict=False\n                ):\n                    if i < len(input_top_logprobs) and input_top_logprobs[i] is None:\n                        # this is always the case for the first token.\n                        # there is never logprob data for the first token since that's a true input\n                        continue\n                    if (\n                        i < len(input_top_logprobs)\n                        and input_top_logprobs[i] is not None\n                    ):\n                        pos_top_logprobs_data: dict[str, dict] = input_top_logprobs[i]  # type: ignore[assignment]\n                        # Ensure pos_top_logprobs_data is a list of lists as expected\n                        if not (\n                            isinstance(pos_top_logprobs_data, dict)\n                            and all(\n                                isinstance(item, dict)\n                                for item in pos_top_logprobs_data.values()\n                            )\n                            and len(pos_top_logprobs_data.keys()) > 0\n                        ):  # [logprob, token_id, token_str]\n                            LOG.warning(\n                                f\"Malformed pos_top_logprobs_data: {pos_top_logprobs_data}. 
Padding this position.\"\n                            )\n                            current_target_logprobs.append(\n                                [-float(\"inf\")] * self.kd_online_topk\n                            )\n                            current_target_token_ids.append(\n                                list(range(self.kd_online_topk))\n                            )\n                            current_target_mask.append([0] * self.kd_online_topk)\n                            continue\n\n                        # pos_top_logprobs: list of logprobs, pos_token_ids: list of token_ids\n                        pos_token_ids_str = list(pos_top_logprobs_data.keys())\n                        pos_logprobs_dict = pos_top_logprobs_data.values()\n                        pos_token_ids = [\n                            int(token_id) for token_id in pos_token_ids_str\n                        ]\n                        pos_logprobs_raw = [\n                            float(logprob.get(\"logprob\", -float(\"inf\")))\n                            for logprob in pos_logprobs_dict\n                        ]\n\n                        # Ensure correct length (top_k)\n                        if len(pos_logprobs_raw) < self.kd_online_topk:\n                            pad_len = self.kd_online_topk - len(pos_logprobs_raw)\n                            LOG.warning(\n                                f\"Padding position {i} with {pad_len} top-k tokens and logprobs.\"\n                            )\n                            pos_logprobs_raw.extend([-float(\"inf\")] * pad_len)\n                            pos_token_ids.extend([0] * pad_len)  # Pad with 0 token_id\n\n                        # truncate to top_k in case the response was longer\n                        current_target_token_ids.append(\n                            pos_token_ids[: self.kd_online_topk]\n                        )\n\n                        if self.kd_normalize_topk:\n                            
normalized_logprobs_for_position = self._normalize_logprobs(\n                                pos_logprobs_raw[: self.kd_online_topk]\n                            )\n                            current_target_logprobs.append(\n                                normalized_logprobs_for_position\n                            )\n                        else:\n                            current_target_logprobs.append(\n                                pos_logprobs_raw[: self.kd_online_topk]\n                            )\n\n                        # Mask depends on the corresponding label for the student\n                        if label == self.DEFAULT_LABEL_PAD_TOKEN_ID:\n                            current_target_mask.append([0] * self.kd_online_topk)\n                        else:\n                            current_target_mask.append([1] * self.kd_online_topk)\n                    else:\n                        # Pad if no logprobs for this position (either due to length mismatch or None entry)\n                        current_target_logprobs.append(\n                            [-float(\"inf\")] * self.kd_online_topk\n                        )\n                        current_target_token_ids.append(\n                            list(range(self.kd_online_topk))\n                        )\n                        current_target_mask.append([0] * self.kd_online_topk)\n                for _ in range(max(0, seq_len - len(current_target_logprobs))):\n                    current_target_logprobs.append(\n                        [-float(\"inf\")] * self.kd_online_topk\n                    )\n                    current_target_token_ids.append(list(range(self.kd_online_topk)))\n                    current_target_mask.append([0] * self.kd_online_topk)\n\n                ret_data_target_token_ids.append(current_target_token_ids)\n                ret_data_target_logprobs.append(current_target_logprobs)\n                ret_data_target_mask.append(current_target_mask)\n\n        
        # TODO save and load targets to disk for caching for next epoch\n                # generate a hmac SHA256 hash over the list seq_input_ids and convert it to an int\n                # if self.kd_cache_dir:\n                #     hash_input_ids = hmac_sha_from_int_list(\n                #         seq_input_ids, f\"{self.kd_online_server_base_url}:{self.kd_online_topk}\"\n                #     )\n                #     with open(f\"{self.kd_cache_dir}/{hash_input_ids}.parquet\", \"wb\") as f:\n                #         pd.DataFrame(ret_logprobs_data).to_parquet(f, index=False)\n\n        except requests.exceptions.RequestException as e:\n            LOG.error(f\"Error fetching logprobs from online teacher: {e}\")\n            raise e\n            # ret_logprobs_data will be returned with empty lists, handled by the caller.\n        except Exception as e:  # Catch other potential errors during processing\n            LOG.error(\n                f\"Unexpected error processing API response in fetch_online_logprobs: {e}\",\n                exc_info=True,\n            )\n            raise e\n\n        return {\n            \"target_token_ids\": ret_data_target_token_ids,\n            \"target_logprobs\": ret_data_target_logprobs,\n            \"target_mask\": ret_data_target_mask,\n        }\n\n    def __call__(\n        self, features: List[List[Dict[str, Any]]], return_tensors: Optional[str] = None\n    ) -> Dict[str, Any]:\n        if not features:\n            return super().__call__(features, return_tensors=return_tensors)\n\n        for (\n            sub_batch_features\n        ) in features:  # sub_batch_features is List[Dict[str, Any]]\n            if not sub_batch_features:\n                continue\n\n            input_ids_for_api_call: List[List[int]] = []\n            labels_for_api_call: List[List[int]] = []\n            # Store references to the original item dictionaries to update them in-place\n            items_for_api_call: List[Dict[str, Any]] = 
[]\n\n            for item_dict in sub_batch_features:\n                if not isinstance(item_dict, dict):\n                    LOG.warning(\n                        f\"Skipping non-dict item in sub_batch_features: {item_dict}\"\n                    )\n                    continue\n\n                current_input_ids = item_dict.get(\"input_ids\")\n                current_labels = item_dict.get(\"labels\")\n\n                if current_input_ids is not None and current_labels is not None:\n                    # Ensure input_ids and labels are lists of ints for JSON serialization\n                    input_ids_list = (\n                        current_input_ids.tolist()\n                        if hasattr(current_input_ids, \"tolist\")\n                        else list(current_input_ids)\n                    )\n                    labels_list = (\n                        current_labels.tolist()\n                        if hasattr(current_labels, \"tolist\")\n                        else list(current_labels)\n                    )\n\n                    input_ids_for_api_call.append(input_ids_list)\n                    labels_for_api_call.append(labels_list)\n                    items_for_api_call.append(item_dict)\n                else:\n                    # This item will not get teacher logprobs from the API.\n                    # Initialize KD fields to empty lists so downstream collators handle them uniformly.\n                    item_dict.setdefault(\"target_token_ids\", [])\n                    item_dict.setdefault(\"target_logprobs\", [])\n                    item_dict.setdefault(\"target_mask\", [])\n\n            # print(items_for_api_call)\n            if items_for_api_call:  # Only call API if there's something to process\n                if self.kd_online_server == \"sglang\":\n                    api_responses_for_sub_batch = self.fetch_online_logprobs_sglang(\n                        input_ids_for_api_call, labels_for_api_call\n                    
)\n                else:\n                    api_responses_for_sub_batch = self.fetch_online_logprobs_vllm(\n                        input_ids_for_api_call, labels_for_api_call\n                    )\n\n                # api_responses_for_sub_batch has keys: \"target_token_ids\", \"target_logprobs\", \"target_mask\"\n                # Each value is a list, corresponding to items_for_api_call\n                for i, item_to_update in enumerate(items_for_api_call):\n                    # TODO make sure to figure out which input in sub_batch_features to update the batch in the original `features` object so the super class can handle it properly.\n                    if api_responses_for_sub_batch and i < len(\n                        api_responses_for_sub_batch[\"target_token_ids\"]\n                    ):  # Check bounds\n                        assert len(\n                            api_responses_for_sub_batch[\"target_token_ids\"][i]\n                        ) == len(item_to_update[\"input_ids\"])\n                        assert len(\n                            api_responses_for_sub_batch[\"target_logprobs\"][i]\n                        ) == len(item_to_update[\"input_ids\"])\n                        assert len(\n                            api_responses_for_sub_batch[\"target_mask\"][i]\n                        ) == len(item_to_update[\"labels\"])\n                        item_to_update[\"target_token_ids\"] = (\n                            api_responses_for_sub_batch[\"target_token_ids\"][i]\n                        )\n                        item_to_update[\"target_logprobs\"] = api_responses_for_sub_batch[\n                            \"target_logprobs\"\n                        ][i]\n                        item_to_update[\"target_mask\"] = api_responses_for_sub_batch[\n                            \"target_mask\"\n                        ][i]\n                    else:\n                        # API call failed for this item, or response was shorter than 
expected.\n                        # Ensure KD fields are initialized as empty lists.\n                        LOG.warning(\n                            f\" (index {i}), or API response was too short. \"\n                            f\"API response keys: {list(api_responses_for_sub_batch.keys()) if api_responses_for_sub_batch else 'None'}\"\n                        )\n                        item_to_update.setdefault(\"target_token_ids\", [])\n                        item_to_update.setdefault(\"target_logprobs\", [])\n                        item_to_update.setdefault(\"target_mask\", [])\n\n        return super().__call__(features, return_tensors=return_tensors)\n"
  },
  {
    "path": "src/axolotl/integrations/kd/kernels/__init__.py",
    "content": "\"\"\"\nLiger Chunked loss optimizations module\n\"\"\"\n\nfrom .liger import LigerFusedLinearKLTopKLogprobLoss\nfrom .models import apply_kernel\n\n__all__ = [\"LigerFusedLinearKLTopKLogprobLoss\", \"apply_kernel\"]\n"
  },
  {
    "path": "src/axolotl/integrations/kd/kernels/liger.py",
    "content": "\"\"\"\nLiger Kernels for Chunked Top-K Log-Prob Distillation\n\"\"\"\n\nimport torch\nimport torch.nn.functional as F\nfrom liger_kernel.chunked_loss.fused_linear_distillation import (\n    LigerFusedLinearDistillationBase,\n)\n\nfrom axolotl.integrations.kd.utils import normalize_logprobs\n\n\nclass LigerFusedLinearKLTopKLogprobFunction(LigerFusedLinearDistillationBase):\n    \"\"\"\n    Chunked kl-div loss for top-k logprobs\n    \"\"\"\n\n    @staticmethod\n    def distillation_loss_fn(\n        student_logits_temp_scaled: torch.Tensor,  # [chunk_size, vocab_size], already temp-scaled\n        target_token_ids_chunk: torch.Tensor,  # [chunk_size, top_k]\n        target_logprobs_chunk: torch.Tensor,  # [chunk_size, top_k], already temp-scaled and normalized logprobs\n        target_mask_chunk: torch.Tensor,  # [chunk_size, top_k]\n        beta: float = 0.0,\n        normalize_topk: bool = True,\n    ) -> torch.Tensor:\n        \"\"\"\n        Compute Top-K KL divergence loss for a chunk.\n        Args:\n            student_logits_temp_scaled: Student logits, scaled by temperature. Shape: (N, V).\n            target_token_ids_chunk: Top-k teacher token IDs. Shape: (N, K).\n            target_logprobs_chunk: Top-k teacher log probabilities (temp-scaled, normalized). Shape: (N, K).\n            target_mask_chunk: Mask for valid top-k tokens. 
Shape: (N, K).\n            beta: Controls the type of KL divergence.\n                  0.0 for Forward KL (P_teacher || P_student).\n                  1.0 for Reverse KL (P_student || P_teacher).\n                  0.5 for Symmetric KL (average of Forward and Reverse).\n            normalize_topk: Whether to normalize the log probabilities\n        Returns:\n            Sum of KL divergence losses for the chunk.\n        \"\"\"\n        topk = target_token_ids_chunk.shape[-1]\n        student_logits_temp_scaled = (  # [chunk_size, vocab_size]\n            student_logits_temp_scaled.float()\n        )\n        target_logprobs_chunk = target_logprobs_chunk.float()\n\n        # Gather student logits for the top-k teacher token IDs\n        # target_token_ids_chunk: [chunk_size, top_k]\n        # student_logits_topk_temp_scaled: [chunk_size, top_k]\n        student_logits_topk_temp_scaled = torch.gather(\n            student_logits_temp_scaled, dim=-1, index=target_token_ids_chunk\n        )\n\n        # Student log-probabilities for the gathered top-k tokens\n        student_lse = torch.logsumexp(\n            student_logits_temp_scaled, dim=-1, keepdim=True\n        )  # [chunk_size, 1]\n        student_logprobs_topk_temp_scaled = (\n            student_logits_topk_temp_scaled - student_lse\n        )\n\n        # we have the top-k student logprobs, normalize them\n        if normalize_topk:\n            student_logprobs_topk_temp_scaled = normalize_logprobs(\n                student_logprobs_topk_temp_scaled, topk\n            )\n\n        valid_mask = target_mask_chunk.to(torch.bool)  # [chunk_size, top_k]\n\n        student_logprobs_topk_valid = student_logprobs_topk_temp_scaled[valid_mask]\n        teacher_logprobs_valid = target_logprobs_chunk[valid_mask]\n\n        # Teacher probabilities P(y|x_teacher) from logprobs\n        # target_logprobs_valid are already normalized (log(softmax(teacher_logits/T)))\n        teacher_probs_valid = 
teacher_logprobs_valid.exp()\n        # Student probabilities P_student from log P_student\n        student_probs_topk_valid = student_logprobs_topk_valid.exp()\n\n        # kd_loss_per_token = torch.zeros_like(target_logprobs_valid)\n\n        # KL divergence: sum(P_teacher * (log P_teacher - log P_student))\n        # = sum(P_teacher * log P_teacher) - sum(P_teacher * log P_student)\n        # The distillation loss is often formulated as -sum(P_teacher * log P_student)\n        # or as sum(P_teacher * (log_softmax_teacher - log_softmax_student))\n        # Here, target_logprobs_valid are log_softmax_teacher.\n        # student_logprobs_topk_valid are log_softmax_student (for the selected K indices).\n        if beta == 0.0:  # Contribution from Forward KL\n            fwd_kl_per_token = teacher_probs_valid * (\n                teacher_logprobs_valid - student_logprobs_topk_valid\n            )\n            kd_loss = fwd_kl_per_token.sum()\n        elif beta == 1.0:  # Contribution from Reverse KL\n            rev_kl_per_token = student_probs_topk_valid * (\n                student_logprobs_topk_valid - teacher_logprobs_valid\n            )\n            kd_loss = rev_kl_per_token.sum()\n        else:\n            # JSD - Jensen-Shannon Divergence / Symmetric\n            mean_probs = (\n                1 - beta\n            ) * student_probs_topk_valid + beta * teacher_probs_valid\n            log_mean_probs = mean_probs.log()\n            student_kl = F.kl_div(\n                log_mean_probs,\n                student_logprobs_topk_valid,\n                reduction=\"sum\",\n                log_target=True,\n            )\n            teacher_kl = F.kl_div(\n                log_mean_probs, teacher_logprobs_valid, reduction=\"sum\", log_target=True\n            )\n            jsd_loss = beta * teacher_kl + (1 - beta) * student_kl\n            kd_loss = jsd_loss\n\n        return kd_loss\n\n    @staticmethod\n    def _compute_loss_kl_topk(\n        
student_input_chunk: torch.Tensor,\n        student_weight: torch.Tensor,\n        # Args for student_bias, target_token_ids_chunk etc. are passed to the lambda wrapped by grad_and_value\n        # or through `partial`. Let's make them explicit here for clarity.\n        target_token_ids_chunk: torch.Tensor,\n        target_logprobs_chunk: torch.Tensor,\n        target_mask_chunk: torch.Tensor,\n        target_chunk: torch.Tensor,  # For hard loss (true labels)\n        student_bias: torch.Tensor = None,  # This will be one of the grad targets\n        # Other params passed via `partial` from `forward`\n        distillation_loss_fn=None,\n        ignore_index: int = -100,\n        weight_hard_loss: float = 0.5,\n        weight_soft_loss: float = 0.5,\n        compute_ce_loss: bool = True,\n        temperature: float = 1.0,\n        beta: float = 0.0,\n        normalize_topk: bool = True,\n    ):\n        # Compute student logits for the chunk from hidden states and LM head\n        # student_input_chunk: [chunk_size, hidden_dim]\n        # student_lm_head_weight: [vocab_size, hidden_dim]\n        # student_logits_chunk: [chunk_size, vocab_size]\n        student_logits_chunk = F.linear(\n            student_input_chunk, student_weight, student_bias\n        )\n\n        ce_loss = torch.tensor(\n            0.0, device=student_logits_chunk.device, dtype=student_logits_chunk.dtype\n        )\n        if compute_ce_loss and weight_hard_loss > 0.0:\n            ce_loss = F.cross_entropy(\n                student_logits_chunk.view(-1, student_logits_chunk.shape[-1]),\n                target_chunk.view(-1),\n                reduction=\"sum\",\n                ignore_index=ignore_index,\n            )\n\n        soft_loss = torch.tensor(\n            0.0, device=student_logits_chunk.device, dtype=student_logits_chunk.dtype\n        )\n        if weight_soft_loss > 0.0:\n            student_logits_chunk_temp_scaled = student_logits_chunk / temperature\n\n            # 
Assuming student_weight.shape[0] (vocab_size) is adequate for target_token_ids_chunk.max()\n            # No explicit padding here; user must ensure vocab alignment or pre-pad student_weight.\n\n            soft_loss = distillation_loss_fn(\n                student_logits_chunk_temp_scaled,\n                target_token_ids_chunk,\n                target_logprobs_chunk,\n                target_mask_chunk,\n                beta=beta,\n                normalize_topk=normalize_topk,\n            )\n\n        return soft_loss, ce_loss\n\n    @classmethod\n    def forward(\n        cls,\n        ctx,\n        student_input: torch.Tensor,  # [batch_size, seq_len, dim]\n        student_lm_head_weight: torch.Tensor,  # [dim, vocab_size]\n        target_token_ids: torch.Tensor,  # [batch_size, seq_len, top_k]\n        target_logprobs: torch.Tensor,  # [batch_size, seq_len, top_k]\n        target_mask: torch.Tensor,  # [batch_size, seq_len, top_k]\n        true_labels: torch.Tensor,  # [batch_size, seq_len]\n        student_lm_head_bias: torch.Tensor = None,\n        weight_hard_loss: float = 0.5,\n        weight_soft_loss: float = 0.5,\n        ignore_index: int = -100,\n        temperature: float = 1.0,\n        beta: float = 0.0,\n        compiled: bool = False,\n        chunk_size: int = 1024,\n        compute_ce_loss: bool = True,\n        normalize_topk: bool = True,\n    ):\n        CHUNK_SIZE = chunk_size\n        grad_weight_acc = torch.zeros_like(student_lm_head_weight)\n        grad_inputs_list = []\n        grad_bias_acc = (\n            torch.zeros_like(student_lm_head_bias)\n            if student_lm_head_bias is not None\n            else None\n        )\n        kd_loss_acc = torch.zeros(\n            (), device=student_input.device, dtype=student_input.dtype\n        )\n        ce_loss_acc = torch.zeros(\n            (), device=student_input.device, dtype=student_input.dtype\n        )\n\n        # This function will be what torch.func.grad_and_value 
differentiates.\n        # It takes student_input_chunk, student_weight (full), student_bias (full) as primals.\n        # Other necessary data (target_*, etc.) are passed as non-differentiable arguments.\n        def loss_fn_for_grad(\n            _student_input_chunk,\n            _student_lm_head_weight,  # full weight\n            _student_lm_head_bias,  # full bias\n            # Fixed arguments for a given chunk, not differentiated:\n            _target_token_ids_chunk,\n            _target_logprobs_chunk,\n            _target_mask_chunk,\n            _true_labels_chunk,\n        ):\n            return cls._compute_loss_kl_topk(\n                student_input_chunk=_student_input_chunk,\n                student_weight=_student_lm_head_weight,\n                target_token_ids_chunk=_target_token_ids_chunk,\n                target_logprobs_chunk=_target_logprobs_chunk,\n                target_mask_chunk=_target_mask_chunk,\n                target_chunk=_true_labels_chunk,\n                student_bias=_student_lm_head_bias,\n                distillation_loss_fn=cls.distillation_loss_fn,\n                ignore_index=ignore_index,\n                weight_hard_loss=weight_hard_loss,\n                weight_soft_loss=weight_soft_loss,\n                compute_ce_loss=compute_ce_loss,\n                temperature=temperature,\n                beta=beta,\n                normalize_topk=normalize_topk,\n            )\n\n        def accumulate_chunk_grads(\n            student_input_chunk_ac,\n            target_token_ids_chunk_ac,\n            target_logprobs_chunk_ac,\n            target_mask_chunk_ac,\n            true_labels_chunk_ac,\n        ):\n            # student_weight and student_bias are closed over from the outer scope (full tensors)\n            if student_lm_head_bias is not None:\n                (\n                    (chunk_grad_input, chunk_grad_weight, chunk_grad_bias),\n                    (chunk_kd_loss, chunk_ce_loss),\n                ) = 
torch.func.grad_and_value(\n                    loss_fn_for_grad, argnums=(0, 1, 2), has_aux=True\n                )(\n                    student_input_chunk_ac,\n                    student_lm_head_weight,\n                    student_lm_head_bias,  # primals\n                    target_token_ids_chunk_ac,\n                    target_logprobs_chunk_ac,\n                    target_mask_chunk_ac,\n                    true_labels_chunk_ac,\n                )  # non-primals\n                grad_bias_acc.add_(chunk_grad_bias)\n            else:\n                argnums_for_grad = (0, 1)  # Differentiate wrt input_chunk, weight\n                (\n                    (chunk_grad_input, chunk_grad_weight),  # No grad for bias\n                    (chunk_kd_loss, chunk_ce_loss),\n                ) = torch.func.grad_and_value(\n                    loss_fn_for_grad, argnums=argnums_for_grad, has_aux=True\n                )(\n                    student_input_chunk_ac,\n                    student_lm_head_weight,\n                    None,  # Pass None for student_bias primal\n                    target_token_ids_chunk_ac,\n                    target_logprobs_chunk_ac,\n                    target_mask_chunk_ac,\n                    true_labels_chunk_ac,\n                )\n\n            grad_weight_acc.add_(chunk_grad_weight)\n            kd_loss_acc.add_(chunk_kd_loss)\n            ce_loss_acc.add_(chunk_ce_loss)\n\n            return chunk_grad_input\n\n        if compiled:\n            accumulate_chunk_grads_compiled = torch.compile(\n                accumulate_chunk_grads, dynamic=True, backend=\"inductor\"\n            )  # dynamic=True often helpful\n        else:\n            accumulate_chunk_grads_compiled = accumulate_chunk_grads\n\n        # Use the same chunking logic as LigerFusedLinearDistillationBase.forward\n        B, N, D = student_input.shape\n        K = target_token_ids.shape[-1]\n\n        student_input_flat = student_input.reshape(-1, 
student_input.shape[-1])\n        target_token_ids_flat = target_token_ids.reshape(-1, target_token_ids.shape[-1])\n        target_logprobs_flat = target_logprobs.reshape(-1, target_logprobs.shape[-1])\n        target_mask_flat = target_mask.reshape(-1, target_mask.shape[-1])\n        # pad and shift for cross entropy loss\n        true_labels = torch.nn.functional.pad(true_labels, (0, 1), value=ignore_index)\n        true_labels_flat = true_labels[:, 1:].contiguous().view(-1)\n\n        num_chunks = max(1, student_input_flat.shape[0] // CHUNK_SIZE)\n\n        _student_input_chunks = torch.chunk(\n            student_input_flat, chunks=num_chunks, dim=0\n        )\n        _target_token_ids_chunks = torch.chunk(\n            target_token_ids_flat, chunks=num_chunks, dim=0\n        )\n        _target_logprobs_chunks = torch.chunk(\n            target_logprobs_flat, chunks=num_chunks, dim=0\n        )\n        _target_mask_chunks = torch.chunk(target_mask_flat, chunks=num_chunks, dim=0)\n        _true_labels_chunks = torch.chunk(true_labels_flat, chunks=num_chunks, dim=0)\n\n        for i in range(num_chunks):\n            grad_input_chunk = accumulate_chunk_grads_compiled(\n                _student_input_chunks[i],\n                _target_token_ids_chunks[i],\n                _target_logprobs_chunks[i],\n                _target_mask_chunks[i],\n                _true_labels_chunks[i],\n            )\n            grad_inputs_list.append(grad_input_chunk)\n\n        grad_inputs_combined = torch.cat(grad_inputs_list, dim=0)\n        ctx.save_for_backward(grad_inputs_combined, grad_weight_acc, grad_bias_acc)\n\n        # For matching None returns in backward for non-tensor/non-grad_requiring inputs\n        ctx.hyperparams_count = 9  # Corresponds to number of hyperparams after main tensors in fwd signature\n        ctx.bias_was_none = student_lm_head_bias is None\n        ctx.orig_dims = (B, N, D, K)\n\n        # since this is packed, there is simply a single batch, so 
batchmean reduction of kl-div is simply the accumulated sum\n        # we still need to scale the kd_loss by the temp^2\n        kd_loss_acc = kd_loss_acc * (temperature**2)\n        final_loss = weight_soft_loss * kd_loss_acc + weight_hard_loss * ce_loss_acc\n\n        return final_loss\n\n    @staticmethod\n    def backward(ctx, grad_output):\n        grad_input_flat, grad_weight, grad_bias_maybe = (\n            ctx.saved_tensors\n        )  # grad_input_flat is (B*N, D)\n\n        # Scale gradients by grad_output if it's not 1.0\n        if not torch.equal(\n            grad_output,\n            torch.tensor(1.0, device=grad_output.device, dtype=grad_output.dtype),\n        ):\n            grad_input_flat = grad_input_flat * grad_output\n            grad_weight = grad_weight * grad_output\n            if grad_bias_maybe is not None:\n                grad_bias_maybe = grad_bias_maybe * grad_output\n\n        # Reshape grad_input_flat to match original student_input shape (B, N, D)\n        # ctx.orig_dims stores (B, N, D, K)\n        # We need the first three dimensions for student_input's shape.\n        # Ensure that orig_dims are not (0,0,0,K) for empty inputs leading to view errors\n        if (\n            ctx.orig_dims[0] * ctx.orig_dims[1] * ctx.orig_dims[2] == 0\n            and grad_input_flat.numel() == 0\n        ):\n            # If original input was empty, gradient should also be empty with correct shape\n            grad_input_reshaped = torch.zeros(\n                ctx.orig_dims[0],\n                ctx.orig_dims[1],\n                ctx.orig_dims[2],\n                dtype=grad_input_flat.dtype,\n                device=grad_input_flat.device,\n            )\n        elif grad_input_flat.numel() == 0 and not (\n            ctx.orig_dims[0] * ctx.orig_dims[1] * ctx.orig_dims[2] == 0\n        ):\n            # This case should ideally not happen if forward path is correct (non-empty input -> non-empty flat grad)\n            # but as a 
safeguard:\n            grad_input_reshaped = torch.zeros(\n                ctx.orig_dims[0],\n                ctx.orig_dims[1],\n                ctx.orig_dims[2],\n                dtype=grad_input_flat.dtype,\n                device=grad_input_flat.device,\n            )\n        else:\n            grad_input_reshaped = grad_input_flat.view(\n                ctx.orig_dims[0], ctx.orig_dims[1], ctx.orig_dims[2]\n            )\n\n        nones_for_hyperparams = [None] * ctx.hyperparams_count\n        grad_bias_return = grad_bias_maybe if not ctx.bias_was_none else None\n\n        return (\n            grad_input_reshaped,  # Gradient for student_input (reshaped)\n            grad_weight,  # Gradient for student_lm_head_weight\n            None,  # Gradient for target_token_ids\n            None,  # Gradient for target_logprobs\n            None,  # Gradient for target_mask\n            None,  # Gradient for true_labels\n            grad_bias_return,  # Gradient for student_lm_head_bias\n            *nones_for_hyperparams,  # Grads for weight_hard_loss, ..., compute_ce_loss\n        )\n\n\nclass LigerFusedLinearKLTopKLogprobLoss(torch.nn.Module):\n    \"\"\"\n    wrapper for chunked top-k logprob kl-d\n    \"\"\"\n\n    def __init__(\n        self,\n        weight_hard_loss: float = 0.5,\n        weight_soft_loss: float = 0.5,\n        temperature: float = 1.0,  # This is the kd_temperature\n        beta: float = 1.0,\n        ignore_index: int = -100,\n        compiled: bool = True,\n        chunk_size: int = 1024,\n        compute_ce_loss: bool = True,\n        normalize_topk: bool = True,\n    ):\n        super().__init__()\n        if not (0.0 <= weight_hard_loss <= 1.0 and 0.0 <= weight_soft_loss <= 1.0):\n            raise ValueError(\"Loss weights must be between 0.0 and 1.0.\")\n        if temperature <= 0:\n            raise ValueError(\"Temperature must be positive.\")\n\n        self.weight_hard_loss = weight_hard_loss\n        self.weight_soft_loss = 
weight_soft_loss\n        self.temperature = temperature\n        self.beta = beta\n        self.ignore_index = ignore_index\n        self.compiled = compiled\n        self.chunk_size = chunk_size\n        self.compute_ce_loss = compute_ce_loss\n        self.normalize_topk = normalize_topk\n\n        if not self.compute_ce_loss and self.weight_hard_loss > 0.0:\n            print(\n                f\"Warning: compute_ce_loss is False, but weight_hard_loss ({self.weight_hard_loss}) > 0. Hard loss will effectively be zero.\"\n            )\n            # self.weight_hard_loss = 0.0 # Or let user manage this\n        if self.weight_soft_loss == 0.0:\n            print(\n                \"Warning: weight_soft_loss is 0.0. Soft (KD) loss will not be computed.\"\n            )\n\n    def forward(\n        self,\n        lm_head_weight: torch.Tensor,  # Weights of the linear layer in the LM head\n        student_hidden_states: torch.Tensor,  # student_hidden_states before the lm_head\n        target_token_ids: torch.Tensor,\n        target_logprobs: torch.Tensor,\n        target_mask: torch.Tensor,\n        true_labels: torch.Tensor,\n        student_bias: torch.Tensor = None,\n    ) -> torch.Tensor:\n        return LigerFusedLinearKLTopKLogprobFunction.apply(\n            student_hidden_states,\n            lm_head_weight,\n            target_token_ids,\n            target_logprobs,\n            target_mask,\n            true_labels,\n            student_bias,\n            self.weight_hard_loss,\n            self.weight_soft_loss,\n            self.ignore_index,\n            self.temperature,\n            self.beta,\n            self.compiled,\n            self.chunk_size,\n            self.compute_ce_loss,\n            self.normalize_topk,\n        )\n"
  },
  {
    "path": "src/axolotl/integrations/kd/kernels/models.py",
    "content": "\"\"\"\nmodel patcher for chunked top-k kl-div\n\"\"\"\n\nfrom typing import Optional, Union, Unpack\n\nimport torch\nfrom transformers import Cache\nfrom transformers.modeling_outputs import CausalLMOutputWithPast\n\ntry:\n    from transformers.modeling_flash_attention_utils import FlashAttentionKwargs\n    from transformers.utils import LossKwargs\n\n    class TransformersKwargs(FlashAttentionKwargs, LossKwargs):\n        \"\"\"\n        placeholder kwargs for hf model classes\n        \"\"\"\n\nexcept ImportError:\n    from transformers.utils.generic import (  # type: ignore[no-redef]\n        TransformersKwargs,\n    )\n\nfrom axolotl.utils.callbacks.models import get_causal_lm_model_cls_prefix\n\n\ndef kldiv_forward_llama_like(\n    self,\n    input_ids: Optional[torch.LongTensor] = None,\n    target_logprobs: Optional[torch.Tensor] = None,\n    target_token_ids: Optional[torch.LongTensor] = None,\n    target_mask: Optional[torch.Tensor] = None,\n    attention_mask: Optional[torch.Tensor] = None,\n    position_ids: Optional[torch.LongTensor] = None,\n    past_key_values: Optional[Cache] = None,\n    inputs_embeds: Optional[torch.FloatTensor] = None,\n    labels: Optional[torch.LongTensor] = None,\n    use_cache: Optional[bool] = None,\n    output_attentions: Optional[bool] = None,\n    output_hidden_states: Optional[bool] = None,\n    cache_position: Optional[torch.LongTensor] = None,\n    logits_to_keep: Union[int, torch.Tensor] = 0,\n    **kwargs: Unpack[TransformersKwargs],  # type: ignore[misc]\n) -> CausalLMOutputWithPast:\n    output_attentions = (\n        output_attentions\n        if output_attentions is not None\n        else self.config.output_attentions\n    )\n    output_hidden_states = (\n        output_hidden_states\n        if output_hidden_states is not None\n        else self.config.output_hidden_states\n    )\n\n    # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)\n    outputs = self.model(\n    
    input_ids=input_ids,\n        attention_mask=attention_mask,\n        position_ids=position_ids,\n        past_key_values=past_key_values,\n        inputs_embeds=inputs_embeds,\n        use_cache=use_cache,\n        output_attentions=output_attentions,\n        output_hidden_states=output_hidden_states,\n        cache_position=cache_position,\n        **kwargs,\n    )\n\n    hidden_states = outputs.last_hidden_state\n\n    # Only compute necessary logits, and do not upcast them to float if we are not computing the loss\n    # TODO, we can optimize this further by filtering hidden_states on sequence dimension using labels != -100\n    # self._loss_function should be LigerFusedLinearKLTopKLogprobLoss\n\n    loss = self._loss_function(\n        self.lm_head.weight,\n        hidden_states,\n        target_token_ids,\n        target_logprobs,\n        target_mask,\n        true_labels=labels,\n    )\n    num_items_in_batch = kwargs.pop(\"num_items_in_batch\", -1)\n    if num_items_in_batch is not None and num_items_in_batch > 0:\n        loss = loss / num_items_in_batch\n\n    return CausalLMOutputWithPast(\n        loss=loss,\n        logits=None,\n        past_key_values=outputs.past_key_values,\n        hidden_states=outputs.hidden_states,\n        attentions=outputs.attentions,\n    )\n\n\ndef apply_kernel(model_type):\n    # Dynamically import the module and attention class\n    module_path = f\"transformers.models.{model_type}.modeling_{model_type}\"\n    model_cls_prefix, _ = get_causal_lm_model_cls_prefix(model_type)\n    module = __import__(module_path, fromlist=[f\"{model_cls_prefix}ForCausalLM\"])\n    model_cls = getattr(module, f\"{model_cls_prefix}ForCausalLM\")\n    model_cls.forward = kldiv_forward_llama_like\n"
  },
  {
    "path": "src/axolotl/integrations/kd/topk_logprob/__init__.py",
    "content": ""
  },
  {
    "path": "src/axolotl/integrations/kd/topk_logprob/forward_kl.py",
    "content": "# Copyright 2024 Axolotl AI. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"\nloss for top_k KL divergence\n\"\"\"\n\nimport torch\nfrom torch import nn\n\n\n@torch.jit.script\ndef loss(\n    student_logits: torch.Tensor,\n    target_token_ids: torch.Tensor,\n    target_logprobs: torch.Tensor,\n    target_mask: torch.Tensor,\n    num_items_in_batch: int = -1,  # Use -1 to indicate \"None\"\n    kd_temperature: float = 1.0,\n) -> torch.Tensor:\n    \"\"\"\n    A KD loss function that is TorchScript-friendly.\n\n    Arguments:\n        student_logits (torch.Tensor): The logits of the student model.\n            Shape: [B, student_seq_len, vocab_size]\n        target_token_ids (torch.Tensor): The top-k teacher/target token IDs\n            Shape: [B, teacher_seq_len, top_k]\n        target_logprobs (torch.Tensor): The top-k teacher/target logprobs, these should already be re-normalized.\n            Shape: [B, teacher_seq_len, top_k]\n        target_mask (torch.Tensor): The mask for valid tokens.\n            Shape: [B, teacher_seq_len, top_k]\n        num_items_in_batch (int, optional): The number of items in the batch.\n        kd_temperature (float, optional): The temperature for KD.\n            Default: 1.0\n    \"\"\"\n\n    target_logprobs = target_logprobs.float()\n\n    # Determine the teacher sequence length\n    # target_token_ids shape: [B, teacher_seq_len, K]\n    # student_logits shape:   [B, 
student_seq_len, vocab_size]\n    teacher_seq_len = target_token_ids.shape[1]\n\n    # Slice student logits to match teacher-provided sequence length\n    student_logits_for_kd = (\n        student_logits[:, :teacher_seq_len, :] / kd_temperature\n    )  # [B, teacher_seq_len, vocab_size]\n\n    # keep in full precision for numerical stability of loss\n    student_logits_for_kd = student_logits_for_kd.float()\n\n    # Gather student logits for teacher's top-K tokens\n    student_logits_topk = torch.gather(\n        student_logits_for_kd, dim=-1, index=target_token_ids\n    )  # [B, teacher_seq_len, K]\n\n    # Compute logsumexp across full vocabulary\n    student_lse = torch.logsumexp(student_logits_for_kd, dim=-1, keepdim=True)\n\n    #  Convert just the top-k logits to logprobs\n    student_logprobs_topk = student_logits_topk - student_lse\n\n    # Convert teacher_mask to boolean for indexing\n    # In TorchScript, .bool() is sometimes unsupported, so we do:\n    valid_mask = target_mask.to(torch.bool)\n\n    # Prune tensors to only keep valid tokens\n    student_logprobs_topk = student_logprobs_topk[valid_mask]\n    target_logprobs = target_logprobs[valid_mask]\n\n    # Convert teacher logprobs to probabilities\n    teacher_probs = target_logprobs.exp()\n\n    # Compute forward KL\n    kd_loss_per_token = teacher_probs * (target_logprobs - student_logprobs_topk)\n    kd_loss = kd_loss_per_token.sum()\n\n    # Normalize by number of items (if provided) or by valid tokens\n    if num_items_in_batch > 0:\n        kd_loss = kd_loss / float(num_items_in_batch)\n    else:\n        # Fall back to average over valid tokens\n        kd_loss = kd_loss / float(kd_loss_per_token.size(0))\n\n    return kd_loss\n\n\nclass ChunkedTopKKDLoss(nn.Module):\n    \"\"\"\n    A wrapper that chunks (splits) the student and teacher outputs along the time dimension\n    to reduce peak memory usage when upcasting from bf16 to fp32, especially for large vocabularies.\n\n    Usage is 
analogous to ForwardKLWithChunkedOutputLoss but adapted to top-K teacher logprobs.\n    \"\"\"\n\n    def __init__(self, num_output_chunks: int = 8, kd_temperature: float = 1.0):\n        super().__init__()\n        self.num_output_chunks = num_output_chunks\n        self.kd_temperature = kd_temperature\n\n    def forward(\n        self,\n        student_logits: torch.Tensor,  # [B, seq_len, vocab_size]\n        target_token_ids: torch.Tensor,  # [B, seq_len, K]\n        target_logprobs: torch.Tensor,  # [B, seq_len, K]\n        target_mask: torch.Tensor,  # [B, seq_len, K]\n        num_items_in_batch: int = -1,  # optional batch size for normalization\n    ) -> torch.Tensor:\n        # 1. Split along the \"token\" dimension (dim=1).\n        student_logits_chunks = student_logits.chunk(self.num_output_chunks, dim=1)\n        token_ids_chunks = target_token_ids.chunk(self.num_output_chunks, dim=1)\n        logprobs_chunks = target_logprobs.chunk(self.num_output_chunks, dim=1)\n        mask_chunks = target_mask.chunk(self.num_output_chunks, dim=1)\n\n        # We'll accumulate a global \"sum of losses\" and \"sum of valid tokens\"\n        # so that our final average is consistent with the entire sequence/batch.\n        total_loss = 0.0\n        total_valid_tokens = 0\n\n        # 2. 
Loop over each chunk and compute a chunk-specific loss.\n        for st_chunk, tid_chunk, lp_chunk, msk_chunk in zip(\n            student_logits_chunks,\n            token_ids_chunks,\n            logprobs_chunks,\n            mask_chunks,\n            strict=False,\n        ):\n            # We pass num_items_in_batch=-1 so that the kd_loss\n            # will average over *this chunk's* valid tokens only.\n            chunk_loss = loss(\n                student_logits=st_chunk,\n                target_token_ids=tid_chunk,\n                target_logprobs=lp_chunk,\n                target_mask=msk_chunk,\n                num_items_in_batch=-1,  # ensure per-chunk averaging by valid tokens\n                kd_temperature=self.kd_temperature,\n            )\n\n            # kd_loss returns an average over the chunk's valid tokens.\n            # We want a global average in the end, so we need to re‐weight\n            # by the number of valid tokens in this chunk and keep track of the total.\n            chunk_valid_mask = msk_chunk.to(torch.bool)\n            chunk_valid_count = chunk_valid_mask.sum()  # scalar tensor\n\n            # Re-scale \"chunk average\" back to \"chunk sum\"\n            chunk_loss_sum = chunk_loss * chunk_valid_count\n\n            total_loss += chunk_loss_sum\n            total_valid_tokens += chunk_valid_count\n\n        # 3. Normalize *once* at the end.\n        if num_items_in_batch > 0:\n            # If the user gave us a manual denominator (e.g. total items in batch),\n            # we divide by it. Typically used if each item is of different length.\n            final_loss = total_loss / float(num_items_in_batch)\n        else:\n            # Otherwise, divide by total valid tokens across all chunks.\n            # to get the same result as a non-chunked approach.\n            final_loss = total_loss / float(total_valid_tokens)\n\n        return final_loss\n"
  },
  {
    "path": "src/axolotl/integrations/kd/trainer.py",
    "content": "# Copyright 2024 Axolotl AI. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"\nKD trainer\n\"\"\"\n\nfrom typing_extensions import override\n\nfrom axolotl.core.trainers.base import AxolotlTrainer\n\nfrom .kernels.liger import LigerFusedLinearKLTopKLogprobLoss\n\n\nclass AxolotlKDTrainer(AxolotlTrainer):\n    \"\"\"\n    Custom trainer subclass for Knowledge Distillation (KD)\n    \"\"\"\n\n    def __init__(self, *args, **kwargs):\n        super().__init__(*args, **kwargs)\n        self.model_accepts_loss_kwargs = True\n\n        loss_fn = LigerFusedLinearKLTopKLogprobLoss(\n            self.args.kd_ce_alpha,  # hard label loss\n            self.args.kd_alpha,  # kd loss\n            self.args.kd_temperature,\n            self.args.kd_beta or 0.0,\n            compute_ce_loss=bool(self.args.kd_ce_alpha),\n            normalize_topk=self.args.kd_normalize_topk,\n        )\n        target = self.model\n\n        # Unwrap PEFT wrapper\n        if hasattr(target, \"get_base_model\"):\n            target = target.get_base_model()\n\n        # Set on the actual model instance\n        target._loss_function = loss_fn\n\n    def _set_signature_columns_if_needed(self):\n        super()._set_signature_columns_if_needed()\n        columns_to_add = []\n        if self._signature_columns:\n            if \"target_logprobs\" not in self._signature_columns:\n                columns_to_add.append(\"target_logprobs\")\n            
if \"target_token_ids\" not in self._signature_columns:\n                columns_to_add.append(\"target_token_ids\")\n            if \"target_mask\" not in self._signature_columns:\n                columns_to_add.append(\"target_mask\")\n            if columns_to_add:\n                self._signature_columns += columns_to_add\n\n    @override\n    def compute_loss(\n        self,\n        model,\n        inputs,\n        return_outputs=False,\n        num_items_in_batch=None,\n    ):\n        \"\"\"\n        How the loss is computed by Trainer. By default, all models return the loss in the first element.\n\n        Subclass and override for custom behavior.\n        \"\"\"\n        if (\n            self.args.sample_packing\n            and hasattr(inputs, \"attention_mask\")\n            and hasattr(inputs, \"position_ids\")\n        ):\n            del inputs[\"attention_mask\"]\n\n        if num_items_in_batch is None and \"labels\" in inputs:\n            num_items_in_batch = (inputs[\"labels\"] != -100).sum().item()\n\n        if self.model_accepts_loss_kwargs:\n            loss_kwargs = {}\n            if num_items_in_batch is not None:\n                loss_kwargs[\"num_items_in_batch\"] = num_items_in_batch\n            inputs = {**inputs, **loss_kwargs}\n\n        outputs = model(**inputs)\n\n        if isinstance(outputs, dict):\n            loss = outputs[\"loss\"]\n        elif isinstance(outputs, tuple):\n            loss = outputs[0]\n        else:\n            loss = outputs.loss if hasattr(outputs, \"loss\") else outputs\n\n        return (loss, outputs) if return_outputs else loss\n"
  },
  {
    "path": "src/axolotl/integrations/kd/utils.py",
    "content": "\"\"\"Helper KD utils\"\"\"\n\nimport math\nfrom typing import List, Union\n\nimport numpy as np\nimport torch\nfrom torch import FloatTensor, Tensor\n\n\ndef normalize_logprobs(logprobs: FloatTensor, topk: int) -> FloatTensor:\n    \"\"\"\n    Re-normalizes top-k raw logprobs as probabilities, and converts back to logprobs.\n    \"\"\"\n    # Ensure raw_logprobs matches kd_online_topk length for tensor operations\n    # This should ideally be handled by the caller ensuring correct padding/truncation first\n    if logprobs.shape[-1] != topk:\n        # pad last dimension of logprobs to match topk length with -inf\n        padding_len = topk - logprobs.shape[-1]\n        padding_tensor = torch.full(\n            (\n                *logprobs.shape[:-1],\n                padding_len,\n            ),  # Takes all dimensions of logprobs except the last, then appends padding_needed\n            float(\"-inf\"),\n            dtype=logprobs.dtype,\n            device=logprobs.device,\n        )\n        logprobs = torch.cat((logprobs, padding_tensor), dim=-1)\n\n    # Convert logprobs at T_online to probabilities\n    # use log sum exp trick to avoid underflow\n    position_logprobs_lse = torch.logsumexp(logprobs, dim=-1, keepdim=True)\n    teacher_probs_t_online = torch.exp(logprobs - position_logprobs_lse)\n\n    # Normalize probabilities (sum to 1)\n    # This is important if the top-k from server aren't a full distribution\n    teacher_probs_t_online_sum = teacher_probs_t_online.sum(dim=-1, keepdim=True)\n    teacher_probs_t_online = teacher_probs_t_online / teacher_probs_t_online_sum\n\n    final_logprobs_tensor = torch.log(teacher_probs_t_online)\n\n    return final_logprobs_tensor\n\n\ndef strided_chunk_views(\n    tensor: Union[np.ndarray, torch.Tensor],\n    chunks: int,\n    dim: int = 0,\n    stride: int = 1,\n    chunk_size: int | None = None,\n) -> List[Union[np.ndarray, torch.Tensor]]:\n    \"\"\"\n    Split a tensor into chunks along a 
dimension with striding, prioritizing views over copies.\n\n    Args:\n        tensor: Input tensor (numpy array or torch tensor)\n        chunks: Number of chunks to create\n        dim: Dimension along which to chunk (default: 0)\n        stride: Stride between chunk starting positions (default: 1)\n        chunk_size: Size of each chunk. If None, calculated automatically (default: None)\n\n    Returns:\n        List of tensor chunks (views when possible, copies when necessary)\n    \"\"\"\n\n    # Get the size of the specified dimension\n    dim_size = tensor.shape[dim]\n\n    # Calculate chunk size if not provided\n    if chunk_size is None:\n        chunk_size = (dim_size + chunks - 1) // chunks  # Ceiling division\n\n    chunks_list = []\n\n    for i in range(chunks):\n        start_idx = i * stride\n        end_idx = min(start_idx + chunk_size, dim_size)\n\n        # Break if we've gone beyond the tensor\n        if start_idx >= dim_size:\n            break\n\n        # Create slice objects for all dimensions\n        slices = [slice(None)] * tensor.ndim\n        slices[dim] = slice(start_idx, end_idx)\n\n        chunk = tensor[tuple(slices)]\n        chunks_list.append(chunk)\n\n    return chunks_list\n\n\ndef chunk_overlap(input_tensor: Tensor, chunks: int, dim: int = 0, overlap: int = 1):\n    dim_size = input_tensor.shape[dim]\n    stride = math.ceil(dim_size / chunks)\n\n    return strided_chunk_views(\n        input_tensor, chunks, dim, stride=stride, chunk_size=stride + overlap\n    )\n"
  },
  {
    "path": "src/axolotl/integrations/kernels/README.md",
    "content": "# Kernels Integration\n\nMoE (Mixture of Experts) kernels speed up training for MoE layers and reduce VRAM costs. In transformers v5, `batched_mm` and `grouped_mm` were integrated as built-in options via the `experts_implementation` config kwarg:\n\n```python\nclass ExpertsInterface(GeneralInterface):\n    _global_mapping = {\n        \"batched_mm\": batched_mm_experts_forward,\n        \"grouped_mm\": grouped_mm_experts_forward,\n    }\n```\n\nIn our custom integration, we add support for **ScatterMoE** and **SonicMoE**, which are more efficient and faster than `grouped_mm`.\n\n## Usage\n\nAdd the following to your axolotl YAML config:\n\n```yaml\nplugins:\n  - axolotl.integrations.kernels.KernelsPlugin\n\nuse_kernels: true\n\n# Choose one (mutually exclusive):\nuse_scattermoe: true\n# OR\nuse_sonicmoe: true\n```\n\n**Important:** The custom MoE kernels require `experts_implementation: eager`. If it is unset, Axolotl sets it to `eager`; any other value is overridden to `eager` with a warning. Enabling either kernel also disables `mlp_kernel` and `lora_mlp_kernel`.\n\n### SonicMoE installation\n\n**Prerequisites:**\n- NVIDIA Hopper (H100, H200) or Blackwell (B200, GB200) GPU\n- CUDA 12.9+ (13.0+ for B300)\n- PyTorch 2.7+ (2.9.1 recommended)\n- For B300: Triton 3.6.0\n\n```bash\npip install --ignore-requires-python --no-deps \"sonic-moe @ git+https://github.com/Dao-AILab/sonic-moe.git@116e2df0a41874f77fa0ad269ce7df3f0cfcb956\" && pip install nvidia-cutlass-dsl==4.4.0 quack-kernels==0.2.5\n```\n\nSee the [SonicMoE installation guide](https://github.com/Dao-AILab/sonic-moe?tab=readme-ov-file#-installation) for the latest prerequisite details.\n\n**Note:** Blackwell support is in upstream beta. On Blackwell GPUs, Axolotl automatically sets `USE_QUACK_GEMM=1` to enable the Blackwell kernels.\n\n## How It Works\n\nThe `KernelsPlugin` runs before model loading and, depending on the selected kernel, does the following:\n\n### ScatterMoE\n1. Registers the ScatterMoE kernel from the local `libs/scattermoe_lora` package (includes fused LoRA support via Triton kernels).\n2. 
Patches the model's `SparseMoeBlock` forward method with the optimized ScatterMoE implementation.\n\n### SonicMoE\n1. Resolves the model's MoE block class(es) from `constants.py`.\n2. Patches the forward method with SonicMoE's optimized kernels and registers a weight converter for the interleaved gate/up projection format.\n3. Supports both softmax->topk and sigmoid->topk routing strategies.\n\nBoth paths use the shared `resolve_moe_block_classes` utility in `constants.py` for model-type-to-class resolution.\n\n#### Supported Models\n\nSee `constants.py` for the full list of supported model types (Qwen2-MoE, Qwen3-MoE, OLMoE, Mixtral, DeepSeek-V3, GLM-MoE, MiniMax, etc.).\n\n## Limitations\n\nScatterMoE uses softmax -> topk routing, so for architectures that route differently (e.g. GPT-OSS) its results may differ from the baseline implementation. ScatterMoE is currently incompatible with `GLM_MOE_DSA` (GLM 5) and `GLM4_MOE_LITE` (GLM 4.7 Flash, model type `glm4_moe_lite`).\n\nSonicMoE supports both softmax->topk and sigmoid->topk routing, covering a wider range of architectures.\n\n## Note on MegaBlocks\n\nWe tested [MegaBlocks](https://huggingface.co/kernels-community/megablocks) but could not verify its numerical accuracy, so we did not integrate it. It was also incompatible with many newer model architectures in transformers.\n"
  },
  {
    "path": "src/axolotl/integrations/kernels/__init__.py",
    "content": "from .args import KernelsArgs\nfrom .plugin import KernelsPlugin\n\n__all__ = [\n    \"KernelsArgs\",\n    \"KernelsPlugin\",\n]\n"
  },
  {
    "path": "src/axolotl/integrations/kernels/args.py",
    "content": "from pydantic import BaseModel, model_validator\n\nfrom axolotl.utils.logging import get_logger\n\nLOG = get_logger(__name__)\n\n\nclass KernelsArgs(BaseModel):\n    use_scattermoe: bool | None = None\n    use_sonicmoe: bool | None = None\n\n    @model_validator(mode=\"before\")\n    @classmethod\n    def check_mutually_exclusive(cls, data):\n        if data.get(\"use_scattermoe\") and data.get(\"use_sonicmoe\"):\n            raise ValueError(\n                \"Cannot use both ScatterMoE and SonicMoE simultaneously. \"\n                \"Please set only one of `use_scattermoe` or `use_sonicmoe` to true.\"\n            )\n        return data\n\n    @model_validator(mode=\"before\")\n    @classmethod\n    def check_use_kernels(cls, data):\n        if data.get(\"use_kernels\") is not True:\n            LOG.warning(\n                \"`use_kernels` must be set to True to use this. Automatically setting it to True.\"\n            )\n            data[\"use_kernels\"] = True\n\n        return data\n\n    @model_validator(mode=\"before\")\n    @classmethod\n    def check_experts_implementation(cls, data):\n        experts_implementation = data.get(\"experts_implementation\")\n        if experts_implementation is None:\n            # transformers may default to batched_mm when unset\n            data[\"experts_implementation\"] = \"eager\"\n        elif experts_implementation != \"eager\":\n            LOG.warning(\n                \"`experts_implementation` must be set to 'eager' to use this. 
Automatically setting it to 'eager'.\"\n            )\n            data[\"experts_implementation\"] = \"eager\"\n\n        return data\n\n    @model_validator(mode=\"before\")\n    @classmethod\n    def disable_mlp_kernel(cls, data):\n        if data.get(\"use_scattermoe\") is True or data.get(\"use_sonicmoe\") is True:\n            if data.get(\"lora_mlp_kernel\") is True:\n                LOG.warning(\n                    \"Disabling lora_mlp_kernel when using custom MoE kernels due to compatibility issues.\"\n                )\n                data[\"lora_mlp_kernel\"] = False\n            data[\"mlp_kernel\"] = False\n\n        return data\n"
  },
  {
    "path": "src/axolotl/integrations/kernels/autotune_callback.py",
    "content": "\"\"\"Trainer callback for reporting Triton autotune results from scattermoe-lora kernels.\"\"\"\n\nimport logging\n\nimport torch\nfrom transformers import (\n    TrainerCallback,\n    TrainerControl,\n    TrainerState,\n    TrainingArguments,\n)\n\nLOG = logging.getLogger(__name__)\n\n# Give up looking for autotune data after this many training steps.\n_MAX_POLL_STEP = 5\n\n\ndef _get_gpu_info() -> dict:\n    \"\"\"Return basic GPU identification for the current device.\"\"\"\n    if not torch.cuda.is_available():\n        return {}\n    try:\n        idx = torch.cuda.current_device()\n        props = torch.cuda.get_device_properties(idx)\n        return {\n            \"gpu_name\": props.name,\n            \"gpu_compute_capability\": f\"{props.major}.{props.minor}\",\n            \"gpu_memory_bytes\": props.total_memory,\n        }\n    except Exception:  # pylint: disable=broad-exception-caught\n        return {}\n\n\ndef _get_smem_capacity() -> dict:\n    \"\"\"Return shared memory capacity from the runtime lora_ops module.\"\"\"\n    try:\n        from axolotl.integrations.kernels.autotune_collector import (\n            _find_lora_ops_module,\n        )\n\n        lora_ops = _find_lora_ops_module()\n        if lora_ops is None:\n            return {}\n        fn = getattr(lora_ops, \"_get_smem_capacity\", None)\n        if fn is None:\n            return {}\n        return {\"smem_capacity_bytes\": fn()}\n    except Exception:  # pylint: disable=broad-exception-caught\n        return {}\n\n\nclass AutotuneReportCallback(TrainerCallback):\n    \"\"\"Reports Triton kernel autotune selections via telemetry.\n\n    Fires **once** after the first training step completes (step 1), at\n    which point the forward and backward passes have both run and the\n    autotuned kernels have populated their caches.  If for some reason\n    the caches are still empty (e.g. 
the kernel was never invoked), the\n    callback retries on subsequent steps up to ``_MAX_POLL_STEP`` and\n    then stops polling.\n\n    After reporting (or giving up) every subsequent ``on_step_end``\n    call short-circuits on the ``_reported`` flag — zero hot-path cost.\n    \"\"\"\n\n    def __init__(self):\n        self._reported = False\n\n    # pylint: disable=unused-argument\n    def on_step_end(\n        self,\n        args: TrainingArguments,\n        state: TrainerState,\n        control: TrainerControl,\n        **kwargs,\n    ):\n        if self._reported:\n            return\n\n        # Lazy import — Triton / scattermoe kernels may not be installed.\n        from axolotl.integrations.kernels.autotune_collector import (\n            collect_autotune_configs,\n        )\n\n        configs = collect_autotune_configs()\n\n        if not configs:\n            if state.global_step >= _MAX_POLL_STEP:\n                LOG.debug(\n                    \"No autotune data found after %d steps; giving up.\",\n                    state.global_step,\n                )\n                self._reported = True\n            return\n\n        self._reported = True\n\n        from axolotl.telemetry.manager import TelemetryManager\n\n        telemetry_manager = TelemetryManager.get_instance()\n        if not telemetry_manager.enabled:\n            return\n\n        properties = {\n            \"kernel_count\": len(configs),\n            \"kernels\": configs,\n        }\n        properties.update(_get_gpu_info())\n        properties.update(_get_smem_capacity())\n\n        telemetry_manager.send_event(\n            event_type=\"scattermoe-autotune\",\n            properties=properties,\n        )\n\n        LOG.info(\n            \"Reported %d scattermoe kernel autotune config(s) to telemetry.\",\n            len(configs),\n        )\n"
  },
  {
    "path": "src/axolotl/integrations/kernels/autotune_collector.py",
    "content": "\"\"\"Collect Triton autotune results from scattermoe-lora kernels.\n\nThis module reads the ``.cache`` attribute from Triton ``@triton.autotune``\ndecorated kernel objects and returns structured dicts describing the selected\nconfigurations.  It has **no** telemetry dependency — callers decide what to\ndo with the data.\n\"\"\"\n\nimport logging\nimport sys\nfrom types import ModuleType\nfrom typing import Any\n\nLOG = logging.getLogger(__name__)\n\n# (human-readable name, attribute on the lora_ops module)\n_KERNEL_REGISTRY: list[tuple[str, str]] = [\n    (\"scatter2scatter_lora_fwd\", \"_scatter2scatter_lora\"),\n    (\"scatter2scatter_lora_dX\", \"_scatter2scatter_lora_dX\"),\n    (\"group_bwd_lora\", \"_group_bwd_lora\"),\n    (\"group_bwd_lora_fused\", \"_group_bwd_lora_fused\"),\n]\n\n# The autotune key declared on every kernel: key=[\"M\", \"N\", \"K\"]\n_KEY_NAMES: list[str] = [\"M\", \"N\", \"K\"]\n\n\ndef _parse_key_tuple(key_tuple: tuple) -> dict[str, Any]:\n    \"\"\"Turn the autotune cache key tuple into a labelled dict.\n\n    Triton builds the cache key from the values of the declared ``key``\n    args (``M``, ``N``, ``K``) followed by dtype signature elements.\n    We label the first three and store the rest under ``_extra``.\n    \"\"\"\n    result: dict[str, Any] = {}\n    for i, name in enumerate(_KEY_NAMES):\n        if i < len(key_tuple):\n            result[name] = key_tuple[i]\n    if len(key_tuple) > len(_KEY_NAMES):\n        result[\"_extra\"] = [str(v) for v in key_tuple[len(_KEY_NAMES) :]]\n    return result\n\n\ndef _find_lora_ops_module() -> ModuleType | None:\n    \"\"\"Locate the *runtime* ``lora_ops`` module in ``sys.modules``.\n\n    The HF ``kernels`` package loads ``scattermoe_lora`` via\n    ``import_from_path`` which registers it in ``sys.modules`` under a\n    hash-suffixed name (e.g. ``scattermoe_lora_a1b2c3d4``).  
A normal\n    import (``from axolotl.integrations.kernels...``) would create a\n    *separate* module instance whose kernel objects have empty\n    ``.cache`` dicts because autotuning ran on the runtime copy.\n\n    We search ``sys.modules`` for any module whose name contains\n    ``lora_ops`` and that has the ``_scatter2scatter_lora`` kernel\n    attribute — that is the runtime copy with populated caches.\n    \"\"\"\n    for name, module in list(sys.modules.items()):\n        if (\n            module is not None\n            and \"lora_ops\" in name\n            and hasattr(module, \"_scatter2scatter_lora\")\n        ):\n            return module\n    return None\n\n\ndef collect_autotune_configs() -> list[dict[str, Any]]:\n    \"\"\"Read autotune caches from the four scattermoe-lora kernels.\n\n    Returns a (possibly empty) list of dicts, each containing:\n\n    * ``kernel`` – human-readable kernel name\n    * ``key``    – dict with the ``M``/``N``/``K`` problem dimensions\n    * ``config`` – dict with the selected tile sizes, ``num_warps``,\n      and ``num_stages``\n\n    Returns ``[]`` if the kernel module cannot be found or if no\n    autotune cache entries exist yet.\n    \"\"\"\n    lora_ops = _find_lora_ops_module()\n    if lora_ops is None:\n        LOG.debug(\n            \"lora_ops module not found in sys.modules; skipping autotune collection\"\n        )\n        return []\n\n    results: list[dict[str, Any]] = []\n\n    for friendly_name, attr_name in _KERNEL_REGISTRY:\n        kernel_fn = getattr(lora_ops, attr_name, None)\n        if kernel_fn is None:\n            continue\n\n        cache = getattr(kernel_fn, \"cache\", None)\n        if not cache:\n            continue\n\n        for key_tuple, config in cache.items():\n            config_dict = dict(config.kwargs)\n            config_dict[\"num_warps\"] = config.num_warps\n            config_dict[\"num_stages\"] = config.num_stages\n            if getattr(config, \"num_ctas\", None) is not 
None:\n                config_dict[\"num_ctas\"] = config.num_ctas\n\n            results.append(\n                {\n                    \"kernel\": friendly_name,\n                    \"key\": _parse_key_tuple(key_tuple),\n                    \"config\": config_dict,\n                }\n            )\n\n    return results\n"
  },
  {
    "path": "src/axolotl/integrations/kernels/constants.py",
    "content": "\"\"\"\nSupported MoE block mappings for kernel integrations.\n\nMaps model_type to the SparseMoeBlock class name(s) in transformers.\nUsed by both ScatterMoE and SonicMoE kernel paths.\n\nValues can be a single class name (str) or a list of class names for models\nwith multiple MoE block types (e.g. qwen3_omni_moe has Thinker + Talker).\n\"\"\"\n\nimport importlib\n\nSPARSE_MOE_BLOCK = {\n    # softmax -> topk routing\n    \"qwen2_moe\": \"Qwen2MoeSparseMoeBlock\",\n    \"qwen3_moe\": \"Qwen3MoeSparseMoeBlock\",\n    \"qwen3_5_moe\": \"Qwen3_5MoeSparseMoeBlock\",\n    \"qwen3_5_moe_text\": \"Qwen3_5MoeSparseMoeBlock\",\n    \"qwen3_next\": \"Qwen3NextSparseMoeBlock\",\n    \"qwen3_vl_moe\": \"Qwen3VLMoeTextSparseMoeBlock\",\n    # qwen3_omni_moe: Thinker (standard) + Talker (shared experts + shared_expert_gate)\n    \"qwen3_omni_moe\": [\n        \"Qwen3OmniMoeThinkerTextSparseMoeBlock\",\n        \"Qwen3OmniMoeTalkerTextSparseMoeBlock\",\n    ],\n    \"olmoe\": \"OlmoeSparseMoeBlock\",\n    \"mixtral\": \"MixtralSparseMoeBlock\",\n    \"minimax\": \"MiniMaxSparseMoeBlock\",\n    # softmax -> topk routing (with group-based expert selection)\n    \"mistral4\": \"Mistral4MoE\",\n    # sigmoid -> topk routing (with group-based expert selection)\n    \"glm_moe_dsa\": \"GlmMoeDsaMoE\",\n    \"deepseek_v3\": \"DeepseekV3MoE\",\n    \"glm4_moe\": \"Glm4MoeMoE\",\n    \"glm4_moe_lite\": \"Glm4MoeLiteMoE\",\n    \"glm4v_moe\": \"Glm4vMoeTextMoE\",\n    # sigmoid -> topk routing (no group selection)\n    \"minimax_m2\": \"MiniMaxM2SparseMoeBlock\",\n    # Models below need custom routing (not yet implemented):\n    # \"ernie4_5_moe\": \"Ernie4_5_MoeSparseMoeBlock\",  # softmax->topk, e_score_correction_bias between softmax and topk\n    # \"deepseek_v2\": \"DeepseekV2Moe\",  # softmax->topk, group_limited_greedy, different attr names (num_group)\n    # \"hunyuan_v1_moe\": \"HunYuanMoEV1Moe\",  # softmax->topk, gate.wg (not gate.weight), scatter routing\n    
# \"gpt_oss\": \"GptOssMLP\",  # topk->softmax, transposed layout [E,H,2*I], custom GLU, expert biases\n}\n\n\ndef resolve_moe_block_classes(model_type: str):\n    \"\"\"Resolve all MoE block classes from transformers for the given model type.\n\n    Returns a list of classes (one for most models, multiple for models with\n    distinct MoE block types like qwen3_omni_moe).\n    \"\"\"\n    entry = SPARSE_MOE_BLOCK.get(model_type)\n    if entry is None:\n        raise ValueError(\n            f\"Unsupported MoE model type '{model_type}'. \"\n            f\"Supported types: {list(SPARSE_MOE_BLOCK.keys())}\"\n        )\n\n    cls_names = entry if isinstance(entry, list) else [entry]\n    module_path = f\"transformers.models.{model_type}.modeling_{model_type}\"\n    try:\n        module = importlib.import_module(module_path)\n    except ModuleNotFoundError:\n        # Text sub-model types (e.g. qwen3_5_moe_text) share the parent module\n        if model_type.endswith(\"_text\"):\n            parent_type = model_type.removesuffix(\"_text\")\n            module_path = f\"transformers.models.{parent_type}.modeling_{parent_type}\"\n            module = importlib.import_module(module_path)\n        else:\n            raise\n\n    classes = []\n    for cls_name in cls_names:\n        moe_cls = getattr(module, cls_name, None)\n        if moe_cls is None:\n            raise ValueError(f\"Could not find class '{cls_name}' in '{module_path}'\")\n        classes.append(moe_cls)\n\n    return classes\n"
  },
  {
    "path": "src/axolotl/integrations/kernels/libs/__init__.py",
    "content": ""
  },
  {
    "path": "src/axolotl/integrations/kernels/libs/scattermoe_lora/__init__.py",
    "content": "# SPDX-License-Identifier: Apache-2.0\n# Copyright (c) Axolotl AI\n# Licensed under the Apache License, Version 2.0\n\nfrom . import layers\nfrom .lora_ops import ParallelExperts\nfrom .parallel_experts import flatten_sort_count, parallel_linear\nfrom .parallel_linear_lora import ScatterMoELoRA, parallel_linear_lora\n\n__all__ = [\n    \"layers\",\n    \"ParallelExperts\",\n    \"flatten_sort_count\",\n    \"parallel_linear\",\n    \"ScatterMoELoRA\",\n    \"parallel_linear_lora\",\n    \"lora_ops\",\n]\n"
  },
  {
    "path": "src/axolotl/integrations/kernels/libs/scattermoe_lora/kernels/__init__.py",
    "content": "# SPDX-License-Identifier: Apache-2.0\n#\n# Original work Copyright (c) Shawn Tan and ScatterMoE Contributors\n# Adapted from https://github.com/shawntan/scattermoe\n# See https://github.com/shawntan/scattermoe/blob/main/LICENSE\n#\n# Modifications and LoRA adaptation Copyright (c) Axolotl AI\n# Licensed under the Apache License, Version 2.0\n\nfrom . import lora_ops, ops\n\n__all__ = [\"ops\", \"lora_ops\"]\n"
  },
  {
    "path": "src/axolotl/integrations/kernels/libs/scattermoe_lora/kernels/lora_ops.py",
    "content": "# SPDX-License-Identifier: Apache-2.0\n# Copyright (c) Axolotl AI\n# Licensed under the Apache License, Version 2.0\n\n\"\"\"\nFused ScatterMoE + LoRA Triton Kernels\n=======================================\n\nProvides fused forward and backward kernels for ScatterMoE with LoRA adapters.\n\nForward: Y = X @ W + scaling * (X @ A^T) @ B^T\nBackward (LoRA training, W frozen):\n  - dX = dY @ W^T + scaling * (dY @ B) @ A    (input gradient)\n  - dA = scaling * (dY @ B)^T @ X              (LoRA A gradient)\n  - dB = scaling * dY^T @ (X @ A^T)            (LoRA B gradient)\n\nLoRA weight layout (from PEFT ParamWrapper):\n  - A: [r*E, K]  -- for expert e, rows [e*r : (e+1)*r] give A_e of shape [r, K]\n  - B: [N, r*E]  -- for expert e, cols [e*r : (e+1)*r] give B_e of shape [N, r]\n\nKey design decisions:\n  - The forward kernel fuses X@W and X@A^T in the same K-loop for data reuse on X,\n    then computes (X@A^T) @ B^T in the epilogue.\n  - The backward dA/dB kernel operates on grouped (expert-contiguous) data and\n    iterates over tokens per expert, accumulating gradients in registers.\n  - R (LoRA rank) is a tl.constexpr, allowing tl.arange(0, R). 
We pad R to a\n    power-of-2 for Triton tile compatibility; typical ranks (4, 8, 16, 32, 64)\n    already satisfy this.\n\"\"\"\n\nfrom itertools import product\nfrom typing import Optional\n\nimport torch\nimport triton\nimport triton.language as tl\n\n# =============================================================================\n# Configuration\n# =============================================================================\n\nBLOCK_M = 128\nALLOW_TF32 = True\n\n\ndef _next_power_of_2(n: int) -> int:\n    \"\"\"Round up to next power of 2.\"\"\"\n    n -= 1\n    n |= n >> 1\n    n |= n >> 2\n    n |= n >> 4\n    n |= n >> 8\n    n |= n >> 16\n    return n + 1\n\n\n# Triton tl.dot requires minimum tile dimensions of 16 on modern GPUs.\nMIN_TRITON_DOT_SIZE = 16\n\n\ndef _block_r_for_rank(r: int) -> int:\n    \"\"\"Compute BLOCK_R: next power-of-2 >= max(r, MIN_TRITON_DOT_SIZE).\"\"\"\n    return _next_power_of_2(max(r, MIN_TRITON_DOT_SIZE))\n\n\n# =============================================================================\n# Token Rounding: pad expert counts to BLOCK_M multiples\n# =============================================================================\n\n\ndef round_expert_counts(\n    sorted_expert_idxs: torch.Tensor,\n    sorted_scattered_idxs: torch.Tensor,\n    expert_offsets: torch.Tensor,\n    E: int,\n    block_m: int = BLOCK_M,\n) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:\n    \"\"\"\n    Pad each expert's token count to a multiple of block_m to eliminate\n    partial-tile waste in the backward kernel.\n\n    Padding is done by duplicating the last valid token index for each expert.\n    The kernel's M_mask = M_idx < real_end_idx masks these padding entries, so\n    correctness is preserved (they contribute 0 to the accumulation via other=0.0).\n\n    This only helps the backward dA/dB kernel where per-expert iteration is\n    explicit. 
The forward scatter2scatter kernel handles partial tiles via masking.\n\n    Args:\n        sorted_expert_idxs: Expert assignments sorted [M*k]\n        sorted_scattered_idxs: Original indices sorted [M*k]\n        expert_offsets: Cumulative token counts per expert [E]\n        E: Number of experts\n        block_m: Block size for token dimension (default: BLOCK_M)\n\n    Returns:\n        padded_expert_idxs: [M_padded] expert assignments with padding\n        padded_scattered_idxs: [M_padded] original indices with padding\n        padded_offsets: [E] cumulative padded counts (for kernel iteration range)\n        real_offsets: [E] original cumulative counts (for M_mask in kernel)\n    \"\"\"\n    device = sorted_expert_idxs.device\n\n    # Compute per-expert counts\n    counts = torch.zeros(E, dtype=torch.int64, device=device)\n    prev = 0\n    for e in range(E):\n        curr = expert_offsets[e].item()\n        counts[e] = curr - prev\n        prev = curr\n\n    # Round up each count to multiple of block_m\n    padded_counts = ((counts + block_m - 1) // block_m) * block_m\n    # Experts with 0 tokens stay at 0\n    padded_counts = torch.where(\n        counts > 0, padded_counts, torch.zeros_like(padded_counts)\n    )\n    total_padded = padded_counts.sum().item()\n\n    padded_expert_idxs = torch.empty(\n        total_padded, dtype=sorted_expert_idxs.dtype, device=device\n    )\n    padded_scattered_idxs = torch.empty(\n        total_padded, dtype=sorted_scattered_idxs.dtype, device=device\n    )\n\n    src_offset = 0\n    dst_offset = 0\n    for e in range(E):\n        count = counts[e].item()\n        padded_count = padded_counts[e].item()\n\n        if count > 0:\n            # Copy original tokens\n            padded_expert_idxs[dst_offset : dst_offset + count] = sorted_expert_idxs[\n                src_offset : src_offset + count\n            ]\n            padded_scattered_idxs[dst_offset : dst_offset + count] = (\n                
sorted_scattered_idxs[src_offset : src_offset + count]\n            )\n\n            # Pad with last valid token (masked out by kernel via M_mask)\n            if padded_count > count:\n                padded_expert_idxs[dst_offset + count : dst_offset + padded_count] = (\n                    sorted_expert_idxs[src_offset + count - 1]\n                )\n                padded_scattered_idxs[\n                    dst_offset + count : dst_offset + padded_count\n                ] = sorted_scattered_idxs[src_offset + count - 1]\n\n        src_offset += count\n        dst_offset += padded_count\n\n    # Padded offsets: cumulative padded counts (for iteration range in kernel)\n    padded_offsets = padded_counts.cumsum(-1).to(expert_offsets.dtype)\n    # Real offsets: original cumulative counts (for M_mask in kernel)\n    real_offsets = expert_offsets.clone()\n\n    return padded_expert_idxs, padded_scattered_idxs, padded_offsets, real_offsets\n\n\n# =============================================================================\n# Autotuning: SMEM estimation and config pruning\n# =============================================================================\n\n_SMEM_CAPACITY: int | None = None\n\n\ndef _get_smem_capacity() -> int:\n    \"\"\"Get device shared memory capacity (bytes). 
Cached after first call.\"\"\"\n    global _SMEM_CAPACITY\n    if _SMEM_CAPACITY is None:\n        props = triton.runtime.driver.active.utils.get_device_properties(\n            torch.cuda.current_device()\n        )\n        _SMEM_CAPACITY = props[\"max_shared_mem\"]\n    return _SMEM_CAPACITY\n\n\ndef _estimate_smem_usage(\n    num_stages: int, BLOCK_M: int, BLOCK_N: int, BLOCK_K: int, dtype_bytes: int = 2\n) -> int:\n    \"\"\"Estimate shared memory in bytes for a GEMM-style tile.\n\n    Formula: stages * BLOCK_K * (BLOCK_M + BLOCK_N) + BLOCK_M * BLOCK_N\n    Multiply by dtype_bytes (2 for fp16/bf16).\n    \"\"\"\n    return (\n        num_stages * BLOCK_K * (BLOCK_M + BLOCK_N) + BLOCK_M * BLOCK_N\n    ) * dtype_bytes\n\n\n# Conservative margin (bytes) subtracted from SMEM capacity to account for\n# estimation inaccuracies and kernel overhead (registers spilled to SMEM, etc.)\n_SMEM_SLACK = 10_000\n\n\ndef _estimate_register_pressure(\n    num_warps: int,\n    *tile_sizes: tuple[int, int],\n) -> float:\n    \"\"\"Rough estimate of per-thread register footprint from live tile sizes.\n\n    This is a heuristic, NOT an accurate register count.  Triton uses tensor\n    core MMA fragments that pack multiple elements per register, and can spill\n    to local memory when the hardware limit (255 regs/thread) is exceeded.\n\n    The estimate is used to prune only truly extreme configs that would cause\n    excessive spilling or compilation failures.  The threshold is set high\n    (``_MAX_REGS_SOFT_LIMIT``) because the heuristic overestimates — it\n    doesn't account for MMA fragment packing.  Configs like M=64,N=64,K=64\n    (est ~520) work fine in practice via spilling.\n\n    Returns estimated registers per thread.\n    \"\"\"\n    # Each thread in a warp holds ~1/32 of the tile elements\n    tile_regs = sum(r * c for r, c in tile_sizes) / 32\n    scalar_overhead = 40\n    return tile_regs + scalar_overhead\n\n\n# Soft limit for register pressure pruning.  
Only prune configs with extreme\n# tile products (e.g. M=128,K=256,N=256) that reliably crash on Blackwell.\n# Moderate configs (M=64,N=64,K=64, est ~520) work via register spilling.\n_MAX_REGS_SOFT_LIMIT = 1024\n\n\n# =============================================================================\n# Forward Kernel: scatter2scatter with fused LoRA\n# =============================================================================\n\n\n@triton.jit\ndef _compute_expert_block_lora(\n    E_idx,\n    E_mask,\n    M_in_idx,\n    N_block,\n    N_mask,\n    # Base weight\n    X_ptr,\n    stride_xm,\n    stride_xk,\n    W_ptr,\n    stride_we,\n    stride_wk,\n    stride_wn,\n    # LoRA weights\n    A_ptr,\n    stride_ar,\n    stride_ak,  # A: [r*E, K], stride_ar = stride for r*E dim, stride_ak = stride for K dim\n    B_ptr,\n    stride_bn,\n    stride_br,  # B: [N, r*E], stride_bn = stride for N dim, stride_br = stride for r*E dim\n    # Dimensions\n    K,\n    ACTUAL_R: tl.constexpr,  # True LoRA rank (for indexing into weight arrays)\n    acc,\n    no_k_mask,\n    BLOCK_M: tl.constexpr,\n    BLOCK_K: tl.constexpr,\n    BLOCK_N: tl.constexpr,\n    BLOCK_R: tl.constexpr,  # Padded tile size >= max(ACTUAL_R, 16)\n    scaling,\n    allow_tf32: tl.constexpr,\n):\n    \"\"\"\n    Compute Y_block = X_block @ W_e + scaling * (X_block @ A_e^T) @ B_e^T\n\n    for tokens in this M-block assigned to expert E_idx.\n\n    ACTUAL_R is the true LoRA rank used for indexing into A[e*r:(e+1)*r, :].\n    BLOCK_R >= ACTUAL_R is the padded tile dimension (must be >= 16 for tl.dot).\n    When BLOCK_R > ACTUAL_R, loads are masked on the R dimension.\n    \"\"\"\n    K_block = tl.arange(0, BLOCK_K)\n    R_block = tl.arange(0, BLOCK_R)\n    R_mask = R_block < ACTUAL_R  # Mask for padding when BLOCK_R > ACTUAL_R\n\n    # Base weight pointers: W[E_idx, :, :] is [K, N], load [BLOCK_K, BLOCK_N]\n    X_blk_ptrs = X_ptr + M_in_idx[:, None] * stride_xm + K_block[None, :] * stride_xk\n    W_blk_ptrs = (\n      
  W_ptr\n        + E_idx * stride_we\n        + K_block[:, None] * stride_wk\n        + N_block[None, :] * stride_wn\n    )\n\n    # LoRA A pointers: A[e*ACTUAL_R:(e+1)*ACTUAL_R, :] for expert e, shape [r, K]\n    A_expert_offset = E_idx * ACTUAL_R\n    A_blk_ptrs = (\n        A_ptr\n        + (A_expert_offset + R_block)[:, None] * stride_ar\n        + K_block[None, :] * stride_ak\n    )\n\n    iters = tl.cdiv(K, BLOCK_K)\n\n    # Accumulator for X @ A^T: [BLOCK_M, BLOCK_R]\n    xa_acc = tl.zeros((BLOCK_M, BLOCK_R), dtype=tl.float32)\n\n    # Determine the input element type for consistent casting.\n    # Masked tl.load with other=0.0 can upcast bf16->fp32 in some Triton versions,\n    # causing dtype mismatches in tl.dot.  We cast all tiles to the same type.\n    INPUT_DTYPE = X_ptr.dtype.element_ty\n\n    for i in range(iters):\n        if no_k_mask:\n            x = tl.load(X_blk_ptrs, mask=E_mask[:, None], other=0.0).to(INPUT_DTYPE)\n            w = tl.load(W_blk_ptrs, mask=N_mask[None, :], other=0.0).to(INPUT_DTYPE)\n            a = tl.load(A_blk_ptrs, mask=R_mask[:, None], other=0.0).to(INPUT_DTYPE)\n        else:\n            K_mask = (i * BLOCK_K + K_block) < K\n            x = tl.load(\n                X_blk_ptrs, mask=E_mask[:, None] & K_mask[None, :], other=0.0\n            ).to(INPUT_DTYPE)\n            w = tl.load(\n                W_blk_ptrs, mask=K_mask[:, None] & N_mask[None, :], other=0.0\n            ).to(INPUT_DTYPE)\n            a = tl.load(\n                A_blk_ptrs, mask=R_mask[:, None] & K_mask[None, :], other=0.0\n            ).to(INPUT_DTYPE)\n\n        # Base: acc += X @ W  ([M, K] @ [K, N] -> [M, N])\n        acc += tl.dot(x, w, allow_tf32=allow_tf32).to(tl.float32)\n\n        # LoRA: xa_acc += X @ A^T  ([M, K] @ [K, R] -> [M, R])\n        xa_acc += tl.dot(x, tl.trans(a), allow_tf32=allow_tf32).to(tl.float32)\n\n        X_blk_ptrs += BLOCK_K * stride_xk\n        W_blk_ptrs += BLOCK_K * stride_wk\n        A_blk_ptrs += BLOCK_K * 
stride_ak\n\n    # Epilogue: load B[e] and compute (X @ A^T) @ B^T\n    # B[e] is B[:, e*ACTUAL_R:(e+1)*ACTUAL_R], shape [N, r]. Load [BLOCK_N, BLOCK_R].\n    B_expert_offset = E_idx * ACTUAL_R\n    B_blk_ptrs = (\n        B_ptr\n        + N_block[:, None] * stride_bn\n        + (B_expert_offset + R_block)[None, :] * stride_br\n    )\n    b = tl.load(\n        B_blk_ptrs, mask=N_mask[:, None] & R_mask[None, :], other=0.0\n    )  # [BLOCK_N, BLOCK_R]\n\n    # tl.dot requires non-float32 inputs (tensor cores); cast back to input dtype\n    b_inp = b.to(INPUT_DTYPE)\n\n    # (X @ A^T) @ B^T: [M, R] @ [R, N] -> [M, N]\n    lora_out = tl.dot(xa_acc.to(INPUT_DTYPE), tl.trans(b_inp), allow_tf32=allow_tf32)\n\n    acc += scaling * lora_out\n    return acc\n\n\ndef _scatter2scatter_lora_configs():\n    \"\"\"Generate forward kernel autotune configs.\n\n    Search space includes BLOCK_M to allow trading token-tile size for\n    larger BLOCK_K/BLOCK_N tiles.  On GPUs with ~99KB SMEM, BLOCK_M=128\n    forces BLOCK_K=32 and BLOCK_N=32; BLOCK_M=64 allows BLOCK_K=128\n    (4× fewer inner-loop iterations).\n\n    Search space:\n      BLOCK_M:    {32, 64, 128}\n      BLOCK_N:    {32, 64, 128, 256}\n      BLOCK_K:    {32, 64, 128}\n      num_warps:  {4, 8}\n      num_stages: {3, 4, 5}\n    \"\"\"\n    configs = []\n    for block_m, block_n, block_k, warps, stages in product(\n        [32, 64, 128],  # BLOCK_M\n        [32, 64, 128, 256],  # BLOCK_N\n        [32, 64, 128],  # BLOCK_K\n        [4, 8],  # num_warps\n        [3, 4, 5],  # num_stages\n    ):\n        configs.append(\n            triton.Config(\n                {\"BLOCK_M\": block_m, \"BLOCK_N\": block_n, \"BLOCK_K\": block_k},\n                num_stages=stages,\n                num_warps=warps,\n            )\n        )\n    return configs\n\n\ndef _prune_fwd_configs(configs, named_args, **kwargs):\n    \"\"\"Prune forward configs based on SMEM capacity and register pressure.\n\n    The forward kernel inner loop loads 
three tiles per pipeline stage:\n      X[BLOCK_M, BLOCK_K], W[BLOCK_K, BLOCK_N], A[BLOCK_R, BLOCK_K].\n    The base estimate only accounts for X and W. We add:\n      - A tile [BLOCK_R, BLOCK_K] per pipeline stage (loaded in the inner loop)\n      - B tile [BLOCK_N, BLOCK_R] loaded once in the epilogue\n      - Extra headroom for compiler overhead (register spills, metadata)\n    \"\"\"\n    smem_cap = _get_smem_capacity()\n\n    # Get BLOCK_R from named_args if available, else assume worst case\n    block_r = named_args.get(\"BLOCK_R\", 64)\n\n    scored = []\n    for config in configs:\n        block_m = config.kwargs[\"BLOCK_M\"]\n        block_n = config.kwargs[\"BLOCK_N\"]\n        block_k = config.kwargs[\"BLOCK_K\"]\n        # Base: stages * BLOCK_K * (BLOCK_M + BLOCK_N) + BLOCK_M * BLOCK_N\n        smem_base = _estimate_smem_usage(config.num_stages, block_m, block_n, block_k)\n        # A tile [BLOCK_R, BLOCK_K] loaded per stage in the inner loop\n        smem_lora_loop = config.num_stages * block_r * block_k * 2\n        # B tile [BLOCK_N, BLOCK_R] loaded once in epilogue\n        smem_lora_epilogue = block_n * block_r * 2\n        smem = smem_base + smem_lora_loop + smem_lora_epilogue\n\n        # Register pressure: live tiles are acc[M,N], xa_acc[M,R],\n        # x[M,K], w[K,N], a[R,K], plus epilogue b[N,R]\n        est_regs = _estimate_register_pressure(\n            config.num_warps,\n            (block_m, block_n),  # acc\n            (block_m, block_r),  # xa_acc\n            (block_m, block_k),  # x tile\n            (block_k, block_n),  # w tile\n            (block_r, block_k),  # a tile\n            (block_n, block_r),  # b tile (epilogue)\n        )\n        if est_regs > _MAX_REGS_SOFT_LIMIT:\n            continue\n\n        scored.append((smem, config))\n\n    pruned = [c for s, c in scored if s <= smem_cap - _SMEM_SLACK]\n    if pruned:\n        return pruned\n    if scored:\n        # All surviving configs exceed SMEM — return the one with 
smallest usage\n        scored.sort(key=lambda x: x[0])\n        return [scored[0][1]]\n    # All configs pruned by register pressure — fall back to smallest tiles\n    return [\n        min(\n            configs,\n            key=lambda c: (\n                c.kwargs[\"BLOCK_M\"] * c.kwargs[\"BLOCK_N\"] * c.kwargs[\"BLOCK_K\"]\n            ),\n        )\n    ]\n\n\n@triton.autotune(\n    configs=_scatter2scatter_lora_configs(),\n    key=[\"M\", \"N\", \"K\"],\n    prune_configs_by={\"early_config_prune\": _prune_fwd_configs},\n)\n@triton.heuristics(\n    {\n        \"NO_K_MASK\": lambda args: (args[\"K\"] % args[\"BLOCK_K\"]) == 0,\n        \"NO_N_MASK\": lambda args: (args[\"N\"] % args[\"BLOCK_N\"]) == 0,\n    }\n)\n@triton.jit\ndef _scatter2scatter_lora(\n    # Input/Output\n    X_ptr,\n    stride_xm: tl.constexpr,\n    stride_xk: tl.constexpr,\n    W_ptr,\n    stride_we,\n    stride_wk: tl.constexpr,\n    stride_wn: tl.constexpr,\n    Y_ptr,\n    stride_ym: tl.constexpr,\n    stride_yn: tl.constexpr,\n    # Bias\n    Bias_ptr,\n    stride_bias_e: tl.constexpr,\n    stride_bias_n: tl.constexpr,\n    # LoRA weights\n    LA_ptr,\n    stride_la_r,\n    stride_la_k,  # A: [r*E, K]\n    LB_ptr,\n    stride_lb_n,\n    stride_lb_r,  # B: [N, r*E]\n    # Routing\n    grouped_idx_ptr,\n    expert_idxs_ptr,\n    # Dimensions\n    FAN_OUT: tl.constexpr,\n    M,\n    K: tl.constexpr,\n    N: tl.constexpr,\n    E: tl.constexpr,\n    ACTUAL_R: tl.constexpr,  # True LoRA rank (for weight indexing)\n    # Block sizes\n    BLOCK_M: tl.constexpr,\n    BLOCK_N: tl.constexpr,\n    BLOCK_K: tl.constexpr,\n    BLOCK_R: tl.constexpr,  # Padded tile size >= max(ACTUAL_R, 16)\n    # Config\n    ACC_TYPE: tl.constexpr,\n    scaling,\n    allow_tf32: tl.constexpr,\n    x_grouped: tl.constexpr,\n    y_grouped: tl.constexpr,\n    NO_K_MASK: tl.constexpr,\n    NO_N_MASK: tl.constexpr,\n):\n    \"\"\"\n    Fused scatter2scatter with LoRA: Y = X @ W + scaling * (X @ A^T) @ B^T + bias\n    
\"\"\"\n    pid = tl.program_id(axis=0)\n\n    N_BLOCK_COUNT = tl.cdiv(N, BLOCK_N)\n    M_block_id = pid // N_BLOCK_COUNT\n    N_block_id = pid % N_BLOCK_COUNT\n\n    M_block = M_block_id * BLOCK_M + tl.arange(0, BLOCK_M)\n    N_block = N_block_id * BLOCK_N + tl.arange(0, BLOCK_N)\n    N_mask = N_block < N\n    M_boundary_mask = M_block < (FAN_OUT * M)\n\n    E_idxs = tl.load(expert_idxs_ptr + M_block, mask=M_boundary_mask, other=E)\n\n    no_k_mask = NO_K_MASK\n\n    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)\n\n    E_first_idx = tl.min(E_idxs)\n    E_last_idx = tl.minimum(tl.max(E_idxs), E - 1)\n    M_idx = tl.load(grouped_idx_ptr + M_block, mask=M_boundary_mask).to(tl.int32)\n\n    for E_idx in range(E_first_idx, E_last_idx + 1):\n        E_mask = E_idxs == E_idx\n        if x_grouped:\n            M_in_idx = M_block\n        else:\n            M_in_idx = M_idx // FAN_OUT\n\n        acc = _compute_expert_block_lora(\n            E_idx,\n            E_mask,\n            M_in_idx,\n            N_block,\n            N_mask,\n            X_ptr,\n            stride_xm,\n            stride_xk,\n            W_ptr,\n            stride_we,\n            stride_wk,\n            stride_wn,\n            LA_ptr,\n            stride_la_r,\n            stride_la_k,\n            LB_ptr,\n            stride_lb_n,\n            stride_lb_r,\n            K,\n            ACTUAL_R,\n            acc,\n            no_k_mask,\n            BLOCK_M,\n            BLOCK_K,\n            BLOCK_N,\n            BLOCK_R,\n            scaling,\n            allow_tf32=allow_tf32,\n        )\n\n    # Add bias if present\n    if Bias_ptr is not None:\n        B_blk_ptrs = (\n            Bias_ptr\n            + E_idxs[:, None] * stride_bias_e\n            + N_block[None, :] * stride_bias_n\n        )\n        acc += tl.load(B_blk_ptrs, mask=M_boundary_mask[:, None] & N_mask[None, :])\n\n    # Store output\n    if y_grouped:\n        M_out_idx = M_block\n    else:\n        M_out_idx = M_idx\n   
 Y_blk_ptrs = Y_ptr + (M_out_idx[:, None] * stride_ym + N_block[None, :] * stride_yn)\n    tl.store(Y_blk_ptrs, acc, mask=M_boundary_mask[:, None] & N_mask[None, :])\n\n\ndef _scatter2scatter_lora_split(\n    X: torch.Tensor,\n    W: torch.Tensor,\n    sorted_expert_idxs: torch.Tensor,\n    sorted_scattered_idxs: torch.Tensor,\n    k: int,\n    lora_A: torch.Tensor,\n    lora_B: torch.Tensor,\n    scaling: float,\n    b: Optional[torch.Tensor] = None,\n    x_grouped: bool = False,\n    y_grouped: bool = False,\n    out: Optional[torch.Tensor] = None,\n) -> torch.Tensor:\n    \"\"\"Split base+LoRA forward: 3 scatter2scatter calls, no fused LoRA kernel.\n\n    Faster for models with few large experts (e.g. Mixtral E=8, I=14336)\n    because the base kernel runs at full speed without LoRA SMEM overhead,\n    and the LoRA matmuls (R=16) are tiny separate passes.\n\n    Y = scatter(X, W) + scaling * scatter(scatter(X, A^T), B^T)\n    \"\"\"\n    from axolotl.integrations.kernels.libs.scattermoe_lora.kernels.ops import (\n        scatter2scatter,\n    )\n\n    E = W.size(0)\n    R = lora_A.size(0) // E\n    K = W.size(1)\n    N = W.size(2)\n\n    # 1. Base: Y_base = X @ W  (uses base kernel with optimal tile sizes)\n    output = scatter2scatter(\n        X=X,\n        W=W,\n        b=b,\n        sorted_expert_idxs=sorted_expert_idxs,\n        sorted_scattered_idxs=sorted_scattered_idxs,\n        k=k,\n        x_grouped=x_grouped,\n        y_grouped=y_grouped,\n        out=out,\n    )\n\n    # 2. XA = X @ A^T  (tiny: output is [M*k, R])\n    # Reshape A: [R*E, K] → [E, K, R] (expert weights for scatter2scatter)\n    W_A = lora_A.reshape(E, R, K).permute(0, 2, 1).contiguous()\n    XA = scatter2scatter(\n        X=X,\n        W=W_A,\n        sorted_expert_idxs=sorted_expert_idxs,\n        sorted_scattered_idxs=sorted_scattered_idxs,\n        k=k,\n        x_grouped=x_grouped,\n        y_grouped=True,\n    )\n\n    # 3. 
Y_lora = XA @ B^T  (R is tiny, so this is very fast)\n    # Reshape B: [N, R*E] → [E, R, N]\n    W_B = lora_B.T.reshape(E, R, N).contiguous()\n    Y_lora = scatter2scatter(\n        X=XA,\n        W=W_B,\n        sorted_expert_idxs=sorted_expert_idxs,\n        sorted_scattered_idxs=sorted_scattered_idxs,\n        k=1,\n        x_grouped=True,\n        y_grouped=y_grouped,\n    )\n\n    # 4. Y = Y_base + scaling * Y_lora\n    output.add_(Y_lora, alpha=scaling)\n    return output\n\n\n# Threshold for switching from fused to split LoRA forward.\n# Split wins when per-expert matmul is large (bandwidth-bound LoRA tile\n# loads dominate in the fused kernel's inner loop).\n# Empirically: split wins for E<=32 with K*N > 20M (e.g. Mixtral, Phi-MoE).\n_SPLIT_LORA_FWD_THRESHOLD = 20_000_000  # per-expert K*N\n_SPLIT_LORA_FWD_MAX_EXPERTS = 32\n\n\ndef scatter2scatter_lora(\n    X: torch.Tensor,\n    W: torch.Tensor,\n    sorted_expert_idxs: torch.Tensor,\n    sorted_scattered_idxs: torch.Tensor,\n    k: int,\n    lora_A: torch.Tensor,\n    lora_B: torch.Tensor,\n    scaling: float,\n    b: Optional[torch.Tensor] = None,\n    x_grouped: bool = False,\n    y_grouped: bool = False,\n    out: Optional[torch.Tensor] = None,\n) -> torch.Tensor:\n    \"\"\"\n    Scatter2scatter with LoRA: Y[i] = X[i] @ W[e] + scaling * (X[i] @ A[e]^T) @ B[e]^T + b[e]\n\n    Automatically selects between:\n    - Fused kernel: single Triton kernel with LoRA in the inner loop.\n      Best for many small experts (E>=64, small K*N).\n    - Split dispatch: 3 separate scatter2scatter calls (base + XA + lora).\n      Best for few large experts (E<=32, large K*N like Mixtral).\n\n    Args:\n        X: Input [M, K] or [M*k, K] if x_grouped\n        W: Expert weights [E, K, N]\n        sorted_expert_idxs: Expert assignments sorted [M*k]\n        sorted_scattered_idxs: Original indices sorted [M*k]\n        k: Fan-out (top-k)\n        lora_A: LoRA A weights [r*E, K]\n        lora_B: LoRA B weights [N, r*E]\n     
   scaling: LoRA scaling factor (alpha/r)\n        b: Optional bias [E, N]\n        x_grouped: Input pre-grouped by expert\n        y_grouped: Keep output grouped\n        out: Optional pre-allocated output buffer\n\n    Returns:\n        Y: Output [M*k, N]\n    \"\"\"\n    E = W.size(0)\n    K = W.size(1)\n    N = W.size(2)\n\n    # Dispatch: split for few large experts, fused for many small experts\n    if E <= _SPLIT_LORA_FWD_MAX_EXPERTS and K * N >= _SPLIT_LORA_FWD_THRESHOLD:\n        return _scatter2scatter_lora_split(\n            X,\n            W,\n            sorted_expert_idxs,\n            sorted_scattered_idxs,\n            k,\n            lora_A,\n            lora_B,\n            scaling,\n            b,\n            x_grouped,\n            y_grouped,\n            out,\n        )\n\n    assert sorted_scattered_idxs.size(0) == sorted_expert_idxs.size(0)\n    assert sorted_scattered_idxs.size(0) == X.size(0) * k\n\n    R = lora_A.size(0) // E\n\n    # Pad R to power of 2 for Triton tile size\n    BLOCK_R = _block_r_for_rank(R)\n\n    L_scattered = sorted_expert_idxs.size(0)\n\n    if out is None:\n        output = torch.empty((L_scattered, N), device=X.device, dtype=X.dtype)\n    else:\n        assert out.size(0) == L_scattered and out.size(1) == N\n        output = out\n\n    def grid(META):\n        return (\n            triton.cdiv(L_scattered, META[\"BLOCK_M\"]) * triton.cdiv(N, META[\"BLOCK_N\"]),\n        )\n\n    if b is None:\n        stride_be = stride_bn = 0\n        b_ptr = None\n    else:\n        stride_be, stride_bn = b.stride()\n        b_ptr = b\n\n    _scatter2scatter_lora[grid](\n        X,\n        X.stride(0),\n        X.stride(1),\n        W,\n        W.stride(0),\n        W.stride(1),\n        W.stride(2),\n        output,\n        output.stride(0),\n        output.stride(1),\n        b_ptr,\n        stride_be,\n        stride_bn,\n        lora_A,\n        lora_A.stride(0),\n        lora_A.stride(1),\n        lora_B,\n        
lora_B.stride(0),\n        lora_B.stride(1),\n        sorted_scattered_idxs,\n        sorted_expert_idxs,\n        FAN_OUT=k,\n        M=X.size(0),\n        K=K,\n        N=N,\n        E=E,\n        ACTUAL_R=R,\n        BLOCK_R=BLOCK_R,\n        ACC_TYPE=tl.float32,\n        scaling=scaling,\n        allow_tf32=ALLOW_TF32,\n        x_grouped=x_grouped,\n        y_grouped=y_grouped,\n    )\n\n    return output\n\n\n# =============================================================================\n# Backward Kernel: Fused dX = dY @ W^T + scaling * (dY @ B) @ A\n# =============================================================================\n\n\n@triton.jit\ndef _compute_expert_block_lora_dX(\n    E_idx,\n    E_mask,\n    M_in_idx,\n    K_block,\n    K_mask,\n    # Input: DY (gradient w.r.t. output)\n    DY_ptr,\n    stride_dym,\n    stride_dyn,\n    # Base weight W^T: we load W[e] as [K, N] and index as W^T[e] = [N, K]\n    W_ptr,\n    stride_we,\n    stride_wk,\n    stride_wn,\n    # LoRA weights\n    A_ptr,\n    stride_ar,\n    stride_ak,  # A: [r*E, K]\n    B_ptr,\n    stride_bn,\n    stride_br,  # B: [N, r*E]\n    # Dimensions\n    N,\n    ACTUAL_R: tl.constexpr,\n    acc,\n    no_n_mask,\n    BLOCK_M: tl.constexpr,\n    BLOCK_N: tl.constexpr,\n    BLOCK_K: tl.constexpr,\n    BLOCK_R: tl.constexpr,\n    scaling,\n    allow_tf32: tl.constexpr,\n):\n    \"\"\"\n    Compute dX_block = DY_block @ W_e^T + scaling * (DY_block @ B_e) @ A_e\n\n    for tokens in this M-block assigned to expert E_idx.\n\n    Inner loop over N dimension (reduction dim for dY @ W^T and dY @ B).\n    Output dimension is K.\n    Epilogue computes (dY @ B) @ A.\n\n    Transpose mapping from forward:\n      Forward: X@W (K-loop), X@A^T (K-loop), (X@A^T)@B^T (epilogue)\n      Backward: DY@W^T (N-loop), DY@B (N-loop), (DY@B)@A (epilogue)\n    \"\"\"\n    N_block = tl.arange(0, BLOCK_N)\n    R_block = tl.arange(0, BLOCK_R)\n    R_mask = R_block < ACTUAL_R\n\n    # DY pointers: DY is [M_total, N], 
load [BLOCK_M, BLOCK_N]\n    DY_blk_ptrs = (\n        DY_ptr + M_in_idx[:, None] * stride_dym + N_block[None, :] * stride_dyn\n    )\n\n    # W^T pointers: W[e] is [K, N], W^T[e] is [N, K]. We load W^T as [BLOCK_N, BLOCK_K].\n    # W stored as [E, K, N], so W^T[e][n, k] = W[e][k, n] = W_ptr + e*stride_we + k*stride_wk + n*stride_wn\n    # As [BLOCK_N, BLOCK_K] tile: row=n, col=k\n    WT_blk_ptrs = (\n        W_ptr\n        + E_idx * stride_we\n        + N_block[:, None] * stride_wn  # row = n dimension\n        + K_block[None, :] * stride_wk\n    )  # col = k dimension\n\n    # B pointers: B[e] is B[:, e*R:(e+1)*R], shape [N, R]. Load [BLOCK_N, BLOCK_R].\n    B_expert_offset = E_idx * ACTUAL_R\n    B_blk_ptrs = (\n        B_ptr\n        + N_block[:, None] * stride_bn\n        + (B_expert_offset + R_block)[None, :] * stride_br\n    )\n\n    iters = tl.cdiv(N, BLOCK_N)\n\n    # Accumulator for DY @ B: [BLOCK_M, BLOCK_R]\n    dy_b_acc = tl.zeros((BLOCK_M, BLOCK_R), dtype=tl.float32)\n\n    # Determine the input element type for consistent casting.\n    INPUT_DTYPE = DY_ptr.dtype.element_ty\n\n    for i in range(iters):\n        if no_n_mask:\n            dy = tl.load(DY_blk_ptrs, mask=E_mask[:, None], other=0.0).to(INPUT_DTYPE)\n            wt = tl.load(WT_blk_ptrs, mask=K_mask[None, :], other=0.0).to(INPUT_DTYPE)\n            b = tl.load(B_blk_ptrs, mask=R_mask[None, :], other=0.0).to(INPUT_DTYPE)\n        else:\n            N_mask_iter = (i * BLOCK_N + N_block) < N\n            dy = tl.load(\n                DY_blk_ptrs, mask=E_mask[:, None] & N_mask_iter[None, :], other=0.0\n            ).to(INPUT_DTYPE)\n            wt = tl.load(\n                WT_blk_ptrs, mask=N_mask_iter[:, None] & K_mask[None, :], other=0.0\n            ).to(INPUT_DTYPE)\n            b = tl.load(\n                B_blk_ptrs, mask=N_mask_iter[:, None] & R_mask[None, :], other=0.0\n            ).to(INPUT_DTYPE)\n\n        # Base: acc += DY @ W^T  ([M, N] @ [N, K] -> [M, K])\n        acc += 
tl.dot(dy, wt, allow_tf32=allow_tf32).to(tl.float32)\n\n        # LoRA: dy_b_acc += DY @ B  ([M, N] @ [N, R] -> [M, R])\n        dy_b_acc += tl.dot(dy, b, allow_tf32=allow_tf32).to(tl.float32)\n\n        DY_blk_ptrs += BLOCK_N * stride_dyn\n        WT_blk_ptrs += BLOCK_N * stride_wn\n        B_blk_ptrs += BLOCK_N * stride_bn\n\n    # Epilogue: load A[e] and compute (DY @ B) @ A\n    # A[e] is A[e*R:(e+1)*R, :], shape [R, K]. Load [BLOCK_R, BLOCK_K].\n    A_expert_offset = E_idx * ACTUAL_R\n    A_blk_ptrs = (\n        A_ptr\n        + (A_expert_offset + R_block)[:, None] * stride_ar\n        + K_block[None, :] * stride_ak\n    )\n    a_e = tl.load(A_blk_ptrs, mask=R_mask[:, None] & K_mask[None, :], other=0.0).to(\n        INPUT_DTYPE\n    )\n\n    # (DY @ B) @ A: [M, R] @ [R, K] -> [M, K]\n    # tl.dot requires non-float32 inputs (tensor cores); cast accumulator back to input dtype\n    lora_dx = tl.dot(dy_b_acc.to(INPUT_DTYPE), a_e, allow_tf32=allow_tf32)\n\n    acc += scaling * lora_dx\n    return acc\n\n\ndef _scatter2scatter_lora_dX_configs():\n    \"\"\"Generate backward dX kernel autotune configs.\n\n    The inner loop is over N (not K as in forward). 
The output dimension is K.\n    So BLOCK_K tiles the output and BLOCK_N tiles the reduction.\n\n    BLOCK_M is now autotunable (was fixed at 128).\n\n    Search space:\n      BLOCK_M:    {32, 64, 128}        (token tile)\n      BLOCK_K:    {32, 64, 128, 256}   (output tile)\n      BLOCK_N:    {32, 64, 128, 256}   (reduction tile)\n      num_warps:  {4, 8}\n      num_stages: {3, 4, 5}\n    \"\"\"\n    configs = []\n    for block_m, block_k, block_n, warps, stages in product(\n        [32, 64, 128],  # BLOCK_M\n        [32, 64, 128, 256],  # BLOCK_K (output dimension)\n        [32, 64, 128, 256],  # BLOCK_N (reduction dimension)\n        [4, 8],  # num_warps\n        [3, 4, 5],  # num_stages\n    ):\n        configs.append(\n            triton.Config(\n                {\"BLOCK_M\": block_m, \"BLOCK_K\": block_k, \"BLOCK_N\": block_n},\n                num_stages=stages,\n                num_warps=warps,\n            )\n        )\n    return configs\n\n\ndef _prune_dX_configs(configs, named_args, **kwargs):\n    \"\"\"Prune backward dX configs based on SMEM capacity and register pressure.\n\n    The dX kernel inner loop loads three tiles per pipeline stage:\n      DY[BLOCK_M, BLOCK_N], W^T[BLOCK_N, BLOCK_K], B[BLOCK_N, BLOCK_R].\n    The base estimate only accounts for DY and W^T. 
We add:\n      - B tile [BLOCK_N, BLOCK_R] per pipeline stage (loaded in the inner loop)\n      - A tile [BLOCK_R, BLOCK_K] loaded once in the epilogue\n      - Extra headroom for compiler overhead (register spills, metadata)\n    \"\"\"\n    smem_cap = _get_smem_capacity()\n\n    # Get BLOCK_R from named_args if available, else assume worst case\n    block_r = named_args.get(\"BLOCK_R\", 64)\n\n    scored = []\n    for config in configs:\n        block_m = config.kwargs[\"BLOCK_M\"]\n        block_k = config.kwargs[\"BLOCK_K\"]\n        block_n = config.kwargs[\"BLOCK_N\"]\n        # Base: stages * BLOCK_N * (BLOCK_M + BLOCK_K) + BLOCK_M * BLOCK_K\n        smem_base = _estimate_smem_usage(config.num_stages, block_m, block_k, block_n)\n        # B tile [BLOCK_N, BLOCK_R] loaded per stage in the inner loop\n        smem_lora_loop = config.num_stages * block_n * block_r * 2\n        # A tile [BLOCK_R, BLOCK_K] loaded once in epilogue\n        smem_lora_epilogue = block_r * block_k * 2\n        smem = smem_base + smem_lora_loop + smem_lora_epilogue\n\n        # Register pressure: live tiles are acc[M,K], dy_b_acc[M,R],\n        # dy[M,N], wt[N,K], b[N,R], plus epilogue a[R,K]\n        est_regs = _estimate_register_pressure(\n            config.num_warps,\n            (block_m, block_k),  # acc\n            (block_m, block_r),  # dy_b_acc\n            (block_m, block_n),  # dy tile\n            (block_n, block_k),  # wt tile\n            (block_n, block_r),  # b tile\n            (block_r, block_k),  # a tile (epilogue)\n        )\n        if est_regs > _MAX_REGS_SOFT_LIMIT:\n            continue\n\n        scored.append((smem, config))\n\n    pruned = [c for s, c in scored if s <= smem_cap - _SMEM_SLACK]\n    if pruned:\n        return pruned\n    if scored:\n        # All surviving configs exceed SMEM — return the one with smallest usage\n        scored.sort(key=lambda x: x[0])\n        return [scored[0][1]]\n    # All configs pruned by register pressure — fall back 
to smallest tiles\n    return [\n        min(\n            configs,\n            key=lambda c: (\n                c.kwargs[\"BLOCK_M\"] * c.kwargs[\"BLOCK_K\"] * c.kwargs[\"BLOCK_N\"]\n            ),\n        )\n    ]\n\n\n@triton.autotune(\n    configs=_scatter2scatter_lora_dX_configs(),\n    key=[\"M\", \"N\", \"K\"],\n    prune_configs_by={\"early_config_prune\": _prune_dX_configs},\n)\n@triton.heuristics(\n    {\n        \"NO_K_MASK\": lambda args: (args[\"K\"] % args[\"BLOCK_K\"]) == 0,\n        \"NO_N_MASK\": lambda args: (args[\"N\"] % args[\"BLOCK_N\"]) == 0,\n    }\n)\n@triton.jit\ndef _scatter2scatter_lora_dX(\n    # Input: DY (gradient w.r.t. output, grouped)\n    DY_ptr,\n    stride_dym: tl.constexpr,\n    stride_dyn: tl.constexpr,\n    # Base weight: W [E, K, N] (we compute DY @ W^T)\n    W_ptr,\n    stride_we,\n    stride_wk: tl.constexpr,\n    stride_wn: tl.constexpr,\n    # Output: dX\n    DX_ptr,\n    stride_dxm: tl.constexpr,\n    stride_dxk: tl.constexpr,\n    # LoRA weights\n    LA_ptr,\n    stride_la_r,\n    stride_la_k,  # A: [r*E, K]\n    LB_ptr,\n    stride_lb_n,\n    stride_lb_r,  # B: [N, r*E]\n    # Routing\n    grouped_idx_ptr,\n    expert_idxs_ptr,\n    # Dimensions\n    FAN_OUT: tl.constexpr,\n    M,\n    K: tl.constexpr,\n    N: tl.constexpr,\n    E: tl.constexpr,\n    ACTUAL_R: tl.constexpr,\n    # Block sizes\n    BLOCK_M: tl.constexpr,\n    BLOCK_N: tl.constexpr,\n    BLOCK_K: tl.constexpr,\n    BLOCK_R: tl.constexpr,\n    # Config\n    ACC_TYPE: tl.constexpr,\n    scaling,\n    allow_tf32: tl.constexpr,\n    dy_grouped: tl.constexpr,\n    dx_grouped: tl.constexpr,\n    NO_K_MASK: tl.constexpr,\n    NO_N_MASK: tl.constexpr,\n):\n    \"\"\"\n    Fused backward dX = DY @ W^T + scaling * (DY @ B) @ A\n\n    DY is in expert-grouped order (x_grouped=True).\n    dX is output in ungrouped or grouped order based on dx_grouped.\n\n    Grid: (cdiv(M_total, BLOCK_M) * cdiv(K, BLOCK_K),)\n    \"\"\"\n    pid = tl.program_id(axis=0)\n\n    
K_BLOCK_COUNT = tl.cdiv(K, BLOCK_K)\n    M_block_id = pid // K_BLOCK_COUNT\n    K_block_id = pid % K_BLOCK_COUNT\n\n    M_block = M_block_id * BLOCK_M + tl.arange(0, BLOCK_M)\n    K_block = K_block_id * BLOCK_K + tl.arange(0, BLOCK_K)\n    K_mask = K_block < K\n    M_boundary_mask = M_block < (FAN_OUT * M)\n\n    E_idxs = tl.load(expert_idxs_ptr + M_block, mask=M_boundary_mask, other=E)\n\n    no_n_mask = NO_N_MASK\n\n    acc = tl.zeros((BLOCK_M, BLOCK_K), dtype=ACC_TYPE)\n\n    E_first_idx = tl.min(E_idxs)\n    E_last_idx = tl.minimum(tl.max(E_idxs), E - 1)\n    M_idx = tl.load(grouped_idx_ptr + M_block, mask=M_boundary_mask).to(tl.int32)\n\n    for E_idx in range(E_first_idx, E_last_idx + 1):\n        E_mask = E_idxs == E_idx\n        if dy_grouped:\n            M_in_idx = M_block\n        else:\n            M_in_idx = M_idx // FAN_OUT\n\n        acc = _compute_expert_block_lora_dX(\n            E_idx,\n            E_mask,\n            M_in_idx,\n            K_block,\n            K_mask,\n            DY_ptr,\n            stride_dym,\n            stride_dyn,\n            W_ptr,\n            stride_we,\n            stride_wk,\n            stride_wn,\n            LA_ptr,\n            stride_la_r,\n            stride_la_k,\n            LB_ptr,\n            stride_lb_n,\n            stride_lb_r,\n            N,\n            ACTUAL_R,\n            acc,\n            no_n_mask,\n            BLOCK_M,\n            BLOCK_N,\n            BLOCK_K,\n            BLOCK_R,\n            scaling,\n            allow_tf32=allow_tf32,\n        )\n\n    # Store output\n    if dx_grouped:\n        M_out_idx = M_block\n    else:\n        M_out_idx = M_idx\n    DX_blk_ptrs = DX_ptr + (\n        M_out_idx[:, None] * stride_dxm + K_block[None, :] * stride_dxk\n    )\n    tl.store(DX_blk_ptrs, acc, mask=M_boundary_mask[:, None] & K_mask[None, :])\n\n\ndef scatter2scatter_lora_dX(\n    DY: torch.Tensor,\n    W: torch.Tensor,\n    sorted_expert_idxs: torch.Tensor,\n    sorted_scattered_idxs: 
torch.Tensor,\n    k: int,\n    lora_A: torch.Tensor,\n    lora_B: torch.Tensor,\n    scaling: float,\n    dy_grouped: bool = True,\n    dx_grouped: bool = False,\n    out: Optional[torch.Tensor] = None,\n) -> torch.Tensor:\n    \"\"\"\n    Fused backward dX = DY @ W^T + scaling * (DY @ B) @ A\n\n    Replaces the separate:\n      1. base_ops.scatter2scatter(DY, W^T, x_grouped=True, ...)\n      2. _compute_lora_input_grad(DY, A, B, ...)\n\n    Args:\n        DY: Gradient w.r.t. output [M*k, N] (grouped by expert)\n        W: Expert weights [E, K, N] (NOT transposed — kernel handles W^T internally)\n        sorted_expert_idxs: Expert assignments sorted [M*k]\n        sorted_scattered_idxs: Original indices sorted [M*k]\n        k: Fan-out (top-k)\n        lora_A: LoRA A weights [r*E, K]\n        lora_B: LoRA B weights [N, r*E]\n        scaling: LoRA scaling factor\n        dy_grouped: Whether DY is in grouped (expert-sorted) order (default True)\n        dx_grouped: Whether to output dX in grouped order (default False)\n        out: Optional pre-allocated output buffer\n\n    Returns:\n        dX: Input gradient [M*k, K]\n    \"\"\"\n    assert sorted_scattered_idxs.size(0) == sorted_expert_idxs.size(0)\n\n    E = W.size(0)\n    K = W.size(1)\n    N = W.size(2)\n    R = lora_A.size(0) // E\n\n    BLOCK_R = _block_r_for_rank(R)\n\n    L_scattered = sorted_expert_idxs.size(0)\n\n    # M for the kernel is DY.size(0) when dy_grouped, else the original M\n    if dy_grouped:\n        M = DY.size(0)\n        fan_out = 1  # DY is already expanded\n    else:\n        M = DY.size(0)\n        fan_out = k\n\n    if out is None:\n        output = torch.empty((L_scattered, K), device=DY.device, dtype=DY.dtype)\n    else:\n        assert out.size(0) == L_scattered and out.size(1) == K\n        output = out\n\n    def grid(META):\n        return (\n            triton.cdiv(L_scattered, META[\"BLOCK_M\"]) * triton.cdiv(K, META[\"BLOCK_K\"]),\n        )\n\n    
_scatter2scatter_lora_dX[grid](\n        DY,\n        DY.stride(0),\n        DY.stride(1),\n        W,\n        W.stride(0),\n        W.stride(1),\n        W.stride(2),\n        output,\n        output.stride(0),\n        output.stride(1),\n        lora_A,\n        lora_A.stride(0),\n        lora_A.stride(1),\n        lora_B,\n        lora_B.stride(0),\n        lora_B.stride(1),\n        sorted_scattered_idxs,\n        sorted_expert_idxs,\n        FAN_OUT=fan_out,\n        M=M,\n        K=K,\n        N=N,\n        E=E,\n        ACTUAL_R=R,\n        # BLOCK_M is autotuned (injected by triton.autotune from Config kwargs)\n        BLOCK_R=BLOCK_R,\n        ACC_TYPE=tl.float32,\n        scaling=scaling,\n        allow_tf32=ALLOW_TF32,\n        dy_grouped=dy_grouped,\n        dx_grouped=dx_grouped,\n    )\n\n    return output\n\n\n# =============================================================================\n# Backward Kernel: LoRA gradient computation (dA, dB)\n# =============================================================================\n\n\ndef _group_bwd_lora_configs():\n    \"\"\"Generate backward (dA/dB) kernel autotune configs.\n\n    Search space includes smaller tile sizes and fewer pipeline stages to\n    support GPUs with limited shared memory (e.g. 
~99KB on some GPUs).\n\n    Search space:\n      BLOCK_M:    {32, 64, 128, 256}   (token-loop tile)\n      BLOCK_K:    {32, 64, 128, 256}\n      BLOCK_N:    {32, 64, 128, 256}\n      num_warps:  {4, 8}\n      num_stages: {3, 4, 5}\n\n    The backward kernel also uses BLOCK_R (from LoRA rank), but that is\n    determined by the rank and not autotunable.\n    \"\"\"\n    configs = []\n    for block_m, block_k, block_n, warps, stages in product(\n        [32, 64, 128, 256],  # BLOCK_M\n        [32, 64, 128, 256],  # BLOCK_K\n        [32, 64, 128, 256],  # BLOCK_N\n        [4, 8],  # num_warps\n        [3, 4, 5],  # num_stages\n    ):\n        configs.append(\n            triton.Config(\n                {\"BLOCK_M\": block_m, \"BLOCK_K\": block_k, \"BLOCK_N\": block_n},\n                num_stages=stages,\n                num_warps=warps,\n            )\n        )\n    return configs\n\n\ndef _prune_bwd_lora_configs(configs, named_args, **kwargs):\n    \"\"\"Prune backward configs based on SMEM capacity and register pressure.\n\n    The backward kernel loads X[BLOCK_M, BLOCK_K] and DY[BLOCK_M, BLOCK_N]\n    in the inner loop, plus holds A[BLOCK_R, BLOCK_K] and B[BLOCK_N, BLOCK_R]\n    for the full expert. 
We estimate SMEM based on the dominant terms.\n    \"\"\"\n    smem_cap = _get_smem_capacity()\n    block_r = named_args.get(\"BLOCK_R\", 64)\n\n    scored = []\n    for config in configs:\n        block_m = config.kwargs[\"BLOCK_M\"]\n        block_k = config.kwargs[\"BLOCK_K\"]\n        block_n = config.kwargs[\"BLOCK_N\"]\n        # Inner loop loads X[M,K] and DY[M,N], pipeline over M iterations\n        smem_base = _estimate_smem_usage(config.num_stages, block_m, block_n, block_k)\n        # A[BLOCK_R, BLOCK_K] and B[BLOCK_N, BLOCK_R] held for the full expert\n        smem_lora = (block_r * block_k + block_n * block_r) * 2\n        smem = smem_base + smem_lora\n\n        # Register pressure: dA_acc[R,K], dB_acc[N,R], x[M,K], dy[M,N],\n        # a[R,K], b[N,R], xa[M,R], dy_b[M,R]\n        est_regs = _estimate_register_pressure(\n            config.num_warps,\n            (block_r, block_k),  # dA_acc\n            (block_n, block_r),  # dB_acc\n            (block_m, block_k),  # x tile\n            (block_m, block_n),  # dy tile\n            (block_r, block_k),  # a tile\n            (block_n, block_r),  # b tile\n            (block_m, block_r),  # xa intermediate\n        )\n        if est_regs > _MAX_REGS_SOFT_LIMIT:\n            continue\n\n        scored.append((smem, config))\n\n    pruned = [c for s, c in scored if s <= smem_cap - _SMEM_SLACK]\n    if pruned:\n        return pruned\n    if scored:\n        # All surviving configs exceed SMEM — return the one with smallest usage\n        scored.sort(key=lambda x: x[0])\n        return [scored[0][1]]\n    # All configs pruned by register pressure — fall back to smallest tiles\n    return [\n        min(\n            configs,\n            key=lambda c: (\n                c.kwargs[\"BLOCK_M\"] * c.kwargs[\"BLOCK_K\"] * c.kwargs[\"BLOCK_N\"]\n            ),\n        )\n    ]\n\n\n@triton.autotune(\n    configs=_group_bwd_lora_configs(),\n    key=[\"M\", \"N\", \"K\"],\n    
prune_configs_by={\"early_config_prune\": _prune_bwd_lora_configs},\n    reset_to_zero=[\"DLA_ptr\", \"DLB_ptr\"],\n)\n@triton.heuristics(\n    {\n        \"NO_K_MASK\": lambda args: (args[\"K\"] % args[\"BLOCK_K\"]) == 0,\n        \"NO_N_MASK\": lambda args: (args[\"N\"] % args[\"BLOCK_N\"]) == 0,\n    }\n)\n@triton.jit\ndef _group_bwd_lora(\n    # Inputs\n    DY_ptr,\n    stride_dym,\n    stride_dyn,\n    X_ptr,\n    stride_xm,\n    stride_xk,\n    # LoRA weights (needed for cross-terms)\n    LA_ptr,\n    stride_la_r,\n    stride_la_k,  # A: [r*E, K]\n    LB_ptr,\n    stride_lb_n,\n    stride_lb_r,  # B: [N, r*E]\n    # Gradient outputs\n    DLA_ptr,\n    stride_dla_r,\n    stride_dla_k,\n    DLB_ptr,\n    stride_dlb_n,\n    stride_dlb_r,\n    # Expert offsets\n    expert_offsets_ptr,\n    # Dimensions\n    M,\n    K: tl.constexpr,\n    N: tl.constexpr,\n    ACTUAL_R: tl.constexpr,  # True LoRA rank (for weight indexing)\n    BLOCK_R: tl.constexpr,  # Padded tile size >= max(ACTUAL_R, 16)\n    scaling,\n    # Block sizes\n    BLOCK_M: tl.constexpr,\n    BLOCK_K: tl.constexpr,\n    BLOCK_N: tl.constexpr,\n    ACC_TYPE: tl.constexpr,\n    allow_tf32: tl.constexpr,\n    NO_K_MASK: tl.constexpr,\n    NO_N_MASK: tl.constexpr,\n):\n    \"\"\"\n    Compute LoRA gradients for each expert on grouped data.\n\n    Grid: (E * cdiv(K, BLOCK_K), cdiv(N, BLOCK_N))\n\n    For expert e:\n      dA[e] = scaling * (dY @ B[e])^T @ X   -> [r, K], accumulate over M tokens\n      dB[e] = scaling * dY^T @ (X @ A[e]^T)  -> [N, r], accumulate over M tokens\n\n    ACTUAL_R is the true LoRA rank. 
BLOCK_R >= ACTUAL_R is padded for tl.dot min size.\n    \"\"\"\n    pid0 = tl.program_id(axis=0)\n    pid1 = tl.program_id(axis=1)\n\n    K_BLOCK_COUNT = tl.cdiv(K, BLOCK_K)\n    E_idx = pid0 // K_BLOCK_COUNT\n    K_block_id = pid0 % K_BLOCK_COUNT\n    N_block_id = pid1\n\n    # Get expert's token range from cumulative offsets\n    if E_idx == 0:\n        start_idx = 0\n    else:\n        start_idx = tl.load(expert_offsets_ptr + E_idx - 1).to(tl.int32)\n    end_idx = tl.load(expert_offsets_ptr + E_idx).to(tl.int32)\n    num_tokens = end_idx - start_idx\n\n    if num_tokens > 0:\n        M_block = tl.arange(0, BLOCK_M)\n        K_block = K_block_id * BLOCK_K + tl.arange(0, BLOCK_K)\n        K_mask = K_block < K\n        N_block = N_block_id * BLOCK_N + tl.arange(0, BLOCK_N)\n        N_mask = N_block < N\n        R_block = tl.arange(0, BLOCK_R)\n        R_mask = R_block < ACTUAL_R  # Mask for padding\n\n        lora_offset = E_idx * ACTUAL_R\n\n        # Determine input element type for consistent casting.\n        INPUT_DTYPE = X_ptr.dtype.element_ty\n\n        # Load B[e]: [BLOCK_N, BLOCK_R] (masked on R and N, other=0 for padding)\n        B_blk_ptrs = (\n            LB_ptr\n            + N_block[:, None] * stride_lb_n\n            + (lora_offset + R_block)[None, :] * stride_lb_r\n        )\n        b_e = tl.load(B_blk_ptrs, mask=N_mask[:, None] & R_mask[None, :], other=0.0).to(\n            INPUT_DTYPE\n        )\n\n        # Load A[e]: [BLOCK_R, BLOCK_K] (masked on R and K, other=0 for padding)\n        A_blk_ptrs = (\n            LA_ptr\n            + (lora_offset + R_block)[:, None] * stride_la_r\n            + K_block[None, :] * stride_la_k\n        )\n        a_e = tl.load(A_blk_ptrs, mask=R_mask[:, None] & K_mask[None, :], other=0.0).to(\n            INPUT_DTYPE\n        )\n\n        # Accumulators\n        dA_acc = tl.zeros((BLOCK_R, BLOCK_K), dtype=ACC_TYPE)\n        dB_acc = tl.zeros((BLOCK_N, BLOCK_R), dtype=ACC_TYPE)\n\n        iters = 
tl.cdiv(num_tokens, BLOCK_M)\n        for i in range(iters):\n            M_idx = start_idx + i * BLOCK_M + M_block\n            M_mask = M_idx < end_idx\n\n            # Load X: [BLOCK_M, BLOCK_K]\n            X_blk_ptrs = (\n                X_ptr + M_idx[:, None] * stride_xm + K_block[None, :] * stride_xk\n            )\n            x = tl.load(\n                X_blk_ptrs, mask=M_mask[:, None] & K_mask[None, :], other=0.0\n            ).to(INPUT_DTYPE)\n\n            # Load dY: [BLOCK_M, BLOCK_N]\n            DY_blk_ptrs = (\n                DY_ptr + M_idx[:, None] * stride_dym + N_block[None, :] * stride_dyn\n            )\n            dy = tl.load(\n                DY_blk_ptrs, mask=M_mask[:, None] & N_mask[None, :], other=0.0\n            ).to(INPUT_DTYPE)\n\n            # X @ A[e]^T: [M, K] @ [K, R] -> [M, R]\n            xa = tl.dot(x, tl.trans(a_e), allow_tf32=allow_tf32)\n\n            # dY @ B[e]: [M, N] @ [N, R] -> [M, R]\n            dy_b = tl.dot(dy, b_e, allow_tf32=allow_tf32)\n\n            # Cast intermediates to input dtype for subsequent tl.dot calls\n            # (tl.dot requires both operands to have the same dtype)\n            dy_b_cast = dy_b.to(INPUT_DTYPE)\n            xa_cast = xa.to(INPUT_DTYPE)\n\n            # dA += (dY @ B)^T @ X: [R, M] @ [M, K] -> [R, K]\n            dA_acc += tl.dot(tl.trans(dy_b_cast), x, allow_tf32=allow_tf32)\n\n            # dB += dY^T @ (X @ A^T): [N, M] @ [M, R] -> [N, R]\n            dB_acc += tl.dot(tl.trans(dy), xa_cast, allow_tf32=allow_tf32)\n\n        # Store dA with scaling (atomic add since multiple N_blocks contribute)\n        # Only store the actual R rows, not the padded ones\n        DLA_blk_ptrs = (\n            DLA_ptr\n            + (lora_offset + R_block)[:, None] * stride_dla_r\n            + K_block[None, :] * stride_dla_k\n        )\n        tl.atomic_add(\n            DLA_blk_ptrs,\n            (dA_acc * scaling).to(DLA_ptr.dtype.element_ty),\n            mask=R_mask[:, None] & 
K_mask[None, :],\n        )\n\n        # Store dB with scaling (atomic add since multiple K_blocks contribute)\n        DLB_blk_ptrs = (\n            DLB_ptr\n            + N_block[:, None] * stride_dlb_n\n            + (lora_offset + R_block)[None, :] * stride_dlb_r\n        )\n        tl.atomic_add(\n            DLB_blk_ptrs,\n            (dB_acc * scaling).to(DLB_ptr.dtype.element_ty),\n            mask=N_mask[:, None] & R_mask[None, :],\n        )\n\n\ndef _group_bwd_split_configs():\n    \"\"\"Autotune configs for split dA/dB kernels.\"\"\"\n    configs = []\n    for block_m, block_dim, warps, stages in product(\n        [32, 64, 128],  # BLOCK_M (token tile)\n        [32, 64, 128, 256],  # BLOCK_DIM (K for dA, N for dB — output tile)\n        [4, 8],  # num_warps\n        [3, 4, 5],  # num_stages\n    ):\n        configs.append(\n            triton.Config(\n                {\"BLOCK_M\": block_m, \"BLOCK_DIM\": block_dim},\n                num_stages=stages,\n                num_warps=warps,\n            )\n        )\n    return configs\n\n\ndef _prune_split_configs(configs, named_args, **kwargs):\n    \"\"\"Prune split kernel configs based on SMEM capacity and register pressure.\"\"\"\n    smem_cap = _get_smem_capacity()\n    block_r = named_args.get(\"BLOCK_R\", 64)\n\n    # Fixed inner tile for reduction dimension\n    BLOCK_INNER = 64\n\n    pruned = []\n    for config in configs:\n        block_m = config.kwargs[\"BLOCK_M\"]\n        block_dim = config.kwargs[\"BLOCK_DIM\"]\n        # Inner loop loads: input[M, INNER] and other[M, INNER_or_DIM]\n        smem = config.num_stages * BLOCK_INNER * (block_m + block_dim) * 2\n        # LoRA weights held in registers: [INNER, R] or [R, DIM]\n        smem += (block_r * max(block_dim, BLOCK_INNER)) * 2\n\n        # Register pressure check\n        est_regs = _estimate_register_pressure(\n            config.num_warps,\n            (block_r, block_dim),  # acc\n            (block_m, BLOCK_INNER),  # input tile\n     
       (block_m, block_dim),  # other tile\n            (block_r, BLOCK_INNER),  # lora weight\n        )\n        if est_regs > _MAX_REGS_SOFT_LIMIT:\n            continue\n\n        if smem <= smem_cap - _SMEM_SLACK:\n            pruned.append(config)\n\n    if pruned:\n        return pruned\n    configs.sort(key=lambda c: c.kwargs[\"BLOCK_M\"] * c.kwargs[\"BLOCK_DIM\"])\n    return [configs[0]]\n\n\n@triton.autotune(\n    configs=_group_bwd_split_configs(),\n    key=[\"M\", \"K\", \"N\"],\n    prune_configs_by={\"early_config_prune\": _prune_split_configs},\n)\n@triton.heuristics(\n    {\n        \"NO_DIM_MASK\": lambda args: (\n            (args[\"K\"] % args[\"BLOCK_DIM\"]) == 0\n            if args[\"COMPUTE_DA\"]\n            else (args[\"N\"] % args[\"BLOCK_DIM\"]) == 0\n        ),\n    }\n)\n@triton.jit\ndef _group_bwd_lora_split(\n    # Data tensors (DY and X are always present)\n    DY_ptr,\n    stride_dym,\n    stride_dyn,\n    X_ptr,\n    stride_xm,\n    stride_xk,\n    # LoRA weight for the inner reduction (B for dA, A for dB)\n    LW_ptr,\n    stride_lw0,\n    stride_lw1,\n    # Output gradient tensor (dA or dB)\n    OUT_ptr,\n    stride_out0,\n    stride_out1,\n    # Expert offsets\n    expert_offsets_ptr,\n    # Dimensions\n    M,\n    K: tl.constexpr,\n    N: tl.constexpr,\n    ACTUAL_R: tl.constexpr,\n    BLOCK_R: tl.constexpr,\n    INNER_DIM: tl.constexpr,  # reduction dimension (N for dA, K for dB)\n    scaling,\n    # Mode flag\n    COMPUTE_DA: tl.constexpr,  # True = compute dA, False = compute dB\n    # Tile sizes\n    BLOCK_M: tl.constexpr,\n    BLOCK_DIM: tl.constexpr,\n    ACC_TYPE: tl.constexpr,\n    allow_tf32: tl.constexpr,\n    NO_DIM_MASK: tl.constexpr,\n):\n    \"\"\"\n    Unified split kernel for LoRA gradient computation.\n\n    When COMPUTE_DA=True:\n      dA[e] = scaling * (dY @ B[e])^T @ X  →  [R, K]\n      Grid: (E, cdiv(K, BLOCK_DIM))\n      - outer_ptr/stride = X (read [M, K_block])\n      - inner reduction over N using DY 
and B\n      - output shape [BLOCK_R, BLOCK_DIM]\n\n    When COMPUTE_DA=False:\n      dB[e] = scaling * dY^T @ (X @ A[e]^T)  →  [N, R]\n      Grid: (E, cdiv(N, BLOCK_DIM))\n      - outer_ptr/stride = DY (read [M, N_block])\n      - inner reduction over K using X and A\n      - output shape [BLOCK_DIM, BLOCK_R]\n\n    No atomic adds — each (E, dim_block) pair is written by exactly one block.\n    \"\"\"\n    E_idx = tl.program_id(0)\n    dim_block_id = tl.program_id(1)\n\n    if E_idx == 0:\n        start_idx = 0\n    else:\n        start_idx = tl.load(expert_offsets_ptr + E_idx - 1).to(tl.int32)\n    end_idx = tl.load(expert_offsets_ptr + E_idx).to(tl.int32)\n    num_tokens = end_idx - start_idx\n\n    # Output dimension tile (K for dA, N for dB)\n    if COMPUTE_DA:\n        OUT_DIM: tl.constexpr = K  # type: ignore[no-redef]\n    else:\n        OUT_DIM: tl.constexpr = N  # type: ignore[no-redef]\n    dim_block = dim_block_id * BLOCK_DIM + tl.arange(0, BLOCK_DIM)\n    dim_mask = dim_block < OUT_DIM\n    R_block = tl.arange(0, BLOCK_R)\n    R_mask = R_block < ACTUAL_R\n    lora_offset = E_idx * ACTUAL_R\n\n    # Output pointers — layout differs: dA is [R, K], dB is [N, R]\n    if COMPUTE_DA:\n        out_blk_ptrs = (\n            OUT_ptr\n            + (lora_offset + R_block)[:, None] * stride_out0\n            + dim_block[None, :] * stride_out1\n        )\n        out_mask = R_mask[:, None] & dim_mask[None, :]\n    else:\n        out_blk_ptrs = (\n            OUT_ptr\n            + dim_block[:, None] * stride_out0\n            + (lora_offset + R_block)[None, :] * stride_out1\n        )\n        out_mask = dim_mask[:, None] & R_mask[None, :]\n\n    if num_tokens > 0:\n        M_block = tl.arange(0, BLOCK_M)\n        INPUT_DTYPE = X_ptr.dtype.element_ty\n        BLOCK_INNER: tl.constexpr = 64\n        inner_iters = tl.cdiv(INNER_DIM, BLOCK_INNER)\n\n        if COMPUTE_DA:\n            acc = tl.zeros((BLOCK_R, BLOCK_DIM), dtype=ACC_TYPE)\n        else:\n            
acc = tl.zeros((BLOCK_DIM, BLOCK_R), dtype=ACC_TYPE)\n\n        M_iters = tl.cdiv(num_tokens, BLOCK_M)\n        for i in range(M_iters):\n            M_idx = start_idx + i * BLOCK_M + M_block\n            M_mask = M_idx < end_idx\n\n            if COMPUTE_DA:\n                # Load X[M, K_block] (the \"outer\" tensor for dA)\n                outer = tl.load(\n                    X_ptr + M_idx[:, None] * stride_xm + dim_block[None, :] * stride_xk,\n                    mask=M_mask[:, None] & dim_mask[None, :],\n                    other=0.0,\n                ).to(INPUT_DTYPE)\n\n                # Reduce DY[M, :] @ B[e][:, R] over N → [M, R]\n                reduced = tl.zeros((BLOCK_M, BLOCK_R), dtype=ACC_TYPE)\n                inner_range = tl.arange(0, BLOCK_INNER)\n                for j in range(inner_iters):\n                    inn_off = j * BLOCK_INNER + inner_range\n                    inn_mask = inn_off < N\n\n                    dy_tile = tl.load(\n                        DY_ptr\n                        + M_idx[:, None] * stride_dym\n                        + inn_off[None, :] * stride_dyn,\n                        mask=M_mask[:, None] & inn_mask[None, :],\n                        other=0.0,\n                    ).to(INPUT_DTYPE)\n                    # B layout: [N, r*E] → stride_lw0=N stride, stride_lw1=r*E stride\n                    lw_tile = tl.load(\n                        LW_ptr\n                        + inn_off[:, None] * stride_lw0\n                        + (lora_offset + R_block)[None, :] * stride_lw1,\n                        mask=inn_mask[:, None] & R_mask[None, :],\n                        other=0.0,\n                    ).to(INPUT_DTYPE)\n                    reduced += tl.dot(dy_tile, lw_tile, allow_tf32=allow_tf32)\n\n                # dA += (DY@B)^T @ X: [R, M] @ [M, K_block] → [R, K_block]\n                acc += tl.dot(\n                    tl.trans(reduced.to(INPUT_DTYPE)), outer, allow_tf32=allow_tf32\n                )\n            
else:\n                # Load DY[M, N_block] (the \"outer\" tensor for dB)\n                outer = tl.load(\n                    DY_ptr\n                    + M_idx[:, None] * stride_dym\n                    + dim_block[None, :] * stride_dyn,\n                    mask=M_mask[:, None] & dim_mask[None, :],\n                    other=0.0,\n                ).to(INPUT_DTYPE)\n\n                # Reduce X[M, :] @ A[e][:, :].T over K → [M, R]\n                reduced = tl.zeros((BLOCK_M, BLOCK_R), dtype=ACC_TYPE)\n                inner_range = tl.arange(0, BLOCK_INNER)\n                for j in range(inner_iters):\n                    inn_off = j * BLOCK_INNER + inner_range\n                    inn_mask = inn_off < K\n\n                    x_tile = tl.load(\n                        X_ptr\n                        + M_idx[:, None] * stride_xm\n                        + inn_off[None, :] * stride_xk,\n                        mask=M_mask[:, None] & inn_mask[None, :],\n                        other=0.0,\n                    ).to(INPUT_DTYPE)\n                    # A layout: [r*E, K] → stride_lw0=r*E stride, stride_lw1=K stride\n                    # We want A[e]^T: [K, R], so load as [K_inner, R]\n                    lw_tile = tl.load(\n                        LW_ptr\n                        + (lora_offset + R_block)[None, :] * stride_lw0\n                        + inn_off[:, None] * stride_lw1,\n                        mask=inn_mask[:, None] & R_mask[None, :],\n                        other=0.0,\n                    ).to(INPUT_DTYPE)\n                    reduced += tl.dot(x_tile, lw_tile, allow_tf32=allow_tf32)\n\n                # dB += DY^T @ (X@A^T): [N_block, M] @ [M, R] → [N_block, R]\n                acc += tl.dot(\n                    tl.trans(outer), reduced.to(INPUT_DTYPE), allow_tf32=allow_tf32\n                )\n\n        tl.store(\n            out_blk_ptrs, (acc * scaling).to(OUT_ptr.dtype.element_ty), mask=out_mask\n        )\n    else:\n        # Zero out this 
expert's slice — needed because output uses empty_like\n        if COMPUTE_DA:\n            tl.store(\n                out_blk_ptrs,\n                tl.zeros((BLOCK_R, BLOCK_DIM), dtype=OUT_ptr.dtype.element_ty),\n                mask=out_mask,\n            )\n        else:\n            tl.store(\n                out_blk_ptrs,\n                tl.zeros((BLOCK_DIM, BLOCK_R), dtype=OUT_ptr.dtype.element_ty),\n                mask=out_mask,\n            )\n\n\ndef group_bwd_lora(\n    DY: torch.Tensor,\n    X: torch.Tensor,\n    lora_A: torch.Tensor,\n    lora_B: torch.Tensor,\n    expert_offsets: torch.Tensor,\n    E: int,\n    scaling: float,\n    sorted_scattered_idxs: Optional[torch.Tensor] = None,\n    k: int = 1,\n) -> tuple[torch.Tensor, torch.Tensor]:\n    \"\"\"\n    Compute LoRA gradients for A and B on expert-grouped data.\n\n    Uses split dA/dB kernels that eliminate atomic adds by giving each\n    (expert, output_block) pair its own thread block.\n\n    Args:\n        DY: Gradient w.r.t. 
output [M_total, N] (grouped by expert)\n        X: Input [M_total, K] (grouped by expert)\n        lora_A: LoRA A weights [r*E, K]\n        lora_B: LoRA B weights [N, r*E]\n        expert_offsets: Cumulative token counts per expert [E]\n        E: Number of experts\n        scaling: LoRA scaling factor\n\n    Returns:\n        dA: Gradient for A [r*E, K]\n        dB: Gradient for B [N, r*E]\n    \"\"\"\n    R = lora_A.size(0) // E\n    K = X.size(1)\n    N = DY.size(1)\n\n    # No zero-init needed: the split kernels write zeros for experts with\n    # zero routed tokens directly in the kernel (else branch).\n    dA = torch.empty_like(lora_A)\n    dB = torch.empty_like(lora_B)\n\n    BLOCK_R = _block_r_for_rank(R)\n\n    def grid_dA(META):\n        return (E, triton.cdiv(K, META[\"BLOCK_DIM\"]))\n\n    _group_bwd_lora_split[grid_dA](\n        DY,\n        DY.stride(0),\n        DY.stride(1),\n        X,\n        X.stride(0),\n        X.stride(1),\n        lora_B,\n        lora_B.stride(0),\n        lora_B.stride(1),\n        dA,\n        dA.stride(0),\n        dA.stride(1),\n        expert_offsets,\n        M=DY.size(0),\n        K=K,\n        N=N,\n        ACTUAL_R=R,\n        BLOCK_R=BLOCK_R,\n        INNER_DIM=N,\n        scaling=scaling,\n        COMPUTE_DA=True,\n        ACC_TYPE=tl.float32,\n        allow_tf32=ALLOW_TF32,\n    )\n\n    def grid_dB(META):\n        return (E, triton.cdiv(N, META[\"BLOCK_DIM\"]))\n\n    _group_bwd_lora_split[grid_dB](\n        DY,\n        DY.stride(0),\n        DY.stride(1),\n        X,\n        X.stride(0),\n        X.stride(1),\n        lora_A,\n        lora_A.stride(0),\n        lora_A.stride(1),\n        dB,\n        dB.stride(0),\n        dB.stride(1),\n        expert_offsets,\n        M=DY.size(0),\n        K=K,\n        N=N,\n        ACTUAL_R=R,\n        BLOCK_R=BLOCK_R,\n        INNER_DIM=K,\n        scaling=scaling,\n        COMPUTE_DA=False,\n        ACC_TYPE=tl.float32,\n        allow_tf32=ALLOW_TF32,\n    )\n\n    
return dA, dB\n\n\n# =============================================================================\n# Backward Kernel: Fused gather + LoRA gradient (dA, dB) — eliminates group()\n# =============================================================================\n\n\n@triton.autotune(\n    configs=_group_bwd_lora_configs(),\n    key=[\"M\", \"N\", \"K\"],\n    prune_configs_by={\"early_config_prune\": _prune_bwd_lora_configs},\n    reset_to_zero=[\"DLA_ptr\", \"DLB_ptr\"],\n)\n@triton.heuristics(\n    {\n        \"NO_K_MASK\": lambda args: (args[\"K\"] % args[\"BLOCK_K\"]) == 0,\n        \"NO_N_MASK\": lambda args: (args[\"N\"] % args[\"BLOCK_N\"]) == 0,\n    }\n)\n@triton.jit\ndef _group_bwd_lora_fused(\n    # Inputs (ungrouped or grouped)\n    DY_ptr,\n    stride_dym,\n    stride_dyn,\n    X_ptr,\n    stride_xm,\n    stride_xk,\n    # Scatter indices for gather-on-load\n    sorted_scattered_idxs_ptr,\n    FAN_OUT: tl.constexpr,\n    # LoRA weights (needed for cross-terms)\n    LA_ptr,\n    stride_la_r,\n    stride_la_k,  # A: [r*E, K]\n    LB_ptr,\n    stride_lb_n,\n    stride_lb_r,  # B: [N, r*E]\n    # Gradient outputs\n    DLA_ptr,\n    stride_dla_r,\n    stride_dla_k,\n    DLB_ptr,\n    stride_dlb_n,\n    stride_dlb_r,\n    # Expert offsets\n    expert_offsets_ptr,\n    # Real expert offsets (for M_mask when using token rounding, else same as expert_offsets_ptr)\n    real_expert_offsets_ptr,\n    # Dimensions\n    M,\n    K: tl.constexpr,\n    N: tl.constexpr,\n    ACTUAL_R: tl.constexpr,\n    BLOCK_R: tl.constexpr,\n    scaling,\n    # Block sizes\n    BLOCK_M: tl.constexpr,\n    BLOCK_K: tl.constexpr,\n    BLOCK_N: tl.constexpr,\n    ACC_TYPE: tl.constexpr,\n    allow_tf32: tl.constexpr,\n    NO_K_MASK: tl.constexpr,\n    NO_N_MASK: tl.constexpr,\n    # Whether DY is already in grouped (expert-sorted) order\n    dy_grouped: tl.constexpr = False,\n):\n    \"\"\"\n    Fused gather + LoRA gradient computation. 
Same as _group_bwd_lora but\n    reads X from ungrouped buffers using sorted_scattered_idxs for indirect\n    indexing, eliminating the need for a separate group(X) call.\n\n    When dy_grouped=False (default): both X and DY are read via indirect\n    indexing through sorted_scattered_idxs.  This eliminates both group()\n    calls entirely.\n\n    When dy_grouped=True: DY is already in grouped order (e.g. gate_up_proj\n    backward where grouped_out=True) and is read directly.  Only X uses\n    indirect indexing.  This avoids the group(X) allocation while\n    still supporting the grouped DY case.\n\n    Grid: (E * cdiv(K, BLOCK_K), cdiv(N, BLOCK_N))\n\n    For expert e:\n      dA[e] = scaling * (dY @ B[e])^T @ X   -> [r, K]\n      dB[e] = scaling * dY^T @ (X @ A[e]^T)  -> [N, r]\n\n    Supports token rounding: expert_offsets_ptr gives the iteration range\n    (padded to BLOCK_M multiples), real_expert_offsets_ptr gives the real\n    token count for M_mask (to exclude padding tokens).\n    \"\"\"\n    pid0 = tl.program_id(axis=0)\n    pid1 = tl.program_id(axis=1)\n\n    K_BLOCK_COUNT = tl.cdiv(K, BLOCK_K)\n    E_idx = pid0 // K_BLOCK_COUNT\n    K_block_id = pid0 % K_BLOCK_COUNT\n    N_block_id = pid1\n\n    # Get expert's token range from cumulative offsets\n    # start_idx/end_idx from expert_offsets_ptr: iteration range (possibly padded)\n    # real_end_idx from real_expert_offsets_ptr: for M_mask (real token count)\n    if E_idx == 0:\n        start_idx = 0\n        real_start_idx = 0\n    else:\n        start_idx = tl.load(expert_offsets_ptr + E_idx - 1).to(tl.int32)\n        real_start_idx = tl.load(real_expert_offsets_ptr + E_idx - 1).to(tl.int32)\n    end_idx = tl.load(expert_offsets_ptr + E_idx).to(tl.int32)\n    real_end_idx = tl.load(real_expert_offsets_ptr + E_idx).to(tl.int32)\n    num_tokens = end_idx - start_idx\n\n    if num_tokens > 0:\n        M_block = tl.arange(0, BLOCK_M)\n        K_block = K_block_id * BLOCK_K + tl.arange(0, BLOCK_K)\n        
K_mask = K_block < K\n        N_block = N_block_id * BLOCK_N + tl.arange(0, BLOCK_N)\n        N_mask = N_block < N\n        R_block = tl.arange(0, BLOCK_R)\n        R_mask = R_block < ACTUAL_R\n\n        lora_offset = E_idx * ACTUAL_R\n\n        # Determine input element type for consistent casting.\n        INPUT_DTYPE = X_ptr.dtype.element_ty\n\n        # Load B[e] and A[e] — same as non-fused kernel\n        B_blk_ptrs = (\n            LB_ptr\n            + N_block[:, None] * stride_lb_n\n            + (lora_offset + R_block)[None, :] * stride_lb_r\n        )\n        b_e = tl.load(B_blk_ptrs, mask=N_mask[:, None] & R_mask[None, :], other=0.0).to(\n            INPUT_DTYPE\n        )\n\n        A_blk_ptrs = (\n            LA_ptr\n            + (lora_offset + R_block)[:, None] * stride_la_r\n            + K_block[None, :] * stride_la_k\n        )\n        a_e = tl.load(A_blk_ptrs, mask=R_mask[:, None] & K_mask[None, :], other=0.0).to(\n            INPUT_DTYPE\n        )\n\n        # Accumulators\n        dA_acc = tl.zeros((BLOCK_R, BLOCK_K), dtype=ACC_TYPE)\n        dB_acc = tl.zeros((BLOCK_N, BLOCK_R), dtype=ACC_TYPE)\n\n        real_num_tokens = real_end_idx - real_start_idx\n        iters = tl.cdiv(num_tokens, BLOCK_M)\n        for i in range(iters):\n            M_idx = start_idx + i * BLOCK_M + M_block\n            # Use real token count for masking (excludes padding tokens)\n            M_local = i * BLOCK_M + M_block\n            M_mask = M_local < real_num_tokens\n\n            # Fused gather: load scatter indices for indirect X access\n            scatter_idx = tl.load(\n                sorted_scattered_idxs_ptr + M_idx, mask=M_mask, other=0\n            ).to(tl.int32)\n            X_token_idx = scatter_idx // FAN_OUT  # X is [M, K], not expanded by k\n\n            # Load X via indirect index: [BLOCK_M, BLOCK_K]\n            X_blk_ptrs = (\n                X_ptr + X_token_idx[:, None] * stride_xm + K_block[None, :] * stride_xk\n            )\n            
x = tl.load(\n                X_blk_ptrs, mask=M_mask[:, None] & K_mask[None, :], other=0.0\n            ).to(INPUT_DTYPE)\n\n            # Load DY: indirect via scatter_idx when ungrouped, direct via M_idx when grouped\n            if dy_grouped:\n                DY_blk_ptrs = (\n                    DY_ptr + M_idx[:, None] * stride_dym + N_block[None, :] * stride_dyn\n                )\n            else:\n                DY_blk_ptrs = (\n                    DY_ptr\n                    + scatter_idx[:, None] * stride_dym\n                    + N_block[None, :] * stride_dyn\n                )\n            dy = tl.load(\n                DY_blk_ptrs, mask=M_mask[:, None] & N_mask[None, :], other=0.0\n            ).to(INPUT_DTYPE)\n\n            # X @ A[e]^T: [M, K] @ [K, R] -> [M, R]\n            xa = tl.dot(x, tl.trans(a_e), allow_tf32=allow_tf32)\n\n            # dY @ B[e]: [M, N] @ [N, R] -> [M, R]\n            dy_b = tl.dot(dy, b_e, allow_tf32=allow_tf32)\n\n            dy_b_cast = dy_b.to(INPUT_DTYPE)\n            xa_cast = xa.to(INPUT_DTYPE)\n\n            # dA += (dY @ B)^T @ X: [R, M] @ [M, K] -> [R, K]\n            dA_acc += tl.dot(tl.trans(dy_b_cast), x, allow_tf32=allow_tf32)\n\n            # dB += dY^T @ (X @ A^T): [N, M] @ [M, R] -> [N, R]\n            dB_acc += tl.dot(tl.trans(dy), xa_cast, allow_tf32=allow_tf32)\n\n        # Store dA with scaling (atomic add since multiple N_blocks contribute)\n        DLA_blk_ptrs = (\n            DLA_ptr\n            + (lora_offset + R_block)[:, None] * stride_dla_r\n            + K_block[None, :] * stride_dla_k\n        )\n        tl.atomic_add(\n            DLA_blk_ptrs,\n            (dA_acc * scaling).to(DLA_ptr.dtype.element_ty),\n            mask=R_mask[:, None] & K_mask[None, :],\n        )\n\n        # Store dB with scaling (atomic add since multiple K_blocks contribute)\n        DLB_blk_ptrs = (\n            DLB_ptr\n            + N_block[:, None] * stride_dlb_n\n            + (lora_offset + R_block)[None, :] 
* stride_dlb_r\n        )\n        tl.atomic_add(\n            DLB_blk_ptrs,\n            (dB_acc * scaling).to(DLB_ptr.dtype.element_ty),\n            mask=N_mask[:, None] & R_mask[None, :],\n        )\n\n\ndef group_bwd_lora_fused(\n    DY: torch.Tensor,\n    X: torch.Tensor,\n    lora_A: torch.Tensor,\n    lora_B: torch.Tensor,\n    expert_offsets: torch.Tensor,\n    sorted_scattered_idxs: torch.Tensor,\n    E: int,\n    k: int,\n    scaling: float,\n    real_expert_offsets: Optional[torch.Tensor] = None,\n    dy_grouped: bool = False,\n) -> tuple[torch.Tensor, torch.Tensor]:\n    \"\"\"\n    Fused gather + LoRA gradient computation. Same result as\n    group(X) + group(DY) + group_bwd_lora(DY, X, ...) but without\n    the intermediate grouped buffers.\n\n    Args:\n        DY: Gradient w.r.t. output [M*k, N].\n            If dy_grouped=False: ungrouped (original token order), read via\n            indirect indexing through sorted_scattered_idxs.\n            If dy_grouped=True: already in grouped (expert-sorted) order,\n            read directly.\n        X: Input [M, K] (ungrouped, original token order).  Always read via\n            indirect indexing through sorted_scattered_idxs.\n        lora_A: LoRA A weights [r*E, K]\n        lora_B: LoRA B weights [N, r*E]\n        expert_offsets: Cumulative token counts per expert [E]\n            (or padded offsets if using token rounding)\n        sorted_scattered_idxs: Maps grouped position -> original position [M*k]\n            (or padded version if using token rounding)\n        E: Number of experts\n        k: Fan-out (top-k)\n        scaling: LoRA scaling factor\n        real_expert_offsets: Original cumulative counts for M_mask when using\n            token rounding. 
If None, expert_offsets is used for both.\n        dy_grouped: Whether DY is already in grouped order (default False).\n            When True, avoids indirect indexing for DY, used for gate_up_proj\n            backward where grouped_out=True.\n\n    Returns:\n        dA: Gradient for A [r*E, K]\n        dB: Gradient for B [N, r*E]\n    \"\"\"\n    R = lora_A.size(0) // E\n    K = X.size(1)\n    N = DY.size(1)\n\n    # Zero-init for atomic accumulation\n    dA = torch.zeros_like(lora_A)\n    dB = torch.zeros_like(lora_B)\n\n    BLOCK_R = _block_r_for_rank(R)\n\n    if real_expert_offsets is None:\n        real_expert_offsets = expert_offsets\n\n    def grid(META):\n        return (\n            E * triton.cdiv(K, META[\"BLOCK_K\"]),\n            triton.cdiv(N, META[\"BLOCK_N\"]),\n        )\n\n    _group_bwd_lora_fused[grid](\n        DY,\n        DY.stride(0),\n        DY.stride(1),\n        X,\n        X.stride(0),\n        X.stride(1),\n        sorted_scattered_idxs,\n        FAN_OUT=k,\n        LA_ptr=lora_A,\n        stride_la_r=lora_A.stride(0),\n        stride_la_k=lora_A.stride(1),\n        LB_ptr=lora_B,\n        stride_lb_n=lora_B.stride(0),\n        stride_lb_r=lora_B.stride(1),\n        DLA_ptr=dA,\n        stride_dla_r=dA.stride(0),\n        stride_dla_k=dA.stride(1),\n        DLB_ptr=dB,\n        stride_dlb_n=dB.stride(0),\n        stride_dlb_r=dB.stride(1),\n        expert_offsets_ptr=expert_offsets,\n        real_expert_offsets_ptr=real_expert_offsets,\n        M=sorted_scattered_idxs.size(0),\n        K=K,\n        N=N,\n        ACTUAL_R=R,\n        BLOCK_R=BLOCK_R,\n        scaling=scaling,\n        ACC_TYPE=tl.float32,\n        allow_tf32=ALLOW_TF32,\n        dy_grouped=dy_grouped,\n    )\n\n    return dA, dB\n"
  },
  {
    "path": "src/axolotl/integrations/kernels/libs/scattermoe_lora/kernels/ops.py",
    "content": "# SPDX-License-Identifier: Apache-2.0\n# Adapted from https://github.com/shawntan/scattermoe\n# Copyright (c) Shawn Tan and ScatterMoE Contributors\n# Licensed under the Apache License, Version 2.0\n# See https://github.com/shawntan/scattermoe/blob/main/LICENSE\n\nfrom typing import Optional\n\nimport torch\nimport triton\nimport triton.language as tl\n\nBLOCK_M = 128\nALLOW_TF32 = True\n\n\n@triton.jit\ndef _compute_expert_block(\n    E_idx,\n    E_mask,\n    M_in_idx,\n    N_block,\n    N_mask,\n    X_ptr,\n    stride_xm,\n    stride_xk,\n    W_ptr,\n    stride_we,\n    stride_wk,\n    stride_wn,\n    K,\n    acc,\n    no_k_mask,\n    BLOCK_K,\n    allow_tf32=True,\n):\n    K_block = tl.arange(0, BLOCK_K)\n    X_blk_ptrs = X_ptr + M_in_idx[:, None] * stride_xm + K_block[None, :] * stride_xk\n    W_blk_ptrs = (\n        W_ptr\n        + K_block[:, None] * stride_wk\n        + N_block[None, :] * stride_wn\n        + E_idx * stride_we\n    )\n    iters = tl.cdiv(K, BLOCK_K)\n\n    for K_block_id in range(iters):\n        if no_k_mask:\n            x = tl.load(X_blk_ptrs, mask=E_mask[:, None])\n            w = tl.load(W_blk_ptrs, mask=N_mask[None, :])\n        else:\n            K_mask = (K_block_id * BLOCK_K + K_block) < K\n            x = tl.load(X_blk_ptrs, mask=E_mask[:, None] & K_mask[None, :])\n            w = tl.load(W_blk_ptrs, mask=K_mask[:, None] & N_mask[None, :])\n\n        X_blk_ptrs += BLOCK_K * stride_xk\n        W_blk_ptrs += BLOCK_K * stride_wk\n        acc = tl.dot(x, w, acc, allow_tf32=allow_tf32)\n    return acc\n\n\ndef _scatter2scatter_configs():\n    return [\n        triton.Config({\"BLOCK_N\": 128, \"BLOCK_K\": 32}, num_stages=4, num_warps=4),\n    ]\n\n\n@triton.autotune(\n    configs=_scatter2scatter_configs(),\n    key=[\"M\", \"N\", \"K\"],\n)\n@triton.heuristics(\n    {\n        \"NO_K_MASK\": lambda args: (args[\"K\"] % args[\"BLOCK_K\"]) == 0,\n        \"NO_N_MASK\": lambda args: (args[\"N\"] % args[\"BLOCK_N\"]) == 
0,\n    }\n)\n@triton.jit\ndef _scatter2scatter(\n    X_ptr,\n    stride_xm: tl.constexpr,\n    stride_xk: tl.constexpr,\n    W_ptr,\n    stride_we,\n    stride_wk: tl.constexpr,\n    stride_wn: tl.constexpr,\n    Y_ptr,\n    stride_ym: tl.constexpr,\n    stride_yn: tl.constexpr,\n    B_ptr,\n    stride_be: tl.constexpr,\n    stride_bn: tl.constexpr,\n    grouped_idx_ptr,\n    expert_idxs_ptr,\n    # block_start_idx_ptr,\n    FAN_OUT: tl.constexpr,\n    M,\n    K: tl.constexpr,\n    N: tl.constexpr,\n    E: tl.constexpr,\n    BLOCK_M: tl.constexpr,\n    BLOCK_N: tl.constexpr,\n    BLOCK_K: tl.constexpr,\n    ACC_TYPE: tl.constexpr,\n    # OUT_M,\n    allow_tf32: tl.constexpr,\n    x_grouped: tl.constexpr,\n    y_grouped: tl.constexpr,\n    NO_K_MASK: tl.constexpr,\n    NO_N_MASK: tl.constexpr,\n):\n    pid = tl.program_id(axis=0)\n\n    N_BLOCK_COUNT = tl.cdiv(N, BLOCK_N)\n    M_block_id = pid // N_BLOCK_COUNT\n    N_block_id = pid % N_BLOCK_COUNT\n\n    M_block = M_block_id * BLOCK_M + tl.arange(0, BLOCK_M)\n    N_block = N_block_id * BLOCK_N + tl.arange(0, BLOCK_N)\n    N_mask = N_block < N\n    M_boundary_mask = M_block < (FAN_OUT * M)\n    E_idxs = tl.load(expert_idxs_ptr + M_block, mask=M_boundary_mask, other=E)\n\n    no_k_mask = K % BLOCK_K == 0\n\n    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)\n    E_first_idx = tl.min(E_idxs)\n    E_last_idx = tl.minimum(tl.max(E_idxs), E - 1)\n    M_idx = tl.load(grouped_idx_ptr + M_block, mask=M_boundary_mask).to(tl.int32)\n    for E_idx in range(E_first_idx, E_last_idx + 1):\n        E_mask = E_idxs == E_idx\n        E_M_idx = M_idx\n        if x_grouped:\n            M_in_idx = M_block\n        else:\n            M_in_idx = E_M_idx // FAN_OUT\n        acc = _compute_expert_block(\n            E_idx,\n            E_mask,\n            M_in_idx,\n            N_block,\n            N_mask,\n            X_ptr,\n            stride_xm,\n            stride_xk,\n            W_ptr,\n            stride_we,\n            
stride_wk,\n            stride_wn,\n            K,\n            acc,\n            no_k_mask,\n            BLOCK_K,\n            allow_tf32=allow_tf32,\n        )\n\n    if B_ptr is not None:\n        B_blk_ptrs = B_ptr + E_idxs[:, None] * stride_be + N_block[None, :] * stride_bn\n        acc += tl.load(B_blk_ptrs, mask=M_boundary_mask[:, None] & N_mask[None, :])\n\n    if y_grouped:\n        M_out_idx = M_block\n    else:\n        M_out_idx = M_idx\n    Y_blk_ptrs = Y_ptr + (M_out_idx[:, None] * stride_ym + N_block[None, :] * stride_yn)\n    tl.store(Y_blk_ptrs, acc, mask=M_boundary_mask[:, None] & N_mask[None, :])\n\n\ndef scatter2scatter(\n    X,\n    W,\n    sorted_expert_idxs,\n    sorted_scattered_idxs,\n    k,\n    b=None,\n    x_grouped=False,\n    y_grouped=False,\n    out=None,\n):\n    assert sorted_scattered_idxs.size(0) == sorted_expert_idxs.size(0)\n    assert sorted_scattered_idxs.size(0) == X.size(0) * k\n    # Pre-kernel setup\n    y_dim = W.size(-1)\n    L_scattered = sorted_expert_idxs.size(0)\n    if out is None:\n        output = torch.empty((L_scattered, y_dim), device=X.device, dtype=X.dtype)\n    else:\n        assert out.size(0) == L_scattered and out.size(1) == y_dim\n        output = out\n\n    scatter2scatter_compileable(\n        output,\n        W,\n        X,\n        k,\n        sorted_expert_idxs,\n        sorted_scattered_idxs,\n        b,\n        x_grouped,\n        y_grouped,\n    )\n    return output\n\n\n@torch.library.custom_op(\"scattermoe::scatter2scatter\", mutates_args={\"output\"})\ndef scatter2scatter_compileable(\n    output: torch.Tensor,\n    W: torch.Tensor,\n    X: torch.Tensor,\n    k: int,\n    sorted_expert_idxs: torch.Tensor,\n    sorted_scattered_idxs: torch.Tensor,\n    b: Optional[torch.Tensor],\n    x_grouped: bool,\n    y_grouped: bool,\n) -> None:\n    def grid(META):\n        grid_num = (\n            triton.cdiv(sorted_expert_idxs.size(0), META[\"BLOCK_M\"])\n            * triton.cdiv(META[\"N\"], 
META[\"BLOCK_N\"]),\n        )\n        return grid_num\n\n    if b is None:\n        b = None\n        stride_be = stride_bn = 0\n    else:\n        stride_be, stride_bn = b.stride()\n\n    _scatter2scatter[grid](\n        # X_ptr, stride_xm, stride_xk,\n        X,\n        X.stride(0),\n        X.stride(1),\n        # W_ptr, stride_we, stride_wk, stride_wn,\n        W,\n        W.stride(0),\n        W.stride(1),\n        W.stride(2),\n        # Y_ptr, stride_ym, stride_yn,\n        output,\n        output.stride(0),\n        output.stride(1),\n        # B_ptr, stride_be, stride_bn\n        b,\n        stride_be,\n        stride_bn,\n        grouped_idx_ptr=sorted_scattered_idxs,\n        expert_idxs_ptr=sorted_expert_idxs,\n        # block_start_idx_ptr=padded_block_idxs,\n        FAN_OUT=k,\n        M=X.size(0),\n        K=X.size(1),\n        N=output.size(1),\n        E=W.size(0),\n        BLOCK_M=BLOCK_M,\n        ACC_TYPE=tl.float32,\n        allow_tf32=ALLOW_TF32,\n        x_grouped=x_grouped,\n        y_grouped=y_grouped,\n    )\n\n\ndef _config_XtY():\n    return [\n        triton.Config(\n            {\"BLOCK_N\": 128, \"BLOCK_K\": 128, \"BLOCK_M\": 32}, num_stages=4, num_warps=4\n        ),\n    ]\n\n\ndef group_bwd_W(DY, X, expert_offsets, E, has_bias=False):\n    DWt = torch.zeros((E, DY.size(-1), X.size(-1)), device=DY.device, dtype=DY.dtype)\n    DW = DWt.permute(0, 2, 1)\n    if has_bias:\n        Db = torch.zeros((E, DY.size(-1)), device=DY.device, dtype=DY.dtype)\n    else:\n        Db = None\n    groupXtY_compileable(E, DW, Db, DY, X, expert_offsets)\n    return DW, Db\n\n\n@torch.library.custom_op(\"scattermoe::groupXtY\", mutates_args={\"DW\", \"Db\"})\ndef groupXtY_compileable(\n    E: int,\n    DW: torch.Tensor,\n    Db: Optional[torch.Tensor],\n    DY: torch.Tensor,\n    X: torch.Tensor,\n    expert_offsets: torch.Tensor,\n) -> None:\n    def grid(META):\n        grid = (\n            E * triton.cdiv(META[\"K\"], META[\"BLOCK_K\"]),\n        
    triton.cdiv(META[\"N\"], META[\"BLOCK_N\"]),\n        )\n        return grid\n\n    if Db is None:\n        stride_dbe = 0\n        stride_dbn = 0\n    else:\n        stride_dbe, stride_dbn = Db.stride()\n\n    _groupXtY[grid](\n        # DY_ptr, stride_dym, stride_dyk,\n        DY,\n        DY.stride(0),\n        DY.stride(1),\n        # X_ptr, stride_xm, stride_xn,\n        X,\n        X.stride(0),\n        X.stride(1),\n        # DW_ptr, stride_dwe, stride_dwk, stride_dwn,\n        DW,\n        DW.stride(0),\n        DW.stride(1),\n        DW.stride(2),\n        # Db_ptr, stride_dwe, stride_dbn,\n        Db,\n        stride_dbe,\n        stride_dbn,\n        # expert_offsets_ptr,\n        expert_offsets,\n        # K: tl.constexpr, N: tl.constexpr,\n        M=DY.size(0),\n        N=DY.size(-1),\n        K=X.size(-1),\n        # ACC_TYPE: tl.constexpr,\n        ACC_TYPE=tl.float32,\n        allow_tf32=ALLOW_TF32,\n    )\n\n\n@triton.autotune(\n    configs=_config_XtY(),\n    key=[\"M\", \"N\", \"K\"],\n)\n@triton.heuristics(\n    {\n        \"NO_K_MASK\": lambda args: (args[\"K\"] % args[\"BLOCK_K\"]) == 0,\n        \"NO_N_MASK\": lambda args: (args[\"N\"] % args[\"BLOCK_N\"]) == 0,\n    }\n)\n@triton.jit\ndef _groupXtY(\n    DY_ptr,\n    stride_dym,\n    stride_dyk,\n    X_ptr,\n    stride_xm,\n    stride_xn,\n    DW_ptr,\n    stride_dwe,\n    stride_dwk,\n    stride_dwn,\n    Db_ptr,\n    stride_dbe,\n    stride_dbn,\n    expert_offsets_ptr,\n    M,\n    K: tl.constexpr,\n    N: tl.constexpr,\n    BLOCK_M: tl.constexpr,\n    BLOCK_N: tl.constexpr,\n    BLOCK_K: tl.constexpr,\n    ACC_TYPE: tl.constexpr,\n    allow_tf32: tl.constexpr,\n    NO_K_MASK: tl.constexpr,\n    NO_N_MASK: tl.constexpr,\n):\n    pid0 = tl.program_id(axis=0)\n    pid1 = tl.program_id(axis=1)\n    num0 = tl.num_programs(0)\n    num1 = tl.num_programs(1)\n    # pid1, pid0 = tl.swizzle2d(pid1, pid0, num1, num0, 128)\n    pid0, pid1 = tl.swizzle2d(pid0, pid1, num0, num1, 4)\n\n    
K_BLOCK_COUNT = tl.cdiv(K, BLOCK_K)\n    E_idx = pid0 // K_BLOCK_COUNT\n    K_block_id = pid0 % K_BLOCK_COUNT\n    N_block_id = pid1\n\n    if E_idx == 0:\n        start_idx = 0\n    else:\n        start_idx = tl.load(expert_offsets_ptr + E_idx - 1).to(tl.int32)\n    end_idx = tl.load(expert_offsets_ptr + E_idx).to(tl.int32)\n\n    if end_idx > start_idx:\n        M_block = tl.max_contiguous(start_idx + tl.arange(0, BLOCK_M), BLOCK_M)\n\n        K_block = K_block_id * BLOCK_K + tl.arange(0, BLOCK_K)\n        K_mask = K_block < K\n        K_block = tl.max_contiguous(tl.multiple_of(K_block % K, BLOCK_K), BLOCK_K)\n\n        N_block = N_block_id * BLOCK_N + tl.arange(0, BLOCK_N)\n        N_mask = N_block < N\n        N_block = tl.max_contiguous(tl.multiple_of(N_block % N, BLOCK_N), BLOCK_N)\n\n        M_idxs = M_block\n        xt_blk_ptrs = X_ptr + K_block[:, None] * stride_xn + M_idxs[None, :] * stride_xm\n        dy_blk_ptrs = (\n            DY_ptr + M_idxs[:, None] * stride_dym + N_block[None, :] * stride_dyk\n        )\n        if (Db_ptr is not None) and (K_block_id == 0):\n            _xty_and_bias(\n                E_idx,\n                start_idx,\n                end_idx,\n                M_block,\n                K_block,\n                K_mask,\n                N_block,\n                N_mask,\n                dy_blk_ptrs,\n                stride_dym,\n                xt_blk_ptrs,\n                stride_xm,\n                DW_ptr,\n                stride_dwe,\n                stride_dwk,\n                stride_dwn,\n                Db_ptr,\n                stride_dbe,\n                stride_dbn,\n                BLOCK_M,\n                BLOCK_N,\n                BLOCK_K,\n                ACC_TYPE,\n                allow_tf32,\n                NO_K_MASK,\n                NO_N_MASK,\n                compute_bias=True,\n            )\n        else:\n            _xty_and_bias(\n                E_idx,\n                start_idx,\n                
end_idx,\n                M_block,\n                K_block,\n                K_mask,\n                N_block,\n                N_mask,\n                dy_blk_ptrs,\n                stride_dym,\n                xt_blk_ptrs,\n                stride_xm,\n                DW_ptr,\n                stride_dwe,\n                stride_dwk,\n                stride_dwn,\n                Db_ptr,\n                stride_dbe,\n                stride_dbn,\n                BLOCK_M,\n                BLOCK_N,\n                BLOCK_K,\n                ACC_TYPE,\n                allow_tf32,\n                NO_K_MASK,\n                NO_N_MASK,\n                compute_bias=False,\n            )\n\n\n@triton.jit\ndef _xty_and_bias(\n    E_idx,\n    start_idx,\n    end_idx,\n    M_block,\n    K_block,\n    K_mask,\n    N_block,\n    N_mask,\n    dy_blk_ptrs,\n    stride_dym,\n    xt_blk_ptrs,\n    stride_xm,\n    DW_ptr,\n    stride_dwe,\n    stride_dwk,\n    stride_dwn,\n    Db_ptr,\n    stride_dbe,\n    stride_dbn,\n    BLOCK_M,\n    BLOCK_N,\n    BLOCK_K,\n    ACC_TYPE,\n    allow_tf32,\n    NO_K_MASK,\n    NO_N_MASK,\n    compute_bias: tl.constexpr,\n):\n    if compute_bias:\n        db_acc = tl.zeros((BLOCK_N,), dtype=ACC_TYPE)\n    else:\n        db_acc = None\n\n    acc = tl.zeros((BLOCK_K, BLOCK_N), dtype=ACC_TYPE)\n    iters = tl.cdiv(end_idx - start_idx, BLOCK_M)\n    for i in range(0, iters):\n        M_mask = (i * BLOCK_M + M_block) < end_idx\n        if NO_K_MASK:\n            xt = tl.load(xt_blk_ptrs, mask=M_mask[None, :])\n        else:\n            xt = tl.load(xt_blk_ptrs, mask=K_mask[:, None] & M_mask[None, :])\n        if NO_N_MASK:\n            dy = tl.load(dy_blk_ptrs, mask=M_mask[:, None])\n        else:\n            dy = tl.load(dy_blk_ptrs, mask=M_mask[:, None] & N_mask[None, :])\n\n        acc += tl.dot(xt, dy, out_dtype=ACC_TYPE, allow_tf32=allow_tf32)\n\n        xt_blk_ptrs += BLOCK_M * stride_xm\n        dy_blk_ptrs += BLOCK_M * stride_dym\n\n        
if compute_bias:\n            db_acc += tl.sum(dy, axis=0)\n\n    DW_blk_ptrs = (\n        DW_ptr\n        + E_idx * stride_dwe\n        + K_block[:, None] * stride_dwk\n        + N_block[None, :] * stride_dwn\n    )\n    acc = acc.to(DW_blk_ptrs.dtype.element_ty)\n    tl.store(DW_blk_ptrs, acc, mask=K_mask[:, None] & N_mask[None, :])\n    if compute_bias:\n        Db_blk_ptrs = Db_ptr + E_idx * stride_dbe + N_block * stride_dbn\n        tl.store(Db_blk_ptrs, db_acc, mask=N_mask)\n\n\ndef _config_grouping():\n    return [\n        triton.Config({\"BLOCK_N\": 256, \"BLOCK_K\": 128}, num_stages=4, num_warps=4),\n        # triton.Config({'BLOCK_N': 128, 'BLOCK_K': 64}, num_stages=4, num_warps=4),\n        # triton.Config({'BLOCK_N': 64, 'BLOCK_K': 32}, num_stages=4, num_warps=4),\n    ]\n\n\ndef group(A, sorted_expert_idxs, coeff=None, fan_out=1, out=None):\n    N = sorted_expert_idxs.size(0)\n    K = A.size(1)\n    assert A.size(0) * fan_out == N\n    if out is not None:\n        Y = out\n    else:\n        Y = torch.empty((N, K), dtype=A.dtype, device=A.device)\n    group_compileable(A, K, N, Y, coeff, coeff is not None, fan_out, sorted_expert_idxs)\n    return Y\n\n\n@torch.library.custom_op(\"scattermoe::group\", mutates_args={\"Y\"})\ndef group_compileable(\n    A: torch.Tensor,\n    K: int,\n    N: int,\n    Y: torch.Tensor,\n    coeff: Optional[torch.Tensor],\n    has_coeff: bool,\n    fan_out: int,\n    sorted_expert_idxs: torch.Tensor,\n) -> None:\n    def grid(META):\n        grid_num = (triton.cdiv(META[\"N\"], META[\"BLOCK_N\"]),)\n        return grid_num\n\n    _group[grid](\n        # A_ptr, stride_an, stride_ai,\n        A,\n        A.stride(0),\n        A.stride(1),\n        has_coeff,\n        coeff,\n        fan_out,\n        # Y_ptr, stride_yn, stride_yk,\n        Y,\n        Y.stride(0),\n        Y.stride(1),\n        # grouped_idx_ptr,\n        sorted_expert_idxs,\n        # N: tl.constexpr, K: tl.constexpr,\n        N,\n        K,\n    
)\n\n\n@triton.autotune(configs=_config_grouping(), key=[\"K\"])\n@triton.heuristics({\"NO_K_MASK\": lambda args: (args[\"K\"] % args[\"BLOCK_K\"]) == 0})\n@triton.jit\ndef _group(\n    src_ptr,\n    stride_sn,\n    stride_sk,\n    has_coeff: tl.constexpr,\n    coeff_ptr,\n    FAN_OUT: tl.constexpr,\n    tgt_ptr,\n    stride_tn,\n    stride_ti,\n    grouped_idx_ptr,\n    N,\n    K: tl.constexpr,\n    BLOCK_N: tl.constexpr,\n    BLOCK_K: tl.constexpr,\n    NO_K_MASK: tl.constexpr,\n):\n    pid = tl.program_id(axis=0)\n\n    N_block_id = pid\n    N_blk = N_block_id * BLOCK_N + tl.arange(0, BLOCK_N)\n    N_mask = N_blk < N\n    N_blk = tl.max_contiguous(tl.multiple_of(N_blk % N, BLOCK_N), BLOCK_N)\n    N_idx = tl.load(grouped_idx_ptr + N_blk, mask=N_mask, other=0)\n\n    K_blk = tl.arange(0, BLOCK_K)\n    src_blk_ptrs = (\n        src_ptr + (N_idx // FAN_OUT)[:, None] * stride_sn + K_blk[None, :] * stride_sk\n    )\n    tgt_blk_ptrs = tgt_ptr + N_blk[:, None] * stride_tn + K_blk[None, :] * stride_ti\n\n    if has_coeff:\n        c = tl.load(coeff_ptr + N_idx, mask=N_mask)[:, None]\n\n    iters = tl.cdiv(K, BLOCK_K)\n    for i in range(0, iters):\n        if NO_K_MASK or i < iters - 1:\n            block = tl.load(src_blk_ptrs, mask=N_mask[:, None])\n            if has_coeff:\n                block *= c\n            tl.store(tgt_blk_ptrs, block, mask=N_mask[:, None])\n\n        else:\n            K_mask = (i * BLOCK_K + K_blk) < K\n            mask = N_mask[:, None] & K_mask[None, :]\n            block = tl.load(src_blk_ptrs, mask=mask)\n            if has_coeff:\n                block *= c\n            tl.store(tgt_blk_ptrs, block, mask=mask)\n        src_blk_ptrs += BLOCK_K * stride_sk\n        tgt_blk_ptrs += BLOCK_K * stride_ti\n"
  },
  {
    "path": "src/axolotl/integrations/kernels/libs/scattermoe_lora/kernels/single.py",
    "content": "# SPDX-License-Identifier: Apache-2.0\n# Adapted from https://github.com/shawntan/scattermoe\n# Copyright (c) Shawn Tan and ScatterMoE Contributors\n# Licensed under the Apache License, Version 2.0\n# See https://github.com/shawntan/scattermoe/blob/main/LICENSE\n\nimport torch\nimport triton\nimport triton.language as tl\n\n\n@triton.jit\ndef _single2scatter(\n    X_ptr,\n    stride_xm,\n    stride_xk,\n    W_ptr,\n    stride_we,\n    stride_wk,\n    stride_wn,\n    Y_ptr,\n    stride_ym,\n    stride_yn,\n    expert_idxs_ptr,\n    FAN_OUT: tl.constexpr,\n    K: tl.constexpr,\n    N: tl.constexpr,\n    E: tl.constexpr,\n    BLOCK_N: tl.constexpr,\n    BLOCK_K: tl.constexpr,\n    ACC_TYPE: tl.constexpr,\n):\n    pid0 = tl.program_id(axis=0)\n    pid1 = tl.program_id(axis=1)\n\n    N_block_id = pid0\n    if FAN_OUT == 1:\n        in_idx = pid1\n    else:\n        in_idx = 0\n    out_idx = pid1\n\n    K_block = tl.arange(0, BLOCK_K)\n    N_block = tl.max_contiguous(\n        tl.multiple_of((N_block_id * BLOCK_N + tl.arange(0, BLOCK_N)) % N, BLOCK_N),\n        BLOCK_N,\n    )\n    E_idx = tl.load(expert_idxs_ptr + pid1)\n    X_blk_ptrs = X_ptr + in_idx * stride_xm + K_block[:, None] * stride_xk\n    W_blk_ptrs = (\n        W_ptr\n        + E_idx * stride_we\n        + K_block[:, None] * stride_wk\n        + N_block[None, :] * stride_wn\n    )\n    N_mask = N_block < N\n    acc = tl.zeros((1, BLOCK_N), dtype=ACC_TYPE)\n    for _K_block_id in range(0, tl.cdiv(K, BLOCK_K)):\n        K_mask = K_block < K\n        x = tl.load(X_blk_ptrs, mask=K_mask[:, None], other=0.0)\n        w = tl.load(W_blk_ptrs, mask=K_mask[:, None] & N_mask[None, :], other=0.0)\n        acc += tl.sum(x * w, axis=0)[None, :]\n        X_blk_ptrs += BLOCK_K * stride_xk\n        W_blk_ptrs += BLOCK_K * stride_wk\n        K_block += BLOCK_K\n    Y_blk_ptrs = Y_ptr + out_idx * stride_ym + N_block[None, :] * stride_yn\n    tl.store(Y_blk_ptrs, acc, mask=N_mask[None, :])\n\n\ndef 
single2scatter(X, W, expert_idxs):\n    E, xdim, ydim = W.size()\n    k = expert_idxs.size(1)\n    assert X.size(0) == k or X.size(0) == 1\n    Y = torch.empty((k, ydim), device=X.device, dtype=X.dtype)\n    BLOCK_N = 128\n    BLOCK_K = 128\n    grid = triton.cdiv(ydim, BLOCK_N), k\n    _single2scatter[grid](\n        X,\n        X.stride(0),\n        X.stride(1),\n        W,\n        W.stride(0),\n        W.stride(1),\n        W.stride(2),\n        Y,\n        Y.stride(0),\n        Y.stride(1),\n        expert_idxs,\n        FAN_OUT=Y.size(0) // X.size(0),\n        K=xdim,\n        N=ydim,\n        E=E,\n        BLOCK_N=BLOCK_N,\n        BLOCK_K=BLOCK_K,\n        ACC_TYPE=tl.float32,\n    )\n    return Y\n"
  },
  {
    "path": "src/axolotl/integrations/kernels/libs/scattermoe_lora/layers.py",
    "content": "# SPDX-License-Identifier: Apache-2.0\n#\n# Original work Copyright (c) Shawn Tan and ScatterMoE Contributors\n# Adapted from https://github.com/shawntan/scattermoe\n# See https://github.com/shawntan/scattermoe/blob/main/LICENSE\n#\n# Modifications and LoRA adaptation Copyright (c) Axolotl AI\n# Licensed under the Apache License, Version 2.0\n\n\"\"\"\nScatterMoE layer replacements for HuggingFace MoE architectures.\n\nProvides drop-in forward replacements that use ScatterMoE kernels for\nacceleration. When used via the HF ``kernels`` library\n(``replace_kernel_forward_from_hub``), these classes replace the forward\nmethod of the original MoE block.\n\nLoRA support\n------------\nWhen peft wraps parameters via ``target_parameters``, the ``self.experts``\nsubmodule becomes a chain of ``ParamWrapper`` objects and the ``self.gate``\nrouter may also become a ``ParamWrapper``.  The ``HFScatterMoEGatedMLP``\nforward detects this and automatically:\n\n1. Unwraps ``self.gate`` to the base router, applying gate LoRA delta\n2. Unwraps ``self.experts`` to the base ``OlmoeExperts`` module\n3. Extracts LoRA A/B weights and scaling from each wrapper\n4. Converts B layout from peft rank-major to scattermoe expert-major\n5. Routes to ``parallel_linear_lora`` for fused LoRA computation\n6. 
Passes through ``self.shared_expert`` / ``self.shared_expert_gate``\n   (peft wraps their linear layers with standard LoRA, no special handling)\n\"\"\"\n\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\nfrom .parallel_experts import flatten_sort_count, parallel_linear\nfrom .parallel_linear_lora import get_lora_params_from_wrapper, parallel_linear_lora\n\n# =============================================================================\n# LoRA layout conversion utilities (peft <-> scattermoe)\n# =============================================================================\n\n\ndef peft_lora_B_to_scattermoe(peft_B, num_experts, rank):\n    \"\"\"Convert peft rank-major lora_B ``[out, E*r]`` to scattermoe\n    expert-major ``[N, r*E]``.\n\n    peft reshapes B to ``[out, r, E]`` (rank-major).\n    scattermoe slices B as ``[:, e*r:(e+1)*r]`` (expert-major).\n    \"\"\"\n    N = peft_B.shape[0]\n    return (\n        peft_B.reshape(N, rank, num_experts)\n        .permute(0, 2, 1)\n        .contiguous()\n        .reshape(N, num_experts * rank)\n    )\n\n\ndef peft_lora_to_scattermoe(peft_A, peft_B, num_experts, rank):\n    \"\"\"Convert peft LoRA weights to scattermoe layout (with A<->B swap).\n\n    peft operates on the parameter in its native storage layout ``[E, dim1, dim2]``\n    where ``in_features=dim1, out_features=dim2``.  ScatterMoE transposes the\n    parameter (``W = param.transpose(2, 1)``) giving ``[E, dim2, dim1]`` with\n    ``K=dim2, N=dim1``.  
Because of this transposition, peft's A and B roles\n    are swapped relative to scattermoe's convention.\n\n    peft gives:\n        lora_A ``[r*E, dim1]``, lora_B ``[dim2, r*E]``\n\n    scattermoe needs:\n        lora_A ``[r*E, K=dim2]``, lora_B ``[N=dim1, r*E]``\n\n    This function swaps A<->B and converts B from rank-major to expert-major.\n    Uses vectorized tensor operations (no Python loop over experts).\n\n    Works for **both** gate_up_proj and down_proj since the transposition\n    issue is the same for any parameter.\n    \"\"\"\n    peft_B_em = peft_lora_B_to_scattermoe(peft_B, num_experts, rank)\n\n    dim1 = peft_A.shape[1]  # peft in_features -> scattermoe N\n    dim2 = peft_B_em.shape[0]  # peft out_features -> scattermoe K\n\n    # smoe_A: per expert, transpose B_e [dim2, r] -> [r, dim2]\n    # [dim2, E*r] -> [dim2, E, r] -> [E, r, dim2] -> [E*r, dim2]\n    smoe_A = (\n        peft_B_em.reshape(dim2, num_experts, rank)\n        .permute(1, 2, 0)\n        .contiguous()\n        .reshape(rank * num_experts, dim2)\n    )\n\n    # smoe_B: per expert, transpose A_e [r, dim1] -> [dim1, r]\n    # [E*r, dim1] -> [E, r, dim1] -> [dim1, E, r] -> [dim1, E*r]\n    smoe_B = (\n        peft_A.reshape(num_experts, rank, dim1)\n        .permute(2, 0, 1)\n        .contiguous()\n        .reshape(dim1, num_experts * rank)\n    )\n\n    return smoe_A, smoe_B\n\n\ndef peft_down_proj_lora_to_scattermoe(peft_A, peft_B, num_experts, rank):\n    \"\"\"Deprecated alias for :func:`peft_lora_to_scattermoe`.\"\"\"\n    return peft_lora_to_scattermoe(peft_A, peft_B, num_experts, rank)\n\n\n# =============================================================================\n# ParamWrapper unwrapping\n# =============================================================================\n\n\ndef _unwrap_gate_lora(gate_module):\n    \"\"\"Unwrap peft ``ParamWrapper`` on the router gate.\n\n    When peft targets ``gate.weight``, ``self.gate`` becomes::\n\n        ParamWrapper(weight)\n      
    -> base_layer: OlmoeTopKRouter (the real module)\n\n    This function detects the wrapping and returns the base router, its\n    weight tensor, and an optional LoRA delta tensor.\n\n    Returns:\n        (base_gate, gate_weight, gate_lora_delta_or_None)\n\n        ``base_gate`` is the original router module (with ``.top_k``,\n        ``.num_experts``, ``.norm_topk_prob``).\n        ``gate_weight`` is the base router weight (may be a DTensor under FSDP).\n        ``gate_lora_delta_or_None`` is the LoRA delta tensor if LoRA is active,\n        else ``None``.  Kept separate to avoid mixing DTensor + Tensor in an add.\n    \"\"\"\n    if hasattr(gate_module, \"base_layer\") and hasattr(gate_module, \"lora_A\"):\n        base_gate = gate_module.base_layer\n        lora_A, lora_B, scaling = get_lora_params_from_wrapper(gate_module)\n        if lora_A is not None:\n            # gate weight: [num_experts, hidden_size]\n            # lora_A: [r, hidden_size], lora_B: [num_experts, r]\n            # delta = scaling * B @ A = [num_experts, hidden_size]\n            delta = scaling * (lora_B @ lora_A)\n            return base_gate, base_gate.weight, delta\n        else:\n            return base_gate, base_gate.weight, None\n    else:\n        # No wrapping — gate is the original module\n        return gate_module, gate_module.weight, None\n\n\ndef _convert_smoe_lora(lora_A, lora_B, num_experts, rank, scaling):\n    \"\"\"Convert peft LoRA weights to scattermoe layout.\"\"\"\n    smoe_A, smoe_B = peft_lora_to_scattermoe(lora_A, lora_B, num_experts, rank)\n    return (smoe_A, smoe_B, scaling)\n\n\ndef _unwrap_experts_lora(experts_module):\n    \"\"\"Walk a peft ``ParamWrapper`` chain on ``self.experts``.\n\n    When peft targets ``experts.gate_up_proj`` and ``experts.down_proj`` via\n    ``target_parameters``, ``self.experts`` becomes a nested chain::\n\n        ParamWrapper(down_proj)\n          -> base_layer: ParamWrapper(gate_up_proj)\n              -> base_layer: 
OlmoeExperts (the real module)\n\n    This function walks the chain, collects LoRA params keyed by\n    ``parameter_name``, and returns the base experts module.\n\n    Returns:\n        (base_experts, gup_lora, down_lora)\n\n        Each ``*_lora`` is either ``(smoe_A, smoe_B, scaling)`` or ``None``.\n        A/B are already in scattermoe layout.\n    \"\"\"\n    # Collect ParamWrapper layers by their parameter_name\n    wrappers = {}\n    module = experts_module\n    while hasattr(module, \"base_layer\") and hasattr(module, \"lora_A\"):\n        param_name = getattr(module, \"parameter_name\", None)\n        if param_name is not None:\n            wrappers[param_name] = module\n        module = module.base_layer\n\n    base_experts = module\n\n    if not wrappers:\n        return base_experts, None, None\n\n    # Determine num_experts from base module\n    num_experts = getattr(base_experts, \"num_experts\", None)\n    if num_experts is None:\n        # Fallback: infer from parameter shape\n        gup = getattr(base_experts, \"gate_up_proj\", None)\n        if gup is not None:\n            num_experts = gup.shape[0]\n\n    # Extract gate_up_proj LoRA (needs A<->B swap due to transposition)\n    gup_lora = None\n    gup_wrapper = wrappers.get(\"gate_up_proj\")\n    if gup_wrapper is not None:\n        lora_A, lora_B, scaling = get_lora_params_from_wrapper(gup_wrapper)\n        if lora_A is not None:\n            rank = lora_A.shape[0] // num_experts\n            gup_lora = _convert_smoe_lora(lora_A, lora_B, num_experts, rank, scaling)\n\n    # Extract down_proj LoRA (needs A<->B swap due to transposition)\n    down_lora = None\n    down_wrapper = wrappers.get(\"down_proj\")\n    if down_wrapper is not None:\n        lora_A, lora_B, scaling = get_lora_params_from_wrapper(down_wrapper)\n        if lora_A is not None:\n            rank = lora_A.shape[0] // num_experts\n            down_lora = _convert_smoe_lora(lora_A, lora_B, num_experts, rank, scaling)\n\n    
return base_experts, gup_lora, down_lora\n\n\n# =============================================================================\n# Routing helpers\n# =============================================================================\n\n\ndef _softmax_topk_route(\n    moe_block, base_gate, hidden_states, gate_weight, gate_lora_delta\n):\n    \"\"\"Softmax→topk routing (Qwen, OLMoE, Mixtral, MiniMax).\n\n    Returns:\n        (routing_weights [T, K], selected_experts [T, K], top_k, num_experts)\n    \"\"\"\n    router_logits = F.linear(hidden_states, gate_weight)\n    if gate_lora_delta is not None:\n        router_logits = router_logits + F.linear(hidden_states, gate_lora_delta)\n    routing_weights = F.softmax(router_logits, dim=-1, dtype=torch.float32)\n\n    top_k = base_gate.top_k\n    num_experts = base_gate.num_experts\n    routing_weights, selected_experts = torch.topk(routing_weights, top_k, dim=-1)\n\n    if getattr(base_gate, \"norm_topk_prob\", True):\n        routing_weights = routing_weights / routing_weights.sum(dim=-1, keepdim=True)\n\n    return routing_weights, selected_experts, top_k, num_experts\n\n\ndef _sigmoid_topk_route(\n    moe_block, base_gate, hidden_states, gate_weight, gate_lora_delta\n):\n    \"\"\"Sigmoid→topk routing (GLM, DeepSeek V3, MiniMax M2).\n\n    Supports:\n    - ``e_score_correction_bias`` on gate or moe_block\n    - Group-based expert selection when ``n_group > 1``\n    - ``routed_scaling_factor`` applied to final weights\n    - Final weights gathered from original sigmoid probs (not bias-corrected)\n\n    Returns:\n        (routing_weights [T, K], selected_experts [T, K], top_k, num_experts)\n    \"\"\"\n    router_logits = F.linear(hidden_states.float(), gate_weight.float())\n    if gate_lora_delta is not None:\n        router_logits = router_logits + F.linear(\n            hidden_states.float(), gate_lora_delta.float()\n        )\n    router_probs = router_logits.sigmoid()  # [T, E]\n\n    top_k = getattr(moe_block, \"top_k\", 
getattr(base_gate, \"top_k\", None))\n    num_experts = getattr(moe_block, \"n_routed_experts\", gate_weight.shape[0])\n\n    # Bias-corrected scores for expert selection (not used for final weights).\n    # glm_moe_dsa/deepseek_v3 store the bias on gate; minimax_m2 on the block.\n    e_score_correction_bias = getattr(base_gate, \"e_score_correction_bias\", None)\n    if e_score_correction_bias is None:\n        e_score_correction_bias = getattr(moe_block, \"e_score_correction_bias\", None)\n    if e_score_correction_bias is not None:\n        scores_for_choice = router_probs + e_score_correction_bias\n    else:\n        scores_for_choice = router_probs\n\n    # Group-based selection: pick top groups, mask the rest\n    n_group = getattr(moe_block, \"n_group\", 1)\n    if n_group > 1:\n        group_scores = (\n            scores_for_choice.view(-1, n_group, num_experts // n_group)\n            .topk(2, dim=-1)[0]\n            .sum(dim=-1)\n        )  # [T, n_group]\n        topk_group = getattr(moe_block, \"topk_group\", n_group)\n        group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=False)[1]\n        group_mask = torch.zeros_like(group_scores)\n        group_mask.scatter_(1, group_idx, 1)\n        score_mask = (\n            group_mask.unsqueeze(-1)\n            .expand(-1, n_group, num_experts // n_group)\n            .reshape(-1, num_experts)\n        )\n        scores_for_choice = scores_for_choice.masked_fill(~score_mask.bool(), 0.0)\n\n    # Final topk from (possibly masked) scores\n    topk_indices = torch.topk(scores_for_choice, k=top_k, dim=-1, sorted=False)[1]\n\n    # Gather weights from original sigmoid scores (not bias-corrected)\n    topk_weights = router_probs.gather(1, topk_indices)\n\n    # Optional renormalization + scaling\n    if getattr(moe_block, \"norm_topk_prob\", True):\n        topk_weights = topk_weights / (topk_weights.sum(dim=-1, keepdim=True) + 1e-20)\n    routed_scaling_factor = getattr(moe_block, 
\"routed_scaling_factor\", 1.0)\n    topk_weights = topk_weights * routed_scaling_factor\n\n    return topk_weights, topk_indices, top_k, num_experts\n\n\ndef _route(moe_block, base_gate, hidden_states, gate_weight, gate_lora_delta):\n    \"\"\"Dispatch to the correct routing strategy based on block attributes.\n\n    Detects sigmoid routing by the presence of ``e_score_correction_bias``\n    on either the gate or the moe_block.\n    \"\"\"\n    has_sigmoid = (\n        getattr(base_gate, \"e_score_correction_bias\", None) is not None\n        or getattr(moe_block, \"e_score_correction_bias\", None) is not None\n    )\n    if has_sigmoid:\n        return _sigmoid_topk_route(\n            moe_block, base_gate, hidden_states, gate_weight, gate_lora_delta\n        )\n    return _softmax_topk_route(\n        moe_block, base_gate, hidden_states, gate_weight, gate_lora_delta\n    )\n\n\n# =============================================================================\n# Shared expert helpers\n# =============================================================================\n\n\ndef _compute_shared_expert(moe_block, hidden_states_flat):\n    \"\"\"Compute shared expert output if the block has one.\n\n    Handles singular (qwen2_moe: ``shared_expert``), plural\n    (glm_moe_dsa/deepseek_v3: ``shared_experts``), and MLP\n    (hunyuan_v1_moe: ``shared_mlp``) attribute names.\n\n    peft wraps individual linear layers inside the shared expert with\n    standard LoRA — calling forward() handles this transparently.\n    \"\"\"\n    shared_expert = (\n        getattr(moe_block, \"shared_expert\", None)\n        or getattr(moe_block, \"shared_experts\", None)\n        or getattr(moe_block, \"shared_mlp\", None)\n    )\n    if shared_expert is None:\n        return None\n\n    shared_expert_output = shared_expert(hidden_states_flat)\n\n    # Optional sigmoid gate (Qwen2MoE pattern).\n    # shared_expert_gate may also be peft-wrapped (standard LoRA\n    # on nn.Linear), its forward() 
applies LoRA automatically.\n    shared_expert_gate = getattr(moe_block, \"shared_expert_gate\", None)\n    if shared_expert_gate is not None:\n        shared_expert_output = (\n            F.sigmoid(shared_expert_gate(hidden_states_flat)) * shared_expert_output\n        )\n\n    return shared_expert_output\n\n\n# =============================================================================\n# Layer classes\n# =============================================================================\n\n\nclass ScatterMoEGatedMLP(nn.Module):\n    def forward(self, layer_input):\n        \"\"\"\n        Forward pass of the mixture of experts layer.\n\n        Args:\n            layer_input (Tensor):\n                Input tensor.\n\n        Returns:\n            Tensor:\n                Output tensor.\n        \"\"\"\n        bsz, length, emb_size = layer_input.size()\n        layer_input = layer_input.reshape(-1, emb_size)\n        # compute the top_k routing decision\n        router_logits = self.router.layer(layer_input)\n        routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)\n        routing_weights, selected_experts = torch.topk(\n            routing_weights, self.router.top_k, dim=-1\n        )\n        routing_weights /= routing_weights.sum(dim=-1, keepdim=True)\n        routing_weights = routing_weights.to(layer_input.dtype)\n        sorted_expert_idxs, sorted_scattered_idxs, expert_offsets = flatten_sort_count(\n            selected_experts, num_experts=self.router.num_experts\n        )\n\n        # compute experts\n        gates, h = parallel_linear(\n            layer_input,\n            self.input_linear.weight.transpose(2, 1),\n            self.router.top_k,\n            sorted_expert_idxs,\n            sorted_scattered_idxs,\n            expert_offsets,\n            grouped_in=False,\n            grouped_out=True,\n        ).chunk(2, dim=-1)\n        h = self.activation(gates) * h\n        layer_output = parallel_linear(\n            h,\n      
      self.output_linear.weight.transpose(2, 1),\n            1,\n            sorted_expert_idxs,\n            sorted_scattered_idxs,\n            expert_offsets,\n            grouped_in=True,\n            grouped_out=False,\n            gates=routing_weights,\n        )\n        layer_output = layer_output.view(bsz, length, emb_size)\n        return layer_output\n\n\nclass HFScatterMoEGatedMLP(nn.Module):\n    \"\"\"\n    ScatterMoE-accelerated forward pass for HF MoEs.\n\n    Used as a kernel layer via the HF ``kernels`` library.  The ``forward``\n    method replaces the original SparseMoeBlock.forward.\n\n    Supports:\n\n    * **Softmax→topk routing**: OLMoE, Qwen2/3MoE, Mixtral, MiniMax\n    * **Sigmoid→topk routing**: GLM, DeepSeek V3, MiniMax M2\n    * **Full-parameter training**: uses ``parallel_linear`` (base ScatterMoE)\n    * **LoRA fine-tuning**: detects peft ``ParamWrapper`` on ``self.experts``,\n      extracts adapter weights, and uses ``parallel_linear_lora`` (fused kernel)\n    \"\"\"\n\n    @staticmethod\n    def forward(self: nn.Module, layer_input: torch.Tensor):\n        \"\"\"\n        Forward pass using ScatterMoE kernels.\n\n        Args:\n            self: The MoeSparseMoeBlock module containing:\n                - self.gate: Router (or peft ParamWrapper wrapping it)\n                - self.experts: Experts module (or peft ParamWrapper chain)\n                - self.shared_expert(s): Optional shared expert\n                - self.shared_expert_gate: Optional shared expert gate\n            layer_input: Input tensor [batch_size, seq_len, hidden_size]\n\n        Returns:\n            Tensor: [batch_size, seq_len, hidden_size]\n        \"\"\"\n        batch_size, sequence_length, hidden_dim = layer_input.shape\n        hidden_states_flat = layer_input.view(-1, hidden_dim)\n\n        # ====================================================================\n        # Shared Expert (if present, e.g. 
Qwen2MoE, DeepSeek V3)\n        # ====================================================================\n        shared_expert_output = _compute_shared_expert(self, hidden_states_flat)\n\n        # ====================================================================\n        # Router Computation (with optional gate LoRA)\n        # ====================================================================\n        base_gate, gate_weight, gate_lora_delta = _unwrap_gate_lora(self.gate)\n        routing_weights, selected_experts, top_k, num_experts = _route(\n            self, base_gate, hidden_states_flat, gate_weight, gate_lora_delta\n        )\n        routing_weights = routing_weights.to(hidden_states_flat.dtype)\n\n        sorted_expert_idxs, sorted_scattered_idxs, expert_offsets = flatten_sort_count(\n            selected_experts, num_experts=num_experts\n        )\n\n        # ====================================================================\n        # Detect LoRA (peft ParamWrapper) and extract adapter weights\n        # ====================================================================\n        experts, gup_lora, down_lora = _unwrap_experts_lora(self.experts)\n\n        # ====================================================================\n        # Selective expert weight dequantization\n        # ====================================================================\n        # When experts are BnB-quantized (quantize_moe_experts), dequantize\n        # only the active experts instead of all E. 
This saves ~97% memory\n        # for the transient dequant buffer when few experts are active.\n        use_selective = (\n            getattr(self, \"_use_selective_dequant\", False)\n            and hasattr(experts, \"parametrizations\")\n            and \"gate_up_proj\" in experts.parametrizations\n        )\n\n        if use_selective:\n            from axolotl.integrations.kernels.libs.scattermoe_lora.selective_dequant import (\n                get_active_experts,\n                remap_expert_indices,\n                selective_expert_weights,\n                selective_lora_weights,\n            )\n\n            active_experts = get_active_experts(sorted_expert_idxs, num_experts)\n            remapped_expert_idxs, compact_offsets = remap_expert_indices(\n                sorted_expert_idxs,\n                expert_offsets,\n                active_experts,\n                num_experts,\n            )\n            # Dequantize only active experts' weights\n            gate_up_W = selective_expert_weights(\n                experts,\n                \"gate_up_proj\",\n                active_experts,\n            ).transpose(2, 1)  # [num_active, hidden, 2*inter]\n\n            # Remap LoRA weights to match compact expert indices\n            if gup_lora is not None:\n                gup_A, gup_B, gup_scaling = gup_lora\n                gup_A, gup_B = selective_lora_weights(\n                    gup_A,\n                    gup_B,\n                    active_experts,\n                    num_experts,\n                )\n                gup_lora = (gup_A, gup_B, gup_scaling)\n\n            # Use remapped indices for ScatterMoE kernels\n            sei_gup = remapped_expert_idxs\n            eo_gup = compact_offsets\n        else:\n            gate_up_W = experts.gate_up_proj.transpose(2, 1)  # [E, hidden, 2*inter]\n            sei_gup = sorted_expert_idxs\n            eo_gup = expert_offsets\n\n        # 
====================================================================\n        # Gate + Up projection\n        # ====================================================================\n        if gup_lora is not None:\n            gup_A, gup_B, gup_scaling = gup_lora\n            gup = parallel_linear_lora(\n                hidden_states_flat,\n                gate_up_W,\n                top_k,\n                sei_gup,\n                sorted_scattered_idxs,\n                eo_gup,\n                lora_A=gup_A,\n                lora_B=gup_B,\n                scaling=gup_scaling,\n                grouped_in=False,\n                grouped_out=True,\n                use_fused_dX=True,\n                use_fused_gather=True,\n            )\n        else:\n            gup = parallel_linear(\n                hidden_states_flat,\n                gate_up_W,\n                top_k,\n                sei_gup,\n                sorted_scattered_idxs,\n                eo_gup,\n                grouped_in=False,\n                grouped_out=True,\n            )\n\n        gates, h = gup.chunk(2, dim=-1)\n        h = experts.act_fn(gates) * h\n\n        # ====================================================================\n        # Down projection\n        # ====================================================================\n        if use_selective:\n            down_W = selective_expert_weights(\n                experts,\n                \"down_proj\",\n                active_experts,\n            ).transpose(2, 1)  # [num_active, inter, hidden]\n\n            if down_lora is not None:\n                down_A, down_B, down_scaling = down_lora\n                down_A, down_B = selective_lora_weights(\n                    down_A,\n                    down_B,\n                    active_experts,\n                    num_experts,\n                )\n                down_lora = (down_A, down_B, down_scaling)\n\n            sei_down = remapped_expert_idxs\n            eo_down = 
compact_offsets\n        else:\n            down_W = experts.down_proj.transpose(2, 1)  # [E, inter, hidden]\n            sei_down = sorted_expert_idxs\n            eo_down = expert_offsets\n\n        if down_lora is not None:\n            down_A, down_B, down_scaling = down_lora\n            expert_output = parallel_linear_lora(\n                h,\n                down_W,\n                1,\n                sei_down,\n                sorted_scattered_idxs,\n                eo_down,\n                lora_A=down_A,\n                lora_B=down_B,\n                scaling=down_scaling,\n                gates=routing_weights,\n                grouped_in=True,\n                grouped_out=False,\n                use_fused_dX=True,\n                use_fused_gather=True,\n            )\n        else:\n            expert_output = parallel_linear(\n                h,\n                down_W,\n                1,\n                sei_down,\n                sorted_scattered_idxs,\n                eo_down,\n                grouped_in=True,\n                grouped_out=False,\n                gates=routing_weights,\n            )\n\n        # ====================================================================\n        # Combine with shared expert and reshape\n        # ====================================================================\n        if shared_expert_output is not None:\n            expert_output = expert_output + shared_expert_output\n\n        expert_output = expert_output.view(batch_size, sequence_length, hidden_dim)\n        return expert_output\n"
  },
  {
    "path": "src/axolotl/integrations/kernels/libs/scattermoe_lora/lora_ops.py",
    "content": "# SPDX-License-Identifier: Apache-2.0\n# Copyright (c) Axolotl AI\n# Licensed under the Apache License, Version 2.0\n\n\"\"\"\nParallelExperts module with LoRA support.\n\nProvides a drop-in replacement for ScatterMoE's ParallelExperts that\nuses the fused LoRA kernel when adapter weights are attached.\n\"\"\"\n\nfrom typing import Optional\n\nimport torch\nimport torch.nn as nn\n\nfrom .parallel_linear_lora import parallel_linear_lora\n\n\nclass ParallelExperts(nn.Module):\n    \"\"\"\n    Parallel Experts with fused LoRA support.\n\n    Drop-in replacement for the original ParallelExperts. When LoRA parameters\n    are attached via set_lora(), the forward pass uses a fused kernel:\n        Y = X @ W + scaling * (X @ A^T) @ B^T\n    \"\"\"\n\n    def __init__(\n        self,\n        num_experts: int,\n        input_size: int,\n        output_size: int,\n        bias: bool = False,\n    ) -> None:\n        super().__init__()\n        self.weight = nn.Parameter(torch.empty(num_experts, output_size, input_size))\n        if bias:\n            self.bias = nn.Parameter(torch.empty(num_experts, output_size))\n        else:\n            self.bias = None\n        self.num_experts = num_experts\n        self.input_size = input_size\n        self.output_size = output_size\n        self._lora_A: torch.Tensor | None = None\n        self._lora_B: torch.Tensor | None = None\n        self._lora_scaling: float | None = None\n        self.reset_parameters()\n\n    def reset_parameters(self) -> None:\n        nn.init.normal_(self.weight, std=0.02)\n        if self.bias is not None:\n            nn.init.zeros_(self.bias)\n\n    def extra_repr(self) -> str:\n        return (\n            f\"num_experts={self.num_experts}, \"\n            f\"input_size={self.input_size}, \"\n            f\"output_size={self.output_size}\"\n        )\n\n    def set_lora(self, lora_A: torch.Tensor, lora_B: torch.Tensor, scaling: float):\n        \"\"\"Attach LoRA parameters for fused 
computation.\"\"\"\n        self._lora_A = lora_A\n        self._lora_B = lora_B\n        self._lora_scaling = scaling\n\n    def clear_lora(self):\n        \"\"\"Remove LoRA parameters.\"\"\"\n        self._lora_A = None\n        self._lora_B = None\n        self._lora_scaling = None\n\n    def forward(\n        self,\n        inputs: torch.Tensor,\n        k: int,\n        sorted_expert_idxs: torch.Tensor,\n        sorted_scattered_idxs: torch.Tensor,\n        expert_offsets: torch.Tensor,\n        gates: Optional[torch.Tensor] = None,\n        grouped_in: bool = False,\n        grouped_out: bool = False,\n    ) -> torch.Tensor:\n        return parallel_linear_lora(\n            inputs,\n            self.weight.permute(0, 2, 1),  # [E, input, output]\n            k,\n            sorted_expert_idxs,\n            sorted_scattered_idxs,\n            expert_offsets,\n            lora_A=self._lora_A,\n            lora_B=self._lora_B,\n            scaling=self._lora_scaling if self._lora_scaling is not None else 1.0,\n            expert_biases=self.bias,\n            gates=gates,\n            grouped_in=grouped_in,\n            grouped_out=grouped_out,\n        )\n"
  },
  {
    "path": "src/axolotl/integrations/kernels/libs/scattermoe_lora/parallel_experts.py",
    "content": "# SPDX-License-Identifier: Apache-2.0\n# Adapted from https://github.com/shawntan/scattermoe\n# Copyright (c) Shawn Tan and ScatterMoE Contributors\n# Licensed under the Apache License, Version 2.0\n# See https://github.com/shawntan/scattermoe/blob/main/LICENSE\n\nfrom typing import Optional\n\nimport torch\nimport torch.nn as nn\n\nfrom . import kernels\n\n\n@torch.library.custom_op(\"scattermoe::bincount\", mutates_args={})\ndef compileable_bincount(x: torch.Tensor, minlength: int) -> torch.Tensor:\n    return x.bincount(minlength=minlength)\n\n\n@compileable_bincount.register_fake\ndef _(x: torch.Tensor, minlength: int) -> torch.Tensor:\n    return torch.empty(minlength, dtype=torch.long, device=x.device)\n\n\n@torch.compile\ndef flatten_sort_count(expert_idxs: torch.Tensor, num_experts: int):\n    with torch.no_grad():\n        flattened_expert_idxs = expert_idxs.flatten()\n        sorted_expert_idxs, sorted_scattered_idxs = torch.sort(flattened_expert_idxs)\n        expert_counts = compileable_bincount(\n            flattened_expert_idxs, minlength=num_experts\n        )\n        expert_offsets = expert_counts.cumsum(-1)\n        return sorted_expert_idxs, sorted_scattered_idxs, expert_offsets\n\n\nclass ParallelLinear(torch.autograd.Function):\n    @staticmethod\n    def forward(\n        ctx,\n        x: torch.Tensor,\n        expert_weights: torch.Tensor,\n        k: int,\n        sorted_expert_idxs: torch.Tensor,\n        sorted_scattered_idxs: torch.Tensor,\n        expert_offsets: torch.Tensor,\n        expert_biases: Optional[torch.Tensor] = None,\n        gates: Optional[torch.Tensor] = None,\n        grouped_in: bool = False,\n        grouped_out: bool = False,\n    ):\n        with torch.device(x.device):\n            output = kernels.ops.scatter2scatter(\n                X=x,\n                W=expert_weights,\n                b=expert_biases,\n                k=k,\n                sorted_expert_idxs=sorted_expert_idxs,\n             
   sorted_scattered_idxs=sorted_scattered_idxs,\n                x_grouped=grouped_in,\n                y_grouped=grouped_out,\n            )\n            if gates is not None:\n                output_expanded = output.view(\n                    gates.size(0), gates.size(1), output.size(-1)\n                )\n                output = (gates.unsqueeze(1) @ output_expanded).squeeze(1)\n            else:\n                output_expanded = None\n\n            ctx.save_for_backward(\n                x,\n                expert_weights,\n                expert_biases,\n                sorted_expert_idxs,\n                sorted_scattered_idxs,\n                expert_offsets,\n                gates,\n                output_expanded,\n            )\n            ctx.grouped_in = grouped_in\n            ctx.grouped_out = grouped_out\n            ctx.k = k\n        return output\n\n    @staticmethod\n    def backward(ctx, grad_out: torch.Tensor):\n        with torch.device(grad_out.device):\n            (\n                x,\n                expert_weights,\n                expert_biases,\n                sorted_expert_idxs,\n                sorted_scattered_idxs,\n                expert_offsets,\n                gates,\n                output_expanded,\n            ) = ctx.saved_tensors\n            k = ctx.k\n            grouped_in = ctx.grouped_in\n            grouped_out = ctx.grouped_out\n\n            if gates is not None:\n                # calculate gates gradient\n                # d_gates = torch.bmm(output_expanded, grad_out[:, :, None]).squeeze(-1)\n                d_gates = (output_expanded @ grad_out.unsqueeze(-1)).squeeze(-1)\n                gates_flat = gates.flatten()\n                gate_fan = gates.size(1)\n                grouped_grad_out = output_expanded.flatten(\n                    0, 1\n                )  # reuse expanded buffer later\n            else:\n                d_gates = None\n                gates_flat = None\n                gate_fan = 
1\n                grouped_grad_out = None\n\n            if grouped_out:\n                grouped_grad_out = grad_out\n            else:\n                grouped_grad_out = kernels.ops.group(\n                    grad_out,\n                    sorted_scattered_idxs,\n                    fan_out=gate_fan,\n                    coeff=gates_flat,\n                    out=grouped_grad_out,\n                )\n            if grouped_in:\n                grouped_x = x\n                d_expanded_input = None\n            else:\n                grouped_x = kernels.ops.group(x, sorted_scattered_idxs, fan_out=k)\n                d_expanded_input = grouped_x\n\n            d_weights, d_biases = kernels.ops.group_bwd_W(\n                DY=grouped_grad_out,\n                X=grouped_x,\n                expert_offsets=expert_offsets,\n                E=expert_weights.size(0),\n                has_bias=expert_biases is not None,\n            )\n\n            d_expanded_input = kernels.ops.scatter2scatter(\n                X=grouped_grad_out,\n                x_grouped=True,\n                W=expert_weights.permute(0, 2, 1),\n                sorted_expert_idxs=sorted_expert_idxs,\n                sorted_scattered_idxs=sorted_scattered_idxs,\n                k=1,\n                y_grouped=grouped_in,\n                out=d_expanded_input,  # Reuse grouped_x buffer\n            )\n\n            if k == 1:\n                d_input = d_expanded_input\n            else:\n                d_input = d_expanded_input.view(\n                    x.size(0), k, d_expanded_input.size(-1)\n                ).sum(-2)\n        return (\n            # x, expert_weights,\n            d_input,\n            d_weights,\n            # k, sorted_expert_idxs, sorted_scattered_idxs, expert_offsets,\n            None,\n            None,\n            None,\n            None,\n            # bias, gates\n            d_biases,\n            d_gates,\n            # grouped_in, grouped_out,\n            
None,\n            None,\n        )\n\n\ndef parallel_linear(\n    inputs,\n    expert_weights,\n    k,\n    sorted_expert_idxs,\n    sorted_scattered_idxs,\n    expert_offsets,\n    expert_biases=None,\n    gates=None,\n    grouped_in=False,\n    grouped_out=False,\n):\n    results = ParallelLinear.apply(\n        inputs,\n        expert_weights,\n        k,\n        sorted_expert_idxs,\n        sorted_scattered_idxs,\n        expert_offsets,\n        expert_biases,\n        gates,\n        grouped_in,\n        grouped_out,\n    )\n    return results\n\n\nclass ParallelExperts(nn.Module):\n    def __init__(self, num_experts, input_size, output_size, bias=False) -> None:\n        super().__init__()\n        self.weight = nn.Parameter(torch.empty(num_experts, output_size, input_size))\n\n        if bias:\n            self.bias = nn.Parameter(torch.empty(num_experts, output_size))\n        else:\n            self.bias = None\n\n        self.num_experts = num_experts\n        self.input_size = input_size\n        self.output_size = output_size\n        self.reset_parameters()\n\n    def extra_repr(self):\n        return \"num_experts={}, input_size={}, output_size={}\".format(\n            self.num_experts, self.input_size, self.output_size\n        )\n\n    def reset_parameters(self) -> None:\n        nn.init.normal_(self.weight, std=0.02)\n        if self.bias is not None:\n            nn.init.zeros_(self.bias)\n\n    def forward(\n        self,\n        inputs,\n        k,\n        sorted_expert_idxs,\n        sorted_scattered_idxs,\n        expert_offsets,\n        gates=None,\n        grouped_in=False,\n        grouped_out=False,\n    ):\n        results = parallel_linear(\n            inputs,\n            self.weight.permute(0, 2, 1),\n            k,\n            sorted_expert_idxs,\n            sorted_scattered_idxs,\n            expert_offsets,\n            expert_biases=self.bias,\n            gates=gates,\n            grouped_in=grouped_in,\n            
grouped_out=grouped_out,\n        )\n        return results\n"
  },
  {
    "path": "src/axolotl/integrations/kernels/libs/scattermoe_lora/parallel_linear_lora.py",
    "content": "# SPDX-License-Identifier: Apache-2.0\n# Copyright (c) Axolotl AI\n# Licensed under the Apache License, Version 2.0\n\n\"\"\"\nScatterMoE + LoRA Autograd Function\n====================================\n\nProvides the autograd function and Python interface for fused ScatterMoE + LoRA.\n\nKey design for LoRA training:\n  - Expert weights W are FROZEN (no gradient computed for W).\n  - Only LoRA adapter weights (A, B) receive gradients.\n  - The input gradient dX is still computed (needed for upstream layers).\n  - This avoids the expensive group_bwd_W computation entirely.\n\nForward:\n  Y = X @ W + scaling * (X @ A^T) @ B^T\n\nBackward (W frozen):\n  dX = dY @ W^T + scaling * (dY @ B) @ A          (via scatter2scatter for base, separate for LoRA)\n  dA = scaling * (dY @ B)^T @ X                     (per-expert, on grouped data)\n  dB = scaling * dY^T @ (X @ A^T)                   (per-expert, on grouped data)\n\"\"\"\n\nfrom typing import Optional\n\nimport torch\n\nfrom .kernels import ops as base_ops\nfrom .kernels.lora_ops import (\n    group_bwd_lora,\n    group_bwd_lora_fused,\n    scatter2scatter_lora,\n    scatter2scatter_lora_dX,\n)\n\n\nclass ScatterMoELoRA(torch.autograd.Function):\n    \"\"\"\n    Autograd function for fused ScatterMoE + LoRA with frozen expert weights.\n\n    This function is optimized for the LoRA fine-tuning scenario where:\n    - Expert weights W are frozen (requires_grad=False)\n    - Only LoRA A and B matrices receive gradients\n    - Input gradients are computed for upstream layer backprop\n    \"\"\"\n\n    @staticmethod\n    def forward(\n        ctx,\n        x: torch.Tensor,\n        expert_weights: torch.Tensor,\n        k: int,\n        sorted_expert_idxs: torch.Tensor,\n        sorted_scattered_idxs: torch.Tensor,\n        expert_offsets: torch.Tensor,\n        lora_A: torch.Tensor,\n        lora_B: torch.Tensor,\n        scaling: float,\n        expert_biases: Optional[torch.Tensor] = None,\n        gates: 
Optional[torch.Tensor] = None,\n        grouped_in: bool = False,\n        grouped_out: bool = False,\n        use_fused_dX: bool = False,\n        use_fused_gather: bool = False,\n    ):\n        with torch.device(x.device):\n            # Fused forward: Y = X @ W + scaling * (X @ A^T) @ B^T\n            output = scatter2scatter_lora(\n                X=x,\n                W=expert_weights,\n                sorted_expert_idxs=sorted_expert_idxs,\n                sorted_scattered_idxs=sorted_scattered_idxs,\n                k=k,\n                lora_A=lora_A,\n                lora_B=lora_B,\n                scaling=scaling,\n                b=expert_biases,\n                x_grouped=grouped_in,\n                y_grouped=grouped_out,\n            )\n\n            # Handle gating (weighted combination of top-k expert outputs)\n            if gates is not None:\n                output_expanded = output.view(\n                    gates.size(0), gates.size(1), output.size(-1)\n                )\n                output = (gates.unsqueeze(1) @ output_expanded).squeeze(1)\n            else:\n                output_expanded = None\n\n            ctx.save_for_backward(\n                x,\n                lora_A,\n                lora_B,\n                sorted_expert_idxs,\n                sorted_scattered_idxs,\n                expert_offsets,\n                gates,\n                output_expanded,\n            )\n            # Store frozen weights as plain Python attributes instead of\n            # save_for_backward.  This avoids:\n            # 1. Version-check conflicts with FSDP unshard/reshard\n            # 2. Pinning all-gathered parameters via saved_tensors hooks\n            # 3. 
Interfering with activation offloading pack/unpack hooks\n            # Safe because expert_weights are frozen (requires_grad=False).\n            ctx.expert_weights = expert_weights\n            ctx.expert_biases = expert_biases\n            ctx.grouped_in = grouped_in\n            ctx.grouped_out = grouped_out\n            ctx.k = k\n            ctx.scaling = scaling\n            ctx.use_fused_dX = use_fused_dX\n            ctx.use_fused_gather = use_fused_gather\n\n        return output\n\n    @staticmethod\n    def backward(ctx, grad_out: torch.Tensor):\n        with torch.device(grad_out.device):\n            (\n                x,\n                lora_A,\n                lora_B,\n                sorted_expert_idxs,\n                sorted_scattered_idxs,\n                expert_offsets,\n                gates,\n                output_expanded,\n            ) = ctx.saved_tensors\n            expert_weights = ctx.expert_weights\n\n            k = ctx.k\n            scaling = ctx.scaling\n            grouped_in = ctx.grouped_in\n            grouped_out = ctx.grouped_out\n            E = expert_weights.size(0)\n\n            # ------------------------------------------------------------------\n            # Gate gradients (if using top-k gating with routing weights)\n            # ------------------------------------------------------------------\n            if gates is not None:\n                # d_gates[t, j] = output_expanded[t, j, :] . 
grad_out[t, :]\n                d_gates = (output_expanded @ grad_out.unsqueeze(-1)).squeeze(-1)\n                gates_flat = gates.flatten()\n                gate_fan = gates.size(1)\n                # Reuse output_expanded buffer for grouped_grad_out\n                grouped_grad_out = output_expanded.flatten(0, 1)\n            else:\n                d_gates = None\n                gates_flat = None\n                gate_fan = 1\n                grouped_grad_out = None\n\n            # ------------------------------------------------------------------\n            # LoRA gradients (dA, dB) and setup for dX\n            # ------------------------------------------------------------------\n            # Fused gather uses sorted_scattered_idxs for indirect X access\n            # in the Triton kernel, avoiding the group(x) allocation.\n            #\n            # can_fuse_gather: X is ungrouped and not too large for scatter loads\n            #   - When gates is None and grouped_out=False: both DY and X ungrouped\n            #   - When grouped_out=True (gate_up_proj): DY already grouped, X ungrouped\n            #     -> use dy_grouped=True in the fused kernel\n            M_total = sorted_scattered_idxs.size(0)\n            K_dim = x.size(-1)\n            N_dim = expert_weights.size(-1)\n            fuse_gather_workload = M_total * max(K_dim, N_dim)\n            _FUSE_GATHER_THRESHOLD = 2**24  # ~16M elements\n\n            can_fuse_gather = (\n                ctx.use_fused_gather\n                and not grouped_in  # X must be ungrouped for scatter access\n                and gates is None  # gate coeff requires multiplicative gather\n                and fuse_gather_workload < _FUSE_GATHER_THRESHOLD\n            )\n\n            if can_fuse_gather:\n                # ------------------------------------------------------------------\n                # Fused path: skip group(x) entirely\n                # 
------------------------------------------------------------------\n                d_expanded_input = None\n\n                d_lora_A, d_lora_B = group_bwd_lora_fused(\n                    DY=grad_out,\n                    X=x,\n                    lora_A=lora_A,\n                    lora_B=lora_B,\n                    expert_offsets=expert_offsets,\n                    sorted_scattered_idxs=sorted_scattered_idxs,\n                    E=E,\n                    k=k,\n                    scaling=scaling,\n                    dy_grouped=grouped_out,\n                )\n\n                # Prepare grouped_grad_out for the dX path (needed by both\n                # the fused dX kernel when grouped_out=True, and the non-fused path)\n                if grouped_out:\n                    grouped_grad_out = grad_out\n                elif not ctx.use_fused_dX:\n                    grouped_grad_out = base_ops.group(\n                        grad_out,\n                        sorted_scattered_idxs,\n                        fan_out=gate_fan,\n                        coeff=gates_flat,\n                        out=grouped_grad_out,\n                    )\n            else:\n                # ------------------------------------------------------------------\n                # Original path: explicit group() calls\n                # ------------------------------------------------------------------\n                if grouped_out:\n                    grouped_grad_out = grad_out\n                else:\n                    grouped_grad_out = base_ops.group(\n                        grad_out,\n                        sorted_scattered_idxs,\n                        fan_out=gate_fan,\n                        coeff=gates_flat,\n                        out=grouped_grad_out,\n                    )\n\n                if grouped_in:\n                    grouped_x = x\n                    d_expanded_input = None\n                else:\n                    grouped_x = base_ops.group(x, 
sorted_scattered_idxs, fan_out=k)\n                    d_expanded_input = grouped_x  # Will be overwritten; reuse buffer\n\n                d_lora_A, d_lora_B = group_bwd_lora(\n                    DY=grouped_grad_out,\n                    X=grouped_x,\n                    lora_A=lora_A,\n                    lora_B=lora_B,\n                    expert_offsets=expert_offsets,\n                    E=E,\n                    scaling=scaling,\n                )\n\n            # ------------------------------------------------------------------\n            # Input gradient: dX = dY @ W^T + scaling * (dY @ B) @ A\n            # ------------------------------------------------------------------\n            if ctx.use_fused_dX:\n                if can_fuse_gather and not grouped_out:\n                    # Fully fused: read ungrouped DY via scatter pattern\n                    d_expanded_input = scatter2scatter_lora_dX(\n                        DY=grad_out,\n                        W=expert_weights,\n                        sorted_expert_idxs=sorted_expert_idxs,\n                        sorted_scattered_idxs=sorted_scattered_idxs,\n                        k=1,\n                        lora_A=lora_A,\n                        lora_B=lora_B,\n                        scaling=scaling,\n                        dy_grouped=False,\n                        dx_grouped=grouped_in,\n                        out=d_expanded_input,\n                    )\n                else:\n                    # Fused dX only: read from pre-grouped DY\n                    d_expanded_input = scatter2scatter_lora_dX(\n                        DY=grouped_grad_out,\n                        W=expert_weights,\n                        sorted_expert_idxs=sorted_expert_idxs,\n                        sorted_scattered_idxs=sorted_scattered_idxs,\n                        k=1,\n                        lora_A=lora_A,\n                        lora_B=lora_B,\n                        scaling=scaling,\n                   
     dy_grouped=True,\n                        dx_grouped=grouped_in,\n                        out=d_expanded_input,\n                    )\n            else:\n                # Original path: separate base scatter2scatter + LoRA Python loop\n                d_expanded_input = base_ops.scatter2scatter(\n                    X=grouped_grad_out,\n                    x_grouped=True,\n                    W=expert_weights.permute(0, 2, 1),  # [E, N, K]\n                    sorted_expert_idxs=sorted_expert_idxs,\n                    sorted_scattered_idxs=sorted_scattered_idxs,\n                    k=1,\n                    y_grouped=grouped_in,\n                    out=d_expanded_input,\n                )\n\n                # LoRA part: dX_lora = scaling * (dY @ B) @ A\n                if scaling != 0.0:\n                    d_input_lora_grouped = _compute_lora_input_grad(\n                        grouped_grad_out,\n                        lora_A,\n                        lora_B,\n                        expert_offsets,\n                        E,\n                        scaling,\n                    )\n                    if grouped_in:\n                        d_expanded_input.add_(d_input_lora_grouped)\n                    else:\n                        # Scatter-add LoRA gradient directly into d_expanded_input.\n                        # Avoids allocating a zeros_like + add result\n                        d_expanded_input[sorted_scattered_idxs] += d_input_lora_grouped\n\n            # Reduce over top-k if k > 1\n            if k == 1:\n                d_input = d_expanded_input\n            else:\n                d_input = d_expanded_input.view(\n                    x.size(0), k, d_expanded_input.size(-1)\n                ).sum(-2)\n\n            # W is frozen during LoRA training -- skip weight gradient\n            d_weights = (\n                torch.zeros_like(expert_weights)\n                if expert_weights.requires_grad\n                else None\n            
)\n            d_biases = None\n\n        return (\n            d_input,\n            d_weights,\n            None,\n            None,\n            None,\n            None,  # k, sorted indices, offsets\n            d_lora_A,\n            d_lora_B,\n            None,  # lora_A, lora_B, scaling\n            d_biases,\n            d_gates,\n            None,\n            None,  # grouped_in, grouped_out\n            None,  # use_fused_dX\n            None,  # use_fused_gather\n        )\n\n\ndef _compute_lora_input_grad(\n    grouped_grad_out: torch.Tensor,\n    lora_A: torch.Tensor,\n    lora_B: torch.Tensor,\n    expert_offsets: torch.Tensor,\n    E: int,\n    scaling: float,\n) -> torch.Tensor:\n    \"\"\"\n    Compute the LoRA contribution to the input gradient:\n      dX_lora = scaling * (dY @ B) @ A\n\n    Uses PyTorch ops on expert-grouped data.\n    Each expert e: dX_e = scaling * (dY_e @ B_e) @ A_e\n    \"\"\"\n    R = lora_A.size(0) // E\n    K = lora_A.size(1)\n    M_total = grouped_grad_out.size(0)\n\n    d_input_lora = torch.zeros(\n        (M_total, K), device=grouped_grad_out.device, dtype=grouped_grad_out.dtype\n    )\n\n    compute_dtype = grouped_grad_out.dtype\n\n    prev_offset = 0\n    for e in range(E):\n        curr_offset = expert_offsets[e].item()\n        if curr_offset > prev_offset:\n            dy_e = grouped_grad_out[prev_offset:curr_offset]  # [M_e, N]\n            a_e = lora_A[e * R : (e + 1) * R, :].to(compute_dtype)  # [r, K]\n            b_e = lora_B[:, e * R : (e + 1) * R].to(compute_dtype)  # [N, r]\n\n            # dX_e = scaling * (dY_e @ B_e) @ A_e\n            dy_b = dy_e @ b_e  # [M_e, r]\n            dx_e = scaling * (dy_b @ a_e)  # [M_e, K]\n            d_input_lora[prev_offset:curr_offset] = dx_e\n\n        prev_offset = curr_offset\n\n    return d_input_lora\n\n\n# =============================================================================\n# Helper: Extract LoRA params from PEFT ParamWrapper\n# 
=============================================================================\n\n\ndef get_lora_params_from_wrapper(module) -> tuple:\n    \"\"\"\n    Extract LoRA parameters from a PEFT ParamWrapper.\n\n    Returns:\n        (lora_A, lora_B, scaling) if LoRA is active, else (None, None, None)\n    \"\"\"\n    if not hasattr(module, \"lora_A\") or not hasattr(module, \"lora_B\"):\n        return None, None, None\n\n    active_adapters = getattr(module, \"active_adapters\", [\"default\"])\n    if not active_adapters:\n        return None, None, None\n\n    adapter_name = active_adapters[0]\n\n    lora_A_dict = getattr(module, \"lora_A\", {})\n    lora_B_dict = getattr(module, \"lora_B\", {})\n    scaling_dict = getattr(module, \"scaling\", {})\n\n    if adapter_name not in lora_A_dict:\n        return None, None, None\n\n    lora_A = lora_A_dict[adapter_name].weight\n    lora_B = lora_B_dict[adapter_name].weight\n    scaling = scaling_dict[adapter_name]\n\n    return lora_A, lora_B, scaling\n\n\n# =============================================================================\n# Drop-in replacement for parallel_linear\n# =============================================================================\n\n\ndef parallel_linear_lora(\n    inputs: torch.Tensor,\n    expert_weights: torch.Tensor,\n    k: int,\n    sorted_expert_idxs: torch.Tensor,\n    sorted_scattered_idxs: torch.Tensor,\n    expert_offsets: torch.Tensor,\n    lora_A: Optional[torch.Tensor] = None,\n    lora_B: Optional[torch.Tensor] = None,\n    scaling: float = 1.0,\n    expert_biases: Optional[torch.Tensor] = None,\n    gates: Optional[torch.Tensor] = None,\n    grouped_in: bool = False,\n    grouped_out: bool = False,\n    use_fused_dX: bool = False,\n    use_fused_gather: bool = False,\n):\n    \"\"\"\n    Drop-in replacement for parallel_linear that supports LoRA.\n\n    If lora_A and lora_B are provided, uses fused LoRA kernel.\n    Otherwise falls back to standard scatter2scatter.\n    \"\"\"\n    if 
lora_A is not None and lora_B is not None:\n        return ScatterMoELoRA.apply(\n            inputs,\n            expert_weights,\n            k,\n            sorted_expert_idxs,\n            sorted_scattered_idxs,\n            expert_offsets,\n            lora_A,\n            lora_B,\n            scaling,\n            expert_biases,\n            gates,\n            grouped_in,\n            grouped_out,\n            use_fused_dX,\n            use_fused_gather,\n        )\n    else:\n        from .parallel_experts import ParallelLinear\n\n        return ParallelLinear.apply(\n            inputs,\n            expert_weights,\n            k,\n            sorted_expert_idxs,\n            sorted_scattered_idxs,\n            expert_offsets,\n            expert_biases,\n            gates,\n            grouped_in,\n            grouped_out,\n        )\n"
  },
  {
    "path": "src/axolotl/integrations/kernels/libs/scattermoe_lora/selective_dequant.py",
    "content": "\"\"\"\nSelective Expert Dequantization\n===============================\n\nInstead of dequantizing all E expert weight matrices at once (which creates\na ~1 GB transient buffer for 256 experts), only dequantize the experts that\nare actually routed to by the current batch's top-k selection.\n\nFor Qwen3.5-35B-A3B (E=256, top_k=8, hidden=2048, intermediate=512):\n  - Full dequant: [256, 2048, 1024] = 1,074 MB per projection\n  - Selective (8 active): [8, 2048, 1024] = 33.5 MB per projection\n  - Savings: ~97% memory reduction per layer\n\nThis module provides format-agnostic selective weight extraction:\n  - BnB 4-bit (nf4/fp4): slice quantized data + absmax per expert\n  - bf16/fp32: direct indexing (no dequant needed)\n  - FP8: slice + cast\n\nThe ScatterMoE kernel itself doesn't change — we remap expert indices\nfrom global (0..E-1) to compact (0..num_active-1) and pass the smaller\nweight tensor.\n\"\"\"\n\nimport torch\nimport torch.nn as nn\n\n\ndef get_active_experts(sorted_expert_idxs: torch.Tensor, E: int) -> torch.Tensor:\n    \"\"\"Get sorted unique expert indices from the routing output.\n\n    Args:\n        sorted_expert_idxs: Expert assignments sorted by expert id [T*k]\n        E: Total number of experts\n\n    Returns:\n        active: Sorted unique expert indices [num_active]\n    \"\"\"\n    return torch.unique(sorted_expert_idxs)\n\n\ndef remap_expert_indices(\n    sorted_expert_idxs: torch.Tensor,\n    expert_offsets: torch.Tensor,\n    active_experts: torch.Tensor,\n    E: int,\n) -> tuple[torch.Tensor, torch.Tensor]:\n    \"\"\"Remap global expert indices to compact indices.\n\n    Maps expert ids from [0..E-1] to [0..num_active-1], preserving the\n    sort order. 
Also compacts expert_offsets to only active experts.\n\n    Args:\n        sorted_expert_idxs: [T*k] expert ids in sorted order\n        expert_offsets: [E] cumulative token counts (original)\n        active_experts: [num_active] sorted unique expert ids\n        E: Total number of experts\n\n    Returns:\n        remapped_idxs: [T*k] expert ids in [0..num_active-1]\n        compact_offsets: [num_active] cumulative token counts\n    \"\"\"\n    # Build remap table: global_id -> compact_id\n    remap = torch.empty(E, dtype=torch.long, device=sorted_expert_idxs.device)\n    remap[active_experts] = torch.arange(\n        len(active_experts), device=sorted_expert_idxs.device\n    )\n\n    remapped_idxs = remap[sorted_expert_idxs]\n\n    # Compact the expert_offsets: only keep active experts' cumulative counts\n    compact_offsets = expert_offsets[active_experts]\n\n    return remapped_idxs, compact_offsets\n\n\ndef _selective_dequant_bnb4(\n    raw_param: torch.Tensor,\n    quant_state,\n    active_experts: torch.Tensor,\n    expert_shape: tuple[int, int],\n) -> torch.Tensor:\n    \"\"\"Dequantize only selected experts from BnB 4-bit packed data.\n\n    The raw parameter is a flattened 4-bit packed tensor. Each expert's\n    data is contiguous (stored in expert-major order), so we can gather\n    the packed data and absmax blocks for active experts, then dequantize\n    as one contiguous block.\n\n    Args:\n        raw_param: Flattened uint8 tensor of packed 4-bit weights\n        quant_state: BnB QuantState with absmax, blocksize, code, etc.\n        active_experts: [num_active] expert indices to dequantize\n        expert_shape: (dim1, dim2) shape per expert (e.g. 
(1024, 2048))\n\n    Returns:\n        Dequantized weights [num_active, dim1, dim2] in original dtype\n    \"\"\"\n    import bitsandbytes.functional as F  # noqa: N812\n    from bitsandbytes.functional import QuantState\n\n    expert_numel = expert_shape[0] * expert_shape[1]\n    packed_per_expert = expert_numel // 2  # 4-bit = 2 values per byte\n    blocks_per_expert = expert_numel // quant_state.blocksize\n    num_active = len(active_experts)\n\n    if blocks_per_expert == 0:\n        # Expert is smaller than one quantization block — blocks span across\n        # expert boundaries, so per-expert slicing isn't possible.\n        # Fallback: full dequantize + index.\n        full = F.dequantize_4bit(raw_param, quant_state)\n        E_total = full.numel() // expert_numel\n        return full.reshape(E_total, *expert_shape)[active_experts]\n\n    # Use fused Triton kernel for NF4 (handles selective gather + dequant in one pass)\n    if quant_state.quant_type == \"nf4\" and raw_param.dtype == torch.uint8:\n        from axolotl.integrations.kernels.libs.scattermoe_lora.selective_dequant_kernel import (\n            selective_dequant_nf4_triton,\n        )\n\n        # Handle nested (double) quantization: dequantize absmax first\n        # BnB uses dequantize_blockwise (not _4bit) for nested absmax + offset\n        if quant_state.nested:\n            absmax = F.dequantize_blockwise(quant_state.absmax, quant_state.state2)\n            absmax += quant_state.offset\n            if absmax.dtype != torch.float32:\n                absmax = absmax.float()\n        else:\n            absmax = quant_state.absmax\n\n        return selective_dequant_nf4_triton(\n            packed_data=raw_param,\n            absmax=absmax,\n            active_experts=active_experts,\n            expert_shape=expert_shape,\n            blocksize=quant_state.blocksize,\n            dtype=quant_state.dtype,\n            codebook=quant_state.code,\n        )\n\n    # Fallback: gather + BnB dequant 
(for fp4 or non-uint8 packed formats)\n    raw_flat = raw_param.reshape(-1)\n\n    offsets_qt = (\n        active_experts.long()[:, None] * packed_per_expert\n        + torch.arange(packed_per_expert, device=raw_param.device)[None, :]\n    ).reshape(-1)\n    qt_gathered = raw_flat[offsets_qt]\n\n    offsets_abs = (\n        active_experts.long()[:, None] * blocks_per_expert\n        + torch.arange(blocks_per_expert, device=raw_param.device)[None, :]\n    ).reshape(-1)\n\n    if quant_state.nested:\n        full_absmax = F.dequantize_blockwise(quant_state.absmax, quant_state.state2)\n        full_absmax += quant_state.offset\n        if full_absmax.dtype != torch.float32:\n            full_absmax = full_absmax.float()\n        absmax_gathered = full_absmax[offsets_abs]\n    else:\n        absmax_gathered = quant_state.absmax[offsets_abs]\n\n    qt_gathered = qt_gathered.unsqueeze(1) if qt_gathered.dim() == 1 else qt_gathered\n\n    gathered_qs = QuantState(\n        absmax=absmax_gathered,\n        shape=torch.Size([num_active * expert_numel]),\n        blocksize=quant_state.blocksize,\n        quant_type=quant_state.quant_type,\n        code=quant_state.code,\n        dtype=quant_state.dtype,\n    )\n\n    deq = F.dequantize_4bit(qt_gathered, gathered_qs)\n    return deq.reshape(num_active, *expert_shape)\n\n\ndef _selective_index_dense(\n    param: torch.Tensor,\n    active_experts: torch.Tensor,\n) -> torch.Tensor:\n    \"\"\"Select experts from a dense (bf16/fp32) weight tensor.\n\n    Simple indexing — no dequantization needed.\n    \"\"\"\n    return param[active_experts]\n\n\ndef selective_expert_weights(\n    experts_module: nn.Module,\n    param_name: str,\n    active_experts: torch.Tensor,\n) -> torch.Tensor:\n    \"\"\"Extract and dequantize only the active experts' weights.\n\n    Format-agnostic: dispatches based on whether the parameter is\n    BnB 4-bit quantized (via parametrize), FP8, or dense bf16/fp32.\n\n    Args:\n        experts_module: The 
base experts module (e.g. Qwen3_5MoeExperts)\n        param_name: \"gate_up_proj\" or \"down_proj\"\n        active_experts: [num_active] sorted unique expert indices\n\n    Returns:\n        Compact weight tensor [num_active, dim1, dim2] ready for ScatterMoE\n    \"\"\"\n    # Check if the parameter is BnB-quantized via parametrize\n    if (\n        hasattr(experts_module, \"parametrizations\")\n        and param_name in experts_module.parametrizations\n    ):\n        param_list = experts_module.parametrizations[param_name]\n        parametrization = param_list[0]\n\n        # BnB 4-bit parametrization\n        if hasattr(parametrization, \"quant_state\"):\n            # The raw quantized data is on the ParametrizationList, not the\n            # individual Bnb4bitParametrization module\n            raw_param = param_list.original\n            qs = parametrization.quant_state\n            # qs.shape is the original tensor shape before flattening.\n            # For MoE experts it's [E, d1, d2] (3D) or [total_elements] (1D).\n            orig_shape = qs.shape\n            if isinstance(orig_shape, torch.Size) and len(orig_shape) == 3:\n                expert_shape = (orig_shape[1], orig_shape[2])\n            elif isinstance(orig_shape, torch.Size) and len(orig_shape) == 1:\n                # Flattened — need to infer from module attributes\n                E_total = getattr(experts_module, \"num_experts\", None)\n                if E_total is None:\n                    E_total = int(active_experts.max().item()) + 1\n                expert_numel = orig_shape[0] // E_total\n                d2 = getattr(experts_module, \"hidden_dim\", None) or getattr(\n                    experts_module, \"intermediate_dim\", None\n                )\n                if d2 and expert_numel % d2 == 0:\n                    expert_shape = (expert_numel // d2, d2)\n                else:\n                    full = getattr(experts_module, param_name)\n                    return 
full[active_experts]\n            else:\n                full = getattr(experts_module, param_name)\n                return full[active_experts]\n\n            return _selective_dequant_bnb4(raw_param, qs, active_experts, expert_shape)\n\n    # Dense parameter (bf16/fp32) — direct indexing\n    param = getattr(experts_module, param_name)\n    if param.dim() == 3:\n        return param[active_experts]\n\n    # Fallback: full access\n    return param\n\n\ndef selective_lora_weights(\n    lora_A: torch.Tensor,\n    lora_B: torch.Tensor,\n    active_experts: torch.Tensor,\n    E: int,\n) -> tuple[torch.Tensor, torch.Tensor]:\n    \"\"\"Select LoRA A and B weights for only the active experts.\n\n    LoRA layout (scattermoe format):\n      A: [r*E, K] — expert e occupies rows [e*r : (e+1)*r]\n      B: [N, r*E] — expert e occupies cols [e*r : (e+1)*r]\n\n    Returns compact:\n      A: [r*num_active, K]\n      B: [N, r*num_active]\n    \"\"\"\n    R = lora_A.size(0) // E\n\n    # Vectorized gather: active_experts[:, None] * R + arange(R)[None, :]\n    row_idx = (\n        active_experts.long()[:, None] * R\n        + torch.arange(R, device=lora_A.device)[None, :]\n    ).reshape(-1)\n\n    compact_A = lora_A[row_idx]  # [r*num_active, K]\n    compact_B = lora_B[:, row_idx]  # [N, r*num_active]\n\n    return compact_A, compact_B\n"
  },
  {
    "path": "src/axolotl/integrations/kernels/libs/scattermoe_lora/selective_dequant_kernel.py",
    "content": "\"\"\"\nTriton kernel for fused selective expert gather + NF4 dequantization.\n\nInstead of:\n  1. Gather packed uint8 data for active experts (memory copy)\n  2. Gather absmax for active experts (memory copy)\n  3. Call BnB dequantize_4bit CUDA kernel\n\nThis kernel does all three in one pass:\n  - Reads packed NF4 bytes from expert-strided positions\n  - Looks up the NF4 codebook\n  - Multiplies by the per-block absmax\n  - Writes bf16 output directly\n\nThis eliminates the intermediate gather buffer entirely.\n\"\"\"\n\nimport torch\nimport triton\nimport triton.language as tl\n\n# NF4 codebook (16 values, precomputed by BnB)\n# These are the normalized float4 reconstruction values\nNF4_CODEBOOK = [\n    -1.0,\n    -0.6961928009986877,\n    -0.5250730514526367,\n    -0.39491748809814453,\n    -0.28444138169288635,\n    -0.18477343022823334,\n    -0.09105003625154495,\n    0.0,\n    0.07958029955625534,\n    0.16093020141124725,\n    0.24611230194568634,\n    0.33791524171829224,\n    0.44070982933044434,\n    0.5626170039176941,\n    0.7229568362236023,\n    1.0,\n]\n\n\n@triton.jit\ndef _selective_dequant_nf4_kernel(\n    # Input: packed NF4 data (flattened, expert-major order)\n    packed_ptr,\n    # Input: absmax values (flattened, expert-major order)\n    absmax_ptr,\n    # Input: active expert indices\n    active_experts_ptr,\n    # Input: NF4 codebook (16 float values)\n    codebook_ptr,\n    # Output: dequantized bf16 weights [num_active, expert_numel]\n    out_ptr,\n    stride_out_e,  # stride for expert dim in output\n    # Dimensions\n    num_active,\n    packed_per_expert,  # expert_numel // 2\n    blocks_per_expert,  # expert_numel // blocksize\n    blocksize: tl.constexpr,\n    # Tile size\n    BLOCK_SIZE: tl.constexpr,  # elements per thread block (must be multiple of 2)\n):\n    \"\"\"\n    Each program processes BLOCK_SIZE elements from one expert.\n\n    Grid: (num_active, cdiv(expert_numel, BLOCK_SIZE))\n\n    For each output 
element:\n      1. Compute which byte in packed data contains this element\n      2. Extract the 4-bit nibble (high or low)\n      3. Look up in NF4 codebook\n      4. Scale by absmax for this block\n    \"\"\"\n    expert_local_idx = tl.program_id(0)  # which active expert (0..num_active-1)\n    block_id = tl.program_id(1)  # which element block\n\n    # Load the global expert index\n    expert_global = tl.load(active_experts_ptr + expert_local_idx).to(tl.int64)\n\n    expert_numel = packed_per_expert * 2  # 2 elements per packed byte\n    elem_offset = block_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)\n    mask = elem_offset < expert_numel\n\n    # Each element is packed as: byte[i//2], high nibble for even i, low nibble for odd i\n    # (is_high is True for odd i, which selects the LOW nibble in the tl.where below)\n    byte_idx = elem_offset // 2\n    is_high = (elem_offset % 2) == 1\n\n    # Read packed bytes from the global expert's region\n    packed_global_offset = expert_global * packed_per_expert + byte_idx\n    packed_bytes = tl.load(packed_ptr + packed_global_offset, mask=mask, other=0).to(\n        tl.int32\n    )\n\n    # Extract 4-bit nibble\n    # BnB packing: high nibble = even element, low nibble = odd element\n    nibble = tl.where(is_high, packed_bytes & 0xF, (packed_bytes >> 4) & 0xF)\n\n    # NF4 codebook lookup\n    # Gather codebook[nibble] per element directly from codebook_ptr\n    code_val = tl.load(codebook_ptr + nibble, mask=mask, other=0.0)\n\n    # Load absmax for this element's quantization block\n    # NOTE(review): per-expert block indexing is only valid when expert_numel is a\n    # multiple of blocksize; otherwise BnB blocks straddle expert boundaries and the\n    # block index would be (expert_global * expert_numel + elem_offset) // blocksize\n    block_idx = elem_offset // blocksize\n    absmax_global_offset = expert_global * blocks_per_expert + block_idx\n    absmax_val = tl.load(absmax_ptr + absmax_global_offset, mask=mask, other=1.0)\n\n    # Dequantize: value = codebook[nibble] * absmax\n    result = code_val * absmax_val\n\n    # Store to output\n    out_offset = expert_local_idx * stride_out_e + elem_offset\n    tl.store(out_ptr + out_offset, result.to(out_ptr.dtype.element_ty), mask=mask)\n\n\ndef selective_dequant_nf4_triton(\n    
packed_data: torch.Tensor,\n    absmax: torch.Tensor,\n    active_experts: torch.Tensor,\n    expert_shape: tuple[int, int],\n    blocksize: int,\n    dtype: torch.dtype = torch.bfloat16,\n    codebook: torch.Tensor | None = None,\n) -> torch.Tensor:\n    \"\"\"Fused selective gather + NF4 dequantization via Triton kernel.\n\n    Args:\n        packed_data: Flattened packed NF4 data [total_packed] or [total_packed, 1]\n        absmax: Per-block scaling factors [total_blocks]\n        active_experts: Sorted indices of experts to dequantize [num_active]\n        expert_shape: (dim1, dim2) per expert\n        blocksize: Quantization block size\n        dtype: Output dtype (default bf16)\n        codebook: NF4 lookup table [16] (uses default NF4 codebook if None)\n\n    Returns:\n        Dequantized weights [num_active, dim1, dim2]\n    \"\"\"\n    num_active = active_experts.shape[0]\n    expert_numel = expert_shape[0] * expert_shape[1]\n    packed_per_expert = expert_numel // 2\n    blocks_per_expert = expert_numel // blocksize\n\n    # Prepare codebook on device\n    if codebook is None:\n        codebook = torch.tensor(\n            NF4_CODEBOOK, dtype=torch.float32, device=packed_data.device\n        )\n    else:\n        codebook = codebook.to(device=packed_data.device, dtype=torch.float32)\n\n    # Flatten inputs\n    packed_flat = packed_data.reshape(-1)\n    absmax_flat = absmax.reshape(-1).float()  # absmax is usually fp32\n\n    # Output buffer\n    out = torch.empty(num_active, expert_numel, dtype=dtype, device=packed_data.device)\n\n    BLOCK_SIZE = 1024  # Process 1024 elements per thread block\n\n    grid = (num_active, triton.cdiv(expert_numel, BLOCK_SIZE))\n\n    _selective_dequant_nf4_kernel[grid](\n        packed_flat,\n        absmax_flat,\n        active_experts,\n        codebook,\n        out,\n        out.stride(0),\n        num_active=num_active,\n        packed_per_expert=packed_per_expert,\n        blocks_per_expert=blocks_per_expert,\n       
 blocksize=blocksize,\n        BLOCK_SIZE=BLOCK_SIZE,\n    )\n\n    return out.reshape(num_active, *expert_shape)\n"
  },
  {
    "path": "src/axolotl/integrations/kernels/plugin.py",
    "content": "import importlib\nimport os\nfrom pathlib import Path\n\nimport torch\n\nfrom axolotl.integrations.base import BasePlugin\nfrom axolotl.utils.logging import get_logger\n\nLOG = get_logger(__name__)\n\n\ndef _check_sonicmoe_gpu_compat():\n    \"\"\"Validate GPU compute capability for SonicMoE and configure env.\n\n    Supported: Hopper (sm_90), Blackwell (sm_100 - sm_103).\n    B300 (sm_103) additionally requires Triton 3.6.0.\n    \"\"\"\n    if not torch.cuda.is_available():\n        return\n\n    cc = torch.cuda.get_device_capability()\n\n    if cc < (9, 0):\n        raise RuntimeError(\n            f\"SonicMoE requires Hopper (sm_90) or Blackwell (sm_100+) GPU, \"\n            f\"but detected sm_{cc[0]}{cc[1]}.\"\n        )\n\n    if cc > (10, 3):\n        raise RuntimeError(\n            f\"SonicMoE does not yet support sm_{cc[0]}{cc[1]}. \"\n            f\"Supported: Hopper (sm_90) and Blackwell (sm_100 - sm_103).\"\n        )\n\n    # Blackwell (sm_100+): enable QuACK GEMM kernels\n    if cc >= (10, 0):\n        os.environ.setdefault(\"USE_QUACK_GEMM\", \"1\")\n        LOG.info(\n            f\"Blackwell GPU (sm_{cc[0]}{cc[1]}) detected, enabling USE_QUACK_GEMM=1\"\n        )\n\n    # B300 (sm_103): requires Triton 3.6.0\n    if cc == (10, 3):\n        triton_spec = importlib.util.find_spec(\"triton\")\n        if triton_spec is None:\n            raise RuntimeError(\n                \"B300 (sm_103) requires Triton 3.6.0, but Triton is not installed.\"\n            )\n        import triton\n\n        triton_version = tuple(int(x) for x in triton.__version__.split(\".\")[:2])\n        if triton_version != (3, 6):\n            raise RuntimeError(\n                f\"B300 (sm_103) requires Triton 3.6.x, but found {triton.__version__}.\"\n            )\n\n\nclass KernelsPlugin(BasePlugin):\n    def get_input_args(self):\n        return \"axolotl.integrations.kernels.KernelsArgs\"\n\n    def pre_model_load(self, cfg):\n        from 
axolotl.integrations.kernels.constants import SPARSE_MOE_BLOCK\n\n        # Prefer text backbone type for VLMs, but fall back to base type\n        # when the text type isn't in the supported mapping (e.g. qwen3_5_moe_text)\n        moe_model_type = cfg.model_config_type_text or cfg.model_config_type\n        if (\n            moe_model_type not in SPARSE_MOE_BLOCK\n            and cfg.model_config_type in SPARSE_MOE_BLOCK\n        ):\n            moe_model_type = cfg.model_config_type\n\n        if cfg.use_scattermoe:\n            self._register_kernels()\n            self._kernelize_model(moe_model_type)\n        elif cfg.use_sonicmoe:\n            if not importlib.util.find_spec(\"sonicmoe\"):\n                raise RuntimeError(\n                    \"SonicMoE is not installed. See installation instructions at \"\n                    \"https://github.com/axolotl-ai-cloud/axolotl/blob/main/src/axolotl/integrations/kernels/README.md#sonicmoe-installation\"\n                )\n\n            _check_sonicmoe_gpu_compat()\n\n            from axolotl.integrations.kernels.sonicmoe import patch_sonicmoe\n\n            LOG.info(f\"Applying SonicMoE patches for model type: {moe_model_type}\")\n            patch_sonicmoe(\n                moe_model_type,\n                torch_compile=bool(getattr(cfg, \"torch_compile\", False)),\n            )\n\n    def _register_kernels(self):\n        from kernels import (\n            LocalLayerRepository,\n            Mode,\n            register_kernel_mapping,\n        )\n\n        plugin_root = Path(__file__).parent\n        register_kernel_mapping(\n            {\n                \"HFScatterMoEParallelExperts\": {\n                    \"cuda\": {\n                        Mode.TRAINING: LocalLayerRepository(\n                            repo_path=plugin_root / \"libs\" / \"scattermoe_lora\",\n                            package_name=\"scattermoe_lora\",\n                            layer_name=\"HFScatterMoEGatedMLP\",\n             
           ),\n                        Mode.INFERENCE: LocalLayerRepository(\n                            repo_path=plugin_root / \"libs\" / \"scattermoe_lora\",\n                            package_name=\"scattermoe_lora\",\n                            layer_name=\"HFScatterMoEGatedMLP\",\n                        ),\n                    },\n                }\n            }\n        )\n\n    def add_callbacks_pre_trainer(self, cfg, model):\n        callbacks = []\n        if cfg.use_scattermoe:\n            from axolotl.integrations.kernels.autotune_callback import (\n                AutotuneReportCallback,\n            )\n\n            callbacks.append(AutotuneReportCallback())\n        return callbacks\n\n    def _kernelize_model(self, model_type: str):\n        from kernels import replace_kernel_forward_from_hub\n\n        from axolotl.integrations.kernels.constants import resolve_moe_block_classes\n\n        for model_moe_cls in resolve_moe_block_classes(model_type):\n            replace_kernel_forward_from_hub(\n                model_moe_cls, \"HFScatterMoEParallelExperts\"\n            )\n"
  },
  {
    "path": "src/axolotl/integrations/kernels/sonicmoe/__init__.py",
    "content": "from .patch import patch_sonicmoe\n\n__all__ = [\"patch_sonicmoe\"]\n"
  },
  {
    "path": "src/axolotl/integrations/kernels/sonicmoe/patch.py",
    "content": "\"\"\"\nSonicMoE patching for SparseMoeBlock forward pass.\n\nMonkeypatches the SparseMoeBlock class for a given model type to use\nSonicMoE's optimized kernels. Two forward paths are supported:\n\n1. **General routing path** (routing_fn is not None):\n   Uses a custom routing function + ``moe_general_routing_inputs``.\n   Suitable for models with non-standard routing (softmax->topk, sigmoid->topk).\n\n2. **Fused topk->softmax path** (routing_fn is None):\n   Uses ``moe_TC_softmax_topk_layer`` which fuses routing + expert computation.\n   Suitable for models with simple topk->softmax routing.\n\nWeight format conversion (interleave/deinterleave) is handled by the\nWeightConverter system, so the forward assumes weights are already in\ninterleaved format.\n\nShared experts are handled generically: if the block has a ``shared_expert``\nor ``shared_experts`` attribute, its output is computed alongside the routed\nexperts and added to the final output. An optional ``shared_expert_gate``\napplies sigmoid gating to the shared expert contribution.\n\"\"\"\n\nimport torch\nimport torch.nn.functional as F\n\nfrom axolotl.integrations.kernels.constants import resolve_moe_block_classes\nfrom axolotl.utils.logging import get_logger\n\nLOG = get_logger(__name__)\n\n\ndef patch_sonicmoe(model_type: str, torch_compile: bool = False):\n    \"\"\"Main entry point: patch SparseMoeBlock for SonicMoE support.\n\n    Args:\n        model_type: The HuggingFace model type (e.g. 
\"qwen3_moe\").\n        torch_compile: If True, wrap routing functions with torch.compile\n            for kernel fusion (fuses softmax+topk+renorm into fewer launches).\n    \"\"\"\n    from .routing import get_model_moe_config\n    from .weight_converter import register_sonicmoe_weight_converter\n\n    routing_fn, activation, router_attr = get_model_moe_config(model_type)\n\n    if torch_compile and routing_fn is not None:\n        routing_fn = _try_compile_routing(routing_fn)\n\n    for moe_cls in resolve_moe_block_classes(model_type):\n        _patch_forward(moe_cls, routing_fn, activation, router_attr)\n    register_sonicmoe_weight_converter(model_type)\n\n\ndef _try_compile_routing(routing_fn):\n    \"\"\"Attempt to torch.compile the routing function, fall back to eager on failure.\"\"\"\n    try:\n        compiled_fn = torch.compile(routing_fn, mode=\"reduce-overhead\", dynamic=False)\n        LOG.info(f\"torch.compile enabled for routing function: {routing_fn.__name__}\")\n        return compiled_fn\n    except Exception as exc:  # pylint: disable=broad-except\n        LOG.warning(\n            f\"torch.compile failed for routing function {routing_fn.__name__}, \"\n            f\"falling back to eager: {exc}\"\n        )\n        return routing_fn\n\n\ndef _patch_forward(moe_cls, routing_fn, activation, router_attr):\n    \"\"\"Monkeypatch the SparseMoeBlock class with a SonicMoE forward.\n\n    The patched forward handles shared experts generically: if\n    ``self.shared_expert`` or ``self.shared_experts`` exists, it is computed\n    and added to the routed output. If ``self.shared_expert_gate`` also exists,\n    it applies sigmoid gating to the shared expert contribution (as in qwen2_moe).\n\n    Args:\n        moe_cls: The SparseMoeBlock class to patch.\n        routing_fn: Routing function (e.g. 
softmax_topk_routing), or None\n            for the fused moe_TC_softmax_topk_layer path.\n        activation: SonicMoE ActivationType enum value.\n        router_attr: Name of the router module attribute on the MoE block.\n    \"\"\"\n    if hasattr(moe_cls, \"_original_forward\"):\n        LOG.info(f\"{moe_cls.__name__}.forward already patched with SonicMoE, skipping\")\n        return\n\n    original_forward = moe_cls.forward\n\n    if routing_fn is not None:\n        _make_general_forward(moe_cls, routing_fn, activation)\n    else:\n        _make_fused_forward(moe_cls, activation, router_attr)\n\n    moe_cls._original_forward = original_forward\n    LOG.info(f\"Patched {moe_cls.__name__}.forward with SonicMoE implementation\")\n\n\ndef _make_general_forward(moe_cls, routing_fn, activation):\n    \"\"\"Create forward using routing_fn + moe_general_routing_inputs.\"\"\"\n\n    def sonicmoe_forward(self, hidden_states: torch.Tensor) -> torch.Tensor:\n        from sonicmoe import moe_general_routing_inputs\n\n        batch_size, sequence_length, hidden_dim = hidden_states.shape\n        hidden_states_flat = hidden_states.view(-1, hidden_dim)\n\n        # Shared expert (computed early, matching original model ordering)\n        shared_expert_output = _compute_shared_expert(self, hidden_states_flat)\n\n        # Routing\n        router_scores, token_indices, expert_indices, _router_logits = routing_fn(\n            hidden_states_flat, self\n        )\n\n        # Permute weights to SonicMoE layout:\n        #   gate_up: [E, 2*I, H] -> [2*I, H, E]\n        #   down:    [E, H, I]   -> [H, I, E]\n        gate_up_weight = self.experts.gate_up_proj.permute(1, 2, 0)\n        down_weight = self.experts.down_proj.permute(1, 2, 0)\n        E = gate_up_weight.shape[-1]\n\n        output, _ = moe_general_routing_inputs(\n            hidden_states_flat,\n            router_scores,\n            token_indices,\n            expert_indices,\n            gate_up_weight,\n            
None,  # b1 (no gate/up bias)\n            down_weight,\n            None,  # b2 (no down bias)\n            E,\n            torch.cuda.current_stream().cuda_stream,\n            activation,\n            False,  # is_inference_mode\n        )\n\n        # Add shared expert contribution if present\n        if shared_expert_output is not None:\n            if hasattr(self, \"shared_expert_gate\"):\n                shared_expert_output = (\n                    F.sigmoid(self.shared_expert_gate(hidden_states_flat))\n                    * shared_expert_output\n                )\n            output = output + shared_expert_output\n\n        return output.view(batch_size, sequence_length, hidden_dim)\n\n    moe_cls.forward = sonicmoe_forward\n\n\ndef _make_fused_forward(moe_cls, activation, router_attr):\n    \"\"\"Create forward using moe_TC_softmax_topk_layer (topk -> softmax).\"\"\"\n\n    def sonicmoe_fused_forward(self, hidden_states: torch.Tensor) -> torch.Tensor:\n        from sonicmoe import moe_TC_softmax_topk_layer\n\n        batch_size, sequence_length, hidden_dim = hidden_states.shape\n        hidden_states_flat = hidden_states.view(-1, hidden_dim)\n\n        # Shared expert (computed early, matching original model ordering)\n        shared_expert_output = _compute_shared_expert(self, hidden_states_flat)\n\n        router = getattr(self, router_attr)\n\n        # Permute weights to SonicMoE layout:\n        #   gate_up: [E, 2*I, H] -> [2*I, H, E]\n        #   down:    [E, H, I]   -> [H, I, E]\n        gate_up_weight = self.experts.gate_up_proj.permute(1, 2, 0)\n        down_weight = self.experts.down_proj.permute(1, 2, 0)\n\n        output, _router_logits, _expert_freq = moe_TC_softmax_topk_layer(\n            hidden_states_flat,\n            router.weight,\n            gate_up_weight,\n            None,  # b1 (no gate/up bias)\n            down_weight,\n            None,  # b2 (no down bias)\n            router.top_k,\n            
torch.cuda.current_stream().cuda_stream,\n            activation,\n            False,  # is_inference_mode\n        )\n\n        # Add shared expert contribution if present\n        if shared_expert_output is not None:\n            if hasattr(self, \"shared_expert_gate\"):\n                shared_expert_output = (\n                    F.sigmoid(self.shared_expert_gate(hidden_states_flat))\n                    * shared_expert_output\n                )\n            output = output + shared_expert_output\n\n        return output.view(batch_size, sequence_length, hidden_dim)\n\n    moe_cls.forward = sonicmoe_fused_forward\n\n\ndef _compute_shared_expert(moe_block, hidden_states_flat):\n    \"\"\"Compute shared expert output if the block has one.\n\n    Handles singular (qwen2_moe: ``shared_expert``), plural\n    (glm_moe_dsa/deepseek_v3: ``shared_experts``), and MLP\n    (hunyuan_v1_moe: ``shared_mlp``) attribute names.\n    \"\"\"\n    shared_expert = (\n        getattr(moe_block, \"shared_expert\", None)\n        or getattr(moe_block, \"shared_experts\", None)\n        or getattr(moe_block, \"shared_mlp\", None)\n    )\n    if shared_expert is not None:\n        return shared_expert(hidden_states_flat)\n    return None\n"
  },
  {
    "path": "src/axolotl/integrations/kernels/sonicmoe/routing.py",
    "content": "\"\"\"\nRouting functions for SonicMoE integration.\n\nDifferent MoE architectures use different routing strategies:\n- qwen2_moe / qwen3_moe / qwen3_5_moe / qwen3_next / qwen3_vl_moe / qwen3_omni_moe / olmoe / mixtral / minimax: softmax -> topk (with optional renormalization)\n- glm_moe_dsa / deepseek_v3 / glm4_moe / glm4_moe_lite / glm4v_moe / minimax_m2: sigmoid -> topk (with group-based expert selection when n_group > 1)\n- mistral4: softmax -> group selection -> topk (with renormalization and scaling)\n- gpt_oss: topk -> softmax via the fused moe_TC_softmax_topk_layer path (routing_fn=None); not yet enabled in get_model_moe_config, so it currently raises ValueError\n\nEach model type maps to a (routing_fn, activation_type, router_attr) triple.\nWhen routing_fn is None, the fused moe_TC_softmax_topk_layer path is used.\n\"\"\"\n\nimport torch\nimport torch.nn.functional as F\n\n\ndef get_model_moe_config(model_type: str):\n    \"\"\"Returns (routing_fn, activation, router_attr) for a given model type.\n\n    Args:\n        model_type: HuggingFace model type string.\n\n    Returns:\n        routing_fn: Callable or None. None signals the fused\n            moe_TC_softmax_topk_layer path (topk -> softmax models).\n        activation: SonicMoE ActivationType enum value.\n        router_attr: Name of the router module attribute on the MoE block\n            (e.g. \"gate\" or \"router\").\n\n    The activation type cannot be derived from config.hidden_act because\n    e.g. qwen3_moe reports \"silu\" but architecturally uses SwiGLU\n    (act_fn(gate) * up pattern). 
So we specify it per model type.\n    \"\"\"\n    from sonicmoe.enums import ActivationType\n\n    if model_type in (\n        \"qwen2_moe\",\n        \"qwen3_moe\",\n        \"qwen3_5_moe\",\n        \"qwen3_next\",\n        \"qwen3_vl_moe\",\n        \"qwen3_omni_moe\",\n        \"olmoe\",\n        \"mixtral\",\n        \"minimax\",\n    ):\n        return softmax_topk_routing, ActivationType.SWIGLU, \"gate\"\n    elif model_type in (\"mistral4\",):\n        return softmax_group_topk_routing, ActivationType.SWIGLU, \"gate\"\n    elif model_type in (\n        \"glm_moe_dsa\",\n        \"deepseek_v3\",\n        \"glm4_moe\",\n        \"glm4_moe_lite\",\n        \"glm4v_moe\",\n        \"minimax_m2\",\n    ):\n        return sigmoid_topk_routing, ActivationType.SWIGLU, \"gate\"\n    # elif model_type in (\"ernie4_5_moe\",):\n    #     # Softmax→topk with e_score_correction_bias applied between softmax and topk.\n    #     return ..., ActivationType.SWIGLU, \"gate\"\n    # elif model_type in (\"deepseek_v2\",):\n    #     # Softmax→topk with group_limited_greedy. Different attr names: num_group\n    #     # (not n_group), gate is nn.Linear (not a router class).\n    #     return ..., ActivationType.SWIGLU, \"gate\"\n    # elif model_type in (\"hunyuan_v1_moe\",):\n    #     # Softmax→topk but gate structure differs: gate.wg (not gate.weight),\n    #     # top_k on block not gate, creates scatter routing matrix.\n    #     return ..., ActivationType.SWIGLU, \"gate\"\n    # Fused topk -> softmax path (routing_fn=None):\n    # elif model_type in (\"gpt_oss\",):\n    #     # NOTE: gpt_oss has a router bias which moe_TC_softmax_topk_layer\n    #     # ignores (it only takes router_w, not bias). 
Also has transposed\n    #     # weight layout [E, H, 2*I] and custom GLU activation.\n    #     return None, ActivationType.SWIGLU, \"router\"\n    else:\n        raise ValueError(f\"SonicMoE: unsupported model type '{model_type}'\")\n\n\ndef softmax_topk_routing(\n    hidden_states: torch.Tensor, moe_block\n) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:\n    \"\"\"Qwen3/Qwen2-style routing: softmax -> topk -> optional renorm.\n\n    Args:\n        hidden_states: [T, H] flattened token representations\n        moe_block: MoE block module (accesses moe_block.gate.*)\n\n    Returns:\n        router_scores: [T*K] flattened scores (float32)\n        token_indices: [T*K] which token each entry belongs to (int32), sorted ascending\n        expert_indices: [T*K] which expert (int32)\n        router_logits: [T, E] original logits for aux loss\n    \"\"\"\n    gate = moe_block.gate\n    T, H = hidden_states.shape\n    K = gate.top_k\n\n    # Compute router logits and softmax over all experts\n    router_logits = F.linear(hidden_states, gate.weight)  # [T, E]\n    router_probs = F.softmax(router_logits, dim=-1, dtype=torch.float32)  # [T, E]\n\n    # Select top-k experts per token\n    top_values, top_indices = torch.topk(router_probs, K, dim=-1)  # [T, K] each\n\n    # Renormalize if configured (default True for models without the attribute,\n    # e.g. Mixtral/MiniMax which always normalize)\n    if getattr(gate, \"norm_topk_prob\", True):\n        top_values = top_values / top_values.sum(dim=-1, keepdim=True)\n\n    # no-op: matches transformers which casts to softmax output dtype (float32).\n    # top_values = top_values.to(router_probs.dtype)\n\n    # Flatten for moe_general_routing_inputs.\n    # Token indices are naturally sorted ascending from the [T, K] layout:\n    # [0, 0, ..., 1, 1, ..., T-1, T-1, ...] 
— this is required by SonicMoE.\n    # Expert sorting is handled internally by general_routing_router_metadata.\n    token_indices = (\n        torch.arange(T, device=hidden_states.device, dtype=torch.int32)\n        .unsqueeze(1)\n        .expand(T, K)\n    )\n\n    flat_scores = top_values.reshape(-1)  # [T*K]\n    flat_token_idx = token_indices.reshape(-1)  # [T*K]\n    flat_expert_idx = top_indices.to(torch.int32).reshape(-1)  # [T*K]\n\n    return flat_scores, flat_token_idx, flat_expert_idx, router_logits\n\n\ndef softmax_group_topk_routing(\n    hidden_states: torch.Tensor, moe_block\n) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:\n    \"\"\"Mistral4-style routing: softmax -> group selection -> topk -> renorm -> scale.\"\"\"\n    gate = moe_block.gate\n    T, H = hidden_states.shape\n    K = moe_block.top_k\n    E = getattr(moe_block, \"n_routed_experts\", gate.weight.shape[0])\n    n_group = getattr(moe_block, \"n_group\", 1)\n\n    router_logits = F.linear(hidden_states, gate.weight)  # [T, E]\n    router_probs = F.softmax(router_logits, dim=-1, dtype=torch.float32)  # [T, E]\n\n    scores_for_choice = router_probs\n\n    # Group selection: pick top groups, mask the rest\n    if n_group > 1:\n        group_scores = (\n            scores_for_choice.view(-1, n_group, E // n_group)\n            .topk(2, dim=-1)[0]\n            .sum(dim=-1)\n        )\n        group_idx = torch.topk(\n            group_scores, k=moe_block.topk_group, dim=-1, sorted=False\n        )[1]\n        group_mask = torch.zeros_like(group_scores)\n        group_mask.scatter_(1, group_idx, 1)\n        score_mask = (\n            group_mask.unsqueeze(-1).expand(-1, n_group, E // n_group).reshape(-1, E)\n        )\n        scores_for_choice = scores_for_choice.masked_fill(~score_mask.bool(), 0.0)\n\n    topk_indices = torch.topk(scores_for_choice, k=K, dim=-1, sorted=False)[1]\n    topk_weights = router_probs.gather(1, topk_indices)\n\n    # Renormalization + scaling\n 
   norm_topk_prob = getattr(moe_block, \"norm_topk_prob\", True)\n    if norm_topk_prob:\n        topk_weights = topk_weights / (topk_weights.sum(dim=-1, keepdim=True) + 1e-20)\n    routed_scaling_factor = getattr(moe_block, \"routed_scaling_factor\", 1.0)\n    topk_weights = topk_weights * routed_scaling_factor\n\n    # Flatten for moe_general_routing_inputs\n    token_indices = (\n        torch.arange(T, device=hidden_states.device, dtype=torch.int32)\n        .unsqueeze(1)\n        .expand(T, K)\n    )\n\n    flat_scores = topk_weights.to(torch.float32).reshape(-1)  # [T*K]\n    flat_token_idx = token_indices.reshape(-1)  # [T*K]\n    flat_expert_idx = topk_indices.to(torch.int32).reshape(-1)  # [T*K]\n\n    return flat_scores, flat_token_idx, flat_expert_idx, router_logits\n\n\ndef sigmoid_topk_routing(\n    hidden_states: torch.Tensor, moe_block\n) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:\n    \"\"\"Sigmoid-based routing: sigmoid -> optional group selection -> topk.\n\n    Supports two variants:\n    - **Group selection** (glm_moe_dsa, deepseek_v3, etc.): n_group > 1,\n      bias on gate, group-based masking before topk.\n    - **No group selection** (minimax_m2): n_group == 1 (or absent),\n      bias on moe_block, straight topk from all experts.\n\n    Final routing weights come from the original sigmoid scores (not\n    bias-corrected), with optional renormalization and scaling.\n\n    Args:\n        hidden_states: [T, H] flattened token representations\n        moe_block: MoE block module (accesses moe_block.gate.* and\n            optional moe_block.n_group, .topk_group, .top_k, .norm_topk_prob,\n            .routed_scaling_factor, .n_routed_experts)\n\n    Returns:\n        router_scores: [T*K] flattened scores (float32)\n        token_indices: [T*K] which token each entry belongs to (int32), sorted ascending\n        expert_indices: [T*K] which expert (int32)\n        router_logits: [T, E] original logits for aux loss\n    
\"\"\"\n    gate = moe_block.gate\n    T, H = hidden_states.shape\n    K = moe_block.top_k\n    E = getattr(moe_block, \"n_routed_experts\", gate.weight.shape[0])\n    n_group = getattr(moe_block, \"n_group\", 1)\n\n    # Compute router logits and sigmoid probabilities\n    router_logits = F.linear(hidden_states.float(), gate.weight.float())  # [T, E]\n    router_probs = router_logits.sigmoid()  # [T, E]\n\n    # Bias-corrected scores for expert selection (not used for final weights).\n    # glm_moe_dsa/deepseek_v3 store the bias on gate; minimax_m2 stores it on the block.\n    e_score_correction_bias = getattr(gate, \"e_score_correction_bias\", None)\n    if e_score_correction_bias is None:\n        e_score_correction_bias = getattr(moe_block, \"e_score_correction_bias\", None)\n    if e_score_correction_bias is None:\n        raise AttributeError(\n            f\"sigmoid_topk_routing requires e_score_correction_bias on \"\n            f\"gate ({type(gate)}) or moe_block ({type(moe_block)}), but neither has it\"\n        )\n    scores_for_choice = router_probs + e_score_correction_bias\n\n    # Group-based selection: pick top groups, mask the rest (skip when n_group == 1)\n    if n_group > 1:\n        group_scores = (\n            scores_for_choice.view(-1, n_group, E // n_group)\n            .topk(2, dim=-1)[0]\n            .sum(dim=-1)\n        )  # [T, n_group]\n        group_idx = torch.topk(\n            group_scores, k=moe_block.topk_group, dim=-1, sorted=False\n        )[1]\n        group_mask = torch.zeros_like(group_scores)\n        group_mask.scatter_(1, group_idx, 1)\n        score_mask = (\n            group_mask.unsqueeze(-1).expand(-1, n_group, E // n_group).reshape(-1, E)\n        )\n        scores_for_choice = scores_for_choice.masked_fill(~score_mask.bool(), 0.0)\n\n    # Final topk from (possibly masked) scores\n    topk_indices = torch.topk(scores_for_choice, k=K, dim=-1, sorted=False)[1]\n\n    # Gather weights from original sigmoid scores (not 
bias-corrected)\n    topk_weights = router_probs.gather(1, topk_indices)\n\n    # Optional renormalization + scaling\n    norm_topk_prob = getattr(moe_block, \"norm_topk_prob\", True)\n    if norm_topk_prob:\n        topk_weights = topk_weights / (topk_weights.sum(dim=-1, keepdim=True) + 1e-20)\n    routed_scaling_factor = getattr(moe_block, \"routed_scaling_factor\", 1.0)\n    topk_weights = topk_weights * routed_scaling_factor\n\n    # Flatten for moe_general_routing_inputs.\n    # Token indices are naturally sorted ascending from the [T, K] layout.\n    token_indices = (\n        torch.arange(T, device=hidden_states.device, dtype=torch.int32)\n        .unsqueeze(1)\n        .expand(T, K)\n    )\n\n    flat_scores = topk_weights.to(torch.float32).reshape(-1)  # [T*K]\n    flat_token_idx = token_indices.reshape(-1)  # [T*K]\n    flat_expert_idx = topk_indices.to(torch.int32).reshape(-1)  # [T*K]\n\n    return flat_scores, flat_token_idx, flat_expert_idx, router_logits\n"
  },
  {
    "path": "src/axolotl/integrations/kernels/sonicmoe/weight_converter.py",
    "content": "\"\"\"\nCustom WeightConverter operations for SonicMoE weight format conversion.\n\nSonicMoE requires gate_up_proj weights in interleaved format:\n- Standard (concatenated): [E, 2*I, H] where first I rows are gate, last I rows are up\n- SonicMoE (interleaved): [E, 2*I, H] where rows alternate [g0, u0, g1, u1, ...]\n\nThese ConversionOps integrate with transformers' WeightConverter system so that\nweights are transparently converted during loading and reverted during saving.\n\"\"\"\n\nfrom typing import Any\n\nimport torch\nfrom einops import rearrange\nfrom transformers.core_model_loading import ConversionOps\n\nfrom axolotl.utils.logging import get_logger\n\nLOG = get_logger(__name__)\n\n\ndef interleave_gate_up(tensor: torch.Tensor) -> torch.Tensor:\n    \"\"\"[gate..., up...] -> [g0, u0, g1, u1, ...] along the 2*I dimension.\"\"\"\n    return rearrange(tensor, \"... (two out) h -> ... (out two) h\", two=2)\n\n\ndef deinterleave_gate_up(tensor: torch.Tensor) -> torch.Tensor:\n    \"\"\"[g0, u0, g1, u1, ...] -> [gate..., up...] along the 2*I dimension.\"\"\"\n    return rearrange(tensor, \"... (out two) h -> ... 
(two out) h\", two=2)\n\n\nclass ConcatenatedToInterleaved(ConversionOps):\n    \"\"\"Convert concatenated gate/up projections to interleaved format.\n\n    Input:  [E, 2*I, H] with gate=[E, :I, H] and up=[E, I:, H]\n    Output: [E, 2*I, H] with rows alternating [g0, u0, g1, u1, ...]\n\n    This operation is applied along ``dim`` (default 1, the 2*I dimension).\n    \"\"\"\n\n    def __init__(self, dim: int = 1):\n        self.dim = dim\n\n    @torch.no_grad()\n    def convert(\n        self,\n        input_dict: dict[str, Any],\n        source_patterns: list[str],\n        target_patterns: list[str],\n        **kwargs,\n    ) -> dict[str, torch.Tensor]:\n        target_pattern = self._get_target_pattern(\n            input_dict, source_patterns, target_patterns\n        )\n        tensors = next(iter(input_dict.values()))\n        tensor = tensors[0] if isinstance(tensors, list) else tensors\n\n        interleaved = interleave_gate_up(tensor)\n\n        return {target_pattern: interleaved}\n\n    def _get_target_pattern(\n        self,\n        input_dict: dict[str, Any],\n        source_patterns: list[str],\n        target_patterns: list[str],\n    ) -> str:\n        # Follow the same logic as Transpose.get_target_pattern\n        if len(input_dict) != 1:\n            raise ValueError(\"Undefined Operation encountered!\")\n        if len(target_patterns) > 1:\n            if len(source_patterns) == 1:\n                return source_patterns[0]\n            raise ValueError(\"Undefined Operation encountered!\")\n        return target_patterns[0]\n\n    @property\n    def reverse_op(self) -> ConversionOps:\n        return InterleavedToConcatenated(self.dim)\n\n\nclass InterleavedToConcatenated(ConversionOps):\n    \"\"\"Convert interleaved gate/up projections back to concatenated format.\n\n    Input:  [E, 2*I, H] with rows alternating [g0, u0, g1, u1, ...]\n    Output: [E, 2*I, H] with gate=[E, :I, H] and up=[E, I:, H]\n\n    This is the reverse of 
``ConcatenatedToInterleaved``.\n    \"\"\"\n\n    def __init__(self, dim: int = 1):\n        self.dim = dim\n\n    @torch.no_grad()\n    def convert(\n        self,\n        input_dict: dict[str, Any],\n        source_patterns: list[str],\n        target_patterns: list[str],\n        **kwargs,\n    ) -> dict[str, torch.Tensor]:\n        target_pattern = self._get_target_pattern(\n            input_dict, source_patterns, target_patterns\n        )\n        tensors = next(iter(input_dict.values()))\n        tensor = tensors[0] if isinstance(tensors, list) else tensors\n\n        concatenated = deinterleave_gate_up(tensor)\n\n        return {target_pattern: concatenated}\n\n    def _get_target_pattern(\n        self,\n        input_dict: dict[str, Any],\n        source_patterns: list[str],\n        target_patterns: list[str],\n    ) -> str:\n        if len(input_dict) != 1:\n            raise ValueError(\"Undefined Operation encountered!\")\n        if len(target_patterns) > 1:\n            if len(source_patterns) == 1:\n                return source_patterns[0]\n            raise ValueError(\"Undefined Operation encountered!\")\n        return target_patterns[0]\n\n    @property\n    def reverse_op(self) -> ConversionOps:\n        return ConcatenatedToInterleaved(self.dim)\n\n\ndef register_sonicmoe_weight_converter(model_type: str):\n    \"\"\"Override the conversion mapping to add interleave step for gate_up_proj.\n\n    Appends a ConcatenatedToInterleaved operation to the existing gate_up_proj\n    converter chain. 
For example, qwen3_moe's chain becomes:\n        MergeModulelist(dim=0) -> Concatenate(dim=1) -> ConcatenatedToInterleaved(dim=1)\n\n    The reverse is auto-generated for saving:\n        InterleavedToConcatenated(dim=1) -> Chunk(dim=1) -> SplitModulelist(dim=0)\n    \"\"\"\n    from transformers.conversion_mapping import (\n        get_checkpoint_conversion_mapping,\n        register_checkpoint_conversion_mapping,\n    )\n\n    existing = get_checkpoint_conversion_mapping(model_type)\n    if existing is None:\n        LOG.warning(\n            f\"No conversion mapping found for model type '{model_type}'. \"\n            \"SonicMoE weight interleaving will not be applied during checkpoint loading.\"\n        )\n        return\n\n    # Find the gate_up_proj converter and append ConcatenatedToInterleaved\n    patched = False\n    for converter in existing:\n        if hasattr(converter, \"operations\") and any(\n            \"gate_up_proj\" in pat for pat in converter.target_patterns\n        ):\n            # Guard against double registration (e.g. plugin reloaded)\n            if any(\n                isinstance(op, ConcatenatedToInterleaved) for op in converter.operations\n            ):\n                LOG.info(\n                    f\"SonicMoE weight converter already registered for '{model_type}'\"\n                )\n                return\n            converter.operations.append(ConcatenatedToInterleaved(dim=1))\n            patched = True\n            break\n\n    if not patched:\n        LOG.warning(\n            f\"Could not find gate_up_proj converter for model type '{model_type}'. \"\n            \"SonicMoE weight interleaving will not be applied during checkpoint loading.\"\n        )\n        return\n\n    register_checkpoint_conversion_mapping(model_type, existing, overwrite=True)\n    LOG.info(f\"Registered SonicMoE weight converter for model type '{model_type}'\")\n"
  },
  {
    "path": "src/axolotl/integrations/liger/LICENSE",
    "content": "\n                                 Apache License\n                           Version 2.0, January 2004\n                        http://www.apache.org/licenses/\n\n   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION\n\n   1. Definitions.\n\n      \"License\" shall mean the terms and conditions for use, reproduction,\n      and distribution as defined by Sections 1 through 9 of this document.\n\n      \"Licensor\" shall mean the copyright owner or entity authorized by\n      the copyright owner that is granting the License.\n\n      \"Legal Entity\" shall mean the union of the acting entity and all\n      other entities that control, are controlled by, or are under common\n      control with that entity. For the purposes of this definition,\n      \"control\" means (i) the power, direct or indirect, to cause the\n      direction or management of such entity, whether by contract or\n      otherwise, or (ii) ownership of fifty percent (50%) or more of the\n      outstanding shares, or (iii) beneficial ownership of such entity.\n\n      \"You\" (or \"Your\") shall mean an individual or Legal Entity\n      exercising permissions granted by this License.\n\n      \"Source\" form shall mean the preferred form for making modifications,\n      including but not limited to software source code, documentation\n      source, and configuration files.\n\n      \"Object\" form shall mean any form resulting from mechanical\n      transformation or translation of a Source form, including but\n      not limited to compiled object code, generated documentation,\n      and conversions to other media types.\n\n      \"Work\" shall mean the work of authorship, whether in Source or\n      Object form, made available under the License, as indicated by a\n      copyright notice that is included in or attached to the work\n      (an example is provided in the Appendix below).\n\n      \"Derivative Works\" shall mean any work, whether in Source or Object\n      
form, that is based on (or derived from) the Work and for which the\n      editorial revisions, annotations, elaborations, or other modifications\n      represent, as a whole, an original work of authorship. For the purposes\n      of this License, Derivative Works shall not include works that remain\n      separable from, or merely link (or bind by name) to the interfaces of,\n      the Work and Derivative Works thereof.\n\n      \"Contribution\" shall mean any work of authorship, including\n      the original version of the Work and any modifications or additions\n      to that Work or Derivative Works thereof, that is intentionally\n      submitted to Licensor for inclusion in the Work by the copyright owner\n      or by an individual or Legal Entity authorized to submit on behalf of\n      the copyright owner. For the purposes of this definition, \"submitted\"\n      means any form of electronic, verbal, or written communication sent\n      to the Licensor or its representatives, including but not limited to\n      communication on electronic mailing lists, source code control systems,\n      and issue tracking systems that are managed by, or on behalf of, the\n      Licensor for the purpose of discussing and improving the Work, but\n      excluding communication that is conspicuously marked or otherwise\n      designated in writing by the copyright owner as \"Not a Contribution.\"\n\n      \"Contributor\" shall mean Licensor and any individual or Legal Entity\n      on behalf of whom a Contribution has been received by Licensor and\n      subsequently incorporated within the Work.\n\n   2. Grant of Copyright License. 
Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      copyright license to reproduce, prepare Derivative Works of,\n      publicly display, publicly perform, sublicense, and distribute the\n      Work and such Derivative Works in Source or Object form.\n\n   3. Grant of Patent License. Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      (except as stated in this section) patent license to make, have made,\n      use, offer to sell, sell, import, and otherwise transfer the Work,\n      where such license applies only to those patent claims licensable\n      by such Contributor that are necessarily infringed by their\n      Contribution(s) alone or by combination of their Contribution(s)\n      with the Work to which such Contribution(s) was submitted. If You\n      institute patent litigation against any entity (including a\n      cross-claim or counterclaim in a lawsuit) alleging that the Work\n      or a Contribution incorporated within the Work constitutes direct\n      or contributory patent infringement, then any patent licenses\n      granted to You under this License for that Work shall terminate\n      as of the date such litigation is filed.\n\n   4. Redistribution. 
You may reproduce and distribute copies of the\n      Work or Derivative Works thereof in any medium, with or without\n      modifications, and in Source or Object form, provided that You\n      meet the following conditions:\n\n      (a) You must give any other recipients of the Work or\n          Derivative Works a copy of this License; and\n\n      (b) You must cause any modified files to carry prominent notices\n          stating that You changed the files; and\n\n      (c) You must retain, in the Source form of any Derivative Works\n          that You distribute, all copyright, patent, trademark, and\n          attribution notices from the Source form of the Work,\n          excluding those notices that do not pertain to any part of\n          the Derivative Works; and\n\n      (d) If the Work includes a \"NOTICE\" text file as part of its\n          distribution, then any Derivative Works that You distribute must\n          include a readable copy of the attribution notices contained\n          within such NOTICE file, excluding those notices that do not\n          pertain to any part of the Derivative Works, in at least one\n          of the following places: within a NOTICE text file distributed\n          as part of the Derivative Works; within the Source form or\n          documentation, if provided along with the Derivative Works; or,\n          within a display generated by the Derivative Works, if and\n          wherever such third-party notices normally appear. The contents\n          of the NOTICE file are for informational purposes only and\n          do not modify the License. 
You may add Your own attribution\n          notices within Derivative Works that You distribute, alongside\n          or as an addendum to the NOTICE text from the Work, provided\n          that such additional attribution notices cannot be construed\n          as modifying the License.\n\n      You may add Your own copyright statement to Your modifications and\n      may provide additional or different license terms and conditions\n      for use, reproduction, or distribution of Your modifications, or\n      for any such Derivative Works as a whole, provided Your use,\n      reproduction, and distribution of the Work otherwise complies with\n      the conditions stated in this License.\n\n   5. Submission of Contributions. Unless You explicitly state otherwise,\n      any Contribution intentionally submitted for inclusion in the Work\n      by You to the Licensor shall be under the terms and conditions of\n      this License, without any additional terms or conditions.\n      Notwithstanding the above, nothing herein shall supersede or modify\n      the terms of any separate license agreement you may have executed\n      with Licensor regarding such Contributions.\n\n   6. Trademarks. This License does not grant permission to use the trade\n      names, trademarks, service marks, or product names of the Licensor,\n      except as required for reasonable and customary use in describing the\n      origin of the Work and reproducing the content of the NOTICE file.\n\n   7. Disclaimer of Warranty. Unless required by applicable law or\n      agreed to in writing, Licensor provides the Work (and each\n      Contributor provides its Contributions) on an \"AS IS\" BASIS,\n      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or\n      implied, including, without limitation, any warranties or conditions\n      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A\n      PARTICULAR PURPOSE. 
You are solely responsible for determining the\n      appropriateness of using or redistributing the Work and assume any\n      risks associated with Your exercise of permissions under this License.\n\n   8. Limitation of Liability. In no event and under no legal theory,\n      whether in tort (including negligence), contract, or otherwise,\n      unless required by applicable law (such as deliberate and grossly\n      negligent acts) or agreed to in writing, shall any Contributor be\n      liable to You for damages, including any direct, indirect, special,\n      incidental, or consequential damages of any character arising as a\n      result of this License or out of the use or inability to use the\n      Work (including but not limited to damages for loss of goodwill,\n      work stoppage, computer failure or malfunction, or any and all\n      other commercial damages or losses), even if such Contributor\n      has been advised of the possibility of such damages.\n\n   9. Accepting Warranty or Additional Liability. While redistributing\n      the Work or Derivative Works thereof, You may choose to offer,\n      and charge a fee for, acceptance of support, warranty, indemnity,\n      or other liability obligations and/or rights consistent with this\n      License. However, in accepting such obligations, You may act only\n      on Your own behalf and on Your sole responsibility, not on behalf\n      of any other Contributor, and only if You agree to indemnify,\n      defend, and hold each Contributor harmless for any liability\n      incurred by, or claims asserted against, such Contributor by reason\n      of your accepting any such warranty or additional liability.\n\n   END OF TERMS AND CONDITIONS\n\n   APPENDIX: How to apply the Apache License to your work.\n\n      To apply the Apache License to your work, attach the following\n      boilerplate notice, with the fields enclosed by brackets \"[]\"\n      replaced with your own identifying information. 
(Don't include\n      the brackets!)  The text should be enclosed in the appropriate\n      comment syntax for the file format. We also recommend that a\n      file or class name and description of purpose be included on the\n      same \"printed page\" as the copyright notice for easier\n      identification within third-party archives.\n\n   Copyright [yyyy] [name of copyright owner]\n\n   Licensed under the Apache License, Version 2.0 (the \"License\");\n   you may not use this file except in compliance with the License.\n   You may obtain a copy of the License at\n\n       http://www.apache.org/licenses/LICENSE-2.0\n\n   Unless required by applicable law or agreed to in writing, software\n   distributed under the License is distributed on an \"AS IS\" BASIS,\n   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n   See the License for the specific language governing permissions and\n   limitations under the License.\n"
  },
  {
    "path": "src/axolotl/integrations/liger/README.md",
    "content": "# Liger Kernel Integration\n\nLiger Kernel provides efficient Triton kernels for LLM training, offering:\n\n- 20% increase in multi-GPU training throughput\n- 60% reduction in memory usage\n- Compatibility with both FSDP and DeepSpeed\n\nSee https://github.com/linkedin/Liger-Kernel\n\n## Usage\n\n```yaml\nplugins:\n  - axolotl.integrations.liger.LigerPlugin\nliger_rope: true\nliger_rms_norm: true\nliger_glu_activation: true\nliger_layer_norm: true\nliger_fused_linear_cross_entropy: true\n\n# FLCE-specific\nliger_use_token_scaling: true\n```\n\n## Supported Models\n\n- deepseek_v2\n- gemma\n- gemma2\n- gemma3\n- granite\n- jamba\n- llama\n- mistral\n- mixtral\n- mllama\n- mllama_text_model\n- olmo2\n- paligemma\n- phi3\n- qwen2\n- qwen2_5_vl\n- qwen2_vl\n\n## Citation\n\n```bib\n@article{hsu2024ligerkernelefficienttriton,\n      title={Liger Kernel: Efficient Triton Kernels for LLM Training},\n      author={Pin-Lun Hsu and Yun Dai and Vignesh Kothapalli and Qingquan Song and Shao Tang and Siyu Zhu and Steven Shimizu and Shivam Sahni and Haowen Ning and Yanning Chen},\n      year={2024},\n      eprint={2410.10989},\n      archivePrefix={arXiv},\n      primaryClass={cs.LG},\n      url={https://arxiv.org/abs/2410.10989},\n      journal={arXiv preprint arXiv:2410.10989},\n}\n```\n"
  },
  {
    "path": "src/axolotl/integrations/liger/__init__.py",
    "content": "# Copyright 2024 Axolotl AI. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"\nModule for the LIGER plugin integration with Axolotl.\n\nLiger Kernel is a collection of Triton-native kernels for LLM training.\nIt is designed to be performant, correct, and lightweight.\n\"\"\"\n\nfrom .args import LigerArgs\nfrom .plugin import LigerPlugin\n\n__all__ = [\n    \"LigerArgs\",\n    \"LigerPlugin\",\n]\n"
  },
  {
    "path": "src/axolotl/integrations/liger/args.py",
    "content": "# Copyright 2024 Axolotl AI. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"\nModule for handling LIGER input arguments.\n\"\"\"\n\nfrom pydantic import BaseModel, Field, model_validator\n\nfrom axolotl.utils.logging import get_logger\n\nLOG = get_logger(__name__)\n\n\nclass LigerArgs(BaseModel):\n    \"\"\"\n    Input args for LIGER.\n    \"\"\"\n\n    liger_rope: bool | None = None\n    liger_rms_norm: bool | None = None\n    liger_layer_norm: bool | None = None\n    liger_swiglu: bool | None = None\n    liger_glu_activation: bool | None = None\n    liger_cross_entropy: bool | None = None\n    liger_fused_linear_cross_entropy: bool | None = None\n    liger_use_token_scaling: bool | None = Field(\n        default=None,\n        json_schema_extra={\n            \"description\": (\n                \"Enables use_token_scaling in fused_linear_cross_entropy. 
\"\n                \"When True, each token's loss is multiplied by its predicted probability (detached from gradients).\"\n            )\n        },\n    )\n\n    @model_validator(mode=\"before\")\n    @classmethod\n    def check_deprecated_swiglu(cls, data):\n        if data.get(\"liger_swiglu\") is not None:\n            if data.get(\"liger_glu_activation\") is not None:\n                raise ValueError(\n                    \"You cannot have both `liger_swiglu` and `liger_glu_activation` set.\"\n                )\n\n            LOG.warning(\n                \"The 'liger_swiglu' argument is deprecated and will be removed in a future release. \"\n                \"Please use 'liger_glu_activation' instead.\"\n            )\n            data[\"liger_glu_activation\"] = data.pop(\"liger_swiglu\")\n        return data\n\n    @model_validator(mode=\"before\")\n    @classmethod\n    def check_tiled_mlp_conflict(cls, data):\n        if (\n            data.get(\"liger_glu_activation\") is True\n            and data.get(\"tiled_mlp\") is True\n            and not data.get(\"tiled_mlp_use_original_mlp\")\n        ):\n            raise ValueError(\n                \"You cannot have both `liger_glu_activation` and `tiled_mlp` set without `tiled_mlp_use_original_mlp: true`.\"\n            )\n        return data\n\n    @model_validator(mode=\"before\")\n    @classmethod\n    def check_liger_rms_norm_tensor_parallel(cls, data):\n        if data.get(\"liger_rms_norm\") and data.get(\"tensor_parallel_size\", 1) > 1:\n            raise ValueError(\n                \"`liger_rms_norm` is incompatible with tensor parallelism, \"\n                \"see https://github.com/linkedin/Liger-Kernel/issues/826\"\n            )\n        return data\n\n    @model_validator(mode=\"before\")\n    @classmethod\n    def check_liger_use_token_scaling_flce(cls, data):\n        if data.get(\"liger_use_token_scaling\") and not data.get(\n            \"liger_fused_linear_cross_entropy\"\n        ):\n 
           raise ValueError(\n                \"`liger_use_token_scaling: true` requires `liger_fused_linear_cross_entropy` enabled.\"\n            )\n\n        return data\n\n    @model_validator(mode=\"after\")\n    def check_tensor_parallel_size_liger_fused_linear_cross_entropy(self):\n        # TODO @SalmanMohammadi this is a larger fix - investigate\n        if self.tensor_parallel_size > 1 and self.liger_fused_linear_cross_entropy:\n            raise ValueError(\"Tensor parallelism is not compatible with liger losses.\")\n        return self\n"
  },
  {
    "path": "src/axolotl/integrations/liger/models/__init__.py",
    "content": ""
  },
  {
    "path": "src/axolotl/integrations/liger/models/base.py",
    "content": "\"\"\"\nGeneric FLCE patch for untested models similar to Llama\n\"\"\"\n\nfrom typing import Optional, Tuple, Union\n\nimport torch\nfrom liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss\nfrom liger_kernel.transformers.trainer.orpo_trainer import _FSDPForwardRedirection\nfrom liger_kernel.utils import PEFT_AVAILABLE\nfrom peft.utils import ModulesToSaveWrapper\nfrom torch.distributed.fsdp import FullyShardedDataParallel\nfrom transformers.modeling_outputs import CausalLMOutputWithPast\n\nfrom axolotl.utils.callbacks.models import get_causal_lm_model_cls_prefix\n\n\ndef lce_forward(\n    self,\n    *args,\n    output_attentions: Optional[bool] = None,\n    output_hidden_states: Optional[bool] = None,\n    return_dict: Optional[bool] = None,\n    labels: Optional[torch.LongTensor] = None,\n    logits_to_keep: Union[int, torch.Tensor] = 0,\n    skip_logits: Optional[bool] = None,\n    **kwargs,\n) -> Union[Tuple, CausalLMOutputWithPast]:\n    r\"\"\"\n    Args:\n        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):\n            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,\n            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored\n            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.\n\n        logits_to_keep (`int` or `torch.Tensor`, *optional*):\n            If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all\n            `input_ids` (special case). 
Only last token logits are needed for generation, and calculating them only for that\n            token can save memory, which becomes pretty significant for long sequences or large vocabulary size.\n            If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.\n            This is useful when using packed tensor format (single dimension for batch and sequence length).\n    \"\"\"\n\n    output_attentions = (\n        output_attentions\n        if output_attentions is not None\n        else self.config.output_attentions\n    )\n    output_hidden_states = (\n        output_hidden_states\n        if output_hidden_states is not None\n        else self.config.output_hidden_states\n    )\n\n    return_dict = (\n        return_dict if return_dict is not None else self.config.use_return_dict\n    )\n\n    # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)\n    outputs = self.model(\n        *args,\n        output_attentions=output_attentions,\n        output_hidden_states=output_hidden_states,\n        **kwargs,\n    )\n\n    hidden_states = outputs[0]\n    # Only compute necessary logits, and do not upcast them to float if we are not computing the loss\n    slice_indices = (\n        slice(-logits_to_keep, None)\n        if isinstance(logits_to_keep, int)\n        else logits_to_keep\n    )\n    kept_hidden_states = hidden_states[:, slice_indices, :]\n\n    shift_labels = kwargs.pop(\"shift_labels\", None)\n    logits = None\n    loss = None\n\n    # if in training mode, don't materialize logits\n    if skip_logits and labels is None and shift_labels is None:\n        raise ValueError(\"skip_logits is True, but labels and shift_labels are None\")\n\n    if skip_logits is None:\n        # By default, if in training mode, don't materialize logits\n        skip_logits = self.training and (labels is not None or shift_labels is not None)\n\n    if skip_logits:\n        loss = 
lce_maybe_trainable_lm_head(\n            self,\n            hidden_states=kept_hidden_states,\n            hidden_size=self.config.hidden_size,\n            labels=labels,\n            shift_labels=shift_labels,\n            **kwargs,\n        )\n\n    else:\n        logits = self.lm_head(kept_hidden_states)\n        if labels is not None:\n            loss = self.loss_function(\n                logits=logits,\n                labels=labels,\n                vocab_size=self.config.vocab_size,\n                **kwargs,\n            )\n\n    if not return_dict:\n        output = (logits,) + outputs[1:]\n        return (loss,) + output if loss is not None else output\n\n    return CausalLMOutputWithPast(\n        loss=loss,\n        logits=logits,\n        past_key_values=outputs.past_key_values,\n        hidden_states=outputs.hidden_states,\n        attentions=outputs.attentions,\n    )\n\n\ndef lce_maybe_trainable_lm_head(\n    self, hidden_states, hidden_size, labels, shift_labels, **loss_kwargs\n):\n    lm_head = self.lm_head\n\n    # Unwrap the module if lm_head has been added as trainable module in PEFT LoRA configuration,\n    # i.e. 
listed in the modules_to_save field of LoraConfig, so the lm_head weights are read\n    # from the unwrapped module.\n    # See https://huggingface.co/docs/peft/package_reference/lora for reference.\n    if PEFT_AVAILABLE and isinstance(lm_head, ModulesToSaveWrapper):\n        lm_head = lm_head.modules_to_save.default\n\n    # If FSDP is used and lm_head is trainable, e.g., during full fine-tuning or with LoRA,\n    # reading the lm_head module weights and calling the kernel must be done within FSDP forward pass\n    # so the module's entire parameters are summoned and kept in memory during the kernel execution.\n    if isinstance(lm_head, FullyShardedDataParallel):\n        return _FSDPForwardRedirection()(\n            lm_head,\n            _liger_for_causal_lm_loss,\n            lm_head.module,\n            hidden_states,\n            hidden_size,\n            labels,\n            shift_labels,\n            **loss_kwargs,\n        )\n\n    # FSDP is not used so we can read the lm_head weights and call the kernel directly\n    return _liger_for_causal_lm_loss(\n        lm_head=self.lm_head,\n        hidden_states=hidden_states,\n        hidden_size=hidden_size,\n        labels=labels,\n        shift_labels=shift_labels,\n        **loss_kwargs,\n    )\n\n\ndef _liger_for_causal_lm_loss(\n    lm_head, hidden_states, hidden_size, labels, shift_labels, **loss_kwargs\n):\n    return LigerForCausalLMLoss(\n        hidden_states=hidden_states,\n        lm_head_weight=lm_head.weight,\n        labels=labels,\n        hidden_size=hidden_size,\n        shift_labels=shift_labels,\n        **loss_kwargs,\n    )\n\n\ndef patch_lce_forward(\n    model_type,\n):\n    try:\n        # Dynamically import the modeling module and its ForCausalLM class\n        module_path = f\"transformers.models.{model_type}.modeling_{model_type}\"\n        model_cls_prefix, _ = get_causal_lm_model_cls_prefix(model_type)\n        module = __import__(module_path, fromlist=[f\"{model_cls_prefix}ForCausalLM\"])\n        
model_cls = getattr(module, f\"{model_cls_prefix}ForCausalLM\")\n\n        model_cls.forward = lce_forward\n\n    except (ImportError, AttributeError) as e:\n        raise RuntimeError(\n            f\"Could not import ForCausalLM class for model_type: {model_type}. \"\n            f\"Error: {str(e)}\"\n        ) from e\n"
  },
  {
    "path": "src/axolotl/integrations/liger/models/deepseekv2.py",
    "content": "\"\"\"\nDeepseekV2 model with LigerFusedLinearCrossEntropyLoss\n\"\"\"\n\nfrom typing import List, Optional, Tuple, Union\n\nimport torch\nfrom liger_kernel.transformers.fused_linear_cross_entropy import (\n    LigerFusedLinearCrossEntropyLoss,\n)\nfrom torch.nn import CrossEntropyLoss\nfrom transformers.modeling_outputs import CausalLMOutputWithPast\n\n\ndef lce_forward(\n    self,\n    input_ids: torch.LongTensor = None,\n    attention_mask: Optional[torch.Tensor] = None,\n    position_ids: Optional[torch.LongTensor] = None,\n    past_key_values: Optional[List[torch.FloatTensor]] = None,\n    inputs_embeds: Optional[torch.FloatTensor] = None,\n    labels: Optional[torch.LongTensor] = None,\n    use_cache: Optional[bool] = None,\n    output_attentions: Optional[bool] = None,\n    output_hidden_states: Optional[bool] = None,\n    return_dict: Optional[bool] = None,\n) -> Union[Tuple, CausalLMOutputWithPast]:\n    r\"\"\"\n    Args:\n        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):\n            Labels for computing the masked language modeling loss. Indices should either be in `[0, transformers.,\n            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored\n            (masked), the loss is only computed for the tokens with labels in `[0, transformers., config.vocab_size]`.\n\n    Returns:\n\n    Example:\n\n    ```python\n    >>> from transformers import AutoTokenizer, DeepseekV2ForCausalLM\n\n    >>> model = DeepseekV2ForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)\n    >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)\n\n    >>> prompt = \"Hey, are you conscious? 
Can you talk to me?\"\n    >>> inputs = tokenizer(prompt, return_tensors=\"pt\")\n\n    >>> # Generate\n    >>> generate_ids = model.generate(inputs.input_ids, max_length=30)\n    >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]\n    \"Hey, are you conscious? Can you talk to me?\\nI'm not conscious, but I can talk to you.\"\n    ```\"\"\"\n    output_attentions = (\n        output_attentions\n        if output_attentions is not None\n        else self.config.output_attentions\n    )\n    output_hidden_states = (\n        output_hidden_states\n        if output_hidden_states is not None\n        else self.config.output_hidden_states\n    )\n    return_dict = (\n        return_dict if return_dict is not None else self.config.use_return_dict\n    )\n\n    # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)\n    outputs = self.model(\n        input_ids=input_ids,\n        attention_mask=attention_mask,\n        position_ids=position_ids,\n        past_key_values=past_key_values,\n        inputs_embeds=inputs_embeds,\n        use_cache=use_cache,\n        output_attentions=output_attentions,\n        output_hidden_states=output_hidden_states,\n        return_dict=return_dict,\n    )\n\n    hidden_states = outputs[0]\n\n    loss = None\n    logits = None\n\n    if self.training:\n        shift_hidden_states = hidden_states[..., :-1, :].contiguous()\n        shift_labels = labels[..., 1:].contiguous()\n\n        # flatten tokens\n        shift_hidden_states = shift_hidden_states.view(-1, self.config.hidden_size)\n        shift_labels = shift_labels.view(-1)\n\n        lce = LigerFusedLinearCrossEntropyLoss()\n        loss = lce(self.lm_head.weight, shift_hidden_states, shift_labels)\n    else:\n        logits = self.lm_head(hidden_states)\n        logits = logits.float()\n\n        loss = None\n        if labels is not None:\n            # Shift so that tokens < n predict n\n            
shift_logits = logits[..., :-1, :].contiguous()\n            shift_labels = labels[..., 1:].contiguous()\n            # Flatten the tokens\n            loss_fct = CrossEntropyLoss()\n            shift_logits = shift_logits.view(-1, self.config.vocab_size)\n            shift_labels = shift_labels.view(-1)\n            # Enable model parallelism\n            shift_labels = shift_labels.to(shift_logits.device)\n            loss = loss_fct(shift_logits, shift_labels)\n\n    if not return_dict:\n        output = (logits,) + outputs[1:]\n        return (loss,) + output if loss is not None else output\n\n    return CausalLMOutputWithPast(\n        loss=loss,\n        logits=logits,\n        past_key_values=outputs.past_key_values,\n        hidden_states=outputs.hidden_states,\n        attentions=outputs.attentions,\n    )\n"
  },
  {
    "path": "src/axolotl/integrations/liger/models/jamba.py",
    "content": "\"\"\"\nJamba model with LigerFusedLinearCrossEntropyLoss\n\"\"\"\n\nfrom typing import Optional, Tuple, Union\n\nimport torch\nfrom liger_kernel.transformers.fused_linear_cross_entropy import (\n    LigerFusedLinearCrossEntropyLoss,\n)\nfrom torch.nn import CrossEntropyLoss\nfrom transformers.modeling_outputs import MoeCausalLMOutputWithPast\nfrom transformers.models.jamba.modeling_jamba import (\n    HybridMambaAttentionDynamicCache,\n    load_balancing_loss_func,\n)\n\n\ndef lce_forward(\n    self,\n    input_ids: torch.LongTensor = None,\n    attention_mask: Optional[torch.Tensor] = None,\n    position_ids: Optional[torch.LongTensor] = None,\n    past_key_values: Optional[HybridMambaAttentionDynamicCache] = None,\n    inputs_embeds: Optional[torch.FloatTensor] = None,\n    labels: Optional[torch.LongTensor] = None,\n    use_cache: Optional[bool] = None,\n    output_attentions: Optional[bool] = None,\n    output_hidden_states: Optional[bool] = None,\n    output_router_logits: Optional[bool] = None,\n    return_dict: Optional[bool] = None,\n    cache_position: Optional[torch.LongTensor] = None,\n    num_logits_to_keep: Optional[Union[int, None]] = None,\n) -> Union[Tuple, MoeCausalLMOutputWithPast]:\n    r\"\"\"\n    Args:\n        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):\n            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,\n            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored\n            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.\n\n        num_logits_to_keep (`int` or `None`, *optional*):\n            Calculate logits for the last `num_logits_to_keep` tokens. If `None`, calculate logits for all\n            `input_ids`. 
Only last token logits are needed for generation, and calculating them only for that token\n            can save memory, which becomes pretty significant for long sequences.\n\n    Returns:\n\n    Example:\n\n    ```python\n    >>> from transformers import AutoTokenizer, JambaForCausalLM\n\n    >>> model = JambaForCausalLM.from_pretrained(\"ai21labs/Jamba-v0.1\")\n    >>> tokenizer = AutoTokenizer.from_pretrained(\"ai21labs/Jamba-v0.1\")\n\n    >>> prompt = \"Hey, are you conscious? Can you talk to me?\"\n    >>> inputs = tokenizer(prompt, return_tensors=\"pt\")\n\n    >>> # Generate\n    >>> generate_ids = model.generate(inputs.input_ids, max_length=30)\n    >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]\n    \"Hey, are you conscious? Can you talk to me?\\nI'm not conscious, but I can talk to you.\"\n    ```\"\"\"\n\n    output_attentions = (\n        output_attentions\n        if output_attentions is not None\n        else self.config.output_attentions\n    )\n    output_router_logits = (\n        output_router_logits\n        if output_router_logits is not None\n        else self.config.output_router_logits\n    )\n\n    output_hidden_states = (\n        output_hidden_states\n        if output_hidden_states is not None\n        else self.config.output_hidden_states\n    )\n    return_dict = (\n        return_dict if return_dict is not None else self.config.use_return_dict\n    )\n\n    # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)\n    outputs = self.model(\n        input_ids=input_ids,\n        attention_mask=attention_mask,\n        position_ids=position_ids,\n        past_key_values=past_key_values,\n        inputs_embeds=inputs_embeds,\n        use_cache=use_cache,\n        output_attentions=output_attentions,\n        output_hidden_states=output_hidden_states,\n        output_router_logits=output_router_logits,\n        cache_position=cache_position,\n        
return_dict=return_dict,\n    )\n\n    hidden_states = outputs[0]\n\n    loss = None\n    logits = None\n\n    if self.training:\n        shift_hidden_states = hidden_states[..., :-1, :].contiguous()\n        shift_labels = labels[..., 1:].contiguous()\n\n        # flatten tokens\n        shift_hidden_states = shift_hidden_states.view(-1, self.config.hidden_size)\n        shift_labels = shift_labels.view(-1)\n\n        lce = LigerFusedLinearCrossEntropyLoss()\n        loss = lce(self.lm_head.weight, shift_hidden_states, shift_labels)\n    else:\n        if num_logits_to_keep is None:\n            logits = self.lm_head(hidden_states)\n        else:\n            logits = self.lm_head(hidden_states[..., -num_logits_to_keep:, :])\n        logits = logits.float()\n\n        if labels is not None:\n            # Shift so that tokens < n predict n\n            shift_logits = logits[..., :-1, :].contiguous()\n            shift_labels = labels[..., 1:].contiguous()\n            # Flatten the tokens\n            loss_fct = CrossEntropyLoss()\n            shift_logits = shift_logits.view(-1, self.config.vocab_size)\n            shift_labels = shift_labels.view(-1)\n            # Enable model parallelism\n            shift_labels = shift_labels.to(shift_logits.device)\n            loss = loss_fct(shift_logits, shift_labels)\n\n    aux_loss = None\n    if output_router_logits:\n        aux_loss = load_balancing_loss_func(\n            outputs.router_logits if return_dict else outputs[-1],\n            self.num_experts,\n            self.num_experts_per_tok,\n            attention_mask,\n        )\n        if labels is not None:\n            loss += self.router_aux_loss_coef * aux_loss.to(\n                loss.device\n            )  # make sure to reside in the same device\n\n    if not return_dict:\n        output = (logits,) + outputs[1:]\n        if output_router_logits:\n            output = (aux_loss,) + output\n        return (loss,) + output if loss is not None else 
output\n\n    return MoeCausalLMOutputWithPast(\n        loss=loss,\n        aux_loss=aux_loss,\n        logits=logits,\n        past_key_values=outputs.past_key_values,\n        hidden_states=outputs.hidden_states,\n        attentions=outputs.attentions,\n        router_logits=outputs.router_logits,\n    )\n"
  },
  {
    "path": "src/axolotl/integrations/liger/models/llama4.py",
    "content": "\"\"\"\nLiger FLCE for llama4\n\"\"\"\n\nimport sys\nfrom copy import deepcopy\nfrom typing import List, Optional, Tuple, Union\n\nimport torch\nfrom liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss\nfrom transformers.modeling_outputs import CausalLMOutputWithPast\n\n\ndef lce_forward(\n    self,\n    input_ids: torch.LongTensor = None,\n    attention_mask: Optional[torch.Tensor] = None,\n    position_ids: Optional[torch.LongTensor] = None,\n    past_key_values: Optional[\n        Union[\"Cache\", List[torch.FloatTensor]]  # noqa: F821\n    ] = None,\n    inputs_embeds: Optional[torch.FloatTensor] = None,\n    labels: Optional[torch.LongTensor] = None,\n    use_cache: Optional[bool] = None,\n    output_attentions: Optional[bool] = None,\n    output_hidden_states: Optional[bool] = None,\n    return_dict: Optional[bool] = None,\n    cache_position: Optional[torch.LongTensor] = None,\n    logits_to_keep: Union[int, torch.Tensor] = 0,\n    **loss_kwargs,\n) -> Union[Tuple, CausalLMOutputWithPast]:\n    r\"\"\"\n    Args:\n        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):\n            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,\n            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored\n            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.\n\n        logits_to_keep (`int` or `torch.Tensor`, *optional*):\n            If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all\n            `input_ids` (special case). 
Only last token logits are needed for generation, and calculating them only for that\n            token can save memory, which becomes pretty significant for long sequences or large vocabulary size.\n            If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.\n            This is useful when using packed tensor format (single dimension for batch and sequence length).\n\n    Returns:\n    \"\"\"\n\n    output_attentions = (\n        output_attentions\n        if output_attentions is not None\n        else self.config.output_attentions\n    )\n    output_hidden_states = (\n        output_hidden_states\n        if output_hidden_states is not None\n        else self.config.output_hidden_states\n    )\n    return_dict = (\n        return_dict if return_dict is not None else self.config.use_return_dict\n    )\n\n    # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)\n    outputs = self.model(\n        input_ids=input_ids,\n        attention_mask=attention_mask,\n        position_ids=position_ids,\n        past_key_values=past_key_values,\n        inputs_embeds=inputs_embeds,\n        use_cache=use_cache,\n        output_attentions=output_attentions,\n        output_hidden_states=output_hidden_states,\n        return_dict=return_dict,\n        cache_position=cache_position,\n    )\n\n    hidden_states = outputs[0]\n\n    if hasattr(self.config, \"pretraining_tp\") and self.config.pretraining_tp > 1:\n        raise Exception(\"Liger Kernel does not support pretraining_tp!!\")\n\n    logits = None\n    loss = None\n    # if in training mode, don't materialize logits\n    if self.training and (labels is not None):\n        loss = LigerForCausalLMLoss(\n            hidden_states=hidden_states,\n            lm_head_weight=self.lm_head.weight,\n            labels=labels,\n            hidden_size=self.config.hidden_size,\n            **loss_kwargs,\n        )\n\n    else:  # if in inference mode 
materialize logits\n        slice_indices = (\n            slice(-logits_to_keep, None)\n            if isinstance(logits_to_keep, int)\n            else logits_to_keep\n        )\n        logits = self.lm_head(hidden_states[:, slice_indices, :])\n        if labels is not None:\n            loss = self.loss_function(\n                logits=logits,\n                labels=labels,\n                vocab_size=self.config.vocab_size,\n                **loss_kwargs,\n            )\n\n    if not return_dict:\n        output = (logits,) + outputs[1:]\n        return (loss,) + output if loss is not None else output\n\n    return CausalLMOutputWithPast(\n        loss=loss,\n        logits=logits,\n        past_key_values=outputs.past_key_values,\n        hidden_states=outputs.hidden_states,\n        attentions=outputs.attentions,\n    )\n\n\ndef apply_liger_kernel_to_llama4(\n    cross_entropy: bool = False,\n    fused_linear_cross_entropy: bool = False,\n    rms_norm: bool = False,\n    glu_activation: bool = False,\n    layer_norm: bool = False,\n    **kwargs,\n) -> None:\n    \"\"\"\n    Apply Liger kernels to replace original implementation in HuggingFace Llama 4 models\n\n    Args:\n        cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.\n        fused_linear_cross_entropy (bool):\n            Whether to apply Liger's fused linear cross entropy loss. Default is False.\n            `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.\n            If `fused_linear_cross_entropy` is True, the logits will not be materialized, which is more memory efficient.\n        rms_norm (bool): Whether to apply Liger's RMSNorm. Default is False.\n        glu_activation (bool): Whether to apply Liger's SwiGLU MLP. Default is False.\n        layer_norm (bool): Whether to apply Liger's LayerNorm. 
Default is False.\n    \"\"\"\n\n    import transformers.models.llama4.modeling_llama4  # noqa: F401\n    from liger_kernel.transformers.functional import liger_cross_entropy\n    from liger_kernel.transformers.layer_norm import LigerLayerNorm\n    from liger_kernel.transformers.rms_norm import LigerRMSNorm\n    from liger_kernel.transformers.swiglu import LigerSwiGLUMLP\n\n    assert not (cross_entropy and fused_linear_cross_entropy), (\n        \"cross_entropy and fused_linear_cross_entropy cannot both be True.\"\n    )\n\n    modeling_llama4 = sys.modules[\"transformers.models.llama4.modeling_llama4\"]\n\n    if rms_norm:\n        modeling_llama4.Llama4TextRMSNorm = LigerRMSNorm\n    if glu_activation:\n\n        def _liger_swiglu_mlp_wrapper(config, intermediate_size=None, **kwargs):\n            \"Accepts intermediate_size to pass to LigerSwiGLUMLP\"\n            # clone config to avoid modifying the original\n            config = deepcopy(config)\n            if intermediate_size:\n                config.intermediate_size = intermediate_size\n            return LigerSwiGLUMLP(config, **kwargs)\n\n        modeling_llama4.Llama4TextMLP = _liger_swiglu_mlp_wrapper\n    if layer_norm:\n        modeling_llama4.nn.LayerNorm = LigerLayerNorm\n\n    if cross_entropy:\n        from transformers.loss.loss_utils import nn\n\n        nn.functional.cross_entropy = liger_cross_entropy\n\n    if fused_linear_cross_entropy:\n        modeling_llama4.Llama4ForCausalLM.forward = lce_forward\n"
  },
  {
    "path": "src/axolotl/integrations/liger/models/qwen3.py",
    "content": "\"\"\"\nLiger FLCE for Qwen3. Based on transformers v4.51.3.\n\"\"\"\n\nimport sys\nfrom typing import Optional, Tuple, Union\n\nimport torch\nfrom liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss\nfrom transformers.cache_utils import Cache\nfrom transformers.modeling_outputs import CausalLMOutputWithPast\n\n\ndef lce_forward(\n    self,\n    input_ids: Optional[torch.LongTensor] = None,\n    attention_mask: Optional[torch.Tensor] = None,\n    position_ids: Optional[torch.LongTensor] = None,\n    past_key_values: Optional[Cache] = None,\n    inputs_embeds: Optional[torch.FloatTensor] = None,\n    labels: Optional[torch.LongTensor] = None,\n    use_cache: Optional[bool] = None,\n    output_attentions: Optional[bool] = None,\n    output_hidden_states: Optional[bool] = None,\n    cache_position: Optional[torch.LongTensor] = None,\n    logits_to_keep: Union[int, torch.Tensor] = 0,\n    **kwargs,\n) -> Union[Tuple, CausalLMOutputWithPast]:\n    r\"\"\"\n    Args:\n        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):\n            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,\n            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored\n            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.\n\n        logits_to_keep (`int` or `torch.Tensor`, *optional*):\n            If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all\n            `input_ids` (special case). 
Only last token logits are needed for generation, and calculating them only for that\n            token can save memory, which becomes pretty significant for long sequences or large vocabulary size.\n            If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.\n            This is useful when using packed tensor format (single dimension for batch and sequence length).\n\n    Returns:\n    \"\"\"\n\n    output_attentions = (\n        output_attentions\n        if output_attentions is not None\n        else self.config.output_attentions\n    )\n    output_hidden_states = (\n        output_hidden_states\n        if output_hidden_states is not None\n        else self.config.output_hidden_states\n    )\n\n    # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)\n    outputs = self.model(\n        input_ids=input_ids,\n        attention_mask=attention_mask,\n        position_ids=position_ids,\n        past_key_values=past_key_values,\n        inputs_embeds=inputs_embeds,\n        use_cache=use_cache,\n        output_attentions=output_attentions,\n        output_hidden_states=output_hidden_states,\n        cache_position=cache_position,\n        **kwargs,\n    )\n\n    hidden_states = outputs[0]\n\n    logits = None\n    loss = None\n    # if in training mode, don't materialize logits\n    if self.training and (labels is not None):\n        loss = LigerForCausalLMLoss(\n            hidden_states=hidden_states,\n            lm_head_weight=self.lm_head.weight,\n            labels=labels,\n            hidden_size=self.config.hidden_size,\n            **kwargs,\n        )\n\n    else:  # if in inference mode materialize logits\n        slice_indices = (\n            slice(-logits_to_keep, None)\n            if isinstance(logits_to_keep, int)\n            else logits_to_keep\n        )\n        logits = self.lm_head(hidden_states[:, slice_indices, :])\n        if labels is not None:\n            
loss = self.loss_function(\n                logits=logits,\n                labels=labels,\n                vocab_size=self.config.vocab_size,\n                **kwargs,\n            )\n\n    return CausalLMOutputWithPast(\n        loss=loss,\n        logits=logits,\n        past_key_values=outputs.past_key_values,\n        hidden_states=outputs.hidden_states,\n        attentions=outputs.attentions,\n    )\n\n\ndef apply_liger_kernel_to_qwen3(\n    cross_entropy: bool = False,\n    fused_linear_cross_entropy: bool = False,\n    rms_norm: bool = False,\n    glu_activation: bool = False,\n    layer_norm: bool = False,\n    **kwargs,\n) -> None:\n    \"\"\"\n    Apply Liger kernels to replace original implementation in HuggingFace Qwen3 models\n\n    Args:\n        cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.\n        fused_linear_cross_entropy (bool):\n            Whether to apply Liger's fused linear cross entropy loss. Default is False.\n            `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.\n            If `fused_linear_cross_entropy` is True, the logits will not be materialized, which is more memory efficient.\n        rms_norm (bool): Whether to apply Liger's RMSNorm. Default is False.\n        glu_activation (bool): Whether to apply Liger's SwiGLU MLP. Default is False.\n        layer_norm (bool): Whether to apply Liger's LayerNorm. 
Default is False.\n    \"\"\"\n\n    import transformers.models.qwen3.modeling_qwen3  # noqa: F401\n    from liger_kernel.transformers.functional import liger_cross_entropy\n    from liger_kernel.transformers.layer_norm import LigerLayerNorm\n    from liger_kernel.transformers.rms_norm import LigerRMSNorm\n    from liger_kernel.transformers.swiglu import LigerSwiGLUMLP\n\n    assert not (cross_entropy and fused_linear_cross_entropy), (\n        \"cross_entropy and fused_linear_cross_entropy cannot both be True.\"\n    )\n\n    modeling_qwen3 = sys.modules[\"transformers.models.qwen3.modeling_qwen3\"]\n\n    if rms_norm:\n        modeling_qwen3.Qwen3RMSNorm = LigerRMSNorm\n\n    if glu_activation:\n        modeling_qwen3.Qwen3MLP = LigerSwiGLUMLP\n\n    if layer_norm:\n        modeling_qwen3.nn.LayerNorm = LigerLayerNorm\n\n    if cross_entropy:\n        from transformers.loss.loss_utils import nn\n\n        nn.functional.cross_entropy = liger_cross_entropy\n\n    if fused_linear_cross_entropy:\n        modeling_qwen3.Qwen3ForCausalLM.forward = lce_forward\n"
  },
  {
    "path": "src/axolotl/integrations/liger/models/qwen3_moe.py",
    "content": "\"\"\"\nLiger FLCE for Qwen3 MoE. Based on transformers v4.51.3.\n\"\"\"\n\nimport sys\nfrom copy import deepcopy\nfrom typing import List, Optional, Union\n\nimport torch\nfrom liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss\nfrom transformers.modeling_outputs import MoeCausalLMOutputWithPast\nfrom transformers.models.qwen3_moe.modeling_qwen3_moe import load_balancing_loss_func\n\n\ndef lce_forward(\n    self,\n    input_ids: Optional[torch.LongTensor] = None,\n    attention_mask: Optional[torch.Tensor] = None,\n    position_ids: Optional[torch.LongTensor] = None,\n    past_key_values: Optional[List[torch.FloatTensor]] = None,\n    inputs_embeds: Optional[torch.FloatTensor] = None,\n    labels: Optional[torch.LongTensor] = None,\n    use_cache: Optional[bool] = None,\n    output_attentions: Optional[bool] = None,\n    output_hidden_states: Optional[bool] = None,\n    output_router_logits: Optional[bool] = None,\n    cache_position: Optional[torch.LongTensor] = None,\n    logits_to_keep: Union[int, torch.Tensor] = 0,\n    **kwargs,\n) -> MoeCausalLMOutputWithPast:\n    r\"\"\"\n    Args:\n        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):\n            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,\n            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored\n            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.\n\n        logits_to_keep (`int` or `torch.Tensor`, *optional*):\n            If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all\n            `input_ids` (special case). 
Only last token logits are needed for generation, and calculating them only for that\n            token can save memory, which becomes pretty significant for long sequences or large vocabulary size.\n            If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.\n            This is useful when using packed tensor format (single dimension for batch and sequence length).\n\n    Returns:\n    \"\"\"\n\n    output_attentions = (\n        output_attentions\n        if output_attentions is not None\n        else self.config.output_attentions\n    )\n    output_router_logits = (\n        output_router_logits\n        if output_router_logits is not None\n        else self.config.output_router_logits\n    )\n    output_hidden_states = (\n        output_hidden_states\n        if output_hidden_states is not None\n        else self.config.output_hidden_states\n    )\n\n    # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)\n    outputs = self.model(\n        input_ids=input_ids,\n        attention_mask=attention_mask,\n        position_ids=position_ids,\n        past_key_values=past_key_values,\n        inputs_embeds=inputs_embeds,\n        use_cache=use_cache,\n        output_attentions=output_attentions,\n        output_hidden_states=output_hidden_states,\n        output_router_logits=output_router_logits,\n        cache_position=cache_position,\n        **kwargs,\n    )\n\n    hidden_states = outputs[0]\n\n    logits = None\n    loss = None\n    # if in training mode, don't materialize logits\n    if self.training and (labels is not None):\n        loss = LigerForCausalLMLoss(\n            hidden_states=hidden_states,\n            lm_head_weight=self.lm_head.weight,\n            labels=labels,\n            hidden_size=self.config.hidden_size,\n            **kwargs,\n        )\n\n    else:  # if in inference mode materialize logits\n        slice_indices = (\n            slice(-logits_to_keep, 
None)\n            if isinstance(logits_to_keep, int)\n            else logits_to_keep\n        )\n        logits = self.lm_head(hidden_states[:, slice_indices, :])\n        if labels is not None:\n            loss = self.loss_function(\n                logits=logits,\n                labels=labels,\n                vocab_size=self.config.vocab_size,\n                **kwargs,\n            )\n\n    aux_loss = None\n    if output_router_logits:\n        aux_loss = load_balancing_loss_func(\n            outputs.router_logits,\n            self.num_experts,\n            self.num_experts_per_tok,\n            attention_mask,\n        )\n        if labels is not None:\n            loss += self.router_aux_loss_coef * aux_loss.to(\n                loss.device\n            )  # make sure to reside in the same device\n\n    return MoeCausalLMOutputWithPast(\n        loss=loss,\n        aux_loss=aux_loss,\n        logits=logits,\n        past_key_values=outputs.past_key_values,\n        hidden_states=outputs.hidden_states,\n        attentions=outputs.attentions,\n    )\n\n\ndef apply_liger_kernel_to_qwen3_moe(\n    cross_entropy: bool = False,\n    fused_linear_cross_entropy: bool = False,\n    rms_norm: bool = False,\n    glu_activation: bool = False,\n    layer_norm: bool = False,\n    **kwargs,\n) -> None:\n    \"\"\"\n    Apply Liger kernels to replace original implementation in HuggingFace Qwen3 MoE models\n\n    Args:\n        cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.\n        fused_linear_cross_entropy (bool):\n            Whether to apply Liger's fused linear cross entropy loss. Default is False.\n            `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.\n            If `fused_linear_cross_entropy` is True, the logits will not be materialized, which is more memory efficient.\n        rms_norm (bool): Whether to apply Liger's RMSNorm. 
Default is False.\n        glu_activation (bool): Whether to apply Liger's SwiGLU MLP. Default is False.\n        layer_norm (bool): Whether to apply Liger's LayerNorm. Default is False.\n    \"\"\"\n\n    import transformers.models.qwen3_moe.modeling_qwen3_moe  # noqa: F401\n    from liger_kernel.transformers.functional import liger_cross_entropy\n    from liger_kernel.transformers.layer_norm import LigerLayerNorm\n    from liger_kernel.transformers.rms_norm import LigerRMSNorm\n    from liger_kernel.transformers.swiglu import LigerSwiGLUMLP\n\n    assert not (cross_entropy and fused_linear_cross_entropy), (\n        \"cross_entropy and fused_linear_cross_entropy cannot both be True.\"\n    )\n\n    modeling_qwen3_moe = sys.modules[\"transformers.models.qwen3_moe.modeling_qwen3_moe\"]\n\n    if rms_norm:\n        modeling_qwen3_moe.Qwen3MoeRMSNorm = LigerRMSNorm\n\n    if glu_activation:\n\n        def _liger_swiglu_mlp_wrapper(config, intermediate_size=None, **kwargs):\n            \"Accepts intermediate_size to pass to LigerSwiGLUMLP\"\n            # clone config to avoid modifying the original\n            config = deepcopy(config)\n            if intermediate_size:\n                config.intermediate_size = intermediate_size\n            return LigerSwiGLUMLP(config, **kwargs)\n\n        modeling_qwen3_moe.Qwen3MoeMLP = _liger_swiglu_mlp_wrapper\n\n    if layer_norm:\n        modeling_qwen3_moe.nn.LayerNorm = LigerLayerNorm\n\n    if cross_entropy:\n        from transformers.loss.loss_utils import nn\n\n        nn.functional.cross_entropy = liger_cross_entropy\n\n    if fused_linear_cross_entropy:\n        modeling_qwen3_moe.Qwen3MoeForCausalLM.forward = lce_forward\n"
  },
  {
    "path": "src/axolotl/integrations/liger/plugin.py",
    "content": "\"\"\"\nLiger-Kernel Plugin for Axolotl\n\"\"\"\n\nimport inspect\nimport sys\n\nfrom axolotl.integrations.base import BasePlugin\nfrom axolotl.utils.logging import get_logger\n\nLOG = get_logger(__name__)\n\n\nclass LigerPlugin(BasePlugin):\n    \"\"\"\n    Plugin for LIGER integraton with Axolotl.\n    \"\"\"\n\n    def get_input_args(self):\n        return \"axolotl.integrations.liger.LigerArgs\"\n\n    def pre_model_load(self, cfg):\n        # shim: liger-kernel 0.7.0 imports ORPOTrainer from old trl path\n        import trl.trainer\n        from trl.experimental.orpo import ORPOTrainer\n\n        trl.trainer.ORPOTrainer = ORPOTrainer\n\n        if cfg.torch_compile:\n            # torch compile will unnecessarily attempt to optimize the triton kernel unless explicitly disabled\n            import liger_kernel.ops.fused_linear_cross_entropy\n\n            from .utils import patch_with_compile_disable\n\n            patch_with_compile_disable(\n                liger_kernel.ops.fused_linear_cross_entropy,\n                \"fused_linear_cross_entropy_forward\",\n            )\n            patch_with_compile_disable(\n                liger_kernel.ops.fused_linear_cross_entropy,\n                \"fused_linear_cross_entropy_backward\",\n            )\n\n        from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss\n        from liger_kernel.transformers.functional import liger_cross_entropy\n        from liger_kernel.transformers.layer_norm import LigerLayerNorm\n        from liger_kernel.transformers.monkey_patch import MODEL_TYPE_TO_APPLY_LIGER_FN\n        from liger_kernel.transformers.rms_norm import LigerRMSNorm\n        from liger_kernel.transformers.rope import liger_rotary_pos_emb\n        from liger_kernel.transformers.swiglu import LigerSwiGLUMLP\n\n        if cfg.liger_cross_entropy and cfg.liger_fused_linear_cross_entropy:\n            raise ValueError(\n                \"Cannot have both `liger_cross_entropy` and 
`liger_fused_linear_cross_entropy` set.\"\n            )\n\n        if cfg.liger_use_token_scaling:\n            # Patch FLCE to set token_scaling=True for function and class API\n            from liger_kernel.transformers import functional\n            from liger_kernel.transformers.fused_linear_cross_entropy import (\n                LigerFusedLinearCrossEntropyLoss,\n            )\n\n            old_liger_fused_linear_cross_entropy = (\n                functional.liger_fused_linear_cross_entropy\n            )\n\n            def patched_liger_fused_linear_cross_entropy(*args, **kwargs):\n                kwargs[\"use_token_scaling\"] = True\n                return old_liger_fused_linear_cross_entropy(*args, **kwargs)\n\n            functional.liger_fused_linear_cross_entropy = (\n                patched_liger_fused_linear_cross_entropy\n            )\n\n            old_init = LigerFusedLinearCrossEntropyLoss.__init__\n\n            def patched_init(self, *args, **kwargs):\n                kwargs[\"use_token_scaling\"] = True\n                return old_init(self, *args, **kwargs)\n\n            LigerFusedLinearCrossEntropyLoss.__init__ = patched_init\n\n        if cfg.model_config_type in MODEL_TYPE_TO_APPLY_LIGER_FN:\n            apply_liger_fn = MODEL_TYPE_TO_APPLY_LIGER_FN[cfg.model_config_type]\n            liger_fn_sig = inspect.signature(apply_liger_fn)\n            kwargs = {}\n            if \"rope\" in liger_fn_sig.parameters:\n                kwargs[\"rope\"] = cfg.liger_rope\n            if \"cross_entropy\" in liger_fn_sig.parameters:\n                kwargs[\"cross_entropy\"] = cfg.liger_cross_entropy\n            if \"fused_linear_cross_entropy\" in liger_fn_sig.parameters:\n                kwargs[\"fused_linear_cross_entropy\"] = (\n                    cfg.liger_fused_linear_cross_entropy\n                )\n            if \"rms_norm\" in liger_fn_sig.parameters:\n                kwargs[\"rms_norm\"] = cfg.liger_rms_norm\n            if 
\"layer_norm\" in liger_fn_sig.parameters:\n                kwargs[\"layer_norm\"] = cfg.liger_layer_norm\n            if \"geglu\" in liger_fn_sig.parameters:\n                kwargs[\"geglu\"] = cfg.liger_glu_activation\n            elif \"swiglu\" in liger_fn_sig.parameters:\n                kwargs[\"swiglu\"] = cfg.liger_glu_activation\n            LOG.info(f\"Applying LIGER to {cfg.model_config_type} with kwargs: {kwargs}\")\n            apply_liger_fn(**kwargs)\n        elif cfg.model_config_type == \"jamba\":\n            from transformers.models.jamba import modeling_jamba\n\n            from .models.jamba import lce_forward as jamba_lce_forward\n\n            if cfg.liger_rope:\n                modeling_jamba.apply_rotary_pos_emb = liger_rotary_pos_emb\n            if cfg.liger_rms_norm:\n                modeling_jamba.JambaRMSNorm = LigerRMSNorm\n            if cfg.liger_glu_activation:\n                modeling_jamba.JambaMLP = LigerSwiGLUMLP\n            if cfg.liger_layer_norm:\n                modeling_jamba.nn.LayerNorm = LigerLayerNorm\n            if cfg.liger_cross_entropy:\n                from transformers.loss.loss_utils import nn\n\n                nn.functional.cross_entropy = liger_cross_entropy\n            if cfg.liger_fused_linear_cross_entropy:\n                modeling_jamba.JambaForCausalLM.forward = jamba_lce_forward\n        elif cfg.model_config_type == \"deepseek_v2\":\n            from accelerate import init_empty_weights\n            from transformers import AutoModelForCausalLM\n\n            with init_empty_weights():\n                model = AutoModelForCausalLM.from_pretrained(\n                    cfg.base_model, trust_remote_code=cfg.trust_remote_code or False\n                )\n                modeling_mod = sys.modules[model.__class__.__module__]\n\n            from .models.deepseekv2 import lce_forward as deepseekv2_lce_forward\n\n            if cfg.liger_rope:\n                # The DeepseekV2 version of RoPE is 
different than upstream LLaMA.\n                # See https://github.com/linkedin/Liger-Kernel/issues/129#issuecomment-2313763528\n                LOG.warning(\"Fused liger_rope is not supported for DeepseekV2.\")\n            if cfg.liger_rms_norm:\n                modeling_mod.DeepseekV2RMSNorm = LigerRMSNorm\n            if cfg.liger_glu_activation:\n                modeling_mod.DeepseekV2MLP.forward = LigerSwiGLUMLP.forward\n            if cfg.liger_layer_norm:\n                LOG.warning(\"liger_layer_norm is not supported for DeepseekV2.\")\n            if cfg.liger_cross_entropy:\n                # We do not patch `nn.functional.cross_entropy` for DeepseekV2 as it still uses\n                # nn.CrossEntropyLoss in the forward method.\n                modeling_mod.CrossEntropyLoss = LigerCrossEntropyLoss\n            if cfg.liger_fused_linear_cross_entropy:\n                modeling_mod.DeepseekV2ForCausalLM.forward = deepseekv2_lce_forward\n        elif cfg.model_config_type == \"llama4\":\n            from axolotl.integrations.liger.models.llama4 import (\n                apply_liger_kernel_to_llama4,\n            )\n\n            apply_liger_kernel_to_llama4(\n                cross_entropy=cfg.liger_cross_entropy,\n                fused_linear_cross_entropy=cfg.liger_fused_linear_cross_entropy,\n                glu_activation=cfg.liger_glu_activation,\n                rms_norm=cfg.liger_rms_norm,\n                layer_norm=cfg.liger_layer_norm,\n            )\n        elif cfg.model_config_type == \"qwen3\":\n            from axolotl.integrations.liger.models.qwen3 import (\n                apply_liger_kernel_to_qwen3,\n            )\n\n            apply_liger_kernel_to_qwen3(\n                cross_entropy=cfg.liger_cross_entropy,\n                fused_linear_cross_entropy=cfg.liger_fused_linear_cross_entropy,\n                glu_activation=cfg.liger_glu_activation,\n                rms_norm=cfg.liger_rms_norm,\n                
layer_norm=cfg.liger_layer_norm,\n            )\n        elif cfg.model_config_type == \"qwen3_moe\":\n            from axolotl.integrations.liger.models.qwen3_moe import (\n                apply_liger_kernel_to_qwen3_moe,\n            )\n\n            apply_liger_kernel_to_qwen3_moe(\n                cross_entropy=cfg.liger_cross_entropy,\n                fused_linear_cross_entropy=cfg.liger_fused_linear_cross_entropy,\n                glu_activation=cfg.liger_glu_activation,\n                rms_norm=cfg.liger_rms_norm,\n                layer_norm=cfg.liger_layer_norm,\n            )\n        elif cfg.model_config_type == \"granitemoe\":\n            from liger_kernel.transformers import apply_liger_kernel_to_granite\n\n            apply_liger_kernel_to_granite(\n                rope=cfg.liger_rope,\n                cross_entropy=cfg.liger_cross_entropy,\n                fused_linear_cross_entropy=cfg.liger_fused_linear_cross_entropy,\n                rms_norm=cfg.liger_rms_norm,\n                swiglu=cfg.liger_glu_activation,\n            )\n        elif cfg.liger_fused_linear_cross_entropy:\n            try:\n                from .models.base import patch_lce_forward\n\n                patch_lce_forward(cfg.model_config_type)\n                LOG.warning_once(\n                    f\"Applied ONLY liger_fused_linear_cross_entropy genericpatches for model type: {cfg.model_config_type}\"\n                )\n                LOG.warning_once(\n                    f\"Liger + {cfg.model_config_type} generic FLCE support is experimental and may not work as expected.\"\n                )\n            except RuntimeError:\n                LOG.warning(\n                    f\"Unsupported model config type: {cfg.model_config_type}. Liger not applied.\"\n                )\n        else:\n            LOG.warning(\n                f\"Unsupported model config type: {cfg.model_config_type}. Liger not applied.\"\n            )\n"
  },
  {
    "path": "src/axolotl/integrations/liger/utils.py",
    "content": "\"\"\"\nutils to patch liger kernel ops to disable torch.compile\n\"\"\"\n\nfrom functools import wraps\n\nimport torch\n\n\ndef patch_with_compile_disable(module, function_name):\n    \"\"\"\n    Patch a function in a module by wrapping it with torch.compile.disable\n\n    Args:\n        module: The module containing the function to patch\n        function_name: The name of the function to patch\n    \"\"\"\n    original_function = getattr(module, function_name)\n\n    @wraps(original_function)\n    @torch.compiler.disable\n    def wrapped_function(*args, **kwargs):\n        return original_function(*args, **kwargs)\n\n    # Replace the original function with the wrapped one\n    setattr(module, function_name, wrapped_function)\n\n    # Return the original function in case you need to restore it later\n    return original_function\n"
  },
  {
    "path": "src/axolotl/integrations/llm_compressor/README.md",
    "content": "# LLMCompressor Integration\n\nFine-tune sparsified models in Axolotl using Neural Magic's [LLMCompressor](https://github.com/vllm-project/llm-compressor).\n\nThis integration enables fine-tuning of models sparsified using LLMCompressor within the Axolotl training framework. By combining LLMCompressor's model compression capabilities with Axolotl's distributed training pipelines, users can efficiently fine-tune sparse models at scale.\n\nIt uses Axolotl’s plugin system to hook into the fine-tuning flows while maintaining sparsity throughout training.\n\n---\n\n## Requirements\n\n- Axolotl with `llmcompressor` extras:\n\n  ```bash\n  pip install \"axolotl[llmcompressor]\"\n  ```\n\n- Requires `llmcompressor >= 0.5.1`\n\nThis will install all necessary dependencies to fine-tune sparsified models using the integration.\n\n---\n\n## Usage\n\nTo enable sparse fine-tuning with this integration, include the plugin in your Axolotl config:\n\n```yaml\nplugins:\n  - axolotl.integrations.llm_compressor.LLMCompressorPlugin\n\nllmcompressor:\n  recipe:\n    finetuning_stage:\n      finetuning_modifiers:\n        ConstantPruningModifier:\n          targets: [\n            're:.*q_proj.weight',\n            're:.*k_proj.weight',\n            're:.*v_proj.weight',\n            're:.*o_proj.weight',\n            're:.*gate_proj.weight',\n            're:.*up_proj.weight',\n            're:.*down_proj.weight',\n          ]\n          start: 0\n  save_compressed: true\n# ... 
(other training arguments)\n```\n\nThis plugin **does not apply pruning or sparsification itself** — it is intended for **fine-tuning models that have already been sparsified**.\n\nPre-sparsified checkpoints can be:\n- Generated using [LLMCompressor](https://github.com/vllm-project/llm-compressor)\n- Downloaded from [Neural Magic's Hugging Face page](https://huggingface.co/neuralmagic)\n- Any custom LLM with compatible sparsity patterns that you've created yourself\n\nTo learn more about writing and customizing LLMCompressor recipes, refer to the official documentation:\n[https://github.com/vllm-project/llm-compressor/blob/main/README.md](https://github.com/vllm-project/llm-compressor/blob/main/README.md)\n\n### Storage Optimization with save_compressed\n\nSetting `save_compressed: true` in your configuration enables saving models in a compressed format, which:\n- Reduces disk space usage by approximately 40%\n- Maintains compatibility with vLLM for accelerated inference\n- Maintains compatibility with llmcompressor for further optimization (example: quantization)\n\nThis option is highly recommended when working with sparse models to maximize the benefits of model compression.\n\n### Example Config\n\nSee [`examples/llama-3/sparse-finetuning.yaml`](examples/llama-3/sparse-finetuning.yaml) for a complete example.\n\n---\n\n## Inference with vLLM\n\nAfter fine-tuning your sparse model, you can leverage vLLM for efficient inference.\nYou can also use LLMCompressor to apply additional quantization to your fine-tuned\nsparse model before inference for even greater performance benefits.\n\nFor example, to run inference on your fine-tuned sparse model with vLLM:\n\n```python\nfrom vllm import LLM, SamplingParams\n\nprompts = [\n    \"Hello, my name is\",\n    \"The president of the United States is\",\n    \"The capital of France is\",\n    \"The future of AI is\",\n]\nsampling_params = SamplingParams(temperature=0.8, top_p=0.95)\nllm = LLM(\"path/to/your/sparse/model\")\noutputs = llm.generate(prompts, sampling_params)\n\nfor output in 
outputs:\n    prompt = output.prompt\n    generated_text = output.outputs[0].text\n    print(f\"Prompt: {prompt!r}, Generated text: {generated_text!r}\")\n```\n\nFor more details on vLLM's capabilities and advanced configuration options, see the [official vLLM documentation](https://docs.vllm.ai/).\n\n## Learn More\n\nFor details on available sparsity and quantization schemes, fine-tuning recipes, and usage examples, visit the official LLMCompressor repository:\n\n[https://github.com/vllm-project/llm-compressor](https://github.com/vllm-project/llm-compressor)\n"
  },
  {
    "path": "src/axolotl/integrations/llm_compressor/__init__.py",
    "content": "\"\"\"Integration entry point for the LLMCompressor plugin.\"\"\"\n\nfrom .plugin import LLMCompressorPlugin\n\n__all__ = [\"LLMCompressorPlugin\"]\n"
  },
  {
    "path": "src/axolotl/integrations/llm_compressor/args.py",
    "content": "\"\"\"\nLLMCompressor and Sparse Finetuning config models.\n\"\"\"\n\nfrom typing import Any\n\nfrom pydantic import BaseModel, Field\nfrom typing_extensions import Annotated\n\n\nclass CompressionArgs(BaseModel):\n    \"\"\"Sparse Finetuning config for LLMCompressor.\"\"\"\n\n    # Typing for recipe is set to Any due to:\n    # https://github.com/vllm-project/llm-compressor/issues/1319\n    recipe: Annotated[\n        Any,\n        Field(\n            description=\"The recipe containing the compression algorithms and hyperparameters to apply.\"\n        ),\n    ]\n\n    save_compressed: Annotated[\n        bool,\n        Field(\n            default=False,\n            description=\"Whether to save the compressed model after training.\",\n        ),\n    ]\n\n\nclass LLMCompressorArgs(BaseModel):\n    \"\"\"LLMCompressor configuration BaseModel.\"\"\"\n\n    llmcompressor: Annotated[\n        CompressionArgs,\n        Field(\n            description=\"Arguments enabling compression pathways through the LLM Compressor plugins\"\n        ),\n    ]\n"
  },
  {
    "path": "src/axolotl/integrations/llm_compressor/plugin.py",
    "content": "\"\"\"\nSparse Finetuning plugin for Axolotl — enables handling of sparse neural networks\nby maintaining masks for zero weights during training.\n\"\"\"\n\nfrom functools import wraps\nfrom typing import Any, Callable, Concatenate, ParamSpec, TypeVar\n\nfrom llmcompressor import active_session, create_session\nfrom llmcompressor.core import callbacks as session_callbacks\nfrom llmcompressor.recipe import Recipe\nfrom torch.nn import Module\nfrom transformers.trainer import Trainer\nfrom transformers.trainer_callback import TrainerCallback, TrainerControl, TrainerState\nfrom transformers.training_args import TrainingArguments\n\nfrom axolotl.integrations.base import BasePlugin\nfrom axolotl.utils.logging import get_logger\n\nP = ParamSpec(\"P\")  # Params for generic function signatures\nR = TypeVar(\"R\")  # Return type for generic function signatures\n\nLOG = get_logger(__name__)\n\n\nclass LLMCompressorCallbackHandler(TrainerCallback):\n    \"\"\"\n    Trainer callback for Sparse Finetuning.\n    Maintains sparsity patterns during training by applying masks after optimization steps,\n    ensuring zero-weight updates are canceled out.\n    \"\"\"\n\n    def __init__(self, trainer: Trainer, recipe: Any):\n        \"\"\"\n        Initialize the Sparse Finetuning callback handler.\n\n        Args:\n            trainer (Trainer): Huggingface Trainer instance.\n            recipe (Recipe | dict): Sparse finetuning recipe to apply.\n        \"\"\"\n        super().__init__()\n        self.trainer = trainer\n        self.recipe = (\n            Recipe.model_validate(recipe) if not isinstance(recipe, Recipe) else recipe\n        )\n        self.original_compute_loss = trainer.compute_loss\n        self.trainer.compute_loss = compute_loss_wrapper(self.trainer.compute_loss)\n        create_session()\n\n    def on_train_begin(\n        self,\n        args: TrainingArguments,\n        state: TrainerState,\n        control: TrainerControl,\n        **kwargs,\n 
   ) -> None:\n        \"\"\"\n        Called at the beginning of training. Initializes the compression session.\n\n        Args:\n            args (TrainingArguments): Training arguments.\n            state (TrainerState): Trainer state.\n            control (TrainerControl): Trainer control.\n        \"\"\"\n        super().on_train_begin(args, state, control, **kwargs)\n        self.trainer.accelerator.wait_for_everyone()\n        active_session().initialize(\n            model=self.trainer.model,\n            optimizer=self.trainer.optimizer,\n            start=state.epoch,\n            recipe=self.recipe,\n        )\n        self.trainer.accelerator.wait_for_everyone()\n\n    def on_step_begin(\n        self,\n        args: TrainingArguments,\n        state: TrainerState,\n        control: TrainerControl,\n        **kwargs,\n    ) -> None:\n        \"\"\"\n        Called at the beginning of a training step. Triggers batch_start callback.\n        \"\"\"\n        super().on_step_begin(args, state, control, **kwargs)\n        session_callbacks.batch_start()\n\n    def on_step_end(\n        self,\n        args: TrainingArguments,\n        state: TrainerState,\n        control: TrainerControl,\n        **kwargs,\n    ) -> None:\n        \"\"\"\n        Called at the end of a training step. Triggers optimizer and batch_end callbacks.\n        \"\"\"\n        super().on_step_end(args, state, control, **kwargs)\n        session_callbacks.optim_pre_step()\n        session_callbacks.optim_post_step()\n        session_callbacks.batch_end()\n\n    def on_train_end(\n        self,\n        args: TrainingArguments,\n        state: TrainerState,\n        control: TrainerControl,\n        **kwargs,\n    ) -> None:\n        \"\"\"\n        Called at the end of training. 
Finalizes the compression session.\n        \"\"\"\n        super().on_train_end(args, state, control, **kwargs)\n        active_session().finalize()\n        self.trainer.compute_loss_func = self.original_compute_loss\n\n\nclass LLMCompressorPlugin(BasePlugin):\n    \"\"\"\n    Sparse Finetuning plugin for Axolotl integration.\n    \"\"\"\n\n    def get_input_args(self) -> str:\n        \"\"\"\n        Returns the path to the plugin's argument definition.\n\n        Returns:\n            str: Dotted path to the LLMCompressorArgs class.\n        \"\"\"\n        return \"axolotl.integrations.llm_compressor.args.LLMCompressorArgs\"\n\n    def add_callbacks_post_trainer(self, cfg: Any, trainer: Trainer) -> list:\n        \"\"\"\n        Adds Sparse Finetuning callback to the Trainer instance.\n\n        Args:\n            cfg (Any): Configuration object containing the sparse recipe.\n            trainer (Trainer): Huggingface Trainer instance.\n\n        Returns:\n            list: List containing the configured callback instances.\n        \"\"\"\n        LOG.info(\"Adding Sparse Finetuning callback to the trainer\")\n        callback = LLMCompressorCallbackHandler(\n            trainer=trainer,\n            recipe=cfg.llmcompressor.recipe,\n        )\n        return [callback]\n\n\ndef compute_loss_wrapper(\n    compute_loss_func: Callable[Concatenate[Module, P], R],\n) -> Callable[Concatenate[Module, P], R]:\n    \"\"\"\n    Wraps the loss computation function to trigger the loss_calculated callback.\n\n    Args:\n        compute_loss_func (Callable): Original loss computation function.\n\n    Returns:\n        Callable: Wrapped function that also invokes the loss_calculated callback.\n    \"\"\"\n\n    @wraps(compute_loss_func)\n    def compute_and_notify(model: Module, *args: P.args, **kwargs: P.kwargs) -> R:\n        loss = compute_loss_func(model, *args, **kwargs)\n        if active_session().lifecycle.initialized_ and model.training:\n            
session_callbacks.loss_calculated(loss=loss)\n        return loss\n\n    return compute_and_notify\n"
  },
  {
    "path": "src/axolotl/integrations/llm_compressor/utils.py",
    "content": "\"\"\"Utilities for llmcompressor integration with axolotl.\"\"\"\n\nfrom typing import Union\n\nfrom llmcompressor.transformers.sparsification.compressed_tensors_utils import (\n    modify_save_pretrained,\n)\nfrom transformers import PreTrainedModel, Trainer\n\n\ndef save_compressed_model(\n    model: PreTrainedModel,\n    output_dir: Union[str, bytes],\n    trainer: Trainer,\n    save_compressed: bool = False,\n) -> None:\n    \"\"\"\n    Synchronize processes, apply compression hooks, and save the model.\n\n    Args:\n        model (PreTrainedModel): The model to be saved.\n        output_dir (str or bytes): Path where the model files will be written.\n        trainer (Trainer): Hugging Face Trainer for process synchronization.\n        save_compressed (bool): Write compressed tensors if True.\n    \"\"\"\n    trainer.accelerator.wait_for_everyone()\n\n    # Only the main process writes the files\n    if not trainer.accelerator.is_main_process:\n        return\n\n    modify_save_pretrained(model)\n    model.save_pretrained(\n        output_dir,\n        save_compressed=save_compressed,\n        skip_sparsity_compression_stats=not save_compressed,\n    )\n"
  },
  {
    "path": "src/axolotl/integrations/lm_eval/README.md",
    "content": "# LM Eval Harness\n\nRun evaluation on model using the popular lm-evaluation-harness library.\n\nSee https://github.com/EleutherAI/lm-evaluation-harness\n\n## Usage\n\nThere are two ways to use the LM Eval integration:\n\n### 1. Post-Training Evaluation\n\nWhen training with the plugin enabled, evaluation runs automatically after training completes:\n\n```yaml\nplugins:\n  - axolotl.integrations.lm_eval.LMEvalPlugin\n\nlm_eval_tasks:\n  - gsm8k\n  - hellaswag\n  - arc_easy\n\nlm_eval_batch_size: # Batch size for evaluation\n\n# Directory to save evaluation results.\n# The final model is loaded from this directory\n# unless specified otherwise (see below)\noutput_dir:\n```\n\nRun training as usual:\n```bash\naxolotl train config.yml\n```\n\n### 2. Standalone CLI Evaluation\n\nEvaluate any model directly without training:\n\n```yaml\nlm_eval_model: meta-llama/Llama-2-7b-hf\n\nplugins:\n  - axolotl.integrations.lm_eval.LMEvalPlugin\n\nlm_eval_tasks:\n  - gsm8k\n  - hellaswag\n  - arc_easy\n\nlm_eval_batch_size: 8\noutput_dir: ./outputs\n```\n\nRun evaluation:\n```bash\naxolotl lm-eval config.yml\n```\n\n## Model Selection Priority\n\nThe model to evaluate is selected in the following priority order:\n\n1. **`lm_eval_model`** - Explicit model path or HuggingFace repo (highest priority)\n2. **`hub_model_id`** - Trained model pushed to HuggingFace Hub\n3. 
**`output_dir`** - Local checkpoint directory containing trained model weights\n\n## Citation\n\n```bib\n@misc{eval-harness,\n  author       = {Gao, Leo and Tow, Jonathan and Abbasi, Baber and Biderman, Stella and Black, Sid and DiPofi, Anthony and Foster, Charles and Golding, Laurence and Hsu, Jeffrey and Le Noac'h, Alain and Li, Haonan and McDonell, Kyle and Muennighoff, Niklas and Ociepa, Chris and Phang, Jason and Reynolds, Laria and Schoelkopf, Hailey and Skowron, Aviya and Sutawika, Lintang and Tang, Eric and Thite, Anish and Wang, Ben and Wang, Kevin and Zou, Andy},\n  title        = {A framework for few-shot language model evaluation},\n  month        = 07,\n  year         = 2024,\n  publisher    = {Zenodo},\n  version      = {v0.4.3},\n  doi          = {10.5281/zenodo.12608602},\n  url          = {https://zenodo.org/records/12608602}\n}\n```\n"
  },
  {
    "path": "src/axolotl/integrations/lm_eval/__init__.py",
    "content": "\"\"\"\nModule for the Plugin for LM Eval Harness\n\"\"\"\n\nimport subprocess  # nosec\n\nfrom axolotl.integrations.base import BasePlugin\nfrom axolotl.integrations.lm_eval.cli import build_lm_eval_command, get_model_path\n\nfrom .args import LMEvalArgs as LMEvalArgs\n\n\nclass LMEvalPlugin(BasePlugin):\n    \"\"\"\n    Plugin for LM Evaluation Harness integraton with Axolotl.\n    \"\"\"\n\n    def get_input_args(self):\n        return \"axolotl.integrations.lm_eval.LMEvalArgs\"\n\n    def post_train_unload(self, cfg):\n        if cfg.lm_eval_post_train:\n            for lm_eval_args in build_lm_eval_command(\n                cfg.lm_eval_tasks,\n                bfloat16=cfg.bfloat16 or cfg.bf16,\n                flash_attention=cfg.flash_attention,\n                output_dir=cfg.output_dir,\n                batch_size=cfg.lm_eval_batch_size,\n                wandb_project=cfg.wandb_project,\n                wandb_entity=cfg.wandb_entity,\n                wandb_name=cfg.wandb_name,\n                model=get_model_path(cfg),\n            ):\n                subprocess.run(  # nosec\n                    lm_eval_args,\n                    check=True,\n                )\n"
  },
  {
    "path": "src/axolotl/integrations/lm_eval/args.py",
    "content": "\"\"\"\nModule for handling lm eval harness input arguments.\n\"\"\"\n\nfrom typing import List, Optional\n\nfrom pydantic import BaseModel\n\n\nclass LMEvalArgs(BaseModel):\n    \"\"\"\n    Input args for lm eval harness\n    \"\"\"\n\n    lm_eval_tasks: List[str] = []\n    lm_eval_batch_size: Optional[int] = 8\n    lm_eval_post_train: Optional[bool] = True\n    lm_eval_model: Optional[str] = None\n"
  },
  {
    "path": "src/axolotl/integrations/lm_eval/cli.py",
    "content": "\"\"\"\naxolotl CLI for running lm_eval tasks\n\"\"\"\n\nimport subprocess  # nosec\nfrom collections import defaultdict\nfrom datetime import datetime\nfrom typing import Optional\n\nimport click\nimport yaml\n\nfrom axolotl.utils.dict import DictDefault\n\n\ndef get_model_path(cfg: DictDefault) -> str | None:\n    \"\"\"\n    Determine which model path to use for evaluation.\n\n    Priority order (highest to lowest):\n    1. lm_eval_model - Explicit model path override\n    2. hub_model_id - Model pushed to HuggingFace Hub\n    3. None - Falls back to output_dir in build_lm_eval_command\n\n    Returns:\n        Model path string or None to use output_dir fallback\n    \"\"\"\n    return cfg.lm_eval_model or cfg.hub_model_id or None\n\n\ndef build_lm_eval_command(\n    tasks: list[str],\n    bfloat16=True,\n    flash_attention=False,\n    output_dir=\"./\",\n    batch_size=8,\n    wandb_project=None,\n    wandb_entity=None,\n    wandb_name=None,\n    model=None,\n    revision=None,\n    apply_chat_template=None,\n    fewshot_as_multiturn=None,\n):\n    \"\"\"\n    Build one `lm_eval` command line per distinct few-shot setting.\n\n    Each task is either `name` or `name:num_fewshot`; tasks that share a\n    few-shot count are grouped into a single invocation. Tasks without a\n    count omit `--num_fewshot`, so lm_eval's own default applies.\n\n    Yields:\n        Argument lists suitable for `subprocess.run`.\n    \"\"\"\n    # Values read straight from YAML are None when unset; fall back to the\n    # defaults instead of passing \"None\" to lm_eval or failing on output_dir.\n    if batch_size is None:\n        batch_size = 8\n    if not output_dir:\n        output_dir = \"./\"\n\n    tasks_by_num_fewshot: dict[str, list] = defaultdict(list)\n    if isinstance(tasks, str):\n        tasks = [tasks]\n    for task in tasks:\n        num_fewshot = \"-1\"\n        task_parts = task.split(\":\")\n        task_name = task_parts[0]\n        if len(task_parts) == 2:\n            task_name, num_fewshot = task_parts\n        tasks_by_num_fewshot[str(num_fewshot)].append(task_name)\n\n    # Model args are the same for every few-shot group, so build them once;\n    # rebinding `revision` inside the loop would re-wrap it on each pass.\n    pretrained = \"pretrained=\" + (model if model else output_dir)\n    fa2 = \",attn_implementation=flash_attention_2\" if flash_attention else \"\"\n    dtype = \",dtype=bfloat16\" if bfloat16 else \",dtype=float16\"\n    revision_arg = f\",revision={revision}\" if revision else \"\"\n    model_args = f\"{pretrained}{fa2}{dtype}{revision_arg}\"\n\n    for num_fewshot, tasks_list in tasks_by_num_fewshot.items():\n        tasks_str = \",\".join(tasks_list)\n        num_fewshot_val = num_fewshot if num_fewshot != \"-1\" else None\n        output_path = output_dir\n        output_path += \"\" if output_dir.endswith(\"/\") else \"/\"\n        output_path += \"lm_eval_results/\" + datetime.now().strftime(\"%Y%m%d_%H%M%S\")\n        lm_eval_args = [\n            \"lm_eval\",\n            \"--model\",\n            \"hf\",\n            \"--model_args\",\n            model_args,\n            \"--tasks\",\n            tasks_str,\n            \"--batch_size\",\n            str(batch_size),\n            \"--output_path\",\n            output_path,\n        ]\n        wandb_args = []\n        if wandb_project:\n            wandb_args.append(f\"project={wandb_project}\")\n        if wandb_entity:\n            wandb_args.append(f\"entity={wandb_entity}\")\n        if wandb_name:\n            wandb_args.append(f\"name={wandb_name}\")\n        if wandb_args:\n            lm_eval_args.append(\"--wandb_args\")\n            lm_eval_args.append(\",\".join(wandb_args))\n        if apply_chat_template:\n            lm_eval_args.append(\"--apply_chat_template\")\n        if num_fewshot_val:\n            lm_eval_args.append(\"--num_fewshot\")\n            lm_eval_args.append(str(num_fewshot_val))\n            if apply_chat_template and fewshot_as_multiturn:\n                lm_eval_args.append(\"--fewshot_as_multiturn\")\n\n        yield lm_eval_args\n\n\n@click.command()\n@click.argument(\"config\", type=click.Path(exists=True, path_type=str))\n@click.option(\"--cloud\", default=None, type=click.Path(exists=True, path_type=str))\ndef lm_eval(config: str, cloud: Optional[str] = None):\n    \"\"\"\n    use lm eval to evaluate a trained language model\n    \"\"\"\n\n    if cloud:\n        from axolotl.cli.cloud import do_cli_lm_eval\n\n        do_cli_lm_eval(cloud_config=cloud, config=config)\n    else:\n        with open(config, encoding=\"utf-8\") as file:\n            cfg: DictDefault = DictDefault(yaml.safe_load(file))\n\n        for lm_eval_args in build_lm_eval_command(\n            cfg.lm_eval_tasks,\n            bfloat16=cfg.bfloat16 or cfg.bf16,\n            flash_attention=cfg.flash_attention,\n            output_dir=cfg.output_dir,\n            batch_size=cfg.lm_eval_batch_size,\n            wandb_project=cfg.wandb_project,\n            wandb_entity=cfg.wandb_entity,\n            wandb_name=cfg.wandb_name,\n            model=get_model_path(cfg),\n            revision=cfg.revision,\n            apply_chat_template=cfg.apply_chat_template,\n            fewshot_as_multiturn=cfg.fewshot_as_multiturn,\n        ):\n            subprocess.run(  # nosec\n                lm_eval_args,\n                check=True,\n            )\n"
  },
  {
    "path": "src/axolotl/integrations/spectrum/LICENSE",
    "content": "\n                                 Apache License\n                           Version 2.0, January 2004\n                        http://www.apache.org/licenses/\n\n   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION\n\n   1. Definitions.\n\n      \"License\" shall mean the terms and conditions for use, reproduction,\n      and distribution as defined by Sections 1 through 9 of this document.\n\n      \"Licensor\" shall mean the copyright owner or entity authorized by\n      the copyright owner that is granting the License.\n\n      \"Legal Entity\" shall mean the union of the acting entity and all\n      other entities that control, are controlled by, or are under common\n      control with that entity. For the purposes of this definition,\n      \"control\" means (i) the power, direct or indirect, to cause the\n      direction or management of such entity, whether by contract or\n      otherwise, or (ii) ownership of fifty percent (50%) or more of the\n      outstanding shares, or (iii) beneficial ownership of such entity.\n\n      \"You\" (or \"Your\") shall mean an individual or Legal Entity\n      exercising permissions granted by this License.\n\n      \"Source\" form shall mean the preferred form for making modifications,\n      including but not limited to software source code, documentation\n      source, and configuration files.\n\n      \"Object\" form shall mean any form resulting from mechanical\n      transformation or translation of a Source form, including but\n      not limited to compiled object code, generated documentation,\n      and conversions to other media types.\n\n      \"Work\" shall mean the work of authorship, whether in Source or\n      Object form, made available under the License, as indicated by a\n      copyright notice that is included in or attached to the work\n      (an example is provided in the Appendix below).\n\n      \"Derivative Works\" shall mean any work, whether in Source or Object\n      
form, that is based on (or derived from) the Work and for which the\n      editorial revisions, annotations, elaborations, or other modifications\n      represent, as a whole, an original work of authorship. For the purposes\n      of this License, Derivative Works shall not include works that remain\n      separable from, or merely link (or bind by name) to the interfaces of,\n      the Work and Derivative Works thereof.\n\n      \"Contribution\" shall mean any work of authorship, including\n      the original version of the Work and any modifications or additions\n      to that Work or Derivative Works thereof, that is intentionally\n      submitted to Licensor for inclusion in the Work by the copyright owner\n      or by an individual or Legal Entity authorized to submit on behalf of\n      the copyright owner. For the purposes of this definition, \"submitted\"\n      means any form of electronic, verbal, or written communication sent\n      to the Licensor or its representatives, including but not limited to\n      communication on electronic mailing lists, source code control systems,\n      and issue tracking systems that are managed by, or on behalf of, the\n      Licensor for the purpose of discussing and improving the Work, but\n      excluding communication that is conspicuously marked or otherwise\n      designated in writing by the copyright owner as \"Not a Contribution.\"\n\n      \"Contributor\" shall mean Licensor and any individual or Legal Entity\n      on behalf of whom a Contribution has been received by Licensor and\n      subsequently incorporated within the Work.\n\n   2. Grant of Copyright License. 
Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      copyright license to reproduce, prepare Derivative Works of,\n      publicly display, publicly perform, sublicense, and distribute the\n      Work and such Derivative Works in Source or Object form.\n\n   3. Grant of Patent License. Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      (except as stated in this section) patent license to make, have made,\n      use, offer to sell, sell, import, and otherwise transfer the Work,\n      where such license applies only to those patent claims licensable\n      by such Contributor that are necessarily infringed by their\n      Contribution(s) alone or by combination of their Contribution(s)\n      with the Work to which such Contribution(s) was submitted. If You\n      institute patent litigation against any entity (including a\n      cross-claim or counterclaim in a lawsuit) alleging that the Work\n      or a Contribution incorporated within the Work constitutes direct\n      or contributory patent infringement, then any patent licenses\n      granted to You under this License for that Work shall terminate\n      as of the date such litigation is filed.\n\n   4. Redistribution. 
You may reproduce and distribute copies of the\n      Work or Derivative Works thereof in any medium, with or without\n      modifications, and in Source or Object form, provided that You\n      meet the following conditions:\n\n      (a) You must give any other recipients of the Work or\n          Derivative Works a copy of this License; and\n\n      (b) You must cause any modified files to carry prominent notices\n          stating that You changed the files; and\n\n      (c) You must retain, in the Source form of any Derivative Works\n          that You distribute, all copyright, patent, trademark, and\n          attribution notices from the Source form of the Work,\n          excluding those notices that do not pertain to any part of\n          the Derivative Works; and\n\n      (d) If the Work includes a \"NOTICE\" text file as part of its\n          distribution, then any Derivative Works that You distribute must\n          include a readable copy of the attribution notices contained\n          within such NOTICE file, excluding those notices that do not\n          pertain to any part of the Derivative Works, in at least one\n          of the following places: within a NOTICE text file distributed\n          as part of the Derivative Works; within the Source form or\n          documentation, if provided along with the Derivative Works; or,\n          within a display generated by the Derivative Works, if and\n          wherever such third-party notices normally appear. The contents\n          of the NOTICE file are for informational purposes only and\n          do not modify the License. 
You may add Your own attribution\n          notices within Derivative Works that You distribute, alongside\n          or as an addendum to the NOTICE text from the Work, provided\n          that such additional attribution notices cannot be construed\n          as modifying the License.\n\n      You may add Your own copyright statement to Your modifications and\n      may provide additional or different license terms and conditions\n      for use, reproduction, or distribution of Your modifications, or\n      for any such Derivative Works as a whole, provided Your use,\n      reproduction, and distribution of the Work otherwise complies with\n      the conditions stated in this License.\n\n   5. Submission of Contributions. Unless You explicitly state otherwise,\n      any Contribution intentionally submitted for inclusion in the Work\n      by You to the Licensor shall be under the terms and conditions of\n      this License, without any additional terms or conditions.\n      Notwithstanding the above, nothing herein shall supersede or modify\n      the terms of any separate license agreement you may have executed\n      with Licensor regarding such Contributions.\n\n   6. Trademarks. This License does not grant permission to use the trade\n      names, trademarks, service marks, or product names of the Licensor,\n      except as required for reasonable and customary use in describing the\n      origin of the Work and reproducing the content of the NOTICE file.\n\n   7. Disclaimer of Warranty. Unless required by applicable law or\n      agreed to in writing, Licensor provides the Work (and each\n      Contributor provides its Contributions) on an \"AS IS\" BASIS,\n      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or\n      implied, including, without limitation, any warranties or conditions\n      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A\n      PARTICULAR PURPOSE. 
You are solely responsible for determining the\n      appropriateness of using or redistributing the Work and assume any\n      risks associated with Your exercise of permissions under this License.\n\n   8. Limitation of Liability. In no event and under no legal theory,\n      whether in tort (including negligence), contract, or otherwise,\n      unless required by applicable law (such as deliberate and grossly\n      negligent acts) or agreed to in writing, shall any Contributor be\n      liable to You for damages, including any direct, indirect, special,\n      incidental, or consequential damages of any character arising as a\n      result of this License or out of the use or inability to use the\n      Work (including but not limited to damages for loss of goodwill,\n      work stoppage, computer failure or malfunction, or any and all\n      other commercial damages or losses), even if such Contributor\n      has been advised of the possibility of such damages.\n\n   9. Accepting Warranty or Additional Liability. While redistributing\n      the Work or Derivative Works thereof, You may choose to offer,\n      and charge a fee for, acceptance of support, warranty, indemnity,\n      or other liability obligations and/or rights consistent with this\n      License. However, in accepting such obligations, You may act only\n      on Your own behalf and on Your sole responsibility, not on behalf\n      of any other Contributor, and only if You agree to indemnify,\n      defend, and hold each Contributor harmless for any liability\n      incurred by, or claims asserted against, such Contributor by reason\n      of your accepting any such warranty or additional liability.\n\n   END OF TERMS AND CONDITIONS\n\n   APPENDIX: How to apply the Apache License to your work.\n\n      To apply the Apache License to your work, attach the following\n      boilerplate notice, with the fields enclosed by brackets \"[]\"\n      replaced with your own identifying information. 
(Don't include\n      the brackets!)  The text should be enclosed in the appropriate\n      comment syntax for the file format. We also recommend that a\n      file or class name and description of purpose be included on the\n      same \"printed page\" as the copyright notice for easier\n      identification within third-party archives.\n\n   Copyright [yyyy] [name of copyright owner]\n\n   Licensed under the Apache License, Version 2.0 (the \"License\");\n   you may not use this file except in compliance with the License.\n   You may obtain a copy of the License at\n\n       http://www.apache.org/licenses/LICENSE-2.0\n\n   Unless required by applicable law or agreed to in writing, software\n   distributed under the License is distributed on an \"AS IS\" BASIS,\n   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n   See the License for the specific language governing permissions and\n   limitations under the License.\n"
  },
  {
    "path": "src/axolotl/integrations/spectrum/README.md",
    "content": "# Spectrum: Targeted Training on Signal to Noise Ratio\n\nby Eric Hartford, Lucas Atkins, Fernando Fernandes Neto, David Golchinfar\n\nThis plugin freezes all but the highest-ranked fraction of modules in a model, ranked by their Signal-to-Noise Ratio (SNR).\n\nSee https://github.com/cognitivecomputations/spectrum\n\n## Overview\n\nSpectrum is a tool for scanning and evaluating the Signal-to-Noise Ratio (SNR) of layers in large language models.\nBy training only the fraction of layers with the highest SNR and freezing the rest, you can improve training efficiency.\n\n## Usage\n\n```yaml\nplugins:\n  - axolotl.integrations.spectrum.SpectrumPlugin\n\n# Fraction (0-1) of layers of each type, ranked by SNR, that remain trainable\nspectrum_top_fraction: 0.5\n# Optional: model name used to look up pre-scanned SNR results (defaults to base_model).\n# Set this when base_model is a mirror or local copy of a pre-scanned model.\nspectrum_model_name: meta-llama/Meta-Llama-3.1-8B\n```\n\nSNR results are read from `./model_snr_results/` (relative to the working directory) if present; otherwise they are downloaded from the Spectrum repository. The `lm_head` and `embed_tokens` weights are always left trainable.\n\n## Citation\n\n```bib\n@misc{hartford2024spectrumtargetedtrainingsignal,\n      title={Spectrum: Targeted Training on Signal to Noise Ratio},\n      author={Eric Hartford and Lucas Atkins and Fernando Fernandes Neto and David Golchinfar},\n      year={2024},\n      eprint={2406.06623},\n      archivePrefix={arXiv},\n      primaryClass={cs.LG},\n      url={https://arxiv.org/abs/2406.06623},\n}\n```\n"
  },
  {
    "path": "src/axolotl/integrations/spectrum/__init__.py",
    "content": "# Copyright 2024 Axolotl AI. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"\nSpectrum Plugin to automatically generate unfrozen parameters based on SNR data.\n\"\"\"\n\nimport json\nfrom pathlib import Path\n\nimport requests\n\nfrom axolotl.integrations.base import BasePlugin\nfrom axolotl.utils.logging import get_logger\n\nfrom .args import SpectrumArgs as SpectrumArgs\n\nLOG = get_logger(__name__)\n\n# Fallback for `spectrum_top_fraction`; matches the SpectrumArgs default.\nDEFAULT_TOP_FRACTION = 0.5\n\n\ndef _generate_unfrozen_params_yaml(snr_data, top_fraction=0.5):\n    \"\"\"\n    Select the highest-SNR layers of each layer type.\n\n    Args:\n        snr_data: mapping of layer name to a dict with \"snr\" and \"type\" keys.\n        top_fraction: fraction of layers to keep per type (rounded down).\n\n    Returns:\n        List of unfrozen parameter patterns: the lm_head and embed_tokens\n        weights followed by the selected layer names.\n    \"\"\"\n    unfrozen_parameters = {}\n    for layer_name, info in snr_data.items():\n        layer_type = info[\"type\"]\n        if layer_type not in unfrozen_parameters:\n            unfrozen_parameters[layer_type] = []\n        unfrozen_parameters[layer_type].append((layer_name, info[\"snr\"]))\n    top_layers_by_type = {}\n    for layer_type, layers in unfrozen_parameters.items():\n        layers_sorted = sorted(layers, key=lambda x: x[1], reverse=True)\n        num_top_layers = int(len(layers) * top_fraction)\n        top_layers_by_type[layer_type] = [\n            layer[0] for layer in layers_sorted[:num_top_layers]\n        ]\n    unfrozen_parameters = [\n        \"^lm_head.weight$\",\n        \"^model.embed_tokens.weight$\",\n    ]\n    for _, layer_names in top_layers_by_type.items():\n        for layer_name in layer_names:\n            unfrozen_parameters.append(layer_name)\n    return unfrozen_parameters\n\n\nclass SpectrumPlugin(BasePlugin):\n    \"\"\"\n    Spectrum Plugin to automatically generate unfrozen parameters based on SNR data.\n    \"\"\"\n\n    base_url = \"https://raw.githubusercontent.com/cognitivecomputations/spectrum/main/model_snr_results/\"\n    base_path = \"./model_snr_results/\"\n    snr_file_template = \"snr_results_{model_name_slug}.json\"\n\n    def get_input_args(self):\n        return \"axolotl.integrations.spectrum.SpectrumArgs\"\n\n    def pre_model_load(self, cfg):\n        if cfg.get(\"spectrum_model_name\"):\n            model_name = cfg[\"spectrum_model_name\"]\n        else:\n            model_name = cfg[\"base_model\"]\n        # The setting is a fraction (e.g. 0.5); an explicit null must also fall\n        # back to the default rather than reach int(len(...) * None).\n        top_fraction = cfg.get(\"spectrum_top_fraction\")\n        if top_fraction is None:\n            top_fraction = DEFAULT_TOP_FRACTION\n        model_slug = model_name.replace(\"/\", \"-\").replace(\"_\", \"-\")\n        snr_file = self.snr_file_template.format(model_name_slug=model_slug)\n        snr_url = self.base_url + snr_file\n        # Prefer local results: first the working directory, then the results\n        # shipped next to this module.\n        snr_paths = [\n            self.base_path + snr_file,\n            str(Path(__file__).resolve().parent / \"model_snr_results\" / snr_file),\n        ]\n        snr_data = None\n        for snr_path in snr_paths:\n            try:\n                with open(snr_path, \"r\", encoding=\"utf-8\") as fin:\n                    snr_data = json.load(fin)\n            except FileNotFoundError:\n                continue\n            except Exception as exc:\n                LOG.warning(f\"Failed to read SNR data from {snr_path}: {exc}\")\n                continue\n            if snr_data:\n                break\n\n        if not snr_data:\n            try:\n                response = requests.get(snr_url, timeout=60)\n                # Report HTTP errors (e.g. 404 for a model that was never\n                # scanned) as fetch failures instead of parsing the error page.\n                response.raise_for_status()\n                snr_data = response.json()\n            # Checked first: in recent requests versions JSON decode errors\n            # also subclass RequestException.\n            except json.JSONDecodeError as exc:\n                LOG.warning(f\"Failed to parse SNR data from {snr_url}: {exc}\")\n                return\n            except requests.exceptions.RequestException as exc:\n                LOG.warning(f\"Failed to fetch SNR data from {snr_url}: {exc}\")\n                return\n\n        unfrozen_parameters = _generate_unfrozen_params_yaml(\n            snr_data, top_fraction=top_fraction\n        )\n        cfg[\"unfrozen_parameters\"] = unfrozen_parameters\n"
  },
  {
    "path": "src/axolotl/integrations/spectrum/args.py",
    "content": "# Copyright 2024 Axolotl AI. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"\nModule for handling Spectrum input arguments.\n\"\"\"\n\nfrom typing import Optional\n\nfrom pydantic import BaseModel, model_validator\n\n\nclass SpectrumArgs(BaseModel):\n    \"\"\"\n    Input args for Spectrum.\n    \"\"\"\n\n    # Fraction (0-1) of layers of each type, ranked by SNR, left trainable.\n    spectrum_top_fraction: Optional[float] = 0.5\n    # Model name used to look up SNR results; base_model is used when unset.\n    spectrum_model_name: Optional[str] = None\n\n    @model_validator(mode=\"before\")\n    @classmethod\n    def check_fsdp_use_orig_params(cls, data):\n        \"\"\"\n        Reject FSDP with `use_orig_params` disabled when SpectrumPlugin is enabled.\n\n        Spectrum freezes part of the model, and FSDP cannot flatten tensors with\n        mixed `requires_grad` unless `use_orig_params` is enabled.\n        \"\"\"\n        # \"before\" validators may receive non-dict input; let pydantic handle it.\n        if not isinstance(data, dict):\n            return data\n        fsdp_config = data.get(\"fsdp_config\")\n        if (\n            data.get(\"fsdp\")\n            and fsdp_config\n            and not fsdp_config.get(\"use_orig_params\")\n            and data.get(\"plugins\")\n            and any(\"SpectrumPlugin\" in plugin for plugin in data[\"plugins\"])\n        ):\n            # would otherwise raise\n            # ValueError: Must flatten tensors with uniform `requires_grad` when `use_orig_params=False`\n            raise ValueError(\n                \"FSDP + SpectrumPlugin cannot be used together when `use_orig_params=False` is set\"\n            )\n        return data\n"
  },
  {
    "path": "src/axolotl/integrations/spectrum/model_snr_results/snr_results_Qwen-Qwen2.5-1.5B-Instruct.json",
    "content": "{\n    \"model.layers.0.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.1.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.2.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.3.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.4.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.5.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.6.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.7.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.8.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.9.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.10.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.11.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.12.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.13.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.14.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.15.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.16.input_layernorm\": {\n        \"snr\": Infinity,\n        
\"type\": \"input_layernorm\"\n    },\n    \"model.layers.17.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.18.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.19.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.20.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.21.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.22.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.23.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.24.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.25.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.26.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.27.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"lm_head\": {\n        \"snr\": Infinity,\n        \"type\": \"lm_head\"\n    },\n    \"model.layers.0.mlp.down_proj\": {\n        \"snr\": 70.50235748291016,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.1.mlp.down_proj\": {\n        \"snr\": 134.4214630126953,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.2.mlp.down_proj\": {\n        \"snr\": 235.74794006347656,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.3.mlp.down_proj\": {\n        \"snr\": 73.25755310058594,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.4.mlp.down_proj\": {\n        \"snr\": 
27.22879981994629,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.5.mlp.down_proj\": {\n        \"snr\": 17.5551815032959,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.6.mlp.down_proj\": {\n        \"snr\": 54.210426330566406,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.7.mlp.down_proj\": {\n        \"snr\": 38.808937072753906,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.8.mlp.down_proj\": {\n        \"snr\": 29.799747467041016,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.9.mlp.down_proj\": {\n        \"snr\": 10.296355247497559,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.10.mlp.down_proj\": {\n        \"snr\": 8.86428165435791,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.11.mlp.down_proj\": {\n        \"snr\": 6.43813943862915,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.12.mlp.down_proj\": {\n        \"snr\": 7.0912184715271,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.13.mlp.down_proj\": {\n        \"snr\": 3.285884141921997,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.14.mlp.down_proj\": {\n        \"snr\": 6.073758125305176,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.15.mlp.down_proj\": {\n        \"snr\": 5.325990676879883,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.16.mlp.down_proj\": {\n        \"snr\": 4.591946601867676,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.17.mlp.down_proj\": {\n        \"snr\": 7.021907329559326,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.18.mlp.down_proj\": {\n        \"snr\": 6.392782211303711,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.19.mlp.down_proj\": {\n        \"snr\": 210.51983642578125,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.20.mlp.down_proj\": {\n        \"snr\": 
7.1035943031311035,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.21.mlp.down_proj\": {\n        \"snr\": 18.701711654663086,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.22.mlp.down_proj\": {\n        \"snr\": 14.842622756958008,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.23.mlp.down_proj\": {\n        \"snr\": 10.50004768371582,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.24.mlp.down_proj\": {\n        \"snr\": 7.225146770477295,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.25.mlp.down_proj\": {\n        \"snr\": 7.463952541351318,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.26.mlp.down_proj\": {\n        \"snr\": 15.226134300231934,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.27.mlp.down_proj\": {\n        \"snr\": 105.4173355102539,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.0.mlp.gate_proj\": {\n        \"snr\": 0.5021594166755676,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.1.mlp.gate_proj\": {\n        \"snr\": 34.75935363769531,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.2.mlp.gate_proj\": {\n        \"snr\": 22.855531692504883,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.3.mlp.gate_proj\": {\n        \"snr\": 25.09166717529297,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.4.mlp.gate_proj\": {\n        \"snr\": 28.533172607421875,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.5.mlp.gate_proj\": {\n        \"snr\": 18.625717163085938,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.6.mlp.gate_proj\": {\n        \"snr\": 39.77565383911133,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.7.mlp.gate_proj\": {\n        \"snr\": 24.77678680419922,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.8.mlp.gate_proj\": {\n        \"snr\": 
11.854388236999512,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.9.mlp.gate_proj\": {\n        \"snr\": 20.372356414794922,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.10.mlp.gate_proj\": {\n        \"snr\": 14.639552116394043,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.11.mlp.gate_proj\": {\n        \"snr\": 9.82955551147461,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.12.mlp.gate_proj\": {\n        \"snr\": 13.942151069641113,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.13.mlp.gate_proj\": {\n        \"snr\": 12.524999618530273,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.14.mlp.gate_proj\": {\n        \"snr\": 8.19681167602539,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.15.mlp.gate_proj\": {\n        \"snr\": 8.561081886291504,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.16.mlp.gate_proj\": {\n        \"snr\": 6.421900749206543,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.17.mlp.gate_proj\": {\n        \"snr\": 5.568161964416504,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.18.mlp.gate_proj\": {\n        \"snr\": 10.090147972106934,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.19.mlp.gate_proj\": {\n        \"snr\": 5.6181230545043945,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.20.mlp.gate_proj\": {\n        \"snr\": 5.173826694488525,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.21.mlp.gate_proj\": {\n        \"snr\": 5.663441181182861,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.22.mlp.gate_proj\": {\n        \"snr\": 6.824708461761475,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.23.mlp.gate_proj\": {\n        \"snr\": 4.724992275238037,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.24.mlp.gate_proj\": {\n        \"snr\": 
6.829834938049316,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.25.mlp.gate_proj\": {\n        \"snr\": 9.968582153320312,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.26.mlp.gate_proj\": {\n        \"snr\": 14.35350513458252,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.27.mlp.gate_proj\": {\n        \"snr\": 20.121768951416016,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.0.mlp.up_proj\": {\n        \"snr\": 1.9020992517471313,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.1.mlp.up_proj\": {\n        \"snr\": 46.9393424987793,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.2.mlp.up_proj\": {\n        \"snr\": 76.04901123046875,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.3.mlp.up_proj\": {\n        \"snr\": 104.08525848388672,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.4.mlp.up_proj\": {\n        \"snr\": 77.74343872070312,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.5.mlp.up_proj\": {\n        \"snr\": 104.15605926513672,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.6.mlp.up_proj\": {\n        \"snr\": 105.16349792480469,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.7.mlp.up_proj\": {\n        \"snr\": 78.4150390625,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.8.mlp.up_proj\": {\n        \"snr\": 57.51069641113281,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.9.mlp.up_proj\": {\n        \"snr\": 50.26409912109375,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.10.mlp.up_proj\": {\n        \"snr\": 50.36701965332031,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.11.mlp.up_proj\": {\n        \"snr\": 56.66413497924805,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.12.mlp.up_proj\": {\n        \"snr\": 62.384559631347656,\n        \"type\": \"mlp.up_proj\"\n    },\n    
\"model.layers.13.mlp.up_proj\": {\n        \"snr\": 44.97883987426758,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.14.mlp.up_proj\": {\n        \"snr\": 69.7376480102539,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.15.mlp.up_proj\": {\n        \"snr\": 35.93111801147461,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.16.mlp.up_proj\": {\n        \"snr\": 33.63168716430664,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.17.mlp.up_proj\": {\n        \"snr\": 37.695919036865234,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.18.mlp.up_proj\": {\n        \"snr\": 43.516517639160156,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.19.mlp.up_proj\": {\n        \"snr\": 30.479318618774414,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.20.mlp.up_proj\": {\n        \"snr\": 12.495409965515137,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.21.mlp.up_proj\": {\n        \"snr\": 19.616689682006836,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.22.mlp.up_proj\": {\n        \"snr\": 18.42948341369629,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.23.mlp.up_proj\": {\n        \"snr\": 10.799560546875,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.24.mlp.up_proj\": {\n        \"snr\": 14.167623519897461,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.25.mlp.up_proj\": {\n        \"snr\": 14.938597679138184,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.26.mlp.up_proj\": {\n        \"snr\": 8.896568298339844,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.27.mlp.up_proj\": {\n        \"snr\": 25.774547576904297,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.embed_tokens\": {\n        \"snr\": Infinity,\n        \"type\": \"model.embed_tokens\"\n    },\n    \"model.norm\": {\n        \"snr\": Infinity,\n        \"type\": \"model.norm\"\n    
},\n    \"model.layers.0.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.1.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.2.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.3.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.4.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.5.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.6.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.7.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.8.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.9.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.10.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.11.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.12.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.13.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.14.post_attention_layernorm\": {\n        \"snr\": Infinity,\n       
 \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.15.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.16.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.17.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.18.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.19.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.20.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.21.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.22.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.23.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.24.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.25.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.26.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.27.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.0.self_attn.k_proj\": {\n        \"snr\": 1.8306859731674194,\n        \"type\": \"self_attn.k_proj\"\n    },\n    
\"model.layers.1.self_attn.k_proj\": {\n        \"snr\": 0.896544337272644,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.2.self_attn.k_proj\": {\n        \"snr\": 2.345759868621826,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.3.self_attn.k_proj\": {\n        \"snr\": 2.0610744953155518,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.4.self_attn.k_proj\": {\n        \"snr\": 2.3658556938171387,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.5.self_attn.k_proj\": {\n        \"snr\": 1.6586917638778687,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.6.self_attn.k_proj\": {\n        \"snr\": 1.7613047361373901,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.7.self_attn.k_proj\": {\n        \"snr\": 1.325312852859497,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.8.self_attn.k_proj\": {\n        \"snr\": 1.458108901977539,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.9.self_attn.k_proj\": {\n        \"snr\": 1.4319790601730347,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.10.self_attn.k_proj\": {\n        \"snr\": 0.9579543471336365,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.11.self_attn.k_proj\": {\n        \"snr\": 0.8787619471549988,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.12.self_attn.k_proj\": {\n        \"snr\": 1.0447536706924438,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.13.self_attn.k_proj\": {\n        \"snr\": 0.9157310724258423,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.14.self_attn.k_proj\": {\n        \"snr\": 0.7528730630874634,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.15.self_attn.k_proj\": {\n        \"snr\": 0.9293556213378906,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.16.self_attn.k_proj\": {\n       
 \"snr\": 0.8057093620300293,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.17.self_attn.k_proj\": {\n        \"snr\": 1.2973601818084717,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.18.self_attn.k_proj\": {\n        \"snr\": 1.1357901096343994,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.19.self_attn.k_proj\": {\n        \"snr\": 1.3661632537841797,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.20.self_attn.k_proj\": {\n        \"snr\": 0.8829066753387451,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.21.self_attn.k_proj\": {\n        \"snr\": 0.9105398654937744,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.22.self_attn.k_proj\": {\n        \"snr\": 2.086926221847534,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.23.self_attn.k_proj\": {\n        \"snr\": 1.0393351316452026,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.24.self_attn.k_proj\": {\n        \"snr\": 1.114574670791626,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.25.self_attn.k_proj\": {\n        \"snr\": 2.599745035171509,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.26.self_attn.k_proj\": {\n        \"snr\": 1.1256712675094604,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.27.self_attn.k_proj\": {\n        \"snr\": 1.1784162521362305,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.0.self_attn.o_proj\": {\n        \"snr\": 0.8094121813774109,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.1.self_attn.o_proj\": {\n        \"snr\": 0.22000817954540253,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.2.self_attn.o_proj\": {\n        \"snr\": 0.21972468495368958,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.3.self_attn.o_proj\": {\n        \"snr\": 0.22064059972763062,\n        
\"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.4.self_attn.o_proj\": {\n        \"snr\": 0.22308556735515594,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.5.self_attn.o_proj\": {\n        \"snr\": 0.22396250069141388,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.6.self_attn.o_proj\": {\n        \"snr\": 0.228360116481781,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.7.self_attn.o_proj\": {\n        \"snr\": 0.2306283563375473,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.8.self_attn.o_proj\": {\n        \"snr\": 0.2430228292942047,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.9.self_attn.o_proj\": {\n        \"snr\": 0.2115175724029541,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.10.self_attn.o_proj\": {\n        \"snr\": 0.18226943910121918,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.11.self_attn.o_proj\": {\n        \"snr\": 0.144245907664299,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.12.self_attn.o_proj\": {\n        \"snr\": 0.21965907514095306,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.13.self_attn.o_proj\": {\n        \"snr\": 0.1797526627779007,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.14.self_attn.o_proj\": {\n        \"snr\": 0.26513636112213135,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.15.self_attn.o_proj\": {\n        \"snr\": 0.19463808834552765,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.16.self_attn.o_proj\": {\n        \"snr\": 0.22129350900650024,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.17.self_attn.o_proj\": {\n        \"snr\": 0.22545330226421356,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.18.self_attn.o_proj\": {\n        \"snr\": 0.25302645564079285,\n        \"type\": \"self_attn.o_proj\"\n  
  },\n    \"model.layers.19.self_attn.o_proj\": {\n        \"snr\": 0.26326504349708557,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.20.self_attn.o_proj\": {\n        \"snr\": 0.15203869342803955,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.21.self_attn.o_proj\": {\n        \"snr\": 0.22418837249279022,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.22.self_attn.o_proj\": {\n        \"snr\": 0.23777326941490173,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.23.self_attn.o_proj\": {\n        \"snr\": 0.18076598644256592,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.24.self_attn.o_proj\": {\n        \"snr\": 0.19919466972351074,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.25.self_attn.o_proj\": {\n        \"snr\": 0.11310968548059464,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.26.self_attn.o_proj\": {\n        \"snr\": 0.08452697843313217,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.27.self_attn.o_proj\": {\n        \"snr\": 0.1029304787516594,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.0.self_attn.q_proj\": {\n        \"snr\": 0.03922705352306366,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.1.self_attn.q_proj\": {\n        \"snr\": 0.1410205066204071,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.2.self_attn.q_proj\": {\n        \"snr\": 0.18240582942962646,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.3.self_attn.q_proj\": {\n        \"snr\": 0.1702580451965332,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.4.self_attn.q_proj\": {\n        \"snr\": 0.19508686661720276,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.5.self_attn.q_proj\": {\n        \"snr\": 0.21549257636070251,\n        \"type\": \"self_attn.q_proj\"\n    },\n    
\"model.layers.6.self_attn.q_proj\": {\n        \"snr\": 0.22021502256393433,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.7.self_attn.q_proj\": {\n        \"snr\": 0.2044307142496109,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.8.self_attn.q_proj\": {\n        \"snr\": 0.22745060920715332,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.9.self_attn.q_proj\": {\n        \"snr\": 0.23825915157794952,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.10.self_attn.q_proj\": {\n        \"snr\": 0.2181481122970581,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.11.self_attn.q_proj\": {\n        \"snr\": 0.23490090668201447,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.12.self_attn.q_proj\": {\n        \"snr\": 0.2379382699728012,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.13.self_attn.q_proj\": {\n        \"snr\": 0.19233369827270508,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.14.self_attn.q_proj\": {\n        \"snr\": 0.2587313652038574,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.15.self_attn.q_proj\": {\n        \"snr\": 0.07332809269428253,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.16.self_attn.q_proj\": {\n        \"snr\": 0.22992204129695892,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.17.self_attn.q_proj\": {\n        \"snr\": 0.2537729740142822,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.18.self_attn.q_proj\": {\n        \"snr\": 0.2389948070049286,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.19.self_attn.q_proj\": {\n        \"snr\": 0.20716068148612976,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.20.self_attn.q_proj\": {\n        \"snr\": 0.2575169503688812,\n        \"type\": \"self_attn.q_proj\"\n    },\n    
\"model.layers.21.self_attn.q_proj\": {\n        \"snr\": 0.22347678244113922,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.22.self_attn.q_proj\": {\n        \"snr\": 0.18831054866313934,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.23.self_attn.q_proj\": {\n        \"snr\": 0.19853907823562622,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.24.self_attn.q_proj\": {\n        \"snr\": 0.16343259811401367,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.25.self_attn.q_proj\": {\n        \"snr\": 0.1583252102136612,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.26.self_attn.q_proj\": {\n        \"snr\": 0.254446804523468,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.27.self_attn.q_proj\": {\n        \"snr\": 0.23828543722629547,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.0.self_attn.v_proj\": {\n        \"snr\": Infinity,\n        \"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.1.self_attn.v_proj\": {\n        \"snr\": 856.5148315429688,\n        \"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.2.self_attn.v_proj\": {\n        \"snr\": Infinity,\n        \"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.3.self_attn.v_proj\": {\n        \"snr\": Infinity,\n        \"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.4.self_attn.v_proj\": {\n        \"snr\": Infinity,\n        \"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.5.self_attn.v_proj\": {\n        \"snr\": Infinity,\n        \"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.6.self_attn.v_proj\": {\n        \"snr\": 48.941104888916016,\n        \"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.7.self_attn.v_proj\": {\n        \"snr\": 70.25466918945312,\n        \"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.8.self_attn.v_proj\": {\n        \"snr\": Infinity,\n        \"type\": 
\"self_attn.v_proj\"\n    },\n    \"model.layers.9.self_attn.v_proj\": {\n        \"snr\": 370.885986328125,\n        \"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.10.self_attn.v_proj\": {\n        \"snr\": Infinity,\n        \"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.11.self_attn.v_proj\": {\n        \"snr\": Infinity,\n        \"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.12.self_attn.v_proj\": {\n        \"snr\": 75.51139831542969,\n        \"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.13.self_attn.v_proj\": {\n        \"snr\": Infinity,\n        \"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.14.self_attn.v_proj\": {\n        \"snr\": 52.004058837890625,\n        \"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.15.self_attn.v_proj\": {\n        \"snr\": Infinity,\n        \"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.16.self_attn.v_proj\": {\n        \"snr\": Infinity,\n        \"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.17.self_attn.v_proj\": {\n        \"snr\": Infinity,\n        \"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.18.self_attn.v_proj\": {\n        \"snr\": Infinity,\n        \"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.19.self_attn.v_proj\": {\n        \"snr\": Infinity,\n        \"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.20.self_attn.v_proj\": {\n        \"snr\": Infinity,\n        \"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.21.self_attn.v_proj\": {\n        \"snr\": 641.026611328125,\n        \"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.22.self_attn.v_proj\": {\n        \"snr\": Infinity,\n        \"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.23.self_attn.v_proj\": {\n        \"snr\": 323.4858093261719,\n        \"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.24.self_attn.v_proj\": {\n        \"snr\": Infinity,\n        \"type\": 
\"self_attn.v_proj\"\n    },\n    \"model.layers.25.self_attn.v_proj\": {\n        \"snr\": 2.1745388507843018,\n        \"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.26.self_attn.v_proj\": {\n        \"snr\": 3.0791690349578857,\n        \"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.27.self_attn.v_proj\": {\n        \"snr\": 2.029968023300171,\n        \"type\": \"self_attn.v_proj\"\n    }\n}\n"
  },
  {
    "path": "src/axolotl/integrations/spectrum/model_snr_results/snr_results_Qwen-Qwen2.5-1.5B.json",
    "content": "{\n    \"model.layers.0.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.1.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.2.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.3.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.4.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.5.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.6.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.7.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.8.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.9.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.10.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.11.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.12.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.13.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.14.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.15.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.16.input_layernorm\": {\n        \"snr\": Infinity,\n        
\"type\": \"input_layernorm\"\n    },\n    \"model.layers.17.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.18.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.19.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.20.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.21.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.22.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.23.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.24.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.25.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.26.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.27.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"lm_head\": {\n        \"snr\": Infinity,\n        \"type\": \"lm_head\"\n    },\n    \"model.layers.0.mlp.down_proj\": {\n        \"snr\": 70.4939193725586,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.1.mlp.down_proj\": {\n        \"snr\": 134.2310028076172,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.2.mlp.down_proj\": {\n        \"snr\": 235.44140625,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.3.mlp.down_proj\": {\n        \"snr\": 73.19381713867188,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.4.mlp.down_proj\": {\n        \"snr\": 
27.216264724731445,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.5.mlp.down_proj\": {\n        \"snr\": 17.544504165649414,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.6.mlp.down_proj\": {\n        \"snr\": 54.17462158203125,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.7.mlp.down_proj\": {\n        \"snr\": 38.78171920776367,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.8.mlp.down_proj\": {\n        \"snr\": 29.777149200439453,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.9.mlp.down_proj\": {\n        \"snr\": 10.289377212524414,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.10.mlp.down_proj\": {\n        \"snr\": 8.858332633972168,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.11.mlp.down_proj\": {\n        \"snr\": 6.433396816253662,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.12.mlp.down_proj\": {\n        \"snr\": 7.085702419281006,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.13.mlp.down_proj\": {\n        \"snr\": 3.323948383331299,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.14.mlp.down_proj\": {\n        \"snr\": 6.204164505004883,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.15.mlp.down_proj\": {\n        \"snr\": 5.321533203125,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.16.mlp.down_proj\": {\n        \"snr\": 4.588479995727539,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.17.mlp.down_proj\": {\n        \"snr\": 7.01450252532959,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.18.mlp.down_proj\": {\n        \"snr\": 6.386813163757324,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.19.mlp.down_proj\": {\n        \"snr\": 210.38458251953125,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.20.mlp.down_proj\": {\n        \"snr\": 
7.096683979034424,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.21.mlp.down_proj\": {\n        \"snr\": 18.68245506286621,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.22.mlp.down_proj\": {\n        \"snr\": 14.824685096740723,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.23.mlp.down_proj\": {\n        \"snr\": 10.491303443908691,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.24.mlp.down_proj\": {\n        \"snr\": 7.2194437980651855,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.25.mlp.down_proj\": {\n        \"snr\": 7.458613872528076,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.26.mlp.down_proj\": {\n        \"snr\": 15.222760200500488,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.27.mlp.down_proj\": {\n        \"snr\": 105.41569519042969,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.0.mlp.gate_proj\": {\n        \"snr\": 0.5017311573028564,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.1.mlp.gate_proj\": {\n        \"snr\": 34.71562576293945,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.2.mlp.gate_proj\": {\n        \"snr\": 22.82915496826172,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.3.mlp.gate_proj\": {\n        \"snr\": 25.0699520111084,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.4.mlp.gate_proj\": {\n        \"snr\": 28.508079528808594,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.5.mlp.gate_proj\": {\n        \"snr\": 18.608009338378906,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.6.mlp.gate_proj\": {\n        \"snr\": 39.732391357421875,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.7.mlp.gate_proj\": {\n        \"snr\": 24.760026931762695,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.8.mlp.gate_proj\": {\n        \"snr\": 
11.842738151550293,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.9.mlp.gate_proj\": {\n        \"snr\": 20.35906982421875,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.10.mlp.gate_proj\": {\n        \"snr\": 14.627532958984375,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.11.mlp.gate_proj\": {\n        \"snr\": 9.821962356567383,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.12.mlp.gate_proj\": {\n        \"snr\": 13.930404663085938,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.13.mlp.gate_proj\": {\n        \"snr\": 12.509871482849121,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.14.mlp.gate_proj\": {\n        \"snr\": 8.187695503234863,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.15.mlp.gate_proj\": {\n        \"snr\": 8.553187370300293,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.16.mlp.gate_proj\": {\n        \"snr\": 6.414614200592041,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.17.mlp.gate_proj\": {\n        \"snr\": 5.561778545379639,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.18.mlp.gate_proj\": {\n        \"snr\": 10.078697204589844,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.19.mlp.gate_proj\": {\n        \"snr\": 5.61345100402832,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.20.mlp.gate_proj\": {\n        \"snr\": 5.265484809875488,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.21.mlp.gate_proj\": {\n        \"snr\": 5.659949779510498,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.22.mlp.gate_proj\": {\n        \"snr\": 6.8203511238098145,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.23.mlp.gate_proj\": {\n        \"snr\": 4.721294403076172,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.24.mlp.gate_proj\": {\n        \"snr\": 
6.82572603225708,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.25.mlp.gate_proj\": {\n        \"snr\": 9.963521003723145,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.26.mlp.gate_proj\": {\n        \"snr\": 14.342291831970215,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.27.mlp.gate_proj\": {\n        \"snr\": 20.092098236083984,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.0.mlp.up_proj\": {\n        \"snr\": 1.901187777519226,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.1.mlp.up_proj\": {\n        \"snr\": 46.9141731262207,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.2.mlp.up_proj\": {\n        \"snr\": 76.07878112792969,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.3.mlp.up_proj\": {\n        \"snr\": 103.9194564819336,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.4.mlp.up_proj\": {\n        \"snr\": 77.62561798095703,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.5.mlp.up_proj\": {\n        \"snr\": 104.01624298095703,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.6.mlp.up_proj\": {\n        \"snr\": 105.0235366821289,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.7.mlp.up_proj\": {\n        \"snr\": 78.33445739746094,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.8.mlp.up_proj\": {\n        \"snr\": 57.44070816040039,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.9.mlp.up_proj\": {\n        \"snr\": 50.20344924926758,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.10.mlp.up_proj\": {\n        \"snr\": 50.32845687866211,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.11.mlp.up_proj\": {\n        \"snr\": 56.6197624206543,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.12.mlp.up_proj\": {\n        \"snr\": 62.338096618652344,\n        \"type\": \"mlp.up_proj\"\n    },\n    
\"model.layers.13.mlp.up_proj\": {\n        \"snr\": 44.92917251586914,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.14.mlp.up_proj\": {\n        \"snr\": 69.69624328613281,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.15.mlp.up_proj\": {\n        \"snr\": 35.90705108642578,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.16.mlp.up_proj\": {\n        \"snr\": 33.610374450683594,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.17.mlp.up_proj\": {\n        \"snr\": 37.67365646362305,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.18.mlp.up_proj\": {\n        \"snr\": 43.488929748535156,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.19.mlp.up_proj\": {\n        \"snr\": 30.451993942260742,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.20.mlp.up_proj\": {\n        \"snr\": 12.480182647705078,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.21.mlp.up_proj\": {\n        \"snr\": 19.595102310180664,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.22.mlp.up_proj\": {\n        \"snr\": 19.067970275878906,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.23.mlp.up_proj\": {\n        \"snr\": 10.786394119262695,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.24.mlp.up_proj\": {\n        \"snr\": 14.150126457214355,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.25.mlp.up_proj\": {\n        \"snr\": 14.927021026611328,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.26.mlp.up_proj\": {\n        \"snr\": 8.891448020935059,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.27.mlp.up_proj\": {\n        \"snr\": 25.74305534362793,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.embed_tokens\": {\n        \"snr\": Infinity,\n        \"type\": \"model.embed_tokens\"\n    },\n    \"model.norm\": {\n        \"snr\": Infinity,\n        \"type\": \"model.norm\"\n 
   },\n    \"model.layers.0.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.1.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.2.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.3.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.4.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.5.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.6.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.7.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.8.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.9.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.10.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.11.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.12.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.13.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.14.post_attention_layernorm\": {\n        \"snr\": Infinity,\n    
    \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.15.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.16.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.17.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.18.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.19.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.20.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.21.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.22.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.23.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.24.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.25.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.26.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.27.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.0.self_attn.k_proj\": {\n        \"snr\": 1.7818864583969116,\n        \"type\": \"self_attn.k_proj\"\n    },\n    
\"model.layers.1.self_attn.k_proj\": {\n        \"snr\": 0.8955822587013245,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.2.self_attn.k_proj\": {\n        \"snr\": 2.344149351119995,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.3.self_attn.k_proj\": {\n        \"snr\": 2.0597119331359863,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.4.self_attn.k_proj\": {\n        \"snr\": 2.36411714553833,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.5.self_attn.k_proj\": {\n        \"snr\": 1.6570613384246826,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.6.self_attn.k_proj\": {\n        \"snr\": 1.7604507207870483,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.7.self_attn.k_proj\": {\n        \"snr\": 1.3245182037353516,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.8.self_attn.k_proj\": {\n        \"snr\": 1.4567548036575317,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.9.self_attn.k_proj\": {\n        \"snr\": 1.4310829639434814,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.10.self_attn.k_proj\": {\n        \"snr\": 0.95713210105896,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.11.self_attn.k_proj\": {\n        \"snr\": 0.8781776428222656,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.12.self_attn.k_proj\": {\n        \"snr\": 1.0438013076782227,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.13.self_attn.k_proj\": {\n        \"snr\": 0.9315219521522522,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.14.self_attn.k_proj\": {\n        \"snr\": 0.7521569728851318,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.15.self_attn.k_proj\": {\n        \"snr\": 0.9286947250366211,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.16.self_attn.k_proj\": {\n        
\"snr\": 0.8047553896903992,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.17.self_attn.k_proj\": {\n        \"snr\": 1.2965552806854248,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.18.self_attn.k_proj\": {\n        \"snr\": 1.134974479675293,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.19.self_attn.k_proj\": {\n        \"snr\": 1.3648872375488281,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.20.self_attn.k_proj\": {\n        \"snr\": 0.8667459487915039,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.21.self_attn.k_proj\": {\n        \"snr\": 0.9100639224052429,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.22.self_attn.k_proj\": {\n        \"snr\": 2.127535820007324,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.23.self_attn.k_proj\": {\n        \"snr\": 1.0382369756698608,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.24.self_attn.k_proj\": {\n        \"snr\": 1.113753318786621,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.25.self_attn.k_proj\": {\n        \"snr\": 2.597890853881836,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.26.self_attn.k_proj\": {\n        \"snr\": 1.1248247623443604,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.27.self_attn.k_proj\": {\n        \"snr\": 1.1984941959381104,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.0.self_attn.o_proj\": {\n        \"snr\": 0.8139898777008057,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.1.self_attn.o_proj\": {\n        \"snr\": 0.21965594589710236,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.2.self_attn.o_proj\": {\n        \"snr\": 0.219479501247406,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.3.self_attn.o_proj\": {\n        \"snr\": 0.22144284844398499,\n        
\"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.4.self_attn.o_proj\": {\n        \"snr\": 0.22390463948249817,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.5.self_attn.o_proj\": {\n        \"snr\": 0.22383669018745422,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.6.self_attn.o_proj\": {\n        \"snr\": 0.22818723320960999,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.7.self_attn.o_proj\": {\n        \"snr\": 0.23134392499923706,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.8.self_attn.o_proj\": {\n        \"snr\": 0.24275101721286774,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.9.self_attn.o_proj\": {\n        \"snr\": 0.21139128506183624,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.10.self_attn.o_proj\": {\n        \"snr\": 0.18210072815418243,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.11.self_attn.o_proj\": {\n        \"snr\": 0.14415481686592102,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.12.self_attn.o_proj\": {\n        \"snr\": 0.21947966516017914,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.13.self_attn.o_proj\": {\n        \"snr\": 0.17875106632709503,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.14.self_attn.o_proj\": {\n        \"snr\": 0.264996200799942,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.15.self_attn.o_proj\": {\n        \"snr\": 0.19353187084197998,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.16.self_attn.o_proj\": {\n        \"snr\": 0.22111012041568756,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.17.self_attn.o_proj\": {\n        \"snr\": 0.2242278754711151,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.18.self_attn.o_proj\": {\n        \"snr\": 0.2527434229850769,\n        \"type\": 
\"self_attn.o_proj\"\n    },\n    \"model.layers.19.self_attn.o_proj\": {\n        \"snr\": 0.26184532046318054,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.20.self_attn.o_proj\": {\n        \"snr\": 0.1519661247730255,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.21.self_attn.o_proj\": {\n        \"snr\": 0.22386522591114044,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.22.self_attn.o_proj\": {\n        \"snr\": 0.2386160045862198,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.23.self_attn.o_proj\": {\n        \"snr\": 0.18057651817798615,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.24.self_attn.o_proj\": {\n        \"snr\": 0.1989467740058899,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.25.self_attn.o_proj\": {\n        \"snr\": 0.11306505650281906,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.26.self_attn.o_proj\": {\n        \"snr\": 0.08449216932058334,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.27.self_attn.o_proj\": {\n        \"snr\": 0.10287519544363022,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.0.self_attn.q_proj\": {\n        \"snr\": 0.039204664528369904,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.1.self_attn.q_proj\": {\n        \"snr\": 0.14075909554958344,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.2.self_attn.q_proj\": {\n        \"snr\": 0.18212397396564484,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.3.self_attn.q_proj\": {\n        \"snr\": 0.1700422316789627,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.4.self_attn.q_proj\": {\n        \"snr\": 0.1948907971382141,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.5.self_attn.q_proj\": {\n        \"snr\": 0.2153141051530838,\n        \"type\": \"self_attn.q_proj\"\n    },\n 
   \"model.layers.6.self_attn.q_proj\": {\n        \"snr\": 0.21998055279254913,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.7.self_attn.q_proj\": {\n        \"snr\": 0.20416118204593658,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.8.self_attn.q_proj\": {\n        \"snr\": 0.2272879034280777,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.9.self_attn.q_proj\": {\n        \"snr\": 0.23795834183692932,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.10.self_attn.q_proj\": {\n        \"snr\": 0.21887299418449402,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.11.self_attn.q_proj\": {\n        \"snr\": 0.23469635844230652,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.12.self_attn.q_proj\": {\n        \"snr\": 0.23774078488349915,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.13.self_attn.q_proj\": {\n        \"snr\": 0.1920779049396515,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.14.self_attn.q_proj\": {\n        \"snr\": 0.2584812641143799,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.15.self_attn.q_proj\": {\n        \"snr\": 0.07330238074064255,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.16.self_attn.q_proj\": {\n        \"snr\": 0.23073157668113708,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.17.self_attn.q_proj\": {\n        \"snr\": 0.2523840367794037,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.18.self_attn.q_proj\": {\n        \"snr\": 0.23874858021736145,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.19.self_attn.q_proj\": {\n        \"snr\": 0.20698708295822144,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.20.self_attn.q_proj\": {\n        \"snr\": 0.25723400712013245,\n        \"type\": \"self_attn.q_proj\"\n    },\n    
\"model.layers.21.self_attn.q_proj\": {\n        \"snr\": 0.223300039768219,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.22.self_attn.q_proj\": {\n        \"snr\": 0.18824049830436707,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.23.self_attn.q_proj\": {\n        \"snr\": 0.19840741157531738,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.24.self_attn.q_proj\": {\n        \"snr\": 0.16326843202114105,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.25.self_attn.q_proj\": {\n        \"snr\": 0.1581888198852539,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.26.self_attn.q_proj\": {\n        \"snr\": 0.25306230783462524,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.27.self_attn.q_proj\": {\n        \"snr\": 0.23808495700359344,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.0.self_attn.v_proj\": {\n        \"snr\": Infinity,\n        \"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.1.self_attn.v_proj\": {\n        \"snr\": 864.8881225585938,\n        \"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.2.self_attn.v_proj\": {\n        \"snr\": Infinity,\n        \"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.3.self_attn.v_proj\": {\n        \"snr\": Infinity,\n        \"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.4.self_attn.v_proj\": {\n        \"snr\": Infinity,\n        \"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.5.self_attn.v_proj\": {\n        \"snr\": Infinity,\n        \"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.6.self_attn.v_proj\": {\n        \"snr\": 48.853694915771484,\n        \"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.7.self_attn.v_proj\": {\n        \"snr\": 70.18457794189453,\n        \"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.8.self_attn.v_proj\": {\n        \"snr\": Infinity,\n        \"type\": 
\"self_attn.v_proj\"\n    },\n    \"model.layers.9.self_attn.v_proj\": {\n        \"snr\": 371.1153259277344,\n        \"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.10.self_attn.v_proj\": {\n        \"snr\": Infinity,\n        \"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.11.self_attn.v_proj\": {\n        \"snr\": Infinity,\n        \"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.12.self_attn.v_proj\": {\n        \"snr\": 75.41203308105469,\n        \"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.13.self_attn.v_proj\": {\n        \"snr\": Infinity,\n        \"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.14.self_attn.v_proj\": {\n        \"snr\": 51.92624282836914,\n        \"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.15.self_attn.v_proj\": {\n        \"snr\": Infinity,\n        \"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.16.self_attn.v_proj\": {\n        \"snr\": Infinity,\n        \"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.17.self_attn.v_proj\": {\n        \"snr\": Infinity,\n        \"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.18.self_attn.v_proj\": {\n        \"snr\": Infinity,\n        \"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.19.self_attn.v_proj\": {\n        \"snr\": Infinity,\n        \"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.20.self_attn.v_proj\": {\n        \"snr\": Infinity,\n        \"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.21.self_attn.v_proj\": {\n        \"snr\": 642.9313354492188,\n        \"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.22.self_attn.v_proj\": {\n        \"snr\": Infinity,\n        \"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.23.self_attn.v_proj\": {\n        \"snr\": 323.5724182128906,\n        \"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.24.self_attn.v_proj\": {\n        \"snr\": Infinity,\n        \"type\": 
\"self_attn.v_proj\"\n    },\n    \"model.layers.25.self_attn.v_proj\": {\n        \"snr\": 2.1736748218536377,\n        \"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.26.self_attn.v_proj\": {\n        \"snr\": 3.1729259490966797,\n        \"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.27.self_attn.v_proj\": {\n        \"snr\": 2.024953842163086,\n        \"type\": \"self_attn.v_proj\"\n    }\n}\n"
  },
  {
    "path": "src/axolotl/integrations/spectrum/model_snr_results/snr_results_Qwen-Qwen2.5-3B-Instruct.json",
    "content": "{\n    \"model.layers.0.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.1.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.2.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.3.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.4.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.5.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.6.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.7.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.8.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.9.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.10.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.11.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.12.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.13.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.14.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.15.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.16.input_layernorm\": {\n        \"snr\": Infinity,\n        
\"type\": \"input_layernorm\"\n    },\n    \"model.layers.17.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.18.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.19.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.20.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.21.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.22.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.23.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.24.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.25.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.26.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.27.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.28.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.29.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.30.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.31.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.32.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.33.input_layernorm\": {\n        
\"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.34.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.35.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"lm_head\": {\n        \"snr\": Infinity,\n        \"type\": \"lm_head\"\n    },\n    \"model.layers.0.mlp.down_proj\": {\n        \"snr\": 20.964319229125977,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.1.mlp.down_proj\": {\n        \"snr\": 0.11561352014541626,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.2.mlp.down_proj\": {\n        \"snr\": 0.14991413056850433,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.3.mlp.down_proj\": {\n        \"snr\": 0.3673713207244873,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.4.mlp.down_proj\": {\n        \"snr\": 0.5076134204864502,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.5.mlp.down_proj\": {\n        \"snr\": 33.89468002319336,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.6.mlp.down_proj\": {\n        \"snr\": 45.08732986450195,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.7.mlp.down_proj\": {\n        \"snr\": 33.234222412109375,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.8.mlp.down_proj\": {\n        \"snr\": 29.3447322845459,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.9.mlp.down_proj\": {\n        \"snr\": 26.664169311523438,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.10.mlp.down_proj\": {\n        \"snr\": 22.323949813842773,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.11.mlp.down_proj\": {\n        \"snr\": 18.259737014770508,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.12.mlp.down_proj\": {\n        \"snr\": 14.422037124633789,\n        \"type\": \"mlp.down_proj\"\n   
 },\n    \"model.layers.13.mlp.down_proj\": {\n        \"snr\": 22.172054290771484,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.14.mlp.down_proj\": {\n        \"snr\": 27.363698959350586,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.15.mlp.down_proj\": {\n        \"snr\": 28.474334716796875,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.16.mlp.down_proj\": {\n        \"snr\": 10.4143648147583,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.17.mlp.down_proj\": {\n        \"snr\": 10.719133377075195,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.18.mlp.down_proj\": {\n        \"snr\": 8.6494722366333,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.19.mlp.down_proj\": {\n        \"snr\": 5.69321870803833,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.20.mlp.down_proj\": {\n        \"snr\": 23.889677047729492,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.21.mlp.down_proj\": {\n        \"snr\": 11.59121036529541,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.22.mlp.down_proj\": {\n        \"snr\": 5.997435569763184,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.23.mlp.down_proj\": {\n        \"snr\": 19.415578842163086,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.24.mlp.down_proj\": {\n        \"snr\": 8.241704940795898,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.25.mlp.down_proj\": {\n        \"snr\": 12.993823051452637,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.26.mlp.down_proj\": {\n        \"snr\": 36.26508712768555,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.27.mlp.down_proj\": {\n        \"snr\": 19.957971572875977,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.28.mlp.down_proj\": {\n        \"snr\": 6.067765235900879,\n        \"type\": \"mlp.down_proj\"\n    },\n    
\"model.layers.29.mlp.down_proj\": {\n        \"snr\": 5.369481086730957,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.30.mlp.down_proj\": {\n        \"snr\": 7.358774662017822,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.31.mlp.down_proj\": {\n        \"snr\": 7.8687238693237305,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.32.mlp.down_proj\": {\n        \"snr\": 8.713484764099121,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.33.mlp.down_proj\": {\n        \"snr\": 21.233531951904297,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.34.mlp.down_proj\": {\n        \"snr\": 32.37357711791992,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.35.mlp.down_proj\": {\n        \"snr\": 179.8053741455078,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.0.mlp.gate_proj\": {\n        \"snr\": 0.24989914894104004,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.1.mlp.gate_proj\": {\n        \"snr\": 0.11613649874925613,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.2.mlp.gate_proj\": {\n        \"snr\": 0.16354432702064514,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.3.mlp.gate_proj\": {\n        \"snr\": 0.36216047406196594,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.4.mlp.gate_proj\": {\n        \"snr\": 0.3485107719898224,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.5.mlp.gate_proj\": {\n        \"snr\": 2.6546616554260254,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.6.mlp.gate_proj\": {\n        \"snr\": 8.362885475158691,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.7.mlp.gate_proj\": {\n        \"snr\": 7.38665246963501,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.8.mlp.gate_proj\": {\n        \"snr\": 13.016111373901367,\n        \"type\": \"mlp.gate_proj\"\n    },\n    
\"model.layers.9.mlp.gate_proj\": {\n        \"snr\": 14.94902515411377,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.10.mlp.gate_proj\": {\n        \"snr\": 20.92418670654297,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.11.mlp.gate_proj\": {\n        \"snr\": 15.954015731811523,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.12.mlp.gate_proj\": {\n        \"snr\": 8.980009078979492,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.13.mlp.gate_proj\": {\n        \"snr\": 17.59958267211914,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.14.mlp.gate_proj\": {\n        \"snr\": 17.23070526123047,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.15.mlp.gate_proj\": {\n        \"snr\": 23.725330352783203,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.16.mlp.gate_proj\": {\n        \"snr\": 17.000444412231445,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.17.mlp.gate_proj\": {\n        \"snr\": 18.293012619018555,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.18.mlp.gate_proj\": {\n        \"snr\": 12.644190788269043,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.19.mlp.gate_proj\": {\n        \"snr\": 16.278690338134766,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.20.mlp.gate_proj\": {\n        \"snr\": 7.407368183135986,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.21.mlp.gate_proj\": {\n        \"snr\": 6.109912395477295,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.22.mlp.gate_proj\": {\n        \"snr\": 5.3692426681518555,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.23.mlp.gate_proj\": {\n        \"snr\": 9.354235649108887,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.24.mlp.gate_proj\": {\n        \"snr\": 7.655010223388672,\n        \"type\": \"mlp.gate_proj\"\n    },\n    
\"model.layers.25.mlp.gate_proj\": {\n        \"snr\": 6.252986431121826,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.26.mlp.gate_proj\": {\n        \"snr\": 14.26718521118164,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.27.mlp.gate_proj\": {\n        \"snr\": 7.705836772918701,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.28.mlp.gate_proj\": {\n        \"snr\": 5.998677730560303,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.29.mlp.gate_proj\": {\n        \"snr\": 6.044872760772705,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.30.mlp.gate_proj\": {\n        \"snr\": 9.027137756347656,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.31.mlp.gate_proj\": {\n        \"snr\": 5.449969291687012,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.32.mlp.gate_proj\": {\n        \"snr\": 4.206825256347656,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.33.mlp.gate_proj\": {\n        \"snr\": 5.22825288772583,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.34.mlp.gate_proj\": {\n        \"snr\": 43.71927261352539,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.35.mlp.gate_proj\": {\n        \"snr\": 45.37385177612305,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.0.mlp.up_proj\": {\n        \"snr\": 0.7069714665412903,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.1.mlp.up_proj\": {\n        \"snr\": 0.17766596376895905,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.2.mlp.up_proj\": {\n        \"snr\": 0.28577035665512085,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.3.mlp.up_proj\": {\n        \"snr\": 0.6763099431991577,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.4.mlp.up_proj\": {\n        \"snr\": 0.8340913653373718,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.5.mlp.up_proj\": 
{\n        \"snr\": 3.946547031402588,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.6.mlp.up_proj\": {\n        \"snr\": 19.56715202331543,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.7.mlp.up_proj\": {\n        \"snr\": 36.21149826049805,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.8.mlp.up_proj\": {\n        \"snr\": 44.28759002685547,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.9.mlp.up_proj\": {\n        \"snr\": 45.47198486328125,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.10.mlp.up_proj\": {\n        \"snr\": 79.00128936767578,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.11.mlp.up_proj\": {\n        \"snr\": 52.28038787841797,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.12.mlp.up_proj\": {\n        \"snr\": 48.08102035522461,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.13.mlp.up_proj\": {\n        \"snr\": 56.071285247802734,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.14.mlp.up_proj\": {\n        \"snr\": 72.24358367919922,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.15.mlp.up_proj\": {\n        \"snr\": 54.818233489990234,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.16.mlp.up_proj\": {\n        \"snr\": 47.251495361328125,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.17.mlp.up_proj\": {\n        \"snr\": 51.585636138916016,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.18.mlp.up_proj\": {\n        \"snr\": 43.47938919067383,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.19.mlp.up_proj\": {\n        \"snr\": 38.132469177246094,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.20.mlp.up_proj\": {\n        \"snr\": 21.78435707092285,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.21.mlp.up_proj\": {\n        \"snr\": 22.261096954345703,\n        \"type\": \"mlp.up_proj\"\n   
 },\n    \"model.layers.22.mlp.up_proj\": {\n        \"snr\": 30.751861572265625,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.23.mlp.up_proj\": {\n        \"snr\": 28.61063575744629,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.24.mlp.up_proj\": {\n        \"snr\": 20.21415901184082,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.25.mlp.up_proj\": {\n        \"snr\": 20.759052276611328,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.26.mlp.up_proj\": {\n        \"snr\": 33.80818557739258,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.27.mlp.up_proj\": {\n        \"snr\": 17.274362564086914,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.28.mlp.up_proj\": {\n        \"snr\": 13.943653106689453,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.29.mlp.up_proj\": {\n        \"snr\": 16.202186584472656,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.30.mlp.up_proj\": {\n        \"snr\": 24.25114631652832,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.31.mlp.up_proj\": {\n        \"snr\": 10.68645191192627,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.32.mlp.up_proj\": {\n        \"snr\": 5.7449774742126465,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.33.mlp.up_proj\": {\n        \"snr\": 11.879876136779785,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.34.mlp.up_proj\": {\n        \"snr\": 25.948715209960938,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.35.mlp.up_proj\": {\n        \"snr\": 38.63526153564453,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.embed_tokens\": {\n        \"snr\": Infinity,\n        \"type\": \"model.embed_tokens\"\n    },\n    \"model.norm\": {\n        \"snr\": Infinity,\n        \"type\": \"model.norm\"\n    },\n    \"model.layers.0.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": 
\"post_attention_layernorm\"\n    },\n    \"model.layers.1.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.2.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.3.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.4.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.5.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.6.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.7.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.8.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.9.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.10.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.11.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.12.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.13.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.14.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.15.post_attention_layernorm\": 
{\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.16.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.17.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.18.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.19.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.20.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.21.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.22.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.23.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.24.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.25.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.26.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.27.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.28.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.29.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n   
 },\n    \"model.layers.30.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.31.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.32.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.33.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.34.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.35.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.0.self_attn.k_proj\": {\n        \"snr\": 12.243099212646484,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.1.self_attn.k_proj\": {\n        \"snr\": 0.6446183323860168,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.2.self_attn.k_proj\": {\n        \"snr\": 0.7159711718559265,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.3.self_attn.k_proj\": {\n        \"snr\": 5.5100932121276855,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.4.self_attn.k_proj\": {\n        \"snr\": 3.0802414417266846,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.5.self_attn.k_proj\": {\n        \"snr\": 1.0472767353057861,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.6.self_attn.k_proj\": {\n        \"snr\": 3.576918601989746,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.7.self_attn.k_proj\": {\n        \"snr\": 3.3793225288391113,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.8.self_attn.k_proj\": {\n        \"snr\": 2.9598212242126465,\n        \"type\": \"self_attn.k_proj\"\n    },\n    
\"model.layers.9.self_attn.k_proj\": {\n        \"snr\": 6.102792263031006,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.10.self_attn.k_proj\": {\n        \"snr\": 2.231630325317383,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.11.self_attn.k_proj\": {\n        \"snr\": 2.176372766494751,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.12.self_attn.k_proj\": {\n        \"snr\": 1.3229435682296753,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.13.self_attn.k_proj\": {\n        \"snr\": 2.6183862686157227,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.14.self_attn.k_proj\": {\n        \"snr\": 2.608288526535034,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.15.self_attn.k_proj\": {\n        \"snr\": 1.5090984106063843,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.16.self_attn.k_proj\": {\n        \"snr\": 1.284422516822815,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.17.self_attn.k_proj\": {\n        \"snr\": 0.8903945088386536,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.18.self_attn.k_proj\": {\n        \"snr\": 1.8880385160446167,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.19.self_attn.k_proj\": {\n        \"snr\": 0.8905735015869141,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.20.self_attn.k_proj\": {\n        \"snr\": 0.9060881733894348,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.21.self_attn.k_proj\": {\n        \"snr\": 0.7572551965713501,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.22.self_attn.k_proj\": {\n        \"snr\": 0.940827488899231,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.23.self_attn.k_proj\": {\n        \"snr\": 3.7776191234588623,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.24.self_attn.k_proj\": {\n 
       \"snr\": 1.328923225402832,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.25.self_attn.k_proj\": {\n        \"snr\": 1.3986345529556274,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.26.self_attn.k_proj\": {\n        \"snr\": 1.2436336278915405,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.27.self_attn.k_proj\": {\n        \"snr\": 0.7737217545509338,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.28.self_attn.k_proj\": {\n        \"snr\": 2.6027626991271973,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.29.self_attn.k_proj\": {\n        \"snr\": 2.2332751750946045,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.30.self_attn.k_proj\": {\n        \"snr\": 2.476585626602173,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.31.self_attn.k_proj\": {\n        \"snr\": 1.1115432977676392,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.32.self_attn.k_proj\": {\n        \"snr\": 0.8251476287841797,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.33.self_attn.k_proj\": {\n        \"snr\": 0.9331105947494507,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.34.self_attn.k_proj\": {\n        \"snr\": 6.602395534515381,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.35.self_attn.k_proj\": {\n        \"snr\": 10.151693344116211,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.0.self_attn.o_proj\": {\n        \"snr\": 0.3661542534828186,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.1.self_attn.o_proj\": {\n        \"snr\": 0.19571374356746674,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.2.self_attn.o_proj\": {\n        \"snr\": 0.2244851142168045,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.3.self_attn.o_proj\": {\n        \"snr\": 0.2593664526939392,\n    
    \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.4.self_attn.o_proj\": {\n        \"snr\": 0.2569783926010132,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.5.self_attn.o_proj\": {\n        \"snr\": 0.2564302980899811,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.6.self_attn.o_proj\": {\n        \"snr\": 0.18539844453334808,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.7.self_attn.o_proj\": {\n        \"snr\": 0.2328651398420334,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.8.self_attn.o_proj\": {\n        \"snr\": 0.22055882215499878,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.9.self_attn.o_proj\": {\n        \"snr\": 0.21800543367862701,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.10.self_attn.o_proj\": {\n        \"snr\": 0.22867777943611145,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.11.self_attn.o_proj\": {\n        \"snr\": 0.23986175656318665,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.12.self_attn.o_proj\": {\n        \"snr\": 0.17598563432693481,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.13.self_attn.o_proj\": {\n        \"snr\": 0.20469218492507935,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.14.self_attn.o_proj\": {\n        \"snr\": 0.21040217578411102,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.15.self_attn.o_proj\": {\n        \"snr\": 0.23787625133991241,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.16.self_attn.o_proj\": {\n        \"snr\": 0.16339677572250366,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.17.self_attn.o_proj\": {\n        \"snr\": 0.2070712298154831,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.18.self_attn.o_proj\": {\n        \"snr\": 0.1826934814453125,\n        \"type\": 
\"self_attn.o_proj\"\n    },\n    \"model.layers.19.self_attn.o_proj\": {\n        \"snr\": 0.19459959864616394,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.20.self_attn.o_proj\": {\n        \"snr\": 0.2668156027793884,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.21.self_attn.o_proj\": {\n        \"snr\": 0.16906610131263733,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.22.self_attn.o_proj\": {\n        \"snr\": 0.18790249526500702,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.23.self_attn.o_proj\": {\n        \"snr\": 0.18883933126926422,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.24.self_attn.o_proj\": {\n        \"snr\": 0.1793188899755478,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.25.self_attn.o_proj\": {\n        \"snr\": 0.1800570785999298,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.26.self_attn.o_proj\": {\n        \"snr\": 0.17790433764457703,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.27.self_attn.o_proj\": {\n        \"snr\": 0.2029498964548111,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.28.self_attn.o_proj\": {\n        \"snr\": 0.17044201493263245,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.29.self_attn.o_proj\": {\n        \"snr\": 0.19938386976718903,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.30.self_attn.o_proj\": {\n        \"snr\": 0.23108959197998047,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.31.self_attn.o_proj\": {\n        \"snr\": 0.16427059471607208,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.32.self_attn.o_proj\": {\n        \"snr\": 0.10631092637777328,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.33.self_attn.o_proj\": {\n        \"snr\": 0.09417019784450531,\n        \"type\": \"self_attn.o_proj\"\n  
  },\n    \"model.layers.34.self_attn.o_proj\": {\n        \"snr\": 0.1324978619813919,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.35.self_attn.o_proj\": {\n        \"snr\": 0.11784011125564575,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.0.self_attn.q_proj\": {\n        \"snr\": 0.05565479397773743,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.1.self_attn.q_proj\": {\n        \"snr\": 0.138458251953125,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.2.self_attn.q_proj\": {\n        \"snr\": 0.12992437183856964,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.3.self_attn.q_proj\": {\n        \"snr\": 0.15362468361854553,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.4.self_attn.q_proj\": {\n        \"snr\": 0.1563446819782257,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.5.self_attn.q_proj\": {\n        \"snr\": 0.15544593334197998,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.6.self_attn.q_proj\": {\n        \"snr\": 0.15956827998161316,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.7.self_attn.q_proj\": {\n        \"snr\": 0.17549948394298553,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.8.self_attn.q_proj\": {\n        \"snr\": 0.16668449342250824,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.9.self_attn.q_proj\": {\n        \"snr\": 0.15626586973667145,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.10.self_attn.q_proj\": {\n        \"snr\": 0.18318884074687958,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.11.self_attn.q_proj\": {\n        \"snr\": 0.171547532081604,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.12.self_attn.q_proj\": {\n        \"snr\": 0.18164905905723572,\n        \"type\": \"self_attn.q_proj\"\n    },\n    
\"model.layers.13.self_attn.q_proj\": {\n        \"snr\": 0.2091975212097168,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.14.self_attn.q_proj\": {\n        \"snr\": 0.17431670427322388,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.15.self_attn.q_proj\": {\n        \"snr\": 0.20902502536773682,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.16.self_attn.q_proj\": {\n        \"snr\": 0.15439842641353607,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.17.self_attn.q_proj\": {\n        \"snr\": 0.1945274919271469,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.18.self_attn.q_proj\": {\n        \"snr\": 0.18916545808315277,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.19.self_attn.q_proj\": {\n        \"snr\": 0.20778712630271912,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.20.self_attn.q_proj\": {\n        \"snr\": 0.20866931974887848,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.21.self_attn.q_proj\": {\n        \"snr\": 0.1900305300951004,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.22.self_attn.q_proj\": {\n        \"snr\": 0.18200653791427612,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.23.self_attn.q_proj\": {\n        \"snr\": 0.2070988416671753,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.24.self_attn.q_proj\": {\n        \"snr\": 0.1845332235097885,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.25.self_attn.q_proj\": {\n        \"snr\": 0.20868781208992004,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.26.self_attn.q_proj\": {\n        \"snr\": 0.19242744147777557,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.27.self_attn.q_proj\": {\n        \"snr\": 0.15225112438201904,\n        \"type\": \"self_attn.q_proj\"\n    },\n    
\"model.layers.28.self_attn.q_proj\": {\n        \"snr\": 0.20065009593963623,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.29.self_attn.q_proj\": {\n        \"snr\": 0.19390477240085602,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.30.self_attn.q_proj\": {\n        \"snr\": 0.18538697063922882,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.31.self_attn.q_proj\": {\n        \"snr\": 0.18954339623451233,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.32.self_attn.q_proj\": {\n        \"snr\": 0.20089596509933472,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.33.self_attn.q_proj\": {\n        \"snr\": 0.19814996421337128,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.34.self_attn.q_proj\": {\n        \"snr\": 0.17733213305473328,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.35.self_attn.q_proj\": {\n        \"snr\": 0.14075976610183716,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.0.self_attn.v_proj\": {\n        \"snr\": Infinity,\n        \"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.1.self_attn.v_proj\": {\n        \"snr\": 845.8053588867188,\n        \"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.2.self_attn.v_proj\": {\n        \"snr\": Infinity,\n        \"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.3.self_attn.v_proj\": {\n        \"snr\": Infinity,\n        \"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.4.self_attn.v_proj\": {\n        \"snr\": Infinity,\n        \"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.5.self_attn.v_proj\": {\n        \"snr\": 83.97241973876953,\n        \"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.6.self_attn.v_proj\": {\n        \"snr\": Infinity,\n        \"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.7.self_attn.v_proj\": {\n        \"snr\": Infinity,\n        \"type\": 
\"self_attn.v_proj\"\n    },\n    \"model.layers.8.self_attn.v_proj\": {\n        \"snr\": Infinity,\n        \"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.9.self_attn.v_proj\": {\n        \"snr\": Infinity,\n        \"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.10.self_attn.v_proj\": {\n        \"snr\": Infinity,\n        \"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.11.self_attn.v_proj\": {\n        \"snr\": Infinity,\n        \"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.12.self_attn.v_proj\": {\n        \"snr\": Infinity,\n        \"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.13.self_attn.v_proj\": {\n        \"snr\": 213.70960998535156,\n        \"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.14.self_attn.v_proj\": {\n        \"snr\": Infinity,\n        \"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.15.self_attn.v_proj\": {\n        \"snr\": Infinity,\n        \"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.16.self_attn.v_proj\": {\n        \"snr\": Infinity,\n        \"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.17.self_attn.v_proj\": {\n        \"snr\": Infinity,\n        \"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.18.self_attn.v_proj\": {\n        \"snr\": Infinity,\n        \"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.19.self_attn.v_proj\": {\n        \"snr\": 18.950267791748047,\n        \"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.20.self_attn.v_proj\": {\n        \"snr\": 435.8339538574219,\n        \"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.21.self_attn.v_proj\": {\n        \"snr\": Infinity,\n        \"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.22.self_attn.v_proj\": {\n        \"snr\": Infinity,\n        \"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.23.self_attn.v_proj\": {\n        \"snr\": Infinity,\n        \"type\": \"self_attn.v_proj\"\n    },\n    
\"model.layers.24.self_attn.v_proj\": {\n        \"snr\": Infinity,\n        \"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.25.self_attn.v_proj\": {\n        \"snr\": Infinity,\n        \"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.26.self_attn.v_proj\": {\n        \"snr\": Infinity,\n        \"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.27.self_attn.v_proj\": {\n        \"snr\": Infinity,\n        \"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.28.self_attn.v_proj\": {\n        \"snr\": Infinity,\n        \"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.29.self_attn.v_proj\": {\n        \"snr\": Infinity,\n        \"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.30.self_attn.v_proj\": {\n        \"snr\": Infinity,\n        \"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.31.self_attn.v_proj\": {\n        \"snr\": Infinity,\n        \"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.32.self_attn.v_proj\": {\n        \"snr\": 1.2341279983520508,\n        \"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.33.self_attn.v_proj\": {\n        \"snr\": 0.6158654689788818,\n        \"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.34.self_attn.v_proj\": {\n        \"snr\": 509.3221130371094,\n        \"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.35.self_attn.v_proj\": {\n        \"snr\": 538.6658325195312,\n        \"type\": \"self_attn.v_proj\"\n    }\n}\n"
  },
  {
    "path": "src/axolotl/integrations/spectrum/model_snr_results/snr_results_Qwen-Qwen2.5-3B.json",
    "content": "{\n    \"model.layers.0.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.1.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.2.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.3.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.4.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.5.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.6.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.7.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.8.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.9.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.10.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.11.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.12.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.13.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.14.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.15.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.16.input_layernorm\": {\n        \"snr\": Infinity,\n        
\"type\": \"input_layernorm\"\n    },\n    \"model.layers.17.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.18.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.19.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.20.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.21.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.22.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.23.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.24.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.25.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.26.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.27.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.28.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.29.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.30.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.31.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.32.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.33.input_layernorm\": {\n        
\"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.34.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.35.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"lm_head\": {\n        \"snr\": Infinity,\n        \"type\": \"lm_head\"\n    },\n    \"model.layers.0.mlp.down_proj\": {\n        \"snr\": 20.942785263061523,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.1.mlp.down_proj\": {\n        \"snr\": 0.11550866067409515,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.2.mlp.down_proj\": {\n        \"snr\": 0.14981402456760406,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.3.mlp.down_proj\": {\n        \"snr\": 0.36719316244125366,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.4.mlp.down_proj\": {\n        \"snr\": 0.5072987079620361,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.5.mlp.down_proj\": {\n        \"snr\": 33.86688232421875,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.6.mlp.down_proj\": {\n        \"snr\": 45.066246032714844,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.7.mlp.down_proj\": {\n        \"snr\": 33.20981979370117,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.8.mlp.down_proj\": {\n        \"snr\": 29.310104370117188,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.9.mlp.down_proj\": {\n        \"snr\": 26.638381958007812,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.10.mlp.down_proj\": {\n        \"snr\": 22.302486419677734,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.11.mlp.down_proj\": {\n        \"snr\": 18.249290466308594,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.12.mlp.down_proj\": {\n        \"snr\": 14.057564735412598,\n        \"type\": 
\"mlp.down_proj\"\n    },\n    \"model.layers.13.mlp.down_proj\": {\n        \"snr\": 22.154281616210938,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.14.mlp.down_proj\": {\n        \"snr\": 27.348575592041016,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.15.mlp.down_proj\": {\n        \"snr\": 28.447378158569336,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.16.mlp.down_proj\": {\n        \"snr\": 10.405216217041016,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.17.mlp.down_proj\": {\n        \"snr\": 10.71042251586914,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.18.mlp.down_proj\": {\n        \"snr\": 8.642854690551758,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.19.mlp.down_proj\": {\n        \"snr\": 5.690433979034424,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.20.mlp.down_proj\": {\n        \"snr\": 23.869070053100586,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.21.mlp.down_proj\": {\n        \"snr\": 11.584356307983398,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.22.mlp.down_proj\": {\n        \"snr\": 5.992950916290283,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.23.mlp.down_proj\": {\n        \"snr\": 18.495361328125,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.24.mlp.down_proj\": {\n        \"snr\": 8.233827590942383,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.25.mlp.down_proj\": {\n        \"snr\": 12.626734733581543,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.26.mlp.down_proj\": {\n        \"snr\": 36.21802520751953,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.27.mlp.down_proj\": {\n        \"snr\": 19.932941436767578,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.28.mlp.down_proj\": {\n        \"snr\": 6.0616455078125,\n        \"type\": 
\"mlp.down_proj\"\n    },\n    \"model.layers.29.mlp.down_proj\": {\n        \"snr\": 5.363720417022705,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.30.mlp.down_proj\": {\n        \"snr\": 7.455615520477295,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.31.mlp.down_proj\": {\n        \"snr\": 7.8631815910339355,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.32.mlp.down_proj\": {\n        \"snr\": 8.706913948059082,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.33.mlp.down_proj\": {\n        \"snr\": 21.220134735107422,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.34.mlp.down_proj\": {\n        \"snr\": 32.33852005004883,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.35.mlp.down_proj\": {\n        \"snr\": 179.8906707763672,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.0.mlp.gate_proj\": {\n        \"snr\": 0.24970805644989014,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.1.mlp.gate_proj\": {\n        \"snr\": 0.11607512086629868,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.2.mlp.gate_proj\": {\n        \"snr\": 0.16310769319534302,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.3.mlp.gate_proj\": {\n        \"snr\": 0.3621424436569214,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.4.mlp.gate_proj\": {\n        \"snr\": 0.3482637107372284,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.5.mlp.gate_proj\": {\n        \"snr\": 2.6533455848693848,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.6.mlp.gate_proj\": {\n        \"snr\": 8.359040260314941,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.7.mlp.gate_proj\": {\n        \"snr\": 7.382037162780762,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.8.mlp.gate_proj\": {\n        \"snr\": 13.00683879852295,\n        \"type\": 
\"mlp.gate_proj\"\n    },\n    \"model.layers.9.mlp.gate_proj\": {\n        \"snr\": 14.936161994934082,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.10.mlp.gate_proj\": {\n        \"snr\": 20.907283782958984,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.11.mlp.gate_proj\": {\n        \"snr\": 15.941497802734375,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.12.mlp.gate_proj\": {\n        \"snr\": 8.97419548034668,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.13.mlp.gate_proj\": {\n        \"snr\": 17.585100173950195,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.14.mlp.gate_proj\": {\n        \"snr\": 17.21462059020996,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.15.mlp.gate_proj\": {\n        \"snr\": 23.703285217285156,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.16.mlp.gate_proj\": {\n        \"snr\": 16.986576080322266,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.17.mlp.gate_proj\": {\n        \"snr\": 18.27729606628418,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.18.mlp.gate_proj\": {\n        \"snr\": 12.63351058959961,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.19.mlp.gate_proj\": {\n        \"snr\": 16.2633113861084,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.20.mlp.gate_proj\": {\n        \"snr\": 7.399787902832031,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.21.mlp.gate_proj\": {\n        \"snr\": 6.10424280166626,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.22.mlp.gate_proj\": {\n        \"snr\": 5.363350868225098,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.23.mlp.gate_proj\": {\n        \"snr\": 9.344535827636719,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.24.mlp.gate_proj\": {\n        \"snr\": 7.647364616394043,\n        \"type\": 
\"mlp.gate_proj\"\n    },\n    \"model.layers.25.mlp.gate_proj\": {\n        \"snr\": 6.143579959869385,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.26.mlp.gate_proj\": {\n        \"snr\": 14.254817008972168,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.27.mlp.gate_proj\": {\n        \"snr\": 7.7000861167907715,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.28.mlp.gate_proj\": {\n        \"snr\": 5.994422435760498,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.29.mlp.gate_proj\": {\n        \"snr\": 6.041909694671631,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.30.mlp.gate_proj\": {\n        \"snr\": 9.027522087097168,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.31.mlp.gate_proj\": {\n        \"snr\": 5.450753211975098,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.32.mlp.gate_proj\": {\n        \"snr\": 4.149200439453125,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.33.mlp.gate_proj\": {\n        \"snr\": 5.223763942718506,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.34.mlp.gate_proj\": {\n        \"snr\": 43.65521240234375,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.35.mlp.gate_proj\": {\n        \"snr\": 45.312774658203125,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.0.mlp.up_proj\": {\n        \"snr\": 0.7065013647079468,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.1.mlp.up_proj\": {\n        \"snr\": 0.17752516269683838,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.2.mlp.up_proj\": {\n        \"snr\": 0.2847473919391632,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.3.mlp.up_proj\": {\n        \"snr\": 0.6757690906524658,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.4.mlp.up_proj\": {\n        \"snr\": 0.8353318572044373,\n        \"type\": \"mlp.up_proj\"\n    },\n    
\"model.layers.5.mlp.up_proj\": {\n        \"snr\": 3.940711736679077,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.6.mlp.up_proj\": {\n        \"snr\": 19.556047439575195,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.7.mlp.up_proj\": {\n        \"snr\": 36.19340515136719,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.8.mlp.up_proj\": {\n        \"snr\": 44.2518424987793,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.9.mlp.up_proj\": {\n        \"snr\": 45.418025970458984,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.10.mlp.up_proj\": {\n        \"snr\": 78.90928649902344,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.11.mlp.up_proj\": {\n        \"snr\": 52.24648666381836,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.12.mlp.up_proj\": {\n        \"snr\": 48.02030563354492,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.13.mlp.up_proj\": {\n        \"snr\": 56.016239166259766,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.14.mlp.up_proj\": {\n        \"snr\": 72.16619873046875,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.15.mlp.up_proj\": {\n        \"snr\": 54.75283432006836,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.16.mlp.up_proj\": {\n        \"snr\": 47.204097747802734,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.17.mlp.up_proj\": {\n        \"snr\": 51.549312591552734,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.18.mlp.up_proj\": {\n        \"snr\": 43.43872833251953,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.19.mlp.up_proj\": {\n        \"snr\": 38.09785461425781,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.20.mlp.up_proj\": {\n        \"snr\": 21.767858505249023,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.21.mlp.up_proj\": {\n        \"snr\": 22.243661880493164,\n      
  \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.22.mlp.up_proj\": {\n        \"snr\": 30.71843147277832,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.23.mlp.up_proj\": {\n        \"snr\": 28.5756778717041,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.24.mlp.up_proj\": {\n        \"snr\": 20.186717987060547,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.25.mlp.up_proj\": {\n        \"snr\": 20.742860794067383,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.26.mlp.up_proj\": {\n        \"snr\": 33.777984619140625,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.27.mlp.up_proj\": {\n        \"snr\": 17.254213333129883,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.28.mlp.up_proj\": {\n        \"snr\": 13.930026054382324,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.29.mlp.up_proj\": {\n        \"snr\": 16.17984390258789,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.30.mlp.up_proj\": {\n        \"snr\": 24.236648559570312,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.31.mlp.up_proj\": {\n        \"snr\": 10.665648460388184,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.32.mlp.up_proj\": {\n        \"snr\": 5.735939025878906,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.33.mlp.up_proj\": {\n        \"snr\": 11.592061042785645,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.34.mlp.up_proj\": {\n        \"snr\": 25.923419952392578,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.35.mlp.up_proj\": {\n        \"snr\": 38.579349517822266,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.embed_tokens\": {\n        \"snr\": Infinity,\n        \"type\": \"model.embed_tokens\"\n    },\n    \"model.norm\": {\n        \"snr\": Infinity,\n        \"type\": \"model.norm\"\n    },\n    \"model.layers.0.post_attention_layernorm\": {\n        \"snr\": 
Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.1.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.2.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.3.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.4.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.5.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.6.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.7.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.8.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.9.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.10.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.11.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.12.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.13.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.14.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    
\"model.layers.15.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.16.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.17.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.18.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.19.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.20.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.21.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.22.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.23.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.24.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.25.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.26.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.27.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.28.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.29.post_attention_layernorm\": {\n        \"snr\": Infinity,\n     
   \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.30.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.31.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.32.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.33.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.34.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.35.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.0.self_attn.k_proj\": {\n        \"snr\": 12.24727725982666,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.1.self_attn.k_proj\": {\n        \"snr\": 0.6436238288879395,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.2.self_attn.k_proj\": {\n        \"snr\": 0.7156716585159302,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.3.self_attn.k_proj\": {\n        \"snr\": 5.505439758300781,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.4.self_attn.k_proj\": {\n        \"snr\": 3.0760715007781982,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.5.self_attn.k_proj\": {\n        \"snr\": 1.0453941822052002,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.6.self_attn.k_proj\": {\n        \"snr\": 3.57472562789917,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.7.self_attn.k_proj\": {\n        \"snr\": 3.3765170574188232,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.8.self_attn.k_proj\": {\n        \"snr\": 2.8859639167785645,\n        
\"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.9.self_attn.k_proj\": {\n        \"snr\": 6.09852409362793,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.10.self_attn.k_proj\": {\n        \"snr\": 2.229580879211426,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.11.self_attn.k_proj\": {\n        \"snr\": 2.173879623413086,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.12.self_attn.k_proj\": {\n        \"snr\": 1.3220131397247314,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.13.self_attn.k_proj\": {\n        \"snr\": 2.61668062210083,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.14.self_attn.k_proj\": {\n        \"snr\": 2.606799840927124,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.15.self_attn.k_proj\": {\n        \"snr\": 1.5080311298370361,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.16.self_attn.k_proj\": {\n        \"snr\": 1.2841484546661377,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.17.self_attn.k_proj\": {\n        \"snr\": 0.8896433115005493,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.18.self_attn.k_proj\": {\n        \"snr\": 1.8873414993286133,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.19.self_attn.k_proj\": {\n        \"snr\": 0.8897770643234253,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.20.self_attn.k_proj\": {\n        \"snr\": 0.9051405787467957,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.21.self_attn.k_proj\": {\n        \"snr\": 0.7568970322608948,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.22.self_attn.k_proj\": {\n        \"snr\": 0.9403582811355591,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.23.self_attn.k_proj\": {\n        \"snr\": 3.777062177658081,\n        \"type\": \"self_attn.k_proj\"\n    },\n    
\"model.layers.24.self_attn.k_proj\": {\n        \"snr\": 1.3280683755874634,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.25.self_attn.k_proj\": {\n        \"snr\": 1.3980307579040527,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.26.self_attn.k_proj\": {\n        \"snr\": 1.2435240745544434,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.27.self_attn.k_proj\": {\n        \"snr\": 0.7732619047164917,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.28.self_attn.k_proj\": {\n        \"snr\": 2.6010243892669678,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.29.self_attn.k_proj\": {\n        \"snr\": 2.232773780822754,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.30.self_attn.k_proj\": {\n        \"snr\": 2.4743099212646484,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.31.self_attn.k_proj\": {\n        \"snr\": 1.11082923412323,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.32.self_attn.k_proj\": {\n        \"snr\": 0.8243986368179321,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.33.self_attn.k_proj\": {\n        \"snr\": 0.932928204536438,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.34.self_attn.k_proj\": {\n        \"snr\": 6.608611583709717,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.35.self_attn.k_proj\": {\n        \"snr\": 10.160987854003906,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.0.self_attn.o_proj\": {\n        \"snr\": 0.36662933230400085,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.1.self_attn.o_proj\": {\n        \"snr\": 0.1955128312110901,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.2.self_attn.o_proj\": {\n        \"snr\": 0.22419843077659607,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.3.self_attn.o_proj\": {\n 
       \"snr\": 0.25902292132377625,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.4.self_attn.o_proj\": {\n        \"snr\": 0.2567676901817322,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.5.self_attn.o_proj\": {\n        \"snr\": 0.2560890316963196,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.6.self_attn.o_proj\": {\n        \"snr\": 0.18518221378326416,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.7.self_attn.o_proj\": {\n        \"snr\": 0.23254290223121643,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.8.self_attn.o_proj\": {\n        \"snr\": 0.2203962802886963,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.9.self_attn.o_proj\": {\n        \"snr\": 0.217017263174057,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.10.self_attn.o_proj\": {\n        \"snr\": 0.22843335568904877,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.11.self_attn.o_proj\": {\n        \"snr\": 0.23816843330860138,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.12.self_attn.o_proj\": {\n        \"snr\": 0.17585325241088867,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.13.self_attn.o_proj\": {\n        \"snr\": 0.20451271533966064,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.14.self_attn.o_proj\": {\n        \"snr\": 0.2095799297094345,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.15.self_attn.o_proj\": {\n        \"snr\": 0.23767071962356567,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.16.self_attn.o_proj\": {\n        \"snr\": 0.16328400373458862,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.17.self_attn.o_proj\": {\n        \"snr\": 0.20690056681632996,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.18.self_attn.o_proj\": {\n        \"snr\": 
0.18191492557525635,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.19.self_attn.o_proj\": {\n        \"snr\": 0.1945018619298935,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.20.self_attn.o_proj\": {\n        \"snr\": 0.26658856868743896,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.21.self_attn.o_proj\": {\n        \"snr\": 0.16897724568843842,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.22.self_attn.o_proj\": {\n        \"snr\": 0.18773262202739716,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.23.self_attn.o_proj\": {\n        \"snr\": 0.18808405101299286,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.24.self_attn.o_proj\": {\n        \"snr\": 0.17919476330280304,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.25.self_attn.o_proj\": {\n        \"snr\": 0.1793426126241684,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.26.self_attn.o_proj\": {\n        \"snr\": 0.1777871698141098,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.27.self_attn.o_proj\": {\n        \"snr\": 0.20279864966869354,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.28.self_attn.o_proj\": {\n        \"snr\": 0.17030371725559235,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.29.self_attn.o_proj\": {\n        \"snr\": 0.1992504596710205,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.30.self_attn.o_proj\": {\n        \"snr\": 0.23085352778434753,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.31.self_attn.o_proj\": {\n        \"snr\": 0.1641533523797989,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.32.self_attn.o_proj\": {\n        \"snr\": 0.10621391236782074,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.33.self_attn.o_proj\": {\n        \"snr\": 0.09411631524562836,\n   
     \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.34.self_attn.o_proj\": {\n        \"snr\": 0.13239727914333344,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.35.self_attn.o_proj\": {\n        \"snr\": 0.11740171164274216,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.0.self_attn.q_proj\": {\n        \"snr\": 0.055595725774765015,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.1.self_attn.q_proj\": {\n        \"snr\": 0.13823610544204712,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.2.self_attn.q_proj\": {\n        \"snr\": 0.1297825127840042,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.3.self_attn.q_proj\": {\n        \"snr\": 0.15291297435760498,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.4.self_attn.q_proj\": {\n        \"snr\": 0.15615035593509674,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.5.self_attn.q_proj\": {\n        \"snr\": 0.15535500645637512,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.6.self_attn.q_proj\": {\n        \"snr\": 0.15993140637874603,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.7.self_attn.q_proj\": {\n        \"snr\": 0.1753682643175125,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.8.self_attn.q_proj\": {\n        \"snr\": 0.1664913445711136,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.9.self_attn.q_proj\": {\n        \"snr\": 0.15656901895999908,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.10.self_attn.q_proj\": {\n        \"snr\": 0.18300014734268188,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.11.self_attn.q_proj\": {\n        \"snr\": 0.1713649481534958,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.12.self_attn.q_proj\": {\n        \"snr\": 0.1809009313583374,\n        \"type\": 
\"self_attn.q_proj\"\n    },\n    \"model.layers.13.self_attn.q_proj\": {\n        \"snr\": 0.20895132422447205,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.14.self_attn.q_proj\": {\n        \"snr\": 0.17413195967674255,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.15.self_attn.q_proj\": {\n        \"snr\": 0.20878490805625916,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.16.self_attn.q_proj\": {\n        \"snr\": 0.1547088772058487,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.17.self_attn.q_proj\": {\n        \"snr\": 0.1943129003047943,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.18.self_attn.q_proj\": {\n        \"snr\": 0.1889297217130661,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.19.self_attn.q_proj\": {\n        \"snr\": 0.207680344581604,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.20.self_attn.q_proj\": {\n        \"snr\": 0.20839959383010864,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.21.self_attn.q_proj\": {\n        \"snr\": 0.18989044427871704,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.22.self_attn.q_proj\": {\n        \"snr\": 0.18180623650550842,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.23.self_attn.q_proj\": {\n        \"snr\": 0.2069384753704071,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.24.self_attn.q_proj\": {\n        \"snr\": 0.1842993050813675,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.25.self_attn.q_proj\": {\n        \"snr\": 0.2078687846660614,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.26.self_attn.q_proj\": {\n        \"snr\": 0.19224946200847626,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.27.self_attn.q_proj\": {\n        \"snr\": 0.15170617401599884,\n        \"type\": \"self_attn.q_proj\"\n    
},\n    \"model.layers.28.self_attn.q_proj\": {\n        \"snr\": 0.20116600394248962,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.29.self_attn.q_proj\": {\n        \"snr\": 0.19373668730258942,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.30.self_attn.q_proj\": {\n        \"snr\": 0.18462225794792175,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.31.self_attn.q_proj\": {\n        \"snr\": 0.18939673900604248,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.32.self_attn.q_proj\": {\n        \"snr\": 0.20071947574615479,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.33.self_attn.q_proj\": {\n        \"snr\": 0.19740056991577148,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.34.self_attn.q_proj\": {\n        \"snr\": 0.17658494412899017,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.35.self_attn.q_proj\": {\n        \"snr\": 0.1407373696565628,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.0.self_attn.v_proj\": {\n        \"snr\": Infinity,\n        \"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.1.self_attn.v_proj\": {\n        \"snr\": 846.30126953125,\n        \"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.2.self_attn.v_proj\": {\n        \"snr\": Infinity,\n        \"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.3.self_attn.v_proj\": {\n        \"snr\": Infinity,\n        \"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.4.self_attn.v_proj\": {\n        \"snr\": Infinity,\n        \"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.5.self_attn.v_proj\": {\n        \"snr\": 83.83415222167969,\n        \"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.6.self_attn.v_proj\": {\n        \"snr\": Infinity,\n        \"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.7.self_attn.v_proj\": {\n        \"snr\": Infinity,\n        
\"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.8.self_attn.v_proj\": {\n        \"snr\": Infinity,\n        \"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.9.self_attn.v_proj\": {\n        \"snr\": Infinity,\n        \"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.10.self_attn.v_proj\": {\n        \"snr\": Infinity,\n        \"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.11.self_attn.v_proj\": {\n        \"snr\": Infinity,\n        \"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.12.self_attn.v_proj\": {\n        \"snr\": Infinity,\n        \"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.13.self_attn.v_proj\": {\n        \"snr\": 213.51316833496094,\n        \"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.14.self_attn.v_proj\": {\n        \"snr\": Infinity,\n        \"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.15.self_attn.v_proj\": {\n        \"snr\": Infinity,\n        \"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.16.self_attn.v_proj\": {\n        \"snr\": Infinity,\n        \"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.17.self_attn.v_proj\": {\n        \"snr\": Infinity,\n        \"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.18.self_attn.v_proj\": {\n        \"snr\": Infinity,\n        \"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.19.self_attn.v_proj\": {\n        \"snr\": 18.92746925354004,\n        \"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.20.self_attn.v_proj\": {\n        \"snr\": 433.9771728515625,\n        \"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.21.self_attn.v_proj\": {\n        \"snr\": Infinity,\n        \"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.22.self_attn.v_proj\": {\n        \"snr\": Infinity,\n        \"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.23.self_attn.v_proj\": {\n        \"snr\": Infinity,\n        \"type\": \"self_attn.v_proj\"\n    
},\n    \"model.layers.24.self_attn.v_proj\": {\n        \"snr\": Infinity,\n        \"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.25.self_attn.v_proj\": {\n        \"snr\": Infinity,\n        \"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.26.self_attn.v_proj\": {\n        \"snr\": Infinity,\n        \"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.27.self_attn.v_proj\": {\n        \"snr\": Infinity,\n        \"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.28.self_attn.v_proj\": {\n        \"snr\": Infinity,\n        \"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.29.self_attn.v_proj\": {\n        \"snr\": Infinity,\n        \"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.30.self_attn.v_proj\": {\n        \"snr\": Infinity,\n        \"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.31.self_attn.v_proj\": {\n        \"snr\": Infinity,\n        \"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.32.self_attn.v_proj\": {\n        \"snr\": 1.2332282066345215,\n        \"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.33.self_attn.v_proj\": {\n        \"snr\": 0.6151890158653259,\n        \"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.34.self_attn.v_proj\": {\n        \"snr\": 509.7169189453125,\n        \"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.35.self_attn.v_proj\": {\n        \"snr\": 536.0748901367188,\n        \"type\": \"self_attn.v_proj\"\n    }\n}\n"
  },
  {
    "path": "src/axolotl/integrations/spectrum/model_snr_results/snr_results_Qwen-Qwen2.5-7B-Instruct.json",
    "content": "{\n    \"model.layers.0.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.1.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.2.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.3.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.4.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.5.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.6.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.7.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.8.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.9.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.10.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.11.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.12.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.13.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.14.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.15.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.16.input_layernorm\": {\n        \"snr\": Infinity,\n        
\"type\": \"input_layernorm\"\n    },\n    \"model.layers.17.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.18.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.19.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.20.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.21.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.22.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.23.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.24.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.25.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.26.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.27.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"lm_head\": {\n        \"snr\": Infinity,\n        \"type\": \"lm_head\"\n    },\n    \"model.layers.0.mlp.down_proj\": {\n        \"snr\": 10.283808708190918,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.1.mlp.down_proj\": {\n        \"snr\": 1.2089825868606567,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.2.mlp.down_proj\": {\n        \"snr\": 19.309062957763672,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.3.mlp.down_proj\": {\n        \"snr\": 50.174461364746094,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.4.mlp.down_proj\": {\n        \"snr\": 
114.28582763671875,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.5.mlp.down_proj\": {\n        \"snr\": 215.5762176513672,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.6.mlp.down_proj\": {\n        \"snr\": 204.5117950439453,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.7.mlp.down_proj\": {\n        \"snr\": 182.5479278564453,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.8.mlp.down_proj\": {\n        \"snr\": 74.92950439453125,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.9.mlp.down_proj\": {\n        \"snr\": 16.482666015625,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.10.mlp.down_proj\": {\n        \"snr\": 55.33920669555664,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.11.mlp.down_proj\": {\n        \"snr\": 16.851062774658203,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.12.mlp.down_proj\": {\n        \"snr\": 58.65230178833008,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.13.mlp.down_proj\": {\n        \"snr\": 11.150161743164062,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.14.mlp.down_proj\": {\n        \"snr\": 65.32643127441406,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.15.mlp.down_proj\": {\n        \"snr\": 46.736305236816406,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.16.mlp.down_proj\": {\n        \"snr\": 14.288785934448242,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.17.mlp.down_proj\": {\n        \"snr\": 23.40110206604004,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.18.mlp.down_proj\": {\n        \"snr\": 86.34363555908203,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.19.mlp.down_proj\": {\n        \"snr\": 49.14613342285156,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.20.mlp.down_proj\": {\n        \"snr\": 
1276.84814453125,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.21.mlp.down_proj\": {\n        \"snr\": 51.803409576416016,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.22.mlp.down_proj\": {\n        \"snr\": 143.0666046142578,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.23.mlp.down_proj\": {\n        \"snr\": 35.14984893798828,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.24.mlp.down_proj\": {\n        \"snr\": 21.41700553894043,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.25.mlp.down_proj\": {\n        \"snr\": 10.651569366455078,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.26.mlp.down_proj\": {\n        \"snr\": 21.635149002075195,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.27.mlp.down_proj\": {\n        \"snr\": 1446.2774658203125,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.0.mlp.gate_proj\": {\n        \"snr\": 0.04497330263257027,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.1.mlp.gate_proj\": {\n        \"snr\": 0.16888172924518585,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.2.mlp.gate_proj\": {\n        \"snr\": 0.33653727173805237,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.3.mlp.gate_proj\": {\n        \"snr\": 3.1445391178131104,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.4.mlp.gate_proj\": {\n        \"snr\": 9.107144355773926,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.5.mlp.gate_proj\": {\n        \"snr\": 15.909018516540527,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.6.mlp.gate_proj\": {\n        \"snr\": 60.9138069152832,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.7.mlp.gate_proj\": {\n        \"snr\": 57.570281982421875,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.8.mlp.gate_proj\": {\n        \"snr\": 
65.82791137695312,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.9.mlp.gate_proj\": {\n        \"snr\": 10.455283164978027,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.10.mlp.gate_proj\": {\n        \"snr\": 26.970706939697266,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.11.mlp.gate_proj\": {\n        \"snr\": 31.139820098876953,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.12.mlp.gate_proj\": {\n        \"snr\": 43.987159729003906,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.13.mlp.gate_proj\": {\n        \"snr\": 20.704849243164062,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.14.mlp.gate_proj\": {\n        \"snr\": 21.191452026367188,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.15.mlp.gate_proj\": {\n        \"snr\": 42.66447830200195,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.16.mlp.gate_proj\": {\n        \"snr\": 22.136825561523438,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.17.mlp.gate_proj\": {\n        \"snr\": 22.60980987548828,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.18.mlp.gate_proj\": {\n        \"snr\": 81.80574035644531,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.19.mlp.gate_proj\": {\n        \"snr\": 20.88619613647461,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.20.mlp.gate_proj\": {\n        \"snr\": 58.3524055480957,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.21.mlp.gate_proj\": {\n        \"snr\": 22.786706924438477,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.22.mlp.gate_proj\": {\n        \"snr\": 16.932226181030273,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.23.mlp.gate_proj\": {\n        \"snr\": 16.819862365722656,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.24.mlp.gate_proj\": {\n        \"snr\": 
19.76348304748535,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.25.mlp.gate_proj\": {\n        \"snr\": 28.98714256286621,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.26.mlp.gate_proj\": {\n        \"snr\": 36.7071533203125,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.27.mlp.gate_proj\": {\n        \"snr\": 51.81539535522461,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.0.mlp.up_proj\": {\n        \"snr\": 0.2243107706308365,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.1.mlp.up_proj\": {\n        \"snr\": 0.4464716613292694,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.2.mlp.up_proj\": {\n        \"snr\": 1.7838181257247925,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.3.mlp.up_proj\": {\n        \"snr\": 17.912736892700195,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.4.mlp.up_proj\": {\n        \"snr\": 47.45841979980469,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.5.mlp.up_proj\": {\n        \"snr\": 56.3084602355957,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.6.mlp.up_proj\": {\n        \"snr\": 173.33717346191406,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.7.mlp.up_proj\": {\n        \"snr\": 148.22750854492188,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.8.mlp.up_proj\": {\n        \"snr\": 133.63565063476562,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.9.mlp.up_proj\": {\n        \"snr\": 83.65129852294922,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.10.mlp.up_proj\": {\n        \"snr\": 117.94369506835938,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.11.mlp.up_proj\": {\n        \"snr\": 94.52413940429688,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.12.mlp.up_proj\": {\n        \"snr\": 130.43333435058594,\n        \"type\": \"mlp.up_proj\"\n    },\n    
\"model.layers.13.mlp.up_proj\": {\n        \"snr\": 76.11975860595703,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.14.mlp.up_proj\": {\n        \"snr\": 158.75192260742188,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.15.mlp.up_proj\": {\n        \"snr\": 143.72706604003906,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.16.mlp.up_proj\": {\n        \"snr\": 84.28279876708984,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.17.mlp.up_proj\": {\n        \"snr\": 116.65055084228516,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.18.mlp.up_proj\": {\n        \"snr\": 177.1201934814453,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.19.mlp.up_proj\": {\n        \"snr\": 82.4564437866211,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.20.mlp.up_proj\": {\n        \"snr\": 137.73019409179688,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.21.mlp.up_proj\": {\n        \"snr\": 89.97538757324219,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.22.mlp.up_proj\": {\n        \"snr\": 86.30876159667969,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.23.mlp.up_proj\": {\n        \"snr\": 61.53449249267578,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.24.mlp.up_proj\": {\n        \"snr\": 45.22392654418945,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.25.mlp.up_proj\": {\n        \"snr\": 60.3155517578125,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.26.mlp.up_proj\": {\n        \"snr\": 40.06092071533203,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.27.mlp.up_proj\": {\n        \"snr\": 48.12322998046875,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.embed_tokens\": {\n        \"snr\": 0.08805440366268158,\n        \"type\": \"model.embed_tokens\"\n    },\n    \"model.norm\": {\n        \"snr\": Infinity,\n        \"type\": 
\"model.norm\"\n    },\n    \"model.layers.0.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.1.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.2.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.3.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.4.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.5.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.6.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.7.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.8.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.9.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.10.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.11.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.12.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.13.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.14.post_attention_layernorm\": {\n        
\"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.15.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.16.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.17.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.18.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.19.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.20.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.21.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.22.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.23.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.24.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.25.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.26.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.27.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.0.self_attn.k_proj\": {\n        \"snr\": 4.771554470062256,\n        \"type\": \"self_attn.k_proj\"\n    },\n    
\"model.layers.1.self_attn.k_proj\": {\n        \"snr\": 0.46674421429634094,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.2.self_attn.k_proj\": {\n        \"snr\": 1.6167784929275513,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.3.self_attn.k_proj\": {\n        \"snr\": 2.0980119705200195,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.4.self_attn.k_proj\": {\n        \"snr\": 1.4339035749435425,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.5.self_attn.k_proj\": {\n        \"snr\": 1.7446703910827637,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.6.self_attn.k_proj\": {\n        \"snr\": 1.2829725742340088,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.7.self_attn.k_proj\": {\n        \"snr\": 2.2314982414245605,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.8.self_attn.k_proj\": {\n        \"snr\": 1.5125916004180908,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.9.self_attn.k_proj\": {\n        \"snr\": 1.2817912101745605,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.10.self_attn.k_proj\": {\n        \"snr\": 3.3553454875946045,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.11.self_attn.k_proj\": {\n        \"snr\": 1.591347336769104,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.12.self_attn.k_proj\": {\n        \"snr\": 1.1114169359207153,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.13.self_attn.k_proj\": {\n        \"snr\": 1.1536189317703247,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.14.self_attn.k_proj\": {\n        \"snr\": 0.994098424911499,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.15.self_attn.k_proj\": {\n        \"snr\": 1.484580636024475,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.16.self_attn.k_proj\": {\n     
   \"snr\": 1.2999093532562256,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.17.self_attn.k_proj\": {\n        \"snr\": 2.1628623008728027,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.18.self_attn.k_proj\": {\n        \"snr\": 1.3842225074768066,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.19.self_attn.k_proj\": {\n        \"snr\": 1.440075159072876,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.20.self_attn.k_proj\": {\n        \"snr\": 1.7816450595855713,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.21.self_attn.k_proj\": {\n        \"snr\": 1.746536135673523,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.22.self_attn.k_proj\": {\n        \"snr\": 1.318993091583252,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.23.self_attn.k_proj\": {\n        \"snr\": 1.7234206199645996,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.24.self_attn.k_proj\": {\n        \"snr\": 2.586996555328369,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.25.self_attn.k_proj\": {\n        \"snr\": 1.6486897468566895,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.26.self_attn.k_proj\": {\n        \"snr\": 1.3349357843399048,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.27.self_attn.k_proj\": {\n        \"snr\": 0.9039687514305115,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.0.self_attn.o_proj\": {\n        \"snr\": 0.10605750232934952,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.1.self_attn.o_proj\": {\n        \"snr\": 0.2503393292427063,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.2.self_attn.o_proj\": {\n        \"snr\": 0.21453581750392914,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.3.self_attn.o_proj\": {\n        \"snr\": 0.20600366592407227,\n       
 \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.4.self_attn.o_proj\": {\n        \"snr\": 0.22004099190235138,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.5.self_attn.o_proj\": {\n        \"snr\": 0.2267625778913498,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.6.self_attn.o_proj\": {\n        \"snr\": 0.1736888736486435,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.7.self_attn.o_proj\": {\n        \"snr\": 0.2314220815896988,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.8.self_attn.o_proj\": {\n        \"snr\": 0.24031606316566467,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.9.self_attn.o_proj\": {\n        \"snr\": 0.13458871841430664,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.10.self_attn.o_proj\": {\n        \"snr\": 0.20170633494853973,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.11.self_attn.o_proj\": {\n        \"snr\": 0.19507651031017303,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.12.self_attn.o_proj\": {\n        \"snr\": 0.1862162947654724,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.13.self_attn.o_proj\": {\n        \"snr\": 0.15117767453193665,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.14.self_attn.o_proj\": {\n        \"snr\": 0.1857745349407196,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.15.self_attn.o_proj\": {\n        \"snr\": 0.2064860314130783,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.16.self_attn.o_proj\": {\n        \"snr\": 0.15419450402259827,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.17.self_attn.o_proj\": {\n        \"snr\": 0.17895667254924774,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.18.self_attn.o_proj\": {\n        \"snr\": 0.18284623324871063,\n        \"type\": 
\"self_attn.o_proj\"\n    },\n    \"model.layers.19.self_attn.o_proj\": {\n        \"snr\": 0.17497135698795319,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.20.self_attn.o_proj\": {\n        \"snr\": 0.178844153881073,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.21.self_attn.o_proj\": {\n        \"snr\": 0.16190896928310394,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.22.self_attn.o_proj\": {\n        \"snr\": 0.19371949136257172,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.23.self_attn.o_proj\": {\n        \"snr\": 0.14116843044757843,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.24.self_attn.o_proj\": {\n        \"snr\": 0.14100700616836548,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.25.self_attn.o_proj\": {\n        \"snr\": 0.14792074263095856,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.26.self_attn.o_proj\": {\n        \"snr\": 0.11953117698431015,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.27.self_attn.o_proj\": {\n        \"snr\": 0.06241385638713837,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.0.self_attn.q_proj\": {\n        \"snr\": 0.02127065323293209,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.1.self_attn.q_proj\": {\n        \"snr\": 0.14693336188793182,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.2.self_attn.q_proj\": {\n        \"snr\": 0.16316214203834534,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.3.self_attn.q_proj\": {\n        \"snr\": 0.1218630000948906,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.4.self_attn.q_proj\": {\n        \"snr\": 0.13916714489459991,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.5.self_attn.q_proj\": {\n        \"snr\": 0.155359148979187,\n        \"type\": \"self_attn.q_proj\"\n    },\n 
   \"model.layers.6.self_attn.q_proj\": {\n        \"snr\": 0.1590007096529007,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.7.self_attn.q_proj\": {\n        \"snr\": 0.1958903819322586,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.8.self_attn.q_proj\": {\n        \"snr\": 0.22448301315307617,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.9.self_attn.q_proj\": {\n        \"snr\": 0.20126597583293915,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.10.self_attn.q_proj\": {\n        \"snr\": 0.1980895698070526,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.11.self_attn.q_proj\": {\n        \"snr\": 0.2289486974477768,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.12.self_attn.q_proj\": {\n        \"snr\": 0.22922305762767792,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.13.self_attn.q_proj\": {\n        \"snr\": 0.21452386677265167,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.14.self_attn.q_proj\": {\n        \"snr\": 0.24151542782783508,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.15.self_attn.q_proj\": {\n        \"snr\": 0.21893717348575592,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.16.self_attn.q_proj\": {\n        \"snr\": 0.2321016639471054,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.17.self_attn.q_proj\": {\n        \"snr\": 0.24078059196472168,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.18.self_attn.q_proj\": {\n        \"snr\": 0.22774985432624817,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.19.self_attn.q_proj\": {\n        \"snr\": 0.20914016664028168,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.20.self_attn.q_proj\": {\n        \"snr\": 0.22847522795200348,\n        \"type\": \"self_attn.q_proj\"\n    },\n    
\"model.layers.21.self_attn.q_proj\": {\n        \"snr\": 0.2500442862510681,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.22.self_attn.q_proj\": {\n        \"snr\": 0.2353251725435257,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.23.self_attn.q_proj\": {\n        \"snr\": 0.20365388691425323,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.24.self_attn.q_proj\": {\n        \"snr\": 0.21967172622680664,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.25.self_attn.q_proj\": {\n        \"snr\": 0.2122868150472641,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.26.self_attn.q_proj\": {\n        \"snr\": 0.2415798157453537,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.27.self_attn.q_proj\": {\n        \"snr\": 0.12347634881734848,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.0.self_attn.v_proj\": {\n        \"snr\": Infinity,\n        \"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.1.self_attn.v_proj\": {\n        \"snr\": 230.88636779785156,\n        \"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.2.self_attn.v_proj\": {\n        \"snr\": Infinity,\n        \"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.3.self_attn.v_proj\": {\n        \"snr\": Infinity,\n        \"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.4.self_attn.v_proj\": {\n        \"snr\": 22.38136100769043,\n        \"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.5.self_attn.v_proj\": {\n        \"snr\": 246.59597778320312,\n        \"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.6.self_attn.v_proj\": {\n        \"snr\": 499.61761474609375,\n        \"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.7.self_attn.v_proj\": {\n        \"snr\": Infinity,\n        \"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.8.self_attn.v_proj\": {\n        \"snr\": Infinity,\n        
\"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.9.self_attn.v_proj\": {\n        \"snr\": Infinity,\n        \"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.10.self_attn.v_proj\": {\n        \"snr\": Infinity,\n        \"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.11.self_attn.v_proj\": {\n        \"snr\": 69.18345642089844,\n        \"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.12.self_attn.v_proj\": {\n        \"snr\": 984.9320068359375,\n        \"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.13.self_attn.v_proj\": {\n        \"snr\": Infinity,\n        \"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.14.self_attn.v_proj\": {\n        \"snr\": 64.06214141845703,\n        \"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.15.self_attn.v_proj\": {\n        \"snr\": 28.43911361694336,\n        \"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.16.self_attn.v_proj\": {\n        \"snr\": 725.1439819335938,\n        \"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.17.self_attn.v_proj\": {\n        \"snr\": 63.43681716918945,\n        \"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.18.self_attn.v_proj\": {\n        \"snr\": Infinity,\n        \"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.19.self_attn.v_proj\": {\n        \"snr\": Infinity,\n        \"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.20.self_attn.v_proj\": {\n        \"snr\": Infinity,\n        \"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.21.self_attn.v_proj\": {\n        \"snr\": Infinity,\n        \"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.22.self_attn.v_proj\": {\n        \"snr\": 238.4695587158203,\n        \"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.23.self_attn.v_proj\": {\n        \"snr\": 111.88697814941406,\n        \"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.24.self_attn.v_proj\": {\n        \"snr\": 
686.2830200195312,\n        \"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.25.self_attn.v_proj\": {\n        \"snr\": 566.2647705078125,\n        \"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.26.self_attn.v_proj\": {\n        \"snr\": 4.070064544677734,\n        \"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.27.self_attn.v_proj\": {\n        \"snr\": 4.3411664962768555,\n        \"type\": \"self_attn.v_proj\"\n    }\n}\n"
  },
  {
    "path": "src/axolotl/integrations/spectrum/model_snr_results/snr_results_Qwen-Qwen2.5-7B.json",
    "content": "{\n    \"model.layers.0.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.1.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.2.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.3.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.4.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.5.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.6.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.7.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.8.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.9.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.10.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.11.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.12.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.13.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.14.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.15.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.16.input_layernorm\": {\n        \"snr\": Infinity,\n        
\"type\": \"input_layernorm\"\n    },\n    \"model.layers.17.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.18.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.19.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.20.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.21.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.22.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.23.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.24.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.25.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.26.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.27.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"lm_head\": {\n        \"snr\": Infinity,\n        \"type\": \"lm_head\"\n    },\n    \"model.layers.0.mlp.down_proj\": {\n        \"snr\": 10.277782440185547,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.1.mlp.down_proj\": {\n        \"snr\": 1.2050706148147583,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.2.mlp.down_proj\": {\n        \"snr\": 19.284534454345703,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.3.mlp.down_proj\": {\n        \"snr\": 50.16513442993164,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.4.mlp.down_proj\": {\n        \"snr\": 
114.24882507324219,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.5.mlp.down_proj\": {\n        \"snr\": 215.48194885253906,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.6.mlp.down_proj\": {\n        \"snr\": 204.39431762695312,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.7.mlp.down_proj\": {\n        \"snr\": 182.5116729736328,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.8.mlp.down_proj\": {\n        \"snr\": 74.9266128540039,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.9.mlp.down_proj\": {\n        \"snr\": 16.474102020263672,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.10.mlp.down_proj\": {\n        \"snr\": 55.30583572387695,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.11.mlp.down_proj\": {\n        \"snr\": 16.84047508239746,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.12.mlp.down_proj\": {\n        \"snr\": 58.62131118774414,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.13.mlp.down_proj\": {\n        \"snr\": 11.144298553466797,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.14.mlp.down_proj\": {\n        \"snr\": 65.28057098388672,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.15.mlp.down_proj\": {\n        \"snr\": 46.701290130615234,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.16.mlp.down_proj\": {\n        \"snr\": 14.278325080871582,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.17.mlp.down_proj\": {\n        \"snr\": 23.382247924804688,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.18.mlp.down_proj\": {\n        \"snr\": 93.8782958984375,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.19.mlp.down_proj\": {\n        \"snr\": 49.10498809814453,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.20.mlp.down_proj\": {\n        \"snr\": 
1277.5101318359375,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.21.mlp.down_proj\": {\n        \"snr\": 51.7880859375,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.22.mlp.down_proj\": {\n        \"snr\": 143.03504943847656,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.23.mlp.down_proj\": {\n        \"snr\": 35.123931884765625,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.24.mlp.down_proj\": {\n        \"snr\": 21.403743743896484,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.25.mlp.down_proj\": {\n        \"snr\": 10.551352500915527,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.26.mlp.down_proj\": {\n        \"snr\": 21.62333869934082,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.27.mlp.down_proj\": {\n        \"snr\": 1541.98681640625,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.0.mlp.gate_proj\": {\n        \"snr\": 0.04497644677758217,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.1.mlp.gate_proj\": {\n        \"snr\": 0.16878646612167358,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.2.mlp.gate_proj\": {\n        \"snr\": 0.336302250623703,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.3.mlp.gate_proj\": {\n        \"snr\": 3.141293525695801,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.4.mlp.gate_proj\": {\n        \"snr\": 9.098686218261719,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.5.mlp.gate_proj\": {\n        \"snr\": 15.89354419708252,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.6.mlp.gate_proj\": {\n        \"snr\": 60.85503387451172,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.7.mlp.gate_proj\": {\n        \"snr\": 57.53098678588867,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.8.mlp.gate_proj\": {\n        \"snr\": 65.77096557617188,\n 
       \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.9.mlp.gate_proj\": {\n        \"snr\": 10.453179359436035,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.10.mlp.gate_proj\": {\n        \"snr\": 26.94801139831543,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.11.mlp.gate_proj\": {\n        \"snr\": 31.111093521118164,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.12.mlp.gate_proj\": {\n        \"snr\": 43.963191986083984,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.13.mlp.gate_proj\": {\n        \"snr\": 20.690765380859375,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.14.mlp.gate_proj\": {\n        \"snr\": 20.47557258605957,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.15.mlp.gate_proj\": {\n        \"snr\": 42.63906478881836,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.16.mlp.gate_proj\": {\n        \"snr\": 22.11542320251465,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.17.mlp.gate_proj\": {\n        \"snr\": 22.590566635131836,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.18.mlp.gate_proj\": {\n        \"snr\": 81.74773406982422,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.19.mlp.gate_proj\": {\n        \"snr\": 20.872997283935547,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.20.mlp.gate_proj\": {\n        \"snr\": 58.32197952270508,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.21.mlp.gate_proj\": {\n        \"snr\": 22.784095764160156,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.22.mlp.gate_proj\": {\n        \"snr\": 16.935768127441406,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.23.mlp.gate_proj\": {\n        \"snr\": 16.830224990844727,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.24.mlp.gate_proj\": {\n        \"snr\": 19.774564743041992,\n        
\"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.25.mlp.gate_proj\": {\n        \"snr\": 27.770675659179688,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.26.mlp.gate_proj\": {\n        \"snr\": 36.714595794677734,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.27.mlp.gate_proj\": {\n        \"snr\": 51.81637191772461,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.0.mlp.up_proj\": {\n        \"snr\": 0.22425401210784912,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.1.mlp.up_proj\": {\n        \"snr\": 0.4456978142261505,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.2.mlp.up_proj\": {\n        \"snr\": 1.7769725322723389,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.3.mlp.up_proj\": {\n        \"snr\": 17.8966121673584,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.4.mlp.up_proj\": {\n        \"snr\": 47.43608856201172,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.5.mlp.up_proj\": {\n        \"snr\": 56.2298698425293,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.6.mlp.up_proj\": {\n        \"snr\": 173.1498260498047,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.7.mlp.up_proj\": {\n        \"snr\": 148.02874755859375,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.8.mlp.up_proj\": {\n        \"snr\": 133.5174560546875,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.9.mlp.up_proj\": {\n        \"snr\": 83.45183563232422,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.10.mlp.up_proj\": {\n        \"snr\": 117.88772583007812,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.11.mlp.up_proj\": {\n        \"snr\": 94.41156768798828,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.12.mlp.up_proj\": {\n        \"snr\": 130.3107452392578,\n        \"type\": \"mlp.up_proj\"\n    },\n    
\"model.layers.13.mlp.up_proj\": {\n        \"snr\": 76.04458618164062,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.14.mlp.up_proj\": {\n        \"snr\": 158.59634399414062,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.15.mlp.up_proj\": {\n        \"snr\": 143.59596252441406,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.16.mlp.up_proj\": {\n        \"snr\": 84.2161636352539,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.17.mlp.up_proj\": {\n        \"snr\": 116.55204010009766,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.18.mlp.up_proj\": {\n        \"snr\": 176.95449829101562,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.19.mlp.up_proj\": {\n        \"snr\": 82.37284088134766,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.20.mlp.up_proj\": {\n        \"snr\": 137.5695343017578,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.21.mlp.up_proj\": {\n        \"snr\": 89.87335205078125,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.22.mlp.up_proj\": {\n        \"snr\": 86.1510238647461,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.23.mlp.up_proj\": {\n        \"snr\": 61.37428665161133,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.24.mlp.up_proj\": {\n        \"snr\": 45.10757064819336,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.25.mlp.up_proj\": {\n        \"snr\": 60.16519546508789,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.26.mlp.up_proj\": {\n        \"snr\": 39.96969223022461,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.27.mlp.up_proj\": {\n        \"snr\": 48.04258346557617,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.embed_tokens\": {\n        \"snr\": 0.08800078183412552,\n        \"type\": \"model.embed_tokens\"\n    },\n    \"model.norm\": {\n        \"snr\": Infinity,\n        \"type\": 
\"model.norm\"\n    },\n    \"model.layers.0.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.1.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.2.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.3.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.4.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.5.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.6.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.7.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.8.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.9.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.10.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.11.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.12.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.13.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.14.post_attention_layernorm\": {\n        
\"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.15.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.16.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.17.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.18.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.19.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.20.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.21.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.22.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.23.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.24.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.25.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.26.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.27.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.0.self_attn.k_proj\": {\n        \"snr\": 4.764852046966553,\n        \"type\": \"self_attn.k_proj\"\n    },\n    
\"model.layers.1.self_attn.k_proj\": {\n        \"snr\": 0.46627077460289,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.2.self_attn.k_proj\": {\n        \"snr\": 1.6155915260314941,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.3.self_attn.k_proj\": {\n        \"snr\": 2.096365451812744,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.4.self_attn.k_proj\": {\n        \"snr\": 1.431254267692566,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.5.self_attn.k_proj\": {\n        \"snr\": 1.7440669536590576,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.6.self_attn.k_proj\": {\n        \"snr\": 1.2815033197402954,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.7.self_attn.k_proj\": {\n        \"snr\": 2.2301025390625,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.8.self_attn.k_proj\": {\n        \"snr\": 1.5116536617279053,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.9.self_attn.k_proj\": {\n        \"snr\": 1.2699830532073975,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.10.self_attn.k_proj\": {\n        \"snr\": 3.3086464405059814,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.11.self_attn.k_proj\": {\n        \"snr\": 1.59111487865448,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.12.self_attn.k_proj\": {\n        \"snr\": 1.1007944345474243,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.13.self_attn.k_proj\": {\n        \"snr\": 1.163416862487793,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.14.self_attn.k_proj\": {\n        \"snr\": 0.9935113787651062,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.15.self_attn.k_proj\": {\n        \"snr\": 1.483581304550171,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.16.self_attn.k_proj\": {\n        
\"snr\": 1.2992271184921265,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.17.self_attn.k_proj\": {\n        \"snr\": 2.162485122680664,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.18.self_attn.k_proj\": {\n        \"snr\": 1.3841017484664917,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.19.self_attn.k_proj\": {\n        \"snr\": 1.453418493270874,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.20.self_attn.k_proj\": {\n        \"snr\": 1.781678557395935,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.21.self_attn.k_proj\": {\n        \"snr\": 1.7460925579071045,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.22.self_attn.k_proj\": {\n        \"snr\": 1.3188031911849976,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.23.self_attn.k_proj\": {\n        \"snr\": 1.723441243171692,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.24.self_attn.k_proj\": {\n        \"snr\": 2.585094928741455,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.25.self_attn.k_proj\": {\n        \"snr\": 1.6478856801986694,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.26.self_attn.k_proj\": {\n        \"snr\": 1.3221096992492676,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.27.self_attn.k_proj\": {\n        \"snr\": 0.9034463167190552,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.0.self_attn.o_proj\": {\n        \"snr\": 0.10636883229017258,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.1.self_attn.o_proj\": {\n        \"snr\": 0.24971255660057068,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.2.self_attn.o_proj\": {\n        \"snr\": 0.21437697112560272,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.3.self_attn.o_proj\": {\n        \"snr\": 0.2058248072862625,\n        
\"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.4.self_attn.o_proj\": {\n        \"snr\": 0.21978946030139923,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.5.self_attn.o_proj\": {\n        \"snr\": 0.2269466072320938,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.6.self_attn.o_proj\": {\n        \"snr\": 0.17318543791770935,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.7.self_attn.o_proj\": {\n        \"snr\": 0.23159846663475037,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.8.self_attn.o_proj\": {\n        \"snr\": 0.2400084286928177,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.9.self_attn.o_proj\": {\n        \"snr\": 0.134766086935997,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.10.self_attn.o_proj\": {\n        \"snr\": 0.20152011513710022,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.11.self_attn.o_proj\": {\n        \"snr\": 0.19492347538471222,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.12.self_attn.o_proj\": {\n        \"snr\": 0.18607021868228912,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.13.self_attn.o_proj\": {\n        \"snr\": 0.15107683837413788,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.14.self_attn.o_proj\": {\n        \"snr\": 0.18565276265144348,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.15.self_attn.o_proj\": {\n        \"snr\": 0.20626339316368103,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.16.self_attn.o_proj\": {\n        \"snr\": 0.1541011780500412,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.17.self_attn.o_proj\": {\n        \"snr\": 0.1784645915031433,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.18.self_attn.o_proj\": {\n        \"snr\": 0.18307389318943024,\n        \"type\": 
\"self_attn.o_proj\"\n    },\n    \"model.layers.19.self_attn.o_proj\": {\n        \"snr\": 0.17449897527694702,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.20.self_attn.o_proj\": {\n        \"snr\": 0.1787375956773758,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.21.self_attn.o_proj\": {\n        \"snr\": 0.161802276968956,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.22.self_attn.o_proj\": {\n        \"snr\": 0.1931520402431488,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.23.self_attn.o_proj\": {\n        \"snr\": 0.14108893275260925,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.24.self_attn.o_proj\": {\n        \"snr\": 0.14064815640449524,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.25.self_attn.o_proj\": {\n        \"snr\": 0.14790543913841248,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.26.self_attn.o_proj\": {\n        \"snr\": 0.11950570344924927,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.27.self_attn.o_proj\": {\n        \"snr\": 0.062389008700847626,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.0.self_attn.q_proj\": {\n        \"snr\": 0.02138795144855976,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.1.self_attn.q_proj\": {\n        \"snr\": 0.14676862955093384,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.2.self_attn.q_proj\": {\n        \"snr\": 0.16297142207622528,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.3.self_attn.q_proj\": {\n        \"snr\": 0.12198334187269211,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.4.self_attn.q_proj\": {\n        \"snr\": 0.13921146094799042,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.5.self_attn.q_proj\": {\n        \"snr\": 0.15567339956760406,\n        \"type\": \"self_attn.q_proj\"\n    
},\n    \"model.layers.6.self_attn.q_proj\": {\n        \"snr\": 0.1589033454656601,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.7.self_attn.q_proj\": {\n        \"snr\": 0.195299431681633,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.8.self_attn.q_proj\": {\n        \"snr\": 0.22430908679962158,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.9.self_attn.q_proj\": {\n        \"snr\": 0.2011336237192154,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.10.self_attn.q_proj\": {\n        \"snr\": 0.1982448250055313,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.11.self_attn.q_proj\": {\n        \"snr\": 0.22880099713802338,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.12.self_attn.q_proj\": {\n        \"snr\": 0.22898294031620026,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.13.self_attn.q_proj\": {\n        \"snr\": 0.21394900977611542,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.14.self_attn.q_proj\": {\n        \"snr\": 0.24130398035049438,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.15.self_attn.q_proj\": {\n        \"snr\": 0.21905161440372467,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.16.self_attn.q_proj\": {\n        \"snr\": 0.2319282442331314,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.17.self_attn.q_proj\": {\n        \"snr\": 0.24004821479320526,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.18.self_attn.q_proj\": {\n        \"snr\": 0.22754515707492828,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.19.self_attn.q_proj\": {\n        \"snr\": 0.2086794078350067,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.20.self_attn.q_proj\": {\n        \"snr\": 0.2290779948234558,\n        \"type\": \"self_attn.q_proj\"\n    },\n    
\"model.layers.21.self_attn.q_proj\": {\n        \"snr\": 0.250373899936676,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.22.self_attn.q_proj\": {\n        \"snr\": 0.23474709689617157,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.23.self_attn.q_proj\": {\n        \"snr\": 0.20302507281303406,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.24.self_attn.q_proj\": {\n        \"snr\": 0.21992310881614685,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.25.self_attn.q_proj\": {\n        \"snr\": 0.2120121270418167,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.26.self_attn.q_proj\": {\n        \"snr\": 0.24161922931671143,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.27.self_attn.q_proj\": {\n        \"snr\": 0.12337693572044373,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.0.self_attn.v_proj\": {\n        \"snr\": Infinity,\n        \"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.1.self_attn.v_proj\": {\n        \"snr\": 231.07347106933594,\n        \"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.2.self_attn.v_proj\": {\n        \"snr\": Infinity,\n        \"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.3.self_attn.v_proj\": {\n        \"snr\": Infinity,\n        \"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.4.self_attn.v_proj\": {\n        \"snr\": 22.34870719909668,\n        \"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.5.self_attn.v_proj\": {\n        \"snr\": 246.30386352539062,\n        \"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.6.self_attn.v_proj\": {\n        \"snr\": 499.5611572265625,\n        \"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.7.self_attn.v_proj\": {\n        \"snr\": Infinity,\n        \"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.8.self_attn.v_proj\": {\n        \"snr\": Infinity,\n        
\"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.9.self_attn.v_proj\": {\n        \"snr\": Infinity,\n        \"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.10.self_attn.v_proj\": {\n        \"snr\": Infinity,\n        \"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.11.self_attn.v_proj\": {\n        \"snr\": 69.09609985351562,\n        \"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.12.self_attn.v_proj\": {\n        \"snr\": 983.3341674804688,\n        \"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.13.self_attn.v_proj\": {\n        \"snr\": Infinity,\n        \"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.14.self_attn.v_proj\": {\n        \"snr\": 64.04925537109375,\n        \"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.15.self_attn.v_proj\": {\n        \"snr\": 28.41021728515625,\n        \"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.16.self_attn.v_proj\": {\n        \"snr\": 724.2736206054688,\n        \"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.17.self_attn.v_proj\": {\n        \"snr\": 63.35670852661133,\n        \"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.18.self_attn.v_proj\": {\n        \"snr\": Infinity,\n        \"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.19.self_attn.v_proj\": {\n        \"snr\": Infinity,\n        \"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.20.self_attn.v_proj\": {\n        \"snr\": Infinity,\n        \"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.21.self_attn.v_proj\": {\n        \"snr\": Infinity,\n        \"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.22.self_attn.v_proj\": {\n        \"snr\": 238.2569122314453,\n        \"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.23.self_attn.v_proj\": {\n        \"snr\": 111.78319549560547,\n        \"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.24.self_attn.v_proj\": {\n        \"snr\": 
687.0054931640625,\n        \"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.25.self_attn.v_proj\": {\n        \"snr\": 565.3272705078125,\n        \"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.26.self_attn.v_proj\": {\n        \"snr\": 4.064513683319092,\n        \"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.27.self_attn.v_proj\": {\n        \"snr\": 4.335177421569824,\n        \"type\": \"self_attn.v_proj\"\n    }\n}\n"
  },
  {
    "path": "src/axolotl/integrations/spectrum/model_snr_results/snr_results_google-gemma-2-2b.json",
    "content": "{\n    \"model.layers.0.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.1.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.2.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.3.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.4.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.5.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.6.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.7.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.8.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.9.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.10.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.11.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.12.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.13.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.14.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.15.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.16.input_layernorm\": {\n        \"snr\": Infinity,\n        
\"type\": \"input_layernorm\"\n    },\n    \"model.layers.17.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.18.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.19.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.20.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.21.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.22.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.23.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.24.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.25.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"lm_head\": {\n        \"snr\": 4.538210391998291,\n        \"type\": \"lm_head\"\n    },\n    \"model.layers.0.mlp.down_proj\": {\n        \"snr\": 7.746472358703613,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.1.mlp.down_proj\": {\n        \"snr\": 4.3358893394470215,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.2.mlp.down_proj\": {\n        \"snr\": 26.88057518005371,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.3.mlp.down_proj\": {\n        \"snr\": 8.699942588806152,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.4.mlp.down_proj\": {\n        \"snr\": 32.808380126953125,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.5.mlp.down_proj\": {\n        \"snr\": 10.831522941589355,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.6.mlp.down_proj\": {\n        
\"snr\": 18.843679428100586,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.7.mlp.down_proj\": {\n        \"snr\": 9.348078727722168,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.8.mlp.down_proj\": {\n        \"snr\": 7.061270236968994,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.9.mlp.down_proj\": {\n        \"snr\": 5.454320907592773,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.10.mlp.down_proj\": {\n        \"snr\": 7.386133193969727,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.11.mlp.down_proj\": {\n        \"snr\": 6.648562908172607,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.12.mlp.down_proj\": {\n        \"snr\": 5.853652477264404,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.13.mlp.down_proj\": {\n        \"snr\": 8.570493698120117,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.14.mlp.down_proj\": {\n        \"snr\": 13.120837211608887,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.15.mlp.down_proj\": {\n        \"snr\": 14.780969619750977,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.16.mlp.down_proj\": {\n        \"snr\": 6.953134059906006,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.17.mlp.down_proj\": {\n        \"snr\": 12.589436531066895,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.18.mlp.down_proj\": {\n        \"snr\": 8.844094276428223,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.19.mlp.down_proj\": {\n        \"snr\": 7.598869800567627,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.20.mlp.down_proj\": {\n        \"snr\": 11.293925285339355,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.21.mlp.down_proj\": {\n        \"snr\": 9.384604454040527,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.22.mlp.down_proj\": {\n        \"snr\": 
12.12533187866211,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.23.mlp.down_proj\": {\n        \"snr\": 11.217570304870605,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.24.mlp.down_proj\": {\n        \"snr\": 14.197714805603027,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.25.mlp.down_proj\": {\n        \"snr\": 12.449926376342773,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.0.mlp.gate_proj\": {\n        \"snr\": 16.885862350463867,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.1.mlp.gate_proj\": {\n        \"snr\": 23.410266876220703,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.2.mlp.gate_proj\": {\n        \"snr\": 22.57662582397461,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.3.mlp.gate_proj\": {\n        \"snr\": 17.29996681213379,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.4.mlp.gate_proj\": {\n        \"snr\": 11.718637466430664,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.5.mlp.gate_proj\": {\n        \"snr\": 6.376136779785156,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.6.mlp.gate_proj\": {\n        \"snr\": 6.794021129608154,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.7.mlp.gate_proj\": {\n        \"snr\": 3.2425343990325928,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.8.mlp.gate_proj\": {\n        \"snr\": 2.368421792984009,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.9.mlp.gate_proj\": {\n        \"snr\": 3.3193087577819824,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.10.mlp.gate_proj\": {\n        \"snr\": 3.9515960216522217,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.11.mlp.gate_proj\": {\n        \"snr\": 3.2761318683624268,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.12.mlp.gate_proj\": {\n        \"snr\": 
4.026322841644287,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.13.mlp.gate_proj\": {\n        \"snr\": 3.415473699569702,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.14.mlp.gate_proj\": {\n        \"snr\": 3.3418092727661133,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.15.mlp.gate_proj\": {\n        \"snr\": 3.6233012676239014,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.16.mlp.gate_proj\": {\n        \"snr\": 3.2199010848999023,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.17.mlp.gate_proj\": {\n        \"snr\": 3.6848936080932617,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.18.mlp.gate_proj\": {\n        \"snr\": 3.4439642429351807,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.19.mlp.gate_proj\": {\n        \"snr\": 3.7366604804992676,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.20.mlp.gate_proj\": {\n        \"snr\": 4.262336254119873,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.21.mlp.gate_proj\": {\n        \"snr\": 4.333253860473633,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.22.mlp.gate_proj\": {\n        \"snr\": 3.640247344970703,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.23.mlp.gate_proj\": {\n        \"snr\": 4.2978034019470215,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.24.mlp.gate_proj\": {\n        \"snr\": 4.339972496032715,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.25.mlp.gate_proj\": {\n        \"snr\": 3.8502564430236816,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.0.mlp.up_proj\": {\n        \"snr\": 28.129924774169922,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.1.mlp.up_proj\": {\n        \"snr\": 41.49960708618164,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.2.mlp.up_proj\": {\n        \"snr\": 
125.47801971435547,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.3.mlp.up_proj\": {\n        \"snr\": 119.93355560302734,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.4.mlp.up_proj\": {\n        \"snr\": 162.62631225585938,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.5.mlp.up_proj\": {\n        \"snr\": 32.36909484863281,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.6.mlp.up_proj\": {\n        \"snr\": 49.10078430175781,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.7.mlp.up_proj\": {\n        \"snr\": 28.541580200195312,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.8.mlp.up_proj\": {\n        \"snr\": 14.764090538024902,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.9.mlp.up_proj\": {\n        \"snr\": 16.5697078704834,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.10.mlp.up_proj\": {\n        \"snr\": 19.26059913635254,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.11.mlp.up_proj\": {\n        \"snr\": 15.082040786743164,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.12.mlp.up_proj\": {\n        \"snr\": 15.5792875289917,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.13.mlp.up_proj\": {\n        \"snr\": 9.84595012664795,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.14.mlp.up_proj\": {\n        \"snr\": 11.506875991821289,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.15.mlp.up_proj\": {\n        \"snr\": 21.507600784301758,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.16.mlp.up_proj\": {\n        \"snr\": 15.110466957092285,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.17.mlp.up_proj\": {\n        \"snr\": 27.062183380126953,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.18.mlp.up_proj\": {\n        \"snr\": 16.40383529663086,\n        \"type\": \"mlp.up_proj\"\n    },\n    
\"model.layers.19.mlp.up_proj\": {\n        \"snr\": 13.117464065551758,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.20.mlp.up_proj\": {\n        \"snr\": 11.393353462219238,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.21.mlp.up_proj\": {\n        \"snr\": 10.791608810424805,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.22.mlp.up_proj\": {\n        \"snr\": 7.512388706207275,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.23.mlp.up_proj\": {\n        \"snr\": 9.889434814453125,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.24.mlp.up_proj\": {\n        \"snr\": 7.587779521942139,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.25.mlp.up_proj\": {\n        \"snr\": 4.561068058013916,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.embed_tokens\": {\n        \"snr\": 4.538210391998291,\n        \"type\": \"model.embed_tokens\"\n    },\n    \"model.norm\": {\n        \"snr\": Infinity,\n        \"type\": \"model.norm\"\n    },\n    \"model.layers.0.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.1.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.2.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.3.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.4.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.5.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.6.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    
},\n    \"model.layers.7.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.8.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.9.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.10.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.11.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.12.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.13.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.14.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.15.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.16.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.17.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.18.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.19.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.20.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.21.post_attention_layernorm\": {\n        \"snr\": 
Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.22.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.23.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.24.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.25.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.0.post_feedforward_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_feedforward_layernorm\"\n    },\n    \"model.layers.1.post_feedforward_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_feedforward_layernorm\"\n    },\n    \"model.layers.2.post_feedforward_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_feedforward_layernorm\"\n    },\n    \"model.layers.3.post_feedforward_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_feedforward_layernorm\"\n    },\n    \"model.layers.4.post_feedforward_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_feedforward_layernorm\"\n    },\n    \"model.layers.5.post_feedforward_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_feedforward_layernorm\"\n    },\n    \"model.layers.6.post_feedforward_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_feedforward_layernorm\"\n    },\n    \"model.layers.7.post_feedforward_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_feedforward_layernorm\"\n    },\n    \"model.layers.8.post_feedforward_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_feedforward_layernorm\"\n    },\n    \"model.layers.9.post_feedforward_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": 
\"post_feedforward_layernorm\"\n    },\n    \"model.layers.10.post_feedforward_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_feedforward_layernorm\"\n    },\n    \"model.layers.11.post_feedforward_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_feedforward_layernorm\"\n    },\n    \"model.layers.12.post_feedforward_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_feedforward_layernorm\"\n    },\n    \"model.layers.13.post_feedforward_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_feedforward_layernorm\"\n    },\n    \"model.layers.14.post_feedforward_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_feedforward_layernorm\"\n    },\n    \"model.layers.15.post_feedforward_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_feedforward_layernorm\"\n    },\n    \"model.layers.16.post_feedforward_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_feedforward_layernorm\"\n    },\n    \"model.layers.17.post_feedforward_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_feedforward_layernorm\"\n    },\n    \"model.layers.18.post_feedforward_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_feedforward_layernorm\"\n    },\n    \"model.layers.19.post_feedforward_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_feedforward_layernorm\"\n    },\n    \"model.layers.20.post_feedforward_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_feedforward_layernorm\"\n    },\n    \"model.layers.21.post_feedforward_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_feedforward_layernorm\"\n    },\n    \"model.layers.22.post_feedforward_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_feedforward_layernorm\"\n    },\n    \"model.layers.23.post_feedforward_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": 
\"post_feedforward_layernorm\"\n    },\n    \"model.layers.24.post_feedforward_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_feedforward_layernorm\"\n    },\n    \"model.layers.25.post_feedforward_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_feedforward_layernorm\"\n    },\n    \"model.layers.0.pre_feedforward_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"pre_feedforward_layernorm\"\n    },\n    \"model.layers.1.pre_feedforward_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"pre_feedforward_layernorm\"\n    },\n    \"model.layers.2.pre_feedforward_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"pre_feedforward_layernorm\"\n    },\n    \"model.layers.3.pre_feedforward_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"pre_feedforward_layernorm\"\n    },\n    \"model.layers.4.pre_feedforward_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"pre_feedforward_layernorm\"\n    },\n    \"model.layers.5.pre_feedforward_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"pre_feedforward_layernorm\"\n    },\n    \"model.layers.6.pre_feedforward_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"pre_feedforward_layernorm\"\n    },\n    \"model.layers.7.pre_feedforward_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"pre_feedforward_layernorm\"\n    },\n    \"model.layers.8.pre_feedforward_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"pre_feedforward_layernorm\"\n    },\n    \"model.layers.9.pre_feedforward_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"pre_feedforward_layernorm\"\n    },\n    \"model.layers.10.pre_feedforward_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"pre_feedforward_layernorm\"\n    },\n    \"model.layers.11.pre_feedforward_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"pre_feedforward_layernorm\"\n    },\n    
\"model.layers.12.pre_feedforward_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"pre_feedforward_layernorm\"\n    },\n    \"model.layers.13.pre_feedforward_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"pre_feedforward_layernorm\"\n    },\n    \"model.layers.14.pre_feedforward_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"pre_feedforward_layernorm\"\n    },\n    \"model.layers.15.pre_feedforward_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"pre_feedforward_layernorm\"\n    },\n    \"model.layers.16.pre_feedforward_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"pre_feedforward_layernorm\"\n    },\n    \"model.layers.17.pre_feedforward_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"pre_feedforward_layernorm\"\n    },\n    \"model.layers.18.pre_feedforward_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"pre_feedforward_layernorm\"\n    },\n    \"model.layers.19.pre_feedforward_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"pre_feedforward_layernorm\"\n    },\n    \"model.layers.20.pre_feedforward_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"pre_feedforward_layernorm\"\n    },\n    \"model.layers.21.pre_feedforward_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"pre_feedforward_layernorm\"\n    },\n    \"model.layers.22.pre_feedforward_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"pre_feedforward_layernorm\"\n    },\n    \"model.layers.23.pre_feedforward_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"pre_feedforward_layernorm\"\n    },\n    \"model.layers.24.pre_feedforward_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"pre_feedforward_layernorm\"\n    },\n    \"model.layers.25.pre_feedforward_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"pre_feedforward_layernorm\"\n    },\n    \"model.layers.0.self_attn.k_proj\": {\n        
\"snr\": 0.5685535073280334,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.1.self_attn.k_proj\": {\n        \"snr\": 1.060130000114441,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.2.self_attn.k_proj\": {\n        \"snr\": 1.0735561847686768,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.3.self_attn.k_proj\": {\n        \"snr\": 1.0217311382293701,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.4.self_attn.k_proj\": {\n        \"snr\": 0.9687430262565613,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.5.self_attn.k_proj\": {\n        \"snr\": 0.8411160111427307,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.6.self_attn.k_proj\": {\n        \"snr\": 0.936741054058075,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.7.self_attn.k_proj\": {\n        \"snr\": 0.7236003279685974,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.8.self_attn.k_proj\": {\n        \"snr\": 0.9032857418060303,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.9.self_attn.k_proj\": {\n        \"snr\": 0.7513307929039001,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.10.self_attn.k_proj\": {\n        \"snr\": 0.6875415444374084,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.11.self_attn.k_proj\": {\n        \"snr\": 0.6611058712005615,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.12.self_attn.k_proj\": {\n        \"snr\": 0.8023670315742493,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.13.self_attn.k_proj\": {\n        \"snr\": 0.7188767194747925,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.14.self_attn.k_proj\": {\n        \"snr\": 0.7930117249488831,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.15.self_attn.k_proj\": {\n        \"snr\": 0.9076258540153503,\n        
\"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.16.self_attn.k_proj\": {\n        \"snr\": 0.7295113801956177,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.17.self_attn.k_proj\": {\n        \"snr\": 0.898467481136322,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.18.self_attn.k_proj\": {\n        \"snr\": 0.9652048945426941,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.19.self_attn.k_proj\": {\n        \"snr\": 0.9855819344520569,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.20.self_attn.k_proj\": {\n        \"snr\": 1.2863355875015259,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.21.self_attn.k_proj\": {\n        \"snr\": 1.116607904434204,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.22.self_attn.k_proj\": {\n        \"snr\": 0.7438228130340576,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.23.self_attn.k_proj\": {\n        \"snr\": 0.8499895334243774,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.24.self_attn.k_proj\": {\n        \"snr\": 0.7764042019844055,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.25.self_attn.k_proj\": {\n        \"snr\": 0.7127887606620789,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.0.self_attn.o_proj\": {\n        \"snr\": 0.2556447386741638,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.1.self_attn.o_proj\": {\n        \"snr\": 0.2930974066257477,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.2.self_attn.o_proj\": {\n        \"snr\": 0.27571651339530945,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.3.self_attn.o_proj\": {\n        \"snr\": 0.280631959438324,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.4.self_attn.o_proj\": {\n        \"snr\": 0.2958097755908966,\n        \"type\": \"self_attn.o_proj\"\n    },\n  
  \"model.layers.5.self_attn.o_proj\": {\n        \"snr\": 0.3072899580001831,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.6.self_attn.o_proj\": {\n        \"snr\": 0.31374114751815796,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.7.self_attn.o_proj\": {\n        \"snr\": 0.2903076410293579,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.8.self_attn.o_proj\": {\n        \"snr\": 0.2625811696052551,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.9.self_attn.o_proj\": {\n        \"snr\": 0.2306082546710968,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.10.self_attn.o_proj\": {\n        \"snr\": 0.24869701266288757,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.11.self_attn.o_proj\": {\n        \"snr\": 0.2556127905845642,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.12.self_attn.o_proj\": {\n        \"snr\": 0.28926730155944824,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.13.self_attn.o_proj\": {\n        \"snr\": 0.25355643033981323,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.14.self_attn.o_proj\": {\n        \"snr\": 0.23122912645339966,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.15.self_attn.o_proj\": {\n        \"snr\": 0.28772857785224915,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.16.self_attn.o_proj\": {\n        \"snr\": 0.22682352364063263,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.17.self_attn.o_proj\": {\n        \"snr\": 0.2558597922325134,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.18.self_attn.o_proj\": {\n        \"snr\": 0.1773315966129303,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.19.self_attn.o_proj\": {\n        \"snr\": 0.2106105089187622,\n        \"type\": \"self_attn.o_proj\"\n    },\n    
\"model.layers.20.self_attn.o_proj\": {\n        \"snr\": 0.2008877396583557,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.21.self_attn.o_proj\": {\n        \"snr\": 0.1973956972360611,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.22.self_attn.o_proj\": {\n        \"snr\": 0.25533634424209595,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.23.self_attn.o_proj\": {\n        \"snr\": 0.20066529512405396,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.24.self_attn.o_proj\": {\n        \"snr\": 0.18342143297195435,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.25.self_attn.o_proj\": {\n        \"snr\": 0.3224162459373474,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.0.self_attn.q_proj\": {\n        \"snr\": 0.2074502408504486,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.1.self_attn.q_proj\": {\n        \"snr\": 0.33233126997947693,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.2.self_attn.q_proj\": {\n        \"snr\": 0.3586291968822479,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.3.self_attn.q_proj\": {\n        \"snr\": 0.2850974202156067,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.4.self_attn.q_proj\": {\n        \"snr\": 0.37816473841667175,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.5.self_attn.q_proj\": {\n        \"snr\": 0.31616899371147156,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.6.self_attn.q_proj\": {\n        \"snr\": 0.4988365173339844,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.7.self_attn.q_proj\": {\n        \"snr\": 0.4238639175891876,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.8.self_attn.q_proj\": {\n        \"snr\": 0.2674674689769745,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.9.self_attn.q_proj\": 
{\n        \"snr\": 0.34524214267730713,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.10.self_attn.q_proj\": {\n        \"snr\": 0.4472109377384186,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.11.self_attn.q_proj\": {\n        \"snr\": 0.41363632678985596,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.12.self_attn.q_proj\": {\n        \"snr\": 0.44623735547065735,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.13.self_attn.q_proj\": {\n        \"snr\": 0.4404333531856537,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.14.self_attn.q_proj\": {\n        \"snr\": 0.5200268626213074,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.15.self_attn.q_proj\": {\n        \"snr\": 0.4320363700389862,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.16.self_attn.q_proj\": {\n        \"snr\": 0.46235284209251404,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.17.self_attn.q_proj\": {\n        \"snr\": 0.47477203607559204,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.18.self_attn.q_proj\": {\n        \"snr\": 0.4001321494579315,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.19.self_attn.q_proj\": {\n        \"snr\": 0.42365774512290955,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.20.self_attn.q_proj\": {\n        \"snr\": 0.37057873606681824,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.21.self_attn.q_proj\": {\n        \"snr\": 0.3990235924720764,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.22.self_attn.q_proj\": {\n        \"snr\": 0.35094162821769714,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.23.self_attn.q_proj\": {\n        \"snr\": 0.35721710324287415,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.24.self_attn.q_proj\": {\n        \"snr\": 
0.2812618315219879,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.25.self_attn.q_proj\": {\n        \"snr\": 0.19463211297988892,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.0.self_attn.v_proj\": {\n        \"snr\": 1.3365743160247803,\n        \"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.1.self_attn.v_proj\": {\n        \"snr\": 2.402009963989258,\n        \"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.2.self_attn.v_proj\": {\n        \"snr\": 3.8695859909057617,\n        \"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.3.self_attn.v_proj\": {\n        \"snr\": 4.117948055267334,\n        \"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.4.self_attn.v_proj\": {\n        \"snr\": 5.651231288909912,\n        \"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.5.self_attn.v_proj\": {\n        \"snr\": 2.720799446105957,\n        \"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.6.self_attn.v_proj\": {\n        \"snr\": 1.4446897506713867,\n        \"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.7.self_attn.v_proj\": {\n        \"snr\": 4.497112274169922,\n        \"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.8.self_attn.v_proj\": {\n        \"snr\": 1.7241870164871216,\n        \"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.9.self_attn.v_proj\": {\n        \"snr\": 1.7104988098144531,\n        \"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.10.self_attn.v_proj\": {\n        \"snr\": 1.4231206178665161,\n        \"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.11.self_attn.v_proj\": {\n        \"snr\": 2.1643989086151123,\n        \"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.12.self_attn.v_proj\": {\n        \"snr\": 1.5254249572753906,\n        \"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.13.self_attn.v_proj\": {\n        \"snr\": 2.3788745403289795,\n        \"type\": 
\"self_attn.v_proj\"\n    },\n    \"model.layers.14.self_attn.v_proj\": {\n        \"snr\": 3.4155967235565186,\n        \"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.15.self_attn.v_proj\": {\n        \"snr\": 4.623549938201904,\n        \"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.16.self_attn.v_proj\": {\n        \"snr\": 1.5291141271591187,\n        \"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.17.self_attn.v_proj\": {\n        \"snr\": 3.9934189319610596,\n        \"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.18.self_attn.v_proj\": {\n        \"snr\": 9.035382270812988,\n        \"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.19.self_attn.v_proj\": {\n        \"snr\": 5.8578925132751465,\n        \"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.20.self_attn.v_proj\": {\n        \"snr\": 3.759958505630493,\n        \"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.21.self_attn.v_proj\": {\n        \"snr\": 4.558528900146484,\n        \"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.22.self_attn.v_proj\": {\n        \"snr\": 0.9163281917572021,\n        \"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.23.self_attn.v_proj\": {\n        \"snr\": 2.564377546310425,\n        \"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.24.self_attn.v_proj\": {\n        \"snr\": 3.689103841781616,\n        \"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.25.self_attn.v_proj\": {\n        \"snr\": 5.6444854736328125,\n        \"type\": \"self_attn.v_proj\"\n    }\n}\n"
  },
  {
    "path": "src/axolotl/integrations/spectrum/model_snr_results/snr_results_meta-llama-Llama-3.2-1B-Instruct.json",
    "content": "{\n    \"model.layers.0.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.1.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.2.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.3.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.4.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.5.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.6.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.7.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.8.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.9.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.10.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.11.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.12.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.13.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.14.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.15.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"lm_head\": {\n        \"snr\": Infinity,\n        \"type\": \"lm_head\"\n    
},\n    \"model.layers.0.mlp.down_proj\": {\n        \"snr\": 70.0594253540039,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.1.mlp.down_proj\": {\n        \"snr\": 11.135851860046387,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.2.mlp.down_proj\": {\n        \"snr\": 7.035482883453369,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.3.mlp.down_proj\": {\n        \"snr\": 6.422532081604004,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.4.mlp.down_proj\": {\n        \"snr\": 5.748020172119141,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.5.mlp.down_proj\": {\n        \"snr\": 3.885556697845459,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.6.mlp.down_proj\": {\n        \"snr\": 3.4336745738983154,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.7.mlp.down_proj\": {\n        \"snr\": 2.791595935821533,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.8.mlp.down_proj\": {\n        \"snr\": 5.36277961730957,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.9.mlp.down_proj\": {\n        \"snr\": 4.459208011627197,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.10.mlp.down_proj\": {\n        \"snr\": 6.272170066833496,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.11.mlp.down_proj\": {\n        \"snr\": 5.264761447906494,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.12.mlp.down_proj\": {\n        \"snr\": 4.324735641479492,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.13.mlp.down_proj\": {\n        \"snr\": 3.878648042678833,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.14.mlp.down_proj\": {\n        \"snr\": 2.9773054122924805,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.15.mlp.down_proj\": {\n        \"snr\": 4.471445560455322,\n        \"type\": \"mlp.down_proj\"\n    },\n    
\"model.layers.0.mlp.gate_proj\": {\n        \"snr\": 25.227100372314453,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.1.mlp.gate_proj\": {\n        \"snr\": 6.58299446105957,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.2.mlp.gate_proj\": {\n        \"snr\": 3.4688243865966797,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.3.mlp.gate_proj\": {\n        \"snr\": 1.555246114730835,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.4.mlp.gate_proj\": {\n        \"snr\": 0.7770601511001587,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.5.mlp.gate_proj\": {\n        \"snr\": 0.6239906549453735,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.6.mlp.gate_proj\": {\n        \"snr\": 0.6440379023551941,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.7.mlp.gate_proj\": {\n        \"snr\": 0.5120116472244263,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.8.mlp.gate_proj\": {\n        \"snr\": 0.6544050574302673,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.9.mlp.gate_proj\": {\n        \"snr\": 0.5381016731262207,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.10.mlp.gate_proj\": {\n        \"snr\": 0.622873842716217,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.11.mlp.gate_proj\": {\n        \"snr\": 0.9361700415611267,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.12.mlp.gate_proj\": {\n        \"snr\": 1.475605845451355,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.13.mlp.gate_proj\": {\n        \"snr\": 1.608325719833374,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.14.mlp.gate_proj\": {\n        \"snr\": 1.0720024108886719,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.15.mlp.gate_proj\": {\n        \"snr\": 0.7111338973045349,\n        \"type\": \"mlp.gate_proj\"\n    },\n    
\"model.layers.0.mlp.up_proj\": {\n        \"snr\": 28.431896209716797,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.1.mlp.up_proj\": {\n        \"snr\": 15.546019554138184,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.2.mlp.up_proj\": {\n        \"snr\": 23.048023223876953,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.3.mlp.up_proj\": {\n        \"snr\": 25.790977478027344,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.4.mlp.up_proj\": {\n        \"snr\": 18.552549362182617,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.5.mlp.up_proj\": {\n        \"snr\": 8.85106372833252,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.6.mlp.up_proj\": {\n        \"snr\": 10.653799057006836,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.7.mlp.up_proj\": {\n        \"snr\": 7.365357875823975,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.8.mlp.up_proj\": {\n        \"snr\": 11.98373794555664,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.9.mlp.up_proj\": {\n        \"snr\": 8.04493236541748,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.10.mlp.up_proj\": {\n        \"snr\": 8.523039817810059,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.11.mlp.up_proj\": {\n        \"snr\": 5.381742477416992,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.12.mlp.up_proj\": {\n        \"snr\": 3.9845118522644043,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.13.mlp.up_proj\": {\n        \"snr\": 3.4893221855163574,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.14.mlp.up_proj\": {\n        \"snr\": 1.764201045036316,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.15.mlp.up_proj\": {\n        \"snr\": 0.9730708599090576,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.embed_tokens\": {\n        \"snr\": Infinity,\n        \"type\": 
\"model.embed_tokens\"\n    },\n    \"model.norm\": {\n        \"snr\": Infinity,\n        \"type\": \"model.norm\"\n    },\n    \"model.layers.0.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.1.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.2.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.3.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.4.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.5.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.6.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.7.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.8.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.9.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.10.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.11.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.12.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.13.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": 
\"post_attention_layernorm\"\n    },\n    \"model.layers.14.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.15.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.0.self_attn.k_proj\": {\n        \"snr\": 0.11727584153413773,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.1.self_attn.k_proj\": {\n        \"snr\": 0.24786807596683502,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.2.self_attn.k_proj\": {\n        \"snr\": 0.36378130316734314,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.3.self_attn.k_proj\": {\n        \"snr\": 0.2983120381832123,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.4.self_attn.k_proj\": {\n        \"snr\": 0.33789733052253723,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.5.self_attn.k_proj\": {\n        \"snr\": 0.29155924916267395,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.6.self_attn.k_proj\": {\n        \"snr\": 0.2537297010421753,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.7.self_attn.k_proj\": {\n        \"snr\": 0.28204113245010376,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.8.self_attn.k_proj\": {\n        \"snr\": 0.2776711583137512,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.9.self_attn.k_proj\": {\n        \"snr\": 0.2927376627922058,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.10.self_attn.k_proj\": {\n        \"snr\": 0.31486213207244873,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.11.self_attn.k_proj\": {\n        \"snr\": 0.32363659143447876,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.12.self_attn.k_proj\": {\n        \"snr\": 0.31382912397384644,\n        \"type\": 
\"self_attn.k_proj\"\n    },\n    \"model.layers.13.self_attn.k_proj\": {\n        \"snr\": 0.4635234773159027,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.14.self_attn.k_proj\": {\n        \"snr\": 0.25379249453544617,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.15.self_attn.k_proj\": {\n        \"snr\": 0.2628238797187805,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.0.self_attn.o_proj\": {\n        \"snr\": 0.27602291107177734,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.1.self_attn.o_proj\": {\n        \"snr\": 0.2149604707956314,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.2.self_attn.o_proj\": {\n        \"snr\": 0.2540294826030731,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.3.self_attn.o_proj\": {\n        \"snr\": 0.27978822588920593,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.4.self_attn.o_proj\": {\n        \"snr\": 0.3121289908885956,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.5.self_attn.o_proj\": {\n        \"snr\": 0.35037684440612793,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.6.self_attn.o_proj\": {\n        \"snr\": 0.366205096244812,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.7.self_attn.o_proj\": {\n        \"snr\": 0.3692712187767029,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.8.self_attn.o_proj\": {\n        \"snr\": 0.3301038146018982,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.9.self_attn.o_proj\": {\n        \"snr\": 0.3003396987915039,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.10.self_attn.o_proj\": {\n        \"snr\": 0.30804169178009033,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.11.self_attn.o_proj\": {\n        \"snr\": 0.28501132130622864,\n        \"type\": \"self_attn.o_proj\"\n    },\n    
\"model.layers.12.self_attn.o_proj\": {\n        \"snr\": 0.2171541005373001,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.13.self_attn.o_proj\": {\n        \"snr\": 0.19183959066867828,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.14.self_attn.o_proj\": {\n        \"snr\": 0.19215913116931915,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.15.self_attn.o_proj\": {\n        \"snr\": 0.25486502051353455,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.0.self_attn.q_proj\": {\n        \"snr\": 0.03850084915757179,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.1.self_attn.q_proj\": {\n        \"snr\": 0.0713055431842804,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.2.self_attn.q_proj\": {\n        \"snr\": 0.07948919385671616,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.3.self_attn.q_proj\": {\n        \"snr\": 0.08047746121883392,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.4.self_attn.q_proj\": {\n        \"snr\": 0.0852593332529068,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.5.self_attn.q_proj\": {\n        \"snr\": 0.09794823825359344,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.6.self_attn.q_proj\": {\n        \"snr\": 0.09627152234315872,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.7.self_attn.q_proj\": {\n        \"snr\": 0.11065381020307541,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.8.self_attn.q_proj\": {\n        \"snr\": 0.12031875550746918,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.9.self_attn.q_proj\": {\n        \"snr\": 0.09804573655128479,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.10.self_attn.q_proj\": {\n        \"snr\": 0.10897502303123474,\n        \"type\": \"self_attn.q_proj\"\n    },\n    
\"model.layers.11.self_attn.q_proj\": {\n        \"snr\": 0.09267337620258331,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.12.self_attn.q_proj\": {\n        \"snr\": 0.08803492039442062,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.13.self_attn.q_proj\": {\n        \"snr\": 0.0902542844414711,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.14.self_attn.q_proj\": {\n        \"snr\": 0.10154066979885101,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.15.self_attn.q_proj\": {\n        \"snr\": 0.09083802253007889,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.0.self_attn.v_proj\": {\n        \"snr\": 2.842210054397583,\n        \"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.1.self_attn.v_proj\": {\n        \"snr\": 10.59461498260498,\n        \"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.2.self_attn.v_proj\": {\n        \"snr\": 8.993025779724121,\n        \"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.3.self_attn.v_proj\": {\n        \"snr\": 62.567787170410156,\n        \"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.4.self_attn.v_proj\": {\n        \"snr\": 23.80082893371582,\n        \"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.5.self_attn.v_proj\": {\n        \"snr\": 7.957369804382324,\n        \"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.6.self_attn.v_proj\": {\n        \"snr\": 12.01815414428711,\n        \"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.7.self_attn.v_proj\": {\n        \"snr\": 5.095500469207764,\n        \"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.8.self_attn.v_proj\": {\n        \"snr\": 11.719332695007324,\n        \"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.9.self_attn.v_proj\": {\n        \"snr\": 555.0869750976562,\n        \"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.10.self_attn.v_proj\": {\n        
\"snr\": 22.95538330078125,\n        \"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.11.self_attn.v_proj\": {\n        \"snr\": 30.042158126831055,\n        \"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.12.self_attn.v_proj\": {\n        \"snr\": 9.577271461486816,\n        \"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.13.self_attn.v_proj\": {\n        \"snr\": 18.176361083984375,\n        \"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.14.self_attn.v_proj\": {\n        \"snr\": 1.5695856809616089,\n        \"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.15.self_attn.v_proj\": {\n        \"snr\": 2.7235565185546875,\n        \"type\": \"self_attn.v_proj\"\n    }\n}\n"
  },
  {
    "path": "src/axolotl/integrations/spectrum/model_snr_results/snr_results_meta-llama-Llama-3.2-1B.json",
    "content": "{\n    \"model.layers.0.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.1.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.2.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.3.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.4.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.5.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.6.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.7.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.8.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.9.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.10.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.11.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.12.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.13.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.14.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.15.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"lm_head\": {\n        \"snr\": Infinity,\n        \"type\": \"lm_head\"\n    
},\n    \"model.layers.0.mlp.down_proj\": {\n        \"snr\": 57.09797286987305,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.1.mlp.down_proj\": {\n        \"snr\": 9.538983345031738,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.2.mlp.down_proj\": {\n        \"snr\": 6.227016925811768,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.3.mlp.down_proj\": {\n        \"snr\": 5.660686492919922,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.4.mlp.down_proj\": {\n        \"snr\": 5.178432464599609,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.5.mlp.down_proj\": {\n        \"snr\": 3.5638349056243896,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.6.mlp.down_proj\": {\n        \"snr\": 3.0918056964874268,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.7.mlp.down_proj\": {\n        \"snr\": 2.456392288208008,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.8.mlp.down_proj\": {\n        \"snr\": 4.525328636169434,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.9.mlp.down_proj\": {\n        \"snr\": 3.9409055709838867,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.10.mlp.down_proj\": {\n        \"snr\": 5.447249412536621,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.11.mlp.down_proj\": {\n        \"snr\": 4.807600975036621,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.12.mlp.down_proj\": {\n        \"snr\": 3.915374517440796,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.13.mlp.down_proj\": {\n        \"snr\": 3.4820363521575928,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.14.mlp.down_proj\": {\n        \"snr\": 2.6045074462890625,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.15.mlp.down_proj\": {\n        \"snr\": 3.7237701416015625,\n        \"type\": \"mlp.down_proj\"\n    },\n    
\"model.layers.0.mlp.gate_proj\": {\n        \"snr\": 22.160131454467773,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.1.mlp.gate_proj\": {\n        \"snr\": 6.072206020355225,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.2.mlp.gate_proj\": {\n        \"snr\": 3.2467362880706787,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.3.mlp.gate_proj\": {\n        \"snr\": 1.4111896753311157,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.4.mlp.gate_proj\": {\n        \"snr\": 0.7405938506126404,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.5.mlp.gate_proj\": {\n        \"snr\": 0.5916463136672974,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.6.mlp.gate_proj\": {\n        \"snr\": 0.6149423718452454,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.7.mlp.gate_proj\": {\n        \"snr\": 0.48369669914245605,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.8.mlp.gate_proj\": {\n        \"snr\": 0.6047574877738953,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.9.mlp.gate_proj\": {\n        \"snr\": 0.5092479586601257,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.10.mlp.gate_proj\": {\n        \"snr\": 0.5999670624732971,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.11.mlp.gate_proj\": {\n        \"snr\": 0.8980127573013306,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.12.mlp.gate_proj\": {\n        \"snr\": 1.4252448081970215,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.13.mlp.gate_proj\": {\n        \"snr\": 1.509937047958374,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.14.mlp.gate_proj\": {\n        \"snr\": 1.0066585540771484,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.15.mlp.gate_proj\": {\n        \"snr\": 0.6413647532463074,\n        \"type\": \"mlp.gate_proj\"\n    },\n    
\"model.layers.0.mlp.up_proj\": {\n        \"snr\": 26.08852195739746,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.1.mlp.up_proj\": {\n        \"snr\": 13.382951736450195,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.2.mlp.up_proj\": {\n        \"snr\": 20.088768005371094,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.3.mlp.up_proj\": {\n        \"snr\": 23.0632381439209,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.4.mlp.up_proj\": {\n        \"snr\": 16.07433319091797,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.5.mlp.up_proj\": {\n        \"snr\": 8.00507640838623,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.6.mlp.up_proj\": {\n        \"snr\": 9.538354873657227,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.7.mlp.up_proj\": {\n        \"snr\": 6.286602973937988,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.8.mlp.up_proj\": {\n        \"snr\": 10.092820167541504,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.9.mlp.up_proj\": {\n        \"snr\": 7.193963527679443,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.10.mlp.up_proj\": {\n        \"snr\": 7.320116996765137,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.11.mlp.up_proj\": {\n        \"snr\": 4.8728532791137695,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.12.mlp.up_proj\": {\n        \"snr\": 3.596583366394043,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.13.mlp.up_proj\": {\n        \"snr\": 3.166161298751831,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.14.mlp.up_proj\": {\n        \"snr\": 1.5600818395614624,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.15.mlp.up_proj\": {\n        \"snr\": 0.8726214170455933,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.embed_tokens\": {\n        \"snr\": Infinity,\n        \"type\": 
\"model.embed_tokens\"\n    },\n    \"model.norm\": {\n        \"snr\": Infinity,\n        \"type\": \"model.norm\"\n    },\n    \"model.layers.0.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.1.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.2.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.3.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.4.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.5.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.6.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.7.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.8.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.9.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.10.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.11.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.12.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.13.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": 
\"post_attention_layernorm\"\n    },\n    \"model.layers.14.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.15.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.0.self_attn.k_proj\": {\n        \"snr\": 0.1154392883181572,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.1.self_attn.k_proj\": {\n        \"snr\": 0.24299409985542297,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.2.self_attn.k_proj\": {\n        \"snr\": 0.3624322712421417,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.3.self_attn.k_proj\": {\n        \"snr\": 0.29509487748146057,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.4.self_attn.k_proj\": {\n        \"snr\": 0.32953736186027527,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.5.self_attn.k_proj\": {\n        \"snr\": 0.2908833622932434,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.6.self_attn.k_proj\": {\n        \"snr\": 0.2488437294960022,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.7.self_attn.k_proj\": {\n        \"snr\": 0.27847856283187866,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.8.self_attn.k_proj\": {\n        \"snr\": 0.27143892645835876,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.9.self_attn.k_proj\": {\n        \"snr\": 0.28804272413253784,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.10.self_attn.k_proj\": {\n        \"snr\": 0.31197959184646606,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.11.self_attn.k_proj\": {\n        \"snr\": 0.3203586935997009,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.12.self_attn.k_proj\": {\n        \"snr\": 0.30905747413635254,\n        \"type\": 
\"self_attn.k_proj\"\n    },\n    \"model.layers.13.self_attn.k_proj\": {\n        \"snr\": 0.46828722953796387,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.14.self_attn.k_proj\": {\n        \"snr\": 0.24205778539180756,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.15.self_attn.k_proj\": {\n        \"snr\": 0.2559327781200409,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.0.self_attn.o_proj\": {\n        \"snr\": 0.2638678550720215,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.1.self_attn.o_proj\": {\n        \"snr\": 0.21109595894813538,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.2.self_attn.o_proj\": {\n        \"snr\": 0.24751724302768707,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.3.self_attn.o_proj\": {\n        \"snr\": 0.2728094160556793,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.4.self_attn.o_proj\": {\n        \"snr\": 0.3001374304294586,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.5.self_attn.o_proj\": {\n        \"snr\": 0.33903488516807556,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.6.self_attn.o_proj\": {\n        \"snr\": 0.3530929982662201,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.7.self_attn.o_proj\": {\n        \"snr\": 0.36753255128860474,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.8.self_attn.o_proj\": {\n        \"snr\": 0.3373180329799652,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.9.self_attn.o_proj\": {\n        \"snr\": 0.2970578670501709,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.10.self_attn.o_proj\": {\n        \"snr\": 0.3076324760913849,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.11.self_attn.o_proj\": {\n        \"snr\": 0.2766900658607483,\n        \"type\": \"self_attn.o_proj\"\n    },\n    
\"model.layers.12.self_attn.o_proj\": {\n        \"snr\": 0.20973259210586548,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.13.self_attn.o_proj\": {\n        \"snr\": 0.18185566365718842,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.14.self_attn.o_proj\": {\n        \"snr\": 0.18329747021198273,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.15.self_attn.o_proj\": {\n        \"snr\": 0.2437991499900818,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.0.self_attn.q_proj\": {\n        \"snr\": 0.038040731102228165,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.1.self_attn.q_proj\": {\n        \"snr\": 0.0707998052239418,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.2.self_attn.q_proj\": {\n        \"snr\": 0.0787411704659462,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.3.self_attn.q_proj\": {\n        \"snr\": 0.08089710026979446,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.4.self_attn.q_proj\": {\n        \"snr\": 0.08591937273740768,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.5.self_attn.q_proj\": {\n        \"snr\": 0.09852176159620285,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.6.self_attn.q_proj\": {\n        \"snr\": 0.09690654277801514,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.7.self_attn.q_proj\": {\n        \"snr\": 0.11181341856718063,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.8.self_attn.q_proj\": {\n        \"snr\": 0.12042108923196793,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.9.self_attn.q_proj\": {\n        \"snr\": 0.09799323976039886,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.10.self_attn.q_proj\": {\n        \"snr\": 0.10901063680648804,\n        \"type\": \"self_attn.q_proj\"\n    },\n    
\"model.layers.11.self_attn.q_proj\": {\n        \"snr\": 0.09307146072387695,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.12.self_attn.q_proj\": {\n        \"snr\": 0.0880950540304184,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.13.self_attn.q_proj\": {\n        \"snr\": 0.08886399120092392,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.14.self_attn.q_proj\": {\n        \"snr\": 0.09955056011676788,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.15.self_attn.q_proj\": {\n        \"snr\": 0.08929339051246643,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.0.self_attn.v_proj\": {\n        \"snr\": 2.5501928329467773,\n        \"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.1.self_attn.v_proj\": {\n        \"snr\": 9.449499130249023,\n        \"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.2.self_attn.v_proj\": {\n        \"snr\": 7.9920830726623535,\n        \"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.3.self_attn.v_proj\": {\n        \"snr\": 50.69462585449219,\n        \"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.4.self_attn.v_proj\": {\n        \"snr\": 19.083511352539062,\n        \"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.5.self_attn.v_proj\": {\n        \"snr\": 7.21597146987915,\n        \"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.6.self_attn.v_proj\": {\n        \"snr\": 11.27744197845459,\n        \"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.7.self_attn.v_proj\": {\n        \"snr\": 4.579711437225342,\n        \"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.8.self_attn.v_proj\": {\n        \"snr\": 10.940719604492188,\n        \"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.9.self_attn.v_proj\": {\n        \"snr\": 553.4417724609375,\n        \"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.10.self_attn.v_proj\": {\n       
 \"snr\": 20.59434700012207,\n        \"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.11.self_attn.v_proj\": {\n        \"snr\": 26.636865615844727,\n        \"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.12.self_attn.v_proj\": {\n        \"snr\": 8.614749908447266,\n        \"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.13.self_attn.v_proj\": {\n        \"snr\": 17.722007751464844,\n        \"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.14.self_attn.v_proj\": {\n        \"snr\": 1.48500657081604,\n        \"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.15.self_attn.v_proj\": {\n        \"snr\": 2.5776851177215576,\n        \"type\": \"self_attn.v_proj\"\n    }\n}\n"
  },
  {
    "path": "src/axolotl/integrations/spectrum/model_snr_results/snr_results_meta-llama-Llama-3.2-3B-Instruct.json",
    "content": "{\n    \"model.layers.0.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.1.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.2.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.3.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.4.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.5.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.6.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.7.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.8.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.9.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.10.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.11.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.12.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.13.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.14.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.15.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.16.input_layernorm\": {\n        \"snr\": Infinity,\n        
\"type\": \"input_layernorm\"\n    },\n    \"model.layers.17.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.18.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.19.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.20.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.21.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.22.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.23.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.24.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.25.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.26.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.27.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"lm_head\": {\n        \"snr\": Infinity,\n        \"type\": \"lm_head\"\n    },\n    \"model.layers.0.mlp.down_proj\": {\n        \"snr\": 2.306217670440674,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.1.mlp.down_proj\": {\n        \"snr\": 2.2327167987823486,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.2.mlp.down_proj\": {\n        \"snr\": 1.4501516819000244,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.3.mlp.down_proj\": {\n        \"snr\": 1.363667607307434,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.4.mlp.down_proj\": {\n        \"snr\": 
1.4520279169082642,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.5.mlp.down_proj\": {\n        \"snr\": 1.4664665460586548,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.6.mlp.down_proj\": {\n        \"snr\": 1.4122329950332642,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.7.mlp.down_proj\": {\n        \"snr\": 1.0504299402236938,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.8.mlp.down_proj\": {\n        \"snr\": 0.9837537407875061,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.9.mlp.down_proj\": {\n        \"snr\": 0.8659006357192993,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.10.mlp.down_proj\": {\n        \"snr\": 0.7936406135559082,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.11.mlp.down_proj\": {\n        \"snr\": 0.9000886678695679,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.12.mlp.down_proj\": {\n        \"snr\": 1.1559213399887085,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.13.mlp.down_proj\": {\n        \"snr\": 1.3054672479629517,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.14.mlp.down_proj\": {\n        \"snr\": 1.196791410446167,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.15.mlp.down_proj\": {\n        \"snr\": 1.3163655996322632,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.16.mlp.down_proj\": {\n        \"snr\": 1.3388997316360474,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.17.mlp.down_proj\": {\n        \"snr\": 1.592497706413269,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.18.mlp.down_proj\": {\n        \"snr\": 1.5399079322814941,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.19.mlp.down_proj\": {\n        \"snr\": 1.5683293342590332,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.20.mlp.down_proj\": {\n        \"snr\": 
1.4739630222320557,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.21.mlp.down_proj\": {\n        \"snr\": 1.2608393430709839,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.22.mlp.down_proj\": {\n        \"snr\": 1.2087301015853882,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.23.mlp.down_proj\": {\n        \"snr\": 1.1851829290390015,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.24.mlp.down_proj\": {\n        \"snr\": 1.0537594556808472,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.25.mlp.down_proj\": {\n        \"snr\": 1.1649317741394043,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.26.mlp.down_proj\": {\n        \"snr\": 1.2376821041107178,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.27.mlp.down_proj\": {\n        \"snr\": 1.147771954536438,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.0.mlp.gate_proj\": {\n        \"snr\": 0.9385462999343872,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.1.mlp.gate_proj\": {\n        \"snr\": 0.8528683185577393,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.2.mlp.gate_proj\": {\n        \"snr\": 0.761657178401947,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.3.mlp.gate_proj\": {\n        \"snr\": 0.6598325371742249,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.4.mlp.gate_proj\": {\n        \"snr\": 0.44578588008880615,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.5.mlp.gate_proj\": {\n        \"snr\": 0.4053060710430145,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.6.mlp.gate_proj\": {\n        \"snr\": 0.3588462769985199,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.7.mlp.gate_proj\": {\n        \"snr\": 0.35667839646339417,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.8.mlp.gate_proj\": {\n        \"snr\": 
0.3106202781200409,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.9.mlp.gate_proj\": {\n        \"snr\": 0.2821919322013855,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.10.mlp.gate_proj\": {\n        \"snr\": 0.29143741726875305,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.11.mlp.gate_proj\": {\n        \"snr\": 0.29830989241600037,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.12.mlp.gate_proj\": {\n        \"snr\": 0.2862427532672882,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.13.mlp.gate_proj\": {\n        \"snr\": 0.2797018587589264,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.14.mlp.gate_proj\": {\n        \"snr\": 0.2679217755794525,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.15.mlp.gate_proj\": {\n        \"snr\": 0.2782425880432129,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.16.mlp.gate_proj\": {\n        \"snr\": 0.3503592610359192,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.17.mlp.gate_proj\": {\n        \"snr\": 0.3968559205532074,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.18.mlp.gate_proj\": {\n        \"snr\": 0.4318574070930481,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.19.mlp.gate_proj\": {\n        \"snr\": 0.4693693220615387,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.20.mlp.gate_proj\": {\n        \"snr\": 0.5051979422569275,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.21.mlp.gate_proj\": {\n        \"snr\": 0.5675955414772034,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.22.mlp.gate_proj\": {\n        \"snr\": 0.5861843824386597,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.23.mlp.gate_proj\": {\n        \"snr\": 0.4759417772293091,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.24.mlp.gate_proj\": {\n        
\"snr\": 0.38529056310653687,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.25.mlp.gate_proj\": {\n        \"snr\": 0.3180919587612152,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.26.mlp.gate_proj\": {\n        \"snr\": 0.2695689797401428,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.27.mlp.gate_proj\": {\n        \"snr\": 0.21765239536762238,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.0.mlp.up_proj\": {\n        \"snr\": 1.4919718503952026,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.1.mlp.up_proj\": {\n        \"snr\": 1.7983858585357666,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.2.mlp.up_proj\": {\n        \"snr\": 2.1709094047546387,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.3.mlp.up_proj\": {\n        \"snr\": 2.751326560974121,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.4.mlp.up_proj\": {\n        \"snr\": 3.063521385192871,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.5.mlp.up_proj\": {\n        \"snr\": 2.4026951789855957,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.6.mlp.up_proj\": {\n        \"snr\": 2.3890223503112793,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.7.mlp.up_proj\": {\n        \"snr\": 2.3861353397369385,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.8.mlp.up_proj\": {\n        \"snr\": 2.0745043754577637,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.9.mlp.up_proj\": {\n        \"snr\": 1.8550645112991333,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.10.mlp.up_proj\": {\n        \"snr\": 1.6184496879577637,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.11.mlp.up_proj\": {\n        \"snr\": 1.9287559986114502,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.12.mlp.up_proj\": {\n        \"snr\": 1.7427546977996826,\n        \"type\": 
\"mlp.up_proj\"\n    },\n    \"model.layers.13.mlp.up_proj\": {\n        \"snr\": 1.9872609376907349,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.14.mlp.up_proj\": {\n        \"snr\": 2.0224087238311768,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.15.mlp.up_proj\": {\n        \"snr\": 1.7851638793945312,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.16.mlp.up_proj\": {\n        \"snr\": 1.7160604000091553,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.17.mlp.up_proj\": {\n        \"snr\": 1.6870195865631104,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.18.mlp.up_proj\": {\n        \"snr\": 1.6585396528244019,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.19.mlp.up_proj\": {\n        \"snr\": 1.5509096384048462,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.20.mlp.up_proj\": {\n        \"snr\": 1.4310423135757446,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.21.mlp.up_proj\": {\n        \"snr\": 1.5009464025497437,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.22.mlp.up_proj\": {\n        \"snr\": 1.4866929054260254,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.23.mlp.up_proj\": {\n        \"snr\": 1.332513689994812,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.24.mlp.up_proj\": {\n        \"snr\": 1.073512077331543,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.25.mlp.up_proj\": {\n        \"snr\": 0.7472100257873535,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.26.mlp.up_proj\": {\n        \"snr\": 0.4880162179470062,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.27.mlp.up_proj\": {\n        \"snr\": 0.2527681589126587,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.embed_tokens\": {\n        \"snr\": Infinity,\n        \"type\": \"model.embed_tokens\"\n    },\n    \"model.norm\": {\n        \"snr\": Infinity,\n  
      \"type\": \"model.norm\"\n    },\n    \"model.layers.0.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.1.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.2.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.3.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.4.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.5.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.6.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.7.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.8.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.9.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.10.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.11.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.12.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.13.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.14.post_attention_layernorm\": 
{\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.15.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.16.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.17.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.18.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.19.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.20.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.21.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.22.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.23.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.24.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.25.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.26.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.27.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.0.self_attn.k_proj\": {\n        \"snr\": 0.08262510597705841,\n        \"type\": \"self_attn.k_proj\"\n    },\n 
   \"model.layers.1.self_attn.k_proj\": {\n        \"snr\": 0.1441459059715271,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.2.self_attn.k_proj\": {\n        \"snr\": 0.21418076753616333,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.3.self_attn.k_proj\": {\n        \"snr\": 0.22496014833450317,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.4.self_attn.k_proj\": {\n        \"snr\": 0.23101305961608887,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.5.self_attn.k_proj\": {\n        \"snr\": 0.23644132912158966,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.6.self_attn.k_proj\": {\n        \"snr\": 0.23666173219680786,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.7.self_attn.k_proj\": {\n        \"snr\": 0.19791515171527863,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.8.self_attn.k_proj\": {\n        \"snr\": 0.22062039375305176,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.9.self_attn.k_proj\": {\n        \"snr\": 0.21218444406986237,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.10.self_attn.k_proj\": {\n        \"snr\": 0.24218571186065674,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.11.self_attn.k_proj\": {\n        \"snr\": 0.21870514750480652,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.12.self_attn.k_proj\": {\n        \"snr\": 0.22160987555980682,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.13.self_attn.k_proj\": {\n        \"snr\": 0.22726823389530182,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.14.self_attn.k_proj\": {\n        \"snr\": 0.20256873965263367,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.15.self_attn.k_proj\": {\n        \"snr\": 0.24100735783576965,\n        \"type\": \"self_attn.k_proj\"\n    },\n    
\"model.layers.16.self_attn.k_proj\": {\n        \"snr\": 0.23794010281562805,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.17.self_attn.k_proj\": {\n        \"snr\": 0.2913324534893036,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.18.self_attn.k_proj\": {\n        \"snr\": 0.28093472123146057,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.19.self_attn.k_proj\": {\n        \"snr\": 0.31062793731689453,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.20.self_attn.k_proj\": {\n        \"snr\": 0.2942160367965698,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.21.self_attn.k_proj\": {\n        \"snr\": 0.28014805912971497,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.22.self_attn.k_proj\": {\n        \"snr\": 0.3512437045574188,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.23.self_attn.k_proj\": {\n        \"snr\": 0.2837671637535095,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.24.self_attn.k_proj\": {\n        \"snr\": 0.2960015535354614,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.25.self_attn.k_proj\": {\n        \"snr\": 0.5086414813995361,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.26.self_attn.k_proj\": {\n        \"snr\": 0.24054698646068573,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.27.self_attn.k_proj\": {\n        \"snr\": 0.247616246342659,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.0.self_attn.o_proj\": {\n        \"snr\": 0.18390265107154846,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.1.self_attn.o_proj\": {\n        \"snr\": 0.14759540557861328,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.2.self_attn.o_proj\": {\n        \"snr\": 0.15726515650749207,\n        \"type\": \"self_attn.o_proj\"\n    },\n    
\"model.layers.3.self_attn.o_proj\": {\n        \"snr\": 0.16903570294380188,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.4.self_attn.o_proj\": {\n        \"snr\": 0.17953157424926758,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.5.self_attn.o_proj\": {\n        \"snr\": 0.2351229190826416,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.6.self_attn.o_proj\": {\n        \"snr\": 0.22804339230060577,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.7.self_attn.o_proj\": {\n        \"snr\": 0.24786025285720825,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.8.self_attn.o_proj\": {\n        \"snr\": 0.21847976744174957,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.9.self_attn.o_proj\": {\n        \"snr\": 0.2092437595129013,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.10.self_attn.o_proj\": {\n        \"snr\": 0.23278094828128815,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.11.self_attn.o_proj\": {\n        \"snr\": 0.20468176901340485,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.12.self_attn.o_proj\": {\n        \"snr\": 0.2353818416595459,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.13.self_attn.o_proj\": {\n        \"snr\": 0.2702614367008209,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.14.self_attn.o_proj\": {\n        \"snr\": 0.19177420437335968,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.15.self_attn.o_proj\": {\n        \"snr\": 0.18293911218643188,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.16.self_attn.o_proj\": {\n        \"snr\": 0.20286045968532562,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.17.self_attn.o_proj\": {\n        \"snr\": 0.20763878524303436,\n        \"type\": \"self_attn.o_proj\"\n    },\n    
\"model.layers.18.self_attn.o_proj\": {\n        \"snr\": 0.190629780292511,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.19.self_attn.o_proj\": {\n        \"snr\": 0.22044304013252258,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.20.self_attn.o_proj\": {\n        \"snr\": 0.21491236984729767,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.21.self_attn.o_proj\": {\n        \"snr\": 0.23289704322814941,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.22.self_attn.o_proj\": {\n        \"snr\": 0.21457163989543915,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.23.self_attn.o_proj\": {\n        \"snr\": 0.1949365884065628,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.24.self_attn.o_proj\": {\n        \"snr\": 0.1606779545545578,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.25.self_attn.o_proj\": {\n        \"snr\": 0.13892440497875214,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.26.self_attn.o_proj\": {\n        \"snr\": 0.1407029926776886,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.27.self_attn.o_proj\": {\n        \"snr\": 0.16027599573135376,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.0.self_attn.q_proj\": {\n        \"snr\": 0.0534212663769722,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.1.self_attn.q_proj\": {\n        \"snr\": 0.06873775273561478,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.2.self_attn.q_proj\": {\n        \"snr\": 0.07522258907556534,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.3.self_attn.q_proj\": {\n        \"snr\": 0.06616844981908798,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.4.self_attn.q_proj\": {\n        \"snr\": 0.06809444725513458,\n        \"type\": \"self_attn.q_proj\"\n    },\n    
\"model.layers.5.self_attn.q_proj\": {\n        \"snr\": 0.0758095383644104,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.6.self_attn.q_proj\": {\n        \"snr\": 0.07800278812646866,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.7.self_attn.q_proj\": {\n        \"snr\": 0.07535763084888458,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.8.self_attn.q_proj\": {\n        \"snr\": 0.09488166123628616,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.9.self_attn.q_proj\": {\n        \"snr\": 0.09709945321083069,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.10.self_attn.q_proj\": {\n        \"snr\": 0.09381720423698425,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.11.self_attn.q_proj\": {\n        \"snr\": 0.08205580711364746,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.12.self_attn.q_proj\": {\n        \"snr\": 0.10723169893026352,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.13.self_attn.q_proj\": {\n        \"snr\": 0.10166660696268082,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.14.self_attn.q_proj\": {\n        \"snr\": 0.08822792023420334,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.15.self_attn.q_proj\": {\n        \"snr\": 0.0814041867852211,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.16.self_attn.q_proj\": {\n        \"snr\": 0.07586681097745895,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.17.self_attn.q_proj\": {\n        \"snr\": 0.07040166854858398,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.18.self_attn.q_proj\": {\n        \"snr\": 0.0728282704949379,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.19.self_attn.q_proj\": {\n        \"snr\": 0.06912193447351456,\n        \"type\": \"self_attn.q_proj\"\n    },\n    
\"model.layers.20.self_attn.q_proj\": {\n        \"snr\": 0.06646180897951126,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.21.self_attn.q_proj\": {\n        \"snr\": 0.06960278004407883,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.22.self_attn.q_proj\": {\n        \"snr\": 0.06566876918077469,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.23.self_attn.q_proj\": {\n        \"snr\": 0.07412787526845932,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.24.self_attn.q_proj\": {\n        \"snr\": 0.07131384313106537,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.25.self_attn.q_proj\": {\n        \"snr\": 0.07768437266349792,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.26.self_attn.q_proj\": {\n        \"snr\": 0.0809575766324997,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.27.self_attn.q_proj\": {\n        \"snr\": 0.06796683371067047,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.0.self_attn.v_proj\": {\n        \"snr\": 1.4029983282089233,\n        \"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.1.self_attn.v_proj\": {\n        \"snr\": 3.123720169067383,\n        \"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.2.self_attn.v_proj\": {\n        \"snr\": 2.4177253246307373,\n        \"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.3.self_attn.v_proj\": {\n        \"snr\": 5.588768005371094,\n        \"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.4.self_attn.v_proj\": {\n        \"snr\": 4.395562648773193,\n        \"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.5.self_attn.v_proj\": {\n        \"snr\": 3.2982685565948486,\n        \"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.6.self_attn.v_proj\": {\n        \"snr\": 3.2798449993133545,\n        \"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.7.self_attn.v_proj\": 
{\n        \"snr\": 2.109200954437256,\n        \"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.8.self_attn.v_proj\": {\n        \"snr\": 3.229325532913208,\n        \"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.9.self_attn.v_proj\": {\n        \"snr\": 1.7349927425384521,\n        \"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.10.self_attn.v_proj\": {\n        \"snr\": 1.5926740169525146,\n        \"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.11.self_attn.v_proj\": {\n        \"snr\": 1.9097802639007568,\n        \"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.12.self_attn.v_proj\": {\n        \"snr\": 2.5654332637786865,\n        \"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.13.self_attn.v_proj\": {\n        \"snr\": 3.536489963531494,\n        \"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.14.self_attn.v_proj\": {\n        \"snr\": 8.366667747497559,\n        \"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.15.self_attn.v_proj\": {\n        \"snr\": 7.348303318023682,\n        \"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.16.self_attn.v_proj\": {\n        \"snr\": 2.815748691558838,\n        \"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.17.self_attn.v_proj\": {\n        \"snr\": 4.048776149749756,\n        \"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.18.self_attn.v_proj\": {\n        \"snr\": 4.426101207733154,\n        \"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.19.self_attn.v_proj\": {\n        \"snr\": 7.098501682281494,\n        \"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.20.self_attn.v_proj\": {\n        \"snr\": 3.700288772583008,\n        \"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.21.self_attn.v_proj\": {\n        \"snr\": 2.1859049797058105,\n        \"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.22.self_attn.v_proj\": {\n        \"snr\": 3.6953284740448,\n        
\"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.23.self_attn.v_proj\": {\n        \"snr\": 11.148802757263184,\n        \"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.24.self_attn.v_proj\": {\n        \"snr\": 2.4171905517578125,\n        \"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.25.self_attn.v_proj\": {\n        \"snr\": 4.404144287109375,\n        \"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.26.self_attn.v_proj\": {\n        \"snr\": 2.340604782104492,\n        \"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.27.self_attn.v_proj\": {\n        \"snr\": 3.284160614013672,\n        \"type\": \"self_attn.v_proj\"\n    }\n}\n"
  },
  {
    "path": "src/axolotl/integrations/spectrum/model_snr_results/snr_results_meta-llama-Llama-3.2-3B.json",
    "content": "{\n    \"model.layers.0.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.1.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.2.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.3.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.4.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.5.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.6.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.7.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.8.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.9.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.10.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.11.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.12.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.13.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.14.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.15.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.16.input_layernorm\": {\n        \"snr\": Infinity,\n        
\"type\": \"input_layernorm\"\n    },\n    \"model.layers.17.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.18.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.19.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.20.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.21.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.22.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.23.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.24.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.25.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.26.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"model.layers.27.input_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"input_layernorm\"\n    },\n    \"lm_head\": {\n        \"snr\": Infinity,\n        \"type\": \"lm_head\"\n    },\n    \"model.layers.0.mlp.down_proj\": {\n        \"snr\": 2.364603281021118,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.1.mlp.down_proj\": {\n        \"snr\": 2.229910373687744,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.2.mlp.down_proj\": {\n        \"snr\": 1.4312117099761963,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.3.mlp.down_proj\": {\n        \"snr\": 1.3216407299041748,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.4.mlp.down_proj\": {\n        \"snr\": 
1.4183496236801147,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.5.mlp.down_proj\": {\n        \"snr\": 1.4453660249710083,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.6.mlp.down_proj\": {\n        \"snr\": 1.4030662775039673,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.7.mlp.down_proj\": {\n        \"snr\": 1.042332649230957,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.8.mlp.down_proj\": {\n        \"snr\": 0.9530982375144958,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.9.mlp.down_proj\": {\n        \"snr\": 0.849862277507782,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.10.mlp.down_proj\": {\n        \"snr\": 0.7704945206642151,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.11.mlp.down_proj\": {\n        \"snr\": 0.8871145844459534,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.12.mlp.down_proj\": {\n        \"snr\": 1.1408143043518066,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.13.mlp.down_proj\": {\n        \"snr\": 1.2769343852996826,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.14.mlp.down_proj\": {\n        \"snr\": 1.1703068017959595,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.15.mlp.down_proj\": {\n        \"snr\": 1.2794467210769653,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.16.mlp.down_proj\": {\n        \"snr\": 1.3154453039169312,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.17.mlp.down_proj\": {\n        \"snr\": 1.5596749782562256,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.18.mlp.down_proj\": {\n        \"snr\": 1.4949405193328857,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.19.mlp.down_proj\": {\n        \"snr\": 1.5329173803329468,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.20.mlp.down_proj\": {\n        \"snr\": 
1.4396660327911377,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.21.mlp.down_proj\": {\n        \"snr\": 1.217085838317871,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.22.mlp.down_proj\": {\n        \"snr\": 1.150472640991211,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.23.mlp.down_proj\": {\n        \"snr\": 1.1166225671768188,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.24.mlp.down_proj\": {\n        \"snr\": 0.9966591000556946,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.25.mlp.down_proj\": {\n        \"snr\": 1.0938347578048706,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.26.mlp.down_proj\": {\n        \"snr\": 1.1505423784255981,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.27.mlp.down_proj\": {\n        \"snr\": 1.1156749725341797,\n        \"type\": \"mlp.down_proj\"\n    },\n    \"model.layers.0.mlp.gate_proj\": {\n        \"snr\": 0.9329171776771545,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.1.mlp.gate_proj\": {\n        \"snr\": 0.8513413667678833,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.2.mlp.gate_proj\": {\n        \"snr\": 0.7584061026573181,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.3.mlp.gate_proj\": {\n        \"snr\": 0.65835040807724,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.4.mlp.gate_proj\": {\n        \"snr\": 0.436420738697052,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.5.mlp.gate_proj\": {\n        \"snr\": 0.39712461829185486,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.6.mlp.gate_proj\": {\n        \"snr\": 0.3530206084251404,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.7.mlp.gate_proj\": {\n        \"snr\": 0.34982794523239136,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.8.mlp.gate_proj\": {\n        \"snr\": 
0.30338960886001587,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.9.mlp.gate_proj\": {\n        \"snr\": 0.27569833397865295,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.10.mlp.gate_proj\": {\n        \"snr\": 0.28934162855148315,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.11.mlp.gate_proj\": {\n        \"snr\": 0.2929173707962036,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.12.mlp.gate_proj\": {\n        \"snr\": 0.28263387084007263,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.13.mlp.gate_proj\": {\n        \"snr\": 0.27778616547584534,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.14.mlp.gate_proj\": {\n        \"snr\": 0.26527827978134155,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.15.mlp.gate_proj\": {\n        \"snr\": 0.27635642886161804,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.16.mlp.gate_proj\": {\n        \"snr\": 0.35072311758995056,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.17.mlp.gate_proj\": {\n        \"snr\": 0.4002636671066284,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.18.mlp.gate_proj\": {\n        \"snr\": 0.4319891333580017,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.19.mlp.gate_proj\": {\n        \"snr\": 0.47527065873146057,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.20.mlp.gate_proj\": {\n        \"snr\": 0.5112077593803406,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.21.mlp.gate_proj\": {\n        \"snr\": 0.5749644637107849,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.22.mlp.gate_proj\": {\n        \"snr\": 0.5967603921890259,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.23.mlp.gate_proj\": {\n        \"snr\": 0.48045310378074646,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.24.mlp.gate_proj\": {\n        
\"snr\": 0.3838970363140106,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.25.mlp.gate_proj\": {\n        \"snr\": 0.3108249604701996,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.26.mlp.gate_proj\": {\n        \"snr\": 0.26704445481300354,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.27.mlp.gate_proj\": {\n        \"snr\": 0.20953254401683807,\n        \"type\": \"mlp.gate_proj\"\n    },\n    \"model.layers.0.mlp.up_proj\": {\n        \"snr\": 1.5084924697875977,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.1.mlp.up_proj\": {\n        \"snr\": 1.7789595127105713,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.2.mlp.up_proj\": {\n        \"snr\": 2.1431775093078613,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.3.mlp.up_proj\": {\n        \"snr\": 2.762744903564453,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.4.mlp.up_proj\": {\n        \"snr\": 3.0324745178222656,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.5.mlp.up_proj\": {\n        \"snr\": 2.3884809017181396,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.6.mlp.up_proj\": {\n        \"snr\": 2.388005256652832,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.7.mlp.up_proj\": {\n        \"snr\": 2.339340925216675,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.8.mlp.up_proj\": {\n        \"snr\": 2.0497021675109863,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.9.mlp.up_proj\": {\n        \"snr\": 1.822119116783142,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.10.mlp.up_proj\": {\n        \"snr\": 1.600373387336731,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.11.mlp.up_proj\": {\n        \"snr\": 1.9298171997070312,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.12.mlp.up_proj\": {\n        \"snr\": 1.728783369064331,\n        \"type\": 
\"mlp.up_proj\"\n    },\n    \"model.layers.13.mlp.up_proj\": {\n        \"snr\": 1.965298056602478,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.14.mlp.up_proj\": {\n        \"snr\": 2.023681640625,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.15.mlp.up_proj\": {\n        \"snr\": 1.7721818685531616,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.16.mlp.up_proj\": {\n        \"snr\": 1.7068361043930054,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.17.mlp.up_proj\": {\n        \"snr\": 1.6673219203948975,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.18.mlp.up_proj\": {\n        \"snr\": 1.6240718364715576,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.19.mlp.up_proj\": {\n        \"snr\": 1.5169662237167358,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.20.mlp.up_proj\": {\n        \"snr\": 1.4018198251724243,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.21.mlp.up_proj\": {\n        \"snr\": 1.4556466341018677,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.22.mlp.up_proj\": {\n        \"snr\": 1.4304454326629639,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.23.mlp.up_proj\": {\n        \"snr\": 1.2785290479660034,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.24.mlp.up_proj\": {\n        \"snr\": 1.023495078086853,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.25.mlp.up_proj\": {\n        \"snr\": 0.6992124915122986,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.26.mlp.up_proj\": {\n        \"snr\": 0.4549211859703064,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.layers.27.mlp.up_proj\": {\n        \"snr\": 0.23889905214309692,\n        \"type\": \"mlp.up_proj\"\n    },\n    \"model.embed_tokens\": {\n        \"snr\": Infinity,\n        \"type\": \"model.embed_tokens\"\n    },\n    \"model.norm\": {\n        \"snr\": Infinity,\n     
   \"type\": \"model.norm\"\n    },\n    \"model.layers.0.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.1.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.2.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.3.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.4.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.5.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.6.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.7.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.8.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.9.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.10.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.11.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.12.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.13.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.14.post_attention_layernorm\": {\n  
      \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.15.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.16.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.17.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.18.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.19.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.20.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.21.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.22.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.23.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.24.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.25.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.26.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.27.post_attention_layernorm\": {\n        \"snr\": Infinity,\n        \"type\": \"post_attention_layernorm\"\n    },\n    \"model.layers.0.self_attn.k_proj\": {\n        \"snr\": 0.08150045573711395,\n        \"type\": \"self_attn.k_proj\"\n    },\n    
\"model.layers.1.self_attn.k_proj\": {\n        \"snr\": 0.1428358554840088,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.2.self_attn.k_proj\": {\n        \"snr\": 0.2096949815750122,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.3.self_attn.k_proj\": {\n        \"snr\": 0.22633400559425354,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.4.self_attn.k_proj\": {\n        \"snr\": 0.2293967455625534,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.5.self_attn.k_proj\": {\n        \"snr\": 0.23336802423000336,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.6.self_attn.k_proj\": {\n        \"snr\": 0.23429904878139496,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.7.self_attn.k_proj\": {\n        \"snr\": 0.19610290229320526,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.8.self_attn.k_proj\": {\n        \"snr\": 0.2163258045911789,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.9.self_attn.k_proj\": {\n        \"snr\": 0.21039333939552307,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.10.self_attn.k_proj\": {\n        \"snr\": 0.23533931374549866,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.11.self_attn.k_proj\": {\n        \"snr\": 0.21457058191299438,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.12.self_attn.k_proj\": {\n        \"snr\": 0.21686571836471558,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.13.self_attn.k_proj\": {\n        \"snr\": 0.22398065030574799,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.14.self_attn.k_proj\": {\n        \"snr\": 0.20160657167434692,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.15.self_attn.k_proj\": {\n        \"snr\": 0.23705022037029266,\n        \"type\": \"self_attn.k_proj\"\n    },\n    
\"model.layers.16.self_attn.k_proj\": {\n        \"snr\": 0.23254962265491486,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.17.self_attn.k_proj\": {\n        \"snr\": 0.2892642617225647,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.18.self_attn.k_proj\": {\n        \"snr\": 0.27587130665779114,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.19.self_attn.k_proj\": {\n        \"snr\": 0.30891212821006775,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.20.self_attn.k_proj\": {\n        \"snr\": 0.28997519612312317,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.21.self_attn.k_proj\": {\n        \"snr\": 0.27534863352775574,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.22.self_attn.k_proj\": {\n        \"snr\": 0.35139667987823486,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.23.self_attn.k_proj\": {\n        \"snr\": 0.2773109972476959,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.24.self_attn.k_proj\": {\n        \"snr\": 0.2853511571884155,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.25.self_attn.k_proj\": {\n        \"snr\": 0.5030262470245361,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.26.self_attn.k_proj\": {\n        \"snr\": 0.2317112237215042,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.27.self_attn.k_proj\": {\n        \"snr\": 0.24419328570365906,\n        \"type\": \"self_attn.k_proj\"\n    },\n    \"model.layers.0.self_attn.o_proj\": {\n        \"snr\": 0.17767645418643951,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.1.self_attn.o_proj\": {\n        \"snr\": 0.14102177321910858,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.2.self_attn.o_proj\": {\n        \"snr\": 0.1523692011833191,\n        \"type\": \"self_attn.o_proj\"\n    },\n    
\"model.layers.3.self_attn.o_proj\": {\n        \"snr\": 0.16522075235843658,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.4.self_attn.o_proj\": {\n        \"snr\": 0.17483487725257874,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.5.self_attn.o_proj\": {\n        \"snr\": 0.227921262383461,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.6.self_attn.o_proj\": {\n        \"snr\": 0.2196175903081894,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.7.self_attn.o_proj\": {\n        \"snr\": 0.24270132184028625,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.8.self_attn.o_proj\": {\n        \"snr\": 0.2118290364742279,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.9.self_attn.o_proj\": {\n        \"snr\": 0.20525991916656494,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.10.self_attn.o_proj\": {\n        \"snr\": 0.22847208380699158,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.11.self_attn.o_proj\": {\n        \"snr\": 0.19665324687957764,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.12.self_attn.o_proj\": {\n        \"snr\": 0.23233532905578613,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.13.self_attn.o_proj\": {\n        \"snr\": 0.2624332308769226,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.14.self_attn.o_proj\": {\n        \"snr\": 0.1868327558040619,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.15.self_attn.o_proj\": {\n        \"snr\": 0.17706255614757538,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.16.self_attn.o_proj\": {\n        \"snr\": 0.19422705471515656,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.17.self_attn.o_proj\": {\n        \"snr\": 0.2000615894794464,\n        \"type\": \"self_attn.o_proj\"\n    },\n    
\"model.layers.18.self_attn.o_proj\": {\n        \"snr\": 0.1874573826789856,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.19.self_attn.o_proj\": {\n        \"snr\": 0.21297843754291534,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.20.self_attn.o_proj\": {\n        \"snr\": 0.2100859135389328,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.21.self_attn.o_proj\": {\n        \"snr\": 0.22561520338058472,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.22.self_attn.o_proj\": {\n        \"snr\": 0.20994484424591064,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.23.self_attn.o_proj\": {\n        \"snr\": 0.18978221714496613,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.24.self_attn.o_proj\": {\n        \"snr\": 0.1571759581565857,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.25.self_attn.o_proj\": {\n        \"snr\": 0.1349896937608719,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.26.self_attn.o_proj\": {\n        \"snr\": 0.1368866115808487,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.27.self_attn.o_proj\": {\n        \"snr\": 0.1571887582540512,\n        \"type\": \"self_attn.o_proj\"\n    },\n    \"model.layers.0.self_attn.q_proj\": {\n        \"snr\": 0.05295897275209427,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.1.self_attn.q_proj\": {\n        \"snr\": 0.06835605204105377,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.2.self_attn.q_proj\": {\n        \"snr\": 0.0746372863650322,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.3.self_attn.q_proj\": {\n        \"snr\": 0.06615085154771805,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.4.self_attn.q_proj\": {\n        \"snr\": 0.06788161396980286,\n        \"type\": \"self_attn.q_proj\"\n    },\n    
\"model.layers.5.self_attn.q_proj\": {\n        \"snr\": 0.07514483481645584,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.6.self_attn.q_proj\": {\n        \"snr\": 0.07777862250804901,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.7.self_attn.q_proj\": {\n        \"snr\": 0.07534090429544449,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.8.self_attn.q_proj\": {\n        \"snr\": 0.09494179487228394,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.9.self_attn.q_proj\": {\n        \"snr\": 0.09699037671089172,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.10.self_attn.q_proj\": {\n        \"snr\": 0.09426294267177582,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.11.self_attn.q_proj\": {\n        \"snr\": 0.08260341733694077,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.12.self_attn.q_proj\": {\n        \"snr\": 0.10650420933961868,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.13.self_attn.q_proj\": {\n        \"snr\": 0.10250870138406754,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.14.self_attn.q_proj\": {\n        \"snr\": 0.08775162696838379,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.15.self_attn.q_proj\": {\n        \"snr\": 0.08071447163820267,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.16.self_attn.q_proj\": {\n        \"snr\": 0.07530857622623444,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.17.self_attn.q_proj\": {\n        \"snr\": 0.06964966654777527,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.18.self_attn.q_proj\": {\n        \"snr\": 0.07150755077600479,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.19.self_attn.q_proj\": {\n        \"snr\": 0.0676807165145874,\n        \"type\": \"self_attn.q_proj\"\n    },\n    
\"model.layers.20.self_attn.q_proj\": {\n        \"snr\": 0.06511317938566208,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.21.self_attn.q_proj\": {\n        \"snr\": 0.06773187220096588,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.22.self_attn.q_proj\": {\n        \"snr\": 0.06400436162948608,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.23.self_attn.q_proj\": {\n        \"snr\": 0.0726117342710495,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.24.self_attn.q_proj\": {\n        \"snr\": 0.06882446259260178,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.25.self_attn.q_proj\": {\n        \"snr\": 0.07506493479013443,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.26.self_attn.q_proj\": {\n        \"snr\": 0.07797915488481522,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.27.self_attn.q_proj\": {\n        \"snr\": 0.06680692732334137,\n        \"type\": \"self_attn.q_proj\"\n    },\n    \"model.layers.0.self_attn.v_proj\": {\n        \"snr\": 1.326789379119873,\n        \"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.1.self_attn.v_proj\": {\n        \"snr\": 3.043806791305542,\n        \"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.2.self_attn.v_proj\": {\n        \"snr\": 2.295107841491699,\n        \"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.3.self_attn.v_proj\": {\n        \"snr\": 5.2584614753723145,\n        \"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.4.self_attn.v_proj\": {\n        \"snr\": 4.038785934448242,\n        \"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.5.self_attn.v_proj\": {\n        \"snr\": 3.0907773971557617,\n        \"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.6.self_attn.v_proj\": {\n        \"snr\": 3.114994525909424,\n        \"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.7.self_attn.v_proj\": 
{\n        \"snr\": 1.9747973680496216,\n        \"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.8.self_attn.v_proj\": {\n        \"snr\": 3.0469374656677246,\n        \"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.9.self_attn.v_proj\": {\n        \"snr\": 1.602966547012329,\n        \"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.10.self_attn.v_proj\": {\n        \"snr\": 1.489019513130188,\n        \"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.11.self_attn.v_proj\": {\n        \"snr\": 1.7490826845169067,\n        \"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.12.self_attn.v_proj\": {\n        \"snr\": 2.451310396194458,\n        \"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.13.self_attn.v_proj\": {\n        \"snr\": 3.250821590423584,\n        \"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.14.self_attn.v_proj\": {\n        \"snr\": 7.944663047790527,\n        \"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.15.self_attn.v_proj\": {\n        \"snr\": 7.013208389282227,\n        \"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.16.self_attn.v_proj\": {\n        \"snr\": 2.68644118309021,\n        \"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.17.self_attn.v_proj\": {\n        \"snr\": 3.9063122272491455,\n        \"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.18.self_attn.v_proj\": {\n        \"snr\": 4.1816816329956055,\n        \"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.19.self_attn.v_proj\": {\n        \"snr\": 6.794488906860352,\n        \"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.20.self_attn.v_proj\": {\n        \"snr\": 3.401334285736084,\n        \"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.21.self_attn.v_proj\": {\n        \"snr\": 2.051994562149048,\n        \"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.22.self_attn.v_proj\": {\n        \"snr\": 3.614379405975342,\n        
\"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.23.self_attn.v_proj\": {\n        \"snr\": 11.180968284606934,\n        \"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.24.self_attn.v_proj\": {\n        \"snr\": 2.3629775047302246,\n        \"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.25.self_attn.v_proj\": {\n        \"snr\": 4.137593746185303,\n        \"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.26.self_attn.v_proj\": {\n        \"snr\": 2.3465518951416016,\n        \"type\": \"self_attn.v_proj\"\n    },\n    \"model.layers.27.self_attn.v_proj\": {\n        \"snr\": 3.10064697265625,\n        \"type\": \"self_attn.v_proj\"\n    }\n}\n"
  },
  {
    "path": "src/axolotl/integrations/swanlab/README.md",
    "content": "# SwanLab Integration for Axolotl\n\nSwanLab is an open-source, lightweight AI experiment tracking and visualization tool that provides a platform for tracking, recording, comparing, and collaborating on experiments.\n\nThis integration enables seamless experiment tracking and visualization of Axolotl training runs using SwanLab.\n\n## Features\n\n- 📊 **Automatic Metrics Logging**: Training loss, learning rate, and other metrics are automatically logged\n- 🎯 **Hyperparameter Tracking**: Model configuration and training parameters are tracked\n- 📈 **Real-time Visualization**: Monitor training progress in real-time through SwanLab dashboard\n- ☁️ **Cloud & Local Support**: Works in both cloud-synced and offline modes\n- 🔄 **Experiment Comparison**: Compare multiple training runs easily\n- 🤝 **Team Collaboration**: Share experiments with team members\n- 🎭 **RLHF Completion Logging**: Automatically log model outputs during DPO/KTO/ORPO/GRPO training for qualitative analysis\n- ⚡ **Performance Profiling**: Built-in profiling decorators to measure and optimize training performance\n- 🔔 **Lark Notifications**: Send real-time training updates to team chat (Feishu/Lark integration)\n\n## Installation\n\n```bash\npip install swanlab\n```\n\n## Quick Start\n\n### 1. Register for SwanLab (Optional for cloud mode)\n\nIf you want to use cloud sync features, register at [https://swanlab.cn](https://swanlab.cn) to get your API key.\n\n### 2. Configure Axolotl Config File\n\nAdd SwanLab configuration to your Axolotl YAML config:\n\n```yaml\n# Enable SwanLab plugin\nplugins:\n  - axolotl.integrations.swanlab.SwanLabPlugin\n\n# SwanLab configuration\nuse_swanlab: true\nswanlab_project: my-llm-project\nswanlab_experiment_name: qwen-finetune-v1\nswanlab_mode: cloud  # Options: cloud, local, offline, disabled\nswanlab_workspace: my-team  # Optional: organization name\nswanlab_api_key: YOUR_API_KEY  # Optional: can also use env var SWANLAB_API_KEY\n```\n\n### 3. 
Run Training\n\n```bash\n# Set API key via environment variable (recommended)\nexport SWANLAB_API_KEY=your-api-key-here\n\n# Or login once\nswanlab login\n\n# Run training as usual\naccelerate launch -m axolotl.cli.train your-config.yaml\n```\n\n## Configuration Options\n\n### Basic Configuration\n\n| Parameter | Type | Default | Description |\n|-----------|------|---------|-------------|\n| `use_swanlab` | bool | `false` | Enable SwanLab tracking |\n| `swanlab_project` | str | `None` | Project name (required) |\n| `swanlab_experiment_name` | str | `None` | Experiment name |\n| `swanlab_description` | str | `None` | Experiment description |\n| `swanlab_mode` | str | `cloud` | Sync mode: `cloud`, `local`, `offline`, `disabled` |\n\n### Advanced Configuration\n\n| Parameter | Type | Default | Description |\n|-----------|------|---------|-------------|\n| `swanlab_workspace` | str | `None` | Workspace/organization name |\n| `swanlab_api_key` | str | `None` | API key (prefer env var) |\n| `swanlab_web_host` | str | `None` | Private deployment web host |\n| `swanlab_api_host` | str | `None` | Private deployment API host |\n| `swanlab_log_model` | bool | `false` | Log model checkpoints (coming soon) |\n| `swanlab_lark_webhook_url` | str | `None` | Lark (Feishu) webhook URL for team notifications |\n| `swanlab_lark_secret` | str | `None` | Lark webhook HMAC secret for authentication |\n| `swanlab_log_completions` | bool | `true` | Enable RLHF completion table logging (DPO/KTO/ORPO/GRPO) |\n| `swanlab_completion_log_interval` | int | `100` | Steps between completion logging |\n| `swanlab_completion_max_buffer` | int | `128` | Max completions to buffer (memory bound) |\n\n## Configuration Examples\n\n### Example 1: Basic Cloud Sync\n\n```yaml\nplugins:\n  - axolotl.integrations.swanlab.SwanLabPlugin\n\nuse_swanlab: true\nswanlab_project: llama-finetune\nswanlab_experiment_name: llama-3-8b-instruct-v1\nswanlab_mode: cloud\n```\n\n### Example 2: Offline/Local 
Mode\n\n```yaml\nplugins:\n  - axolotl.integrations.swanlab.SwanLabPlugin\n\nuse_swanlab: true\nswanlab_project: local-experiments\nswanlab_experiment_name: test-run-1\nswanlab_mode: local  # or 'offline'\n```\n\n### Example 3: Team Workspace\n\n```yaml\nplugins:\n  - axolotl.integrations.swanlab.SwanLabPlugin\n\nuse_swanlab: true\nswanlab_project: research-project\nswanlab_experiment_name: experiment-42\nswanlab_workspace: my-research-team\nswanlab_mode: cloud\n```\n\n### Example 4: Private Deployment\n\n```yaml\nplugins:\n  - axolotl.integrations.swanlab.SwanLabPlugin\n\nuse_swanlab: true\nswanlab_project: internal-project\nswanlab_experiment_name: secure-training\nswanlab_mode: cloud\nswanlab_web_host: https://swanlab.yourcompany.com\nswanlab_api_host: https://api.swanlab.yourcompany.com\n```\n\n## Team Notifications with Lark (Feishu)\n\nSwanLab supports sending real-time training notifications to your team chat via Lark (Feishu), ByteDance's enterprise collaboration platform. This is especially useful for:\n- **Production training monitoring**: Get alerts when training starts, completes, or encounters errors\n- **Team collaboration**: Keep your ML team informed about long-running experiments\n- **Multi-timezone teams**: Team members can check training progress without being online\n\n### Prerequisites\n\n1. **Lark Bot Setup**: Create a custom bot in your Lark group chat\n2. **Webhook URL**: Get the webhook URL from your Lark bot settings\n3. 
**HMAC Secret** (recommended): Enable signature verification in your Lark bot for security\n\nFor detailed Lark bot setup instructions, see [Lark Custom Bot Documentation](https://open.feishu.cn/document/ukTMukTMukTM/ucTM5YjL3ETO24yNxkjN).\n\n### Example 5: Basic Lark Notifications\n\nSend training notifications to a Lark group chat:\n\n```yaml\nplugins:\n  - axolotl.integrations.swanlab.SwanLabPlugin\n\nuse_swanlab: true\nswanlab_project: production-training\nswanlab_experiment_name: llama-3-finetune-v2\nswanlab_mode: cloud\n\n# Lark notification (basic, no HMAC verification)\nswanlab_lark_webhook_url: https://open.feishu.cn/open-apis/bot/v2/hook/xxxxxxxxxx\n```\n\n**Note**: This configuration will work, but you'll see a security warning recommending HMAC secret configuration.\n\n### Example 6: Lark Notifications with HMAC Security (Recommended)\n\nFor production use, enable HMAC signature verification:\n\n```yaml\nplugins:\n  - axolotl.integrations.swanlab.SwanLabPlugin\n\nuse_swanlab: true\nswanlab_project: production-training\nswanlab_experiment_name: llama-3-finetune-v2\nswanlab_mode: cloud\n\n# Lark notification with HMAC authentication\nswanlab_lark_webhook_url: https://open.feishu.cn/open-apis/bot/v2/hook/xxxxxxxxxx\nswanlab_lark_secret: your-webhook-secret-key\n```\n\n**Why HMAC secret matters**:\n- Prevents unauthorized parties from sending fake notifications to your Lark group\n- Ensures notifications genuinely come from your training jobs\n- Required for production deployments with sensitive training data\n\n### Example 7: Team Workspace + Lark Notifications\n\nCombine team workspace collaboration with Lark notifications:\n\n```yaml\nplugins:\n  - axolotl.integrations.swanlab.SwanLabPlugin\n\nuse_swanlab: true\nswanlab_project: research-project\nswanlab_experiment_name: multimodal-experiment-42\nswanlab_workspace: ml-research-team\nswanlab_mode: cloud\n\n# Notify team via Lark when training starts/completes\nswanlab_lark_webhook_url: 
https://open.feishu.cn/open-apis/bot/v2/hook/xxxxxxxxxx\nswanlab_lark_secret: your-webhook-secret-key\n```\n\n### What Notifications Are Sent?\n\nSwanLab's Lark integration sends notifications for key training events:\n- **Training Start**: When your experiment begins\n- **Training Complete**: When training finishes successfully\n- **Training Errors**: If training crashes or encounters critical errors\n- **Metric Milestones**: Configurable alerts for metric thresholds (if configured in SwanLab)\n\nEach notification includes:\n- Experiment name and project\n- Training status\n- Key metrics (loss, learning rate)\n- Direct link to SwanLab dashboard\n\n### Lark Configuration Validation\n\nThe plugin validates your Lark configuration at startup:\n\n#### ✅ Valid Configurations\n\n```yaml\n# Option 1: No Lark (default)\nuse_swanlab: true\nswanlab_project: my-project\n# No swanlab_lark_webhook_url → Lark disabled, no warnings\n\n# Option 2: Lark with HMAC secret (recommended)\nuse_swanlab: true\nswanlab_project: my-project\nswanlab_lark_webhook_url: https://open.feishu.cn/open-apis/bot/v2/hook/xxx\nswanlab_lark_secret: your-secret\n# ✅ Logs: \"Registered Lark notification callback with HMAC authentication\"\n\n# Option 3: Lark without secret (works but not recommended)\nuse_swanlab: true\nswanlab_project: my-project\nswanlab_lark_webhook_url: https://open.feishu.cn/open-apis/bot/v2/hook/xxx\n# ⚠️ Logs: \"Registered Lark notification callback (no HMAC secret)\"\n# ⚠️ Warning: \"Lark webhook has no secret configured. For production use, set 'swanlab_lark_secret'...\"\n```\n\n### Security Best Practices\n\n1. **Always use HMAC secret in production**:\n   ```yaml\n   swanlab_lark_webhook_url: https://open.feishu.cn/...\n   swanlab_lark_secret: your-secret-key  # ✅ Add this!\n   ```\n\n2. 
**Store secrets in environment variables** (even better):\n   ```bash\n   # In your training script/environment\n   export SWANLAB_LARK_WEBHOOK_URL=\"https://open.feishu.cn/...\"\n   export SWANLAB_LARK_SECRET=\"your-secret-key\"\n   ```\n\n   Then in config:\n   ```yaml\n   # SwanLab plugin will auto-detect environment variables\n   use_swanlab: true\n   swanlab_project: my-project\n   # Lark URL and secret read from env vars\n   ```\n\n3. **Rotate webhook secrets periodically**: Update your Lark bot's secret every 90 days\n\n4. **Use separate webhooks for dev/prod**: Don't mix development and production notifications\n\n### Distributed Training\n\nLark notifications are automatically deduplicated in distributed training:\n- Only **rank 0** sends notifications\n- Other GPU ranks skip Lark registration\n- Prevents duplicate messages in multi-GPU training\n\n```bash\n# Running on 4 GPUs\ntorchrun --nproc_per_node=4 -m axolotl.cli.train config.yml\n\n# Expected logs:\n# [Rank 0] Registered Lark notification callback with HMAC authentication\n# [Rank 1-3] (no Lark registration messages)\n```\n\n## RLHF Completion Table Logging\n\nFor RLHF (Reinforcement Learning from Human Feedback) training methods like DPO, KTO, ORPO, and GRPO, SwanLab can log model completions (prompts, chosen/rejected responses, rewards) to a visual table for qualitative analysis. 
This helps you:\n\n- **Inspect model behavior**: See actual model outputs during training\n- **Debug preference learning**: Compare chosen vs rejected responses\n- **Track reward patterns**: Monitor how rewards evolve over training\n- **Share examples with team**: Visual tables in SwanLab dashboard\n\n### Features\n\n- ✅ **Automatic detection**: Works with DPO, KTO, ORPO, GRPO trainers\n- ✅ **Memory-safe buffering**: Bounded buffer prevents memory leaks in long training runs\n- ✅ **Periodic logging**: Configurable logging interval to reduce overhead\n- ✅ **Rich visualization**: SwanLab tables show prompts, responses, and metrics side-by-side\n\n### Configuration\n\n| Parameter | Type | Default | Description |\n|-----------|------|---------|-------------|\n| `swanlab_log_completions` | bool | `true` | Enable completion logging for RLHF trainers |\n| `swanlab_completion_log_interval` | int | `100` | Log completions to SwanLab every N training steps |\n| `swanlab_completion_max_buffer` | int | `128` | Maximum completions to buffer (memory bound) |\n\n### Example: DPO Training with Completion Logging\n\n```yaml\nplugins:\n  - axolotl.integrations.swanlab.SwanLabPlugin\n\nuse_swanlab: true\nswanlab_project: dpo-training\nswanlab_experiment_name: llama-3-dpo-v1\nswanlab_mode: cloud\n\n# RLHF completion logging (enabled by default)\nswanlab_log_completions: true\nswanlab_completion_log_interval: 100  # Log every 100 steps\nswanlab_completion_max_buffer: 128    # Keep last 128 completions\n\n# DPO-specific config\nrl: dpo\ndatasets:\n  - path: /path/to/preference_dataset\n    type: chatml.intel\n```\n\n### Example: Disable Completion Logging\n\nIf you're doing a quick test run or don't need completion tables:\n\n```yaml\nplugins:\n  - axolotl.integrations.swanlab.SwanLabPlugin\n\nuse_swanlab: true\nswanlab_project: dpo-training\n\n# Disable completion logging\nswanlab_log_completions: false\n```\n\n### Supported RLHF Trainers\n\nThe completion logging callback 
automatically activates for these trainer types:\n\n- **DPO (Direct Preference Optimization)**: Logs prompts, chosen, rejected, reward_diff\n- **KTO (Kahneman-Tversky Optimization)**: Logs prompts, completions, labels, rewards\n- **ORPO (Odds Ratio Preference Optimization)**: Logs prompts, chosen, rejected, log_odds_ratio\n- **GRPO (Group Relative Policy Optimization)**: Logs prompts, completions, rewards, advantages\n- **CPO (Contrastive Preference Optimization)**: Logs prompts, chosen, rejected\n\nFor non-RLHF trainers (standard supervised fine-tuning), the completion callback is automatically skipped.\n\n### How It Works\n\n1. **Auto-detection**: Plugin detects trainer type at initialization\n2. **Buffering**: Completions are buffered in memory (up to `swanlab_completion_max_buffer`)\n3. **Periodic logging**: Every `swanlab_completion_log_interval` steps, buffer is logged to SwanLab\n4. **Memory safety**: Old completions are automatically dropped when buffer is full (uses `collections.deque`)\n5. **Final flush**: Remaining completions are logged when training completes\n\n### Viewing Completion Tables\n\nAfter training starts, you can view completion tables in your SwanLab dashboard:\n\n1. Navigate to your experiment in SwanLab\n2. Look for the \"rlhf_completions\" table in the metrics panel\n3. 
The table shows:\n   - **step**: Training step when completion was generated\n   - **prompt**: Input prompt\n   - **chosen**: Preferred response (DPO/ORPO)\n   - **rejected**: Non-preferred response (DPO/ORPO)\n   - **completion**: Model output (KTO/GRPO)\n   - **reward_diff/reward**: Reward metrics\n   - Trainer-specific metrics (e.g., log_odds_ratio for ORPO)\n\n### Memory Management\n\nThe completion buffer is **memory-bounded** to prevent memory leaks:\n\n```python\n# Internal implementation uses deque with maxlen\nfrom collections import deque\n\nbuffer = deque(maxlen=128)  # Old completions automatically dropped\n```\n\n**Memory usage estimate**:\n- Average completion: ~500 characters (prompt + responses)\n- Buffer size 128: ~64 KB (negligible)\n- Buffer size 1024: ~512 KB (still small)\n\n**Recommendation**: Default buffer size (128) works well for most cases. Increase to 512-1024 only if you need to review more historical completions.\n\n### Performance Impact\n\nCompletion logging has minimal overhead:\n\n- **Buffering**: O(1) append operation, negligible CPU/memory\n- **Logging**: Only happens every N steps (default: 100)\n- **Network**: SwanLab batches table uploads efficiently\n\n**Expected overhead**: < 0.5% per training step\n\n### Troubleshooting\n\n#### Completions not appearing in SwanLab\n\n**Cause**: Trainer may not be logging completion data in the expected format.\n\n**Diagnostic steps**:\n1. Check trainer type detection in logs:\n   ```text\n   INFO: SwanLab RLHF completion logging enabled for DPOTrainer (type: dpo)\n   ```\n2. Verify your trainer is an RLHF trainer (DPO/KTO/ORPO/GRPO)\n3. Check if trainer logs completion data (this depends on TRL version)\n\n**Note**: The current implementation expects trainers to log completion data in the `logs` dict during `on_log()` callback. Some TRL trainers may not expose this data by default. 
You may need to patch the trainer to expose completions.\n\n#### Buffer fills up too quickly\n\n**Cause**: High logging frequency with small buffer size.\n\n**Solution**: Increase buffer size or logging interval:\n```yaml\nswanlab_completion_log_interval: 200  # Log less frequently\nswanlab_completion_max_buffer: 512    # Larger buffer\n```\n\n#### Memory usage growing over time\n\n**Cause**: Buffer should be bounded, so this indicates a bug.\n\n**Solution**:\n1. Verify `swanlab_completion_max_buffer` is set\n2. Check SwanLab version is up to date\n3. Report issue with memory profiling data\n\n## Performance Profiling\n\nSwanLab integration includes profiling utilities to measure and log execution time of trainer methods. This helps you:\n\n- **Identify bottlenecks**: Find slow operations in your training loop\n- **Optimize performance**: Track improvements after optimization changes\n- **Monitor distributed training**: See per-rank timing differences\n- **Debug hangs**: Detect methods that take unexpectedly long\n\n### Features\n\n- ✅ **Zero-config profiling**: Automatic timing of key trainer methods\n- ✅ **Decorator-based**: Easy to add profiling to custom methods with `@swanlab_profile`\n- ✅ **Context manager**: Fine-grained profiling with `swanlab_profiling_context()`\n- ✅ **Advanced filtering**: `ProfilingConfig` for throttling and minimum duration thresholds\n- ✅ **Exception-safe**: Logs duration even if function raises an exception\n\n### Basic Usage: Decorator\n\nAdd profiling to any trainer method with the `@swanlab_profile` decorator:\n\n```python\nfrom axolotl.integrations.swanlab.profiling import swanlab_profile\n\nclass MyCustomTrainer(AxolotlTrainer):\n    @swanlab_profile\n    def training_step(self, model, inputs):\n        # Your training step logic\n        return super().training_step(model, inputs)\n\n    @swanlab_profile\n    def prediction_step(self, model, inputs, prediction_loss_only):\n        # Your prediction logic\n        return 
super().prediction_step(model, inputs, prediction_loss_only)\n```\n\nThe decorator automatically:\n1. Measures execution time with high-precision timer\n2. Logs to SwanLab as `profiling/Time taken: ClassName.method_name`\n3. Only logs if SwanLab is enabled (`use_swanlab: true`)\n4. Gracefully handles exceptions (logs duration, then re-raises)\n\n### Advanced Usage: Context Manager\n\nFor fine-grained profiling within a method:\n\n```python\nfrom axolotl.integrations.swanlab.profiling import swanlab_profiling_context\n\nclass MyTrainer(AxolotlTrainer):\n    def complex_training_step(self, model, inputs):\n        # Profile just the forward pass\n        with swanlab_profiling_context(self, \"forward_pass\"):\n            outputs = model(**inputs)\n\n        # Profile just the backward pass\n        with swanlab_profiling_context(self, \"backward_pass\"):\n            loss = outputs.loss\n            loss.backward()\n\n        return outputs\n```\n\n### Advanced Usage: ProfilingConfig\n\nFilter and throttle profiling logs with `ProfilingConfig`:\n\n```python\nfrom axolotl.integrations.swanlab.profiling import (\n    swanlab_profiling_context_advanced,\n    ProfilingConfig,\n)\n\n# Create custom profiling config\nprofiling_config = ProfilingConfig(\n    enabled=True,\n    min_duration_ms=1.0,    # Only log if duration > 1ms\n    log_interval=10,        # Log every 10th call\n)\n\nclass MyTrainer(AxolotlTrainer):\n    def frequently_called_method(self, data):\n        with swanlab_profiling_context_advanced(\n            self,\n            \"frequent_op\",\n            config=profiling_config\n        ):\n            # This only logs every 10th call, and only if it takes > 1ms\n            result = expensive_computation(data)\n        return result\n```\n\n**ProfilingConfig Parameters**:\n- `enabled`: Enable/disable profiling globally (default: `True`)\n- `min_duration_ms`: Minimum duration to log in milliseconds (default: `0.1`)\n- `log_interval`: Log every Nth 
function call (default: `1` = log all)\n\n**Use cases**:\n- **High-frequency methods**: Use `log_interval=100` to reduce logging overhead\n- **Filter noise**: Use `min_duration_ms=1.0` to skip very fast operations\n- **Debugging**: Use `log_interval=1, min_duration_ms=0.0` to log everything\n\n### Viewing Profiling Metrics\n\nIn your SwanLab dashboard, profiling metrics appear under the \"profiling\" namespace:\n\n```text\nprofiling/Time taken: AxolotlTrainer.training_step\nprofiling/Time taken: AxolotlTrainer.prediction_step\nprofiling/Time taken: MyTrainer.forward_pass\nprofiling/Time taken: MyTrainer.backward_pass\n```\n\nYou can:\n- **Track over time**: See if methods get faster/slower during training\n- **Compare runs**: Compare profiling metrics across experiments\n- **Identify regressions**: Detect if a code change slowed down training\n\n### Configuration in Axolotl Config\n\nProfiling is automatically enabled when SwanLab is enabled. No additional config needed:\n\n```yaml\nplugins:\n  - axolotl.integrations.swanlab.SwanLabPlugin\n\nuse_swanlab: true\nswanlab_project: my-project\n\n# Profiling is automatically enabled\n# Add @swanlab_profile decorators to your custom trainer methods\n```\n\nTo disable profiling while keeping SwanLab enabled:\n\n```python\n# In your custom trainer code\nfrom axolotl.integrations.swanlab.profiling import DEFAULT_PROFILING_CONFIG\n\n# Disable profiling globally\nDEFAULT_PROFILING_CONFIG.enabled = False\n```\n\n### Performance Impact\n\n- **Decorator overhead**: ~2-5 microseconds per call (negligible)\n- **Context manager overhead**: ~1-3 microseconds (negligible)\n- **Logging overhead**: Only when SwanLab is enabled and method duration exceeds threshold\n- **Network overhead**: SwanLab batches metrics efficiently\n\n**Expected overhead**: < 0.1% per training step (effectively zero)\n\n### Best Practices\n\n1. **Profile bottlenecks first**: Start by profiling suspected slow operations\n2. 
**Use min_duration_ms**: Filter out fast operations (< 1ms) to reduce noise\n3. **Throttle high-frequency calls**: Use `log_interval` for methods called > 100 times/step\n4. **Profile across runs**: Compare profiling metrics before/after optimization\n5. **Monitor distributed training**: Check for rank-specific slowdowns\n\n### Example: Complete Profiling Setup\n\n```python\nfrom axolotl.integrations.swanlab.profiling import (\n    swanlab_profile,\n    swanlab_profiling_context_advanced,\n    ProfilingConfig,\n)\n\nclass OptimizedTrainer(AxolotlTrainer):\n    def __init__(self, *args, **kwargs):\n        super().__init__(*args, **kwargs)\n\n        # Custom profiling config for high-frequency operations\n        self.fast_op_config = ProfilingConfig(\n            enabled=True,\n            min_duration_ms=0.5,\n            log_interval=50,\n        )\n\n    @swanlab_profile\n    def training_step(self, model, inputs):\n        \"\"\"Main training step - always profile.\"\"\"\n        return super().training_step(model, inputs)\n\n    @swanlab_profile\n    def compute_loss(self, model, inputs, return_outputs=False):\n        \"\"\"Loss computation - always profile.\"\"\"\n        return super().compute_loss(model, inputs, return_outputs)\n\n    def _prepare_inputs(self, inputs):\n        \"\"\"High-frequency operation - throttled profiling.\"\"\"\n        with swanlab_profiling_context_advanced(\n            self,\n            \"prepare_inputs\",\n            config=self.fast_op_config,\n        ):\n            return super()._prepare_inputs(inputs)\n```\n\n### Troubleshooting\n\n#### Profiling metrics not appearing in SwanLab\n\n**Cause**: SwanLab is not enabled or not initialized.\n\n**Solution**:\n```yaml\n# Ensure SwanLab is enabled\nuse_swanlab: true\nswanlab_project: my-project\n```\n\nCheck logs for:\n```text\nINFO: SwanLab initialized for project: my-project\n```\n\n#### Too many profiling metrics cluttering dashboard\n\n**Cause**: Profiling every function call for 
high-frequency operations.\n\n**Solution**: Use `ProfilingConfig` with throttling:\n```python\nconfig = ProfilingConfig(\n    min_duration_ms=1.0,    # Skip fast ops\n    log_interval=100,       # Log every 100th call\n)\n```\n\n#### Profiling overhead impacting training speed\n\n**Cause**: Profiling itself should have negligible overhead (< 0.1%). If you see > 1% slowdown, this indicates a bug.\n\n**Solution**:\n1. Disable profiling temporarily to confirm:\n   ```python\n   DEFAULT_PROFILING_CONFIG.enabled = False\n   ```\n2. Report issue with profiling data and trainer details\n\n#### Profiling shows inconsistent timing\n\n**Cause**: Normal variation due to GPU warmup, data loading, or system load.\n\n**Solution**:\n- Ignore first few steps (warmup period)\n- Look at average/median timing over many steps\n- Use `log_interval` to reduce noise from individual outliers\n\n## Complete Config Example\n\nHere's a complete example config that integrates SwanLab into an Axolotl training run (the project, experiment, and workspace names are illustrative):\n\n```yaml\nbase_model: /path/to/your/model\nmodel_type: Qwen2ForCausalLM\n\n# SwanLab Integration\nplugins:\n  - axolotl.integrations.swanlab.SwanLabPlugin\n  - axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin\n\nuse_swanlab: true\nswanlab_project: RVQ-Alpha-Training\nswanlab_experiment_name: Qwen2.5-7B-MetaQA-Perturb-P020\nswanlab_description: \"Training on MetaQA and Perturbation datasets with NEW-RVQ encoding\"\nswanlab_mode: cloud\nswanlab_workspace: single-cell-genomics\n\n# Training configuration\nsequence_len: 32768\nmicro_batch_size: 1\ngradient_accumulation_steps: 1\nnum_epochs: 2\nlearning_rate: 2e-5\noptimizer: adamw_torch_fused\n\n# Datasets\ndatasets:\n  - path: /path/to/dataset\n    type: chat_template\n\n# Output\noutput_dir: ./outputs\n```\n\n## Modes Explained\n\n### `cloud` Mode (Default)\n- Syncs experiments to SwanLab cloud in real-time\n- Requires API key and internet connection\n- Best for: Team collaboration, remote monitoring\n\n### `local` Mode\n- Saves 
experiments locally only\n- No cloud sync\n- Best for: Local development, air-gapped environments\n\n### `offline` Mode\n- Saves metadata locally\n- Can sync to cloud later using `swanlab sync`\n- Best for: Unstable internet, sync later\n\n### `disabled` Mode\n- Turns off SwanLab completely\n- No logging or tracking\n- Best for: Debugging, testing\n\n## Configuration Validation & Conflict Detection\n\nSwanLab integration includes comprehensive validation and conflict detection to help you catch configuration errors early and avoid performance issues.\n\n### Required Fields Validation\n\nThe plugin validates your configuration at startup and provides clear error messages with solutions:\n\n#### Missing Project Name\n\n```yaml\n# ❌ INVALID: use_swanlab enabled but no project\nuse_swanlab: true\n# Error: SwanLab enabled but 'swanlab_project' is not set.\n```\n\n**Solution**:\n```yaml\n# ✅ VALID: Provide project name\nuse_swanlab: true\nswanlab_project: my-project\n```\n\n#### Invalid Mode\n\n```yaml\n# ❌ INVALID: Unknown mode\nuse_swanlab: true\nswanlab_project: my-project\nswanlab_mode: invalid-mode\n# Error: Invalid swanlab_mode: 'invalid-mode'. 
Valid options: cloud, local, offline, disabled\n```\n\n**Solution**:\n```yaml\n# ✅ VALID: Use one of the valid modes\nuse_swanlab: true\nswanlab_project: my-project\nswanlab_mode: cloud  # or: local, offline, disabled\n```\n\n#### Empty Project Name\n\n```yaml\n# ❌ INVALID: Empty string project name\nuse_swanlab: true\nswanlab_project: \"\"\n# Error: swanlab_project cannot be an empty string.\n```\n\n**Solution**:\n```yaml\n# ✅ VALID: Provide non-empty project name\nuse_swanlab: true\nswanlab_project: my-project\n```\n\n### Cloud Mode API Key Warning\n\nWhen using `cloud` mode without an API key, you'll receive a warning with multiple solutions:\n\n```yaml\nuse_swanlab: true\nswanlab_project: my-project\nswanlab_mode: cloud\n# No API key set\n# Warning: SwanLab cloud mode enabled but no API key found.\n```\n\n**Solutions**:\n1. Set environment variable: `export SWANLAB_API_KEY=your-api-key`\n2. Add to config (less secure): `swanlab_api_key: your-api-key`\n3. Run `swanlab login` before training\n4. 
Use `swanlab_mode: local` for offline tracking\n\n### Multi-Logger Performance Warnings\n\nUsing multiple logging tools simultaneously (SwanLab + WandB + MLflow + Comet) can impact training performance:\n\n#### Two Loggers - Warning\n\n```yaml\nuse_swanlab: true\nswanlab_project: my-project\n\nuse_wandb: true\nwandb_project: my-project\n\n# Warning: Multiple logging tools enabled: SwanLab, WandB\n# Expected overhead: ~3.0% per training step.\n```\n\n**Impact**:\n- Performance overhead: ~1-2% per logger (cumulative)\n- Increased memory usage\n- Longer training time per step\n- Potential config/callback conflicts\n\n**Recommendations**:\n- Choose ONE primary logging tool for production training\n- Use multiple loggers only for:\n  - Migration period (transitioning between tools)\n  - Short comparison runs\n  - Debugging specific tool issues\n- Monitor system resources (CPU, memory) during training\n\n#### Three+ Loggers - Error-Level Warning\n\n```yaml\nuse_swanlab: true\nswanlab_project: my-project\n\nuse_wandb: true\nwandb_project: my-project\n\nuse_mlflow: true\nmlflow_tracking_uri: http://localhost:5000\n\n# ERROR: 3 logging tools enabled simultaneously!\n# Expected overhead: ~4.5% per training step.\n# STRONGLY RECOMMEND: Disable all but ONE logging tool\n```\n\n**Why This Matters**:\n- With 3 loggers: ~4-5% overhead per step → significant slowdown over long training\n- Example: 10,000 steps at 2s/step (~20,000 seconds) → ~800-1,000 seconds extra (~13-17 minutes)\n- Memory overhead scales with number of loggers\n- Rare edge cases with callback ordering conflicts\n\n### Auto-Enable Logic\n\nFor convenience, SwanLab will auto-enable if you specify a project without setting `use_swanlab`:\n\n```yaml\n# This configuration:\nswanlab_project: my-project\n\n# Automatically becomes:\nuse_swanlab: true\nswanlab_project: my-project\n```\n\n### Distributed Training Detection\n\nIn distributed training scenarios (multi-GPU), the plugin automatically detects and reports:\n\n```yaml\nuse_swanlab: 
true\nswanlab_project: my-project\nswanlab_mode: cloud\n\n# When running with torchrun --nproc_per_node=4:\n# Info: Distributed training detected (world_size=4)\n# Info: SwanLab mode: cloud\n# Info: Only rank 0 will initialize SwanLab\n# Info: Other ranks will skip SwanLab to avoid conflicts\n```\n\n**Why Only Rank 0**:\n- Avoids duplicate experiment runs\n- Reduces network/cloud API overhead on worker ranks\n- Prevents race conditions in metric logging\n\n## Authentication\n\n### Method 1: Environment Variable (Recommended)\n```bash\nexport SWANLAB_API_KEY=your-api-key-here\n```\n\n### Method 2: Login Command\n```bash\nswanlab login\n# Enter your API key when prompted\n```\n\n### Method 3: Config File\n```yaml\nswanlab_api_key: your-api-key-here\n```\n\n## What Gets Logged?\n\n### Automatically Logged Metrics\n- Training loss\n- Learning rate\n- Gradient norm\n- Training steps\n- Epoch progress\n\n### Automatically Logged Config\n- Model configuration (base_model, model_type)\n- Training hyperparameters (learning_rate, batch_size, etc.)\n- Optimizer settings\n- Parallelization settings (FSDP, DeepSpeed, Context Parallel)\n- Axolotl configuration file\n- DeepSpeed configuration (if used)\n\n## Viewing Your Experiments\n\n### Cloud Mode\nVisit [https://swanlab.cn](https://swanlab.cn) and navigate to your project to view:\n- Real-time training metrics\n- Hyperparameter comparison\n- System resource usage\n- Configuration files\n\n### Local Mode\n```bash\n# Start local dashboard\nswanlab watch ./swanlog\n\n# Open browser to http://localhost:5092\n```\n\n## Integration with Existing Tools\n\nSwanLab can work alongside other tracking tools:\n\n```yaml\nplugins:\n  - axolotl.integrations.swanlab.SwanLabPlugin\n\n# Use both SwanLab and Wandb\nuse_swanlab: true\nswanlab_project: my-project\n\nuse_wandb: true\nwandb_project: my-project\n```\n\n## Troubleshooting\n\n### Configuration Errors\n\n#### Error: \"SwanLab enabled but 'swanlab_project' is not set\"\n\n**Cause**: You 
enabled SwanLab (`use_swanlab: true`) but forgot to specify a project name.\n\n**Solution**:\n```yaml\nuse_swanlab: true\nswanlab_project: my-project  # Add this line\n```\n\n#### Error: \"Invalid swanlab_mode: 'xxx'\"\n\n**Cause**: You provided an invalid mode value.\n\n**Solution**: Use one of the valid modes:\n```yaml\nswanlab_mode: cloud     # or: local, offline, disabled\n```\n\n#### Error: \"swanlab_project cannot be an empty string\"\n\n**Cause**: You set `swanlab_project: \"\"` (empty string).\n\n**Solution**: Either provide a valid name, or remove the field and disable SwanLab (`use_swanlab` defaults to `true`, which requires a project name):\n```yaml\n# Option 1: Provide valid name\nswanlab_project: my-project\n\n# Option 2: Remove the field entirely and disable SwanLab\n# swanlab_project: \"\"  <- Remove this line\nuse_swanlab: false\n```\n\n### Import Errors\n\n#### Error: \"SwanLab is not installed\"\n\n**Cause**: SwanLab package is not installed in your environment.\n\n**Solution**:\n```bash\npip install swanlab\n# or\npip install 'swanlab>=0.3.0'\n```\n\n### Performance Issues\n\n#### Warning: \"Multiple logging tools enabled\"\n\n**Cause**: You have multiple experiment tracking tools enabled (e.g., SwanLab + WandB + MLflow).\n\n**Impact**: ~1-2% performance overhead per logger, cumulative.\n\n**Solution**: For production training, disable all but one logger:\n```yaml\n# Option 1: Keep only SwanLab\nuse_swanlab: true\nswanlab_project: my-project\nuse_wandb: false      # Disable others\nuse_mlflow: false\n\n# Option 2: Keep only WandB\nuse_swanlab: false\nuse_wandb: true\nwandb_project: my-project\n```\n\n**Exception**: Multiple loggers are acceptable for:\n- Short comparison runs (< 100 steps)\n- Migration testing between logging tools\n- Debugging logger-specific issues\n\n### Distributed Training Issues\n\n#### SwanLab creates duplicate runs in multi-GPU training\n\n**Cause**: All ranks are initializing SwanLab instead of just rank 0.\n\n**Expected Behavior**: The plugin automatically ensures only rank 0 initializes SwanLab. 
You should see:\n```text\nInfo: Distributed training detected (world_size=4)\nInfo: Only rank 0 will initialize SwanLab\nInfo: Other ranks will skip SwanLab to avoid conflicts\n```\n\n**If you see duplicates**:\n1. Check your plugin is loaded correctly\n2. Verify you're using the latest SwanLab integration code\n3. Check logs for initialization messages on all ranks\n\n### SwanLab not logging metrics\n\n**Solution**: Ensure SwanLab is initialized before training starts. The plugin automatically handles this in `pre_model_load`.\n\n### API Key errors\n\n**Solution**:\n```bash\n# Verify API key\necho $SWANLAB_API_KEY\n\n# Re-login\nswanlab login\n```\n\n### Cloud sync issues\n\n**Solution**: Use `offline` mode and sync later:\n```yaml\nswanlab_mode: offline\n```\n\nThen sync when ready:\n```bash\nswanlab sync ./swanlog\n```\n\n### Plugin not loaded\n\n**Solution**: Verify plugin path in config:\n```yaml\nplugins:\n  - axolotl.integrations.swanlab.SwanLabPlugin  # Correct path\n```\n\n### Lark Notification Issues\n\n#### Error: \"Failed to import SwanLab Lark plugin\"\n\n**Cause**: Your SwanLab version doesn't include the Lark plugin (requires SwanLab >= 0.3.0).\n\n**Solution**:\n```bash\n# Upgrade SwanLab to latest version\npip install --upgrade swanlab\n\n# Or install specific version\npip install 'swanlab>=0.3.0'\n```\n\n#### Warning: \"Lark webhook has no secret configured\"\n\n**Cause**: You provided `swanlab_lark_webhook_url` but no `swanlab_lark_secret`.\n\n**Impact**: Lark notifications will work, but without HMAC authentication (security risk).\n\n**Solution**: Add HMAC secret for production use:\n```yaml\nswanlab_lark_webhook_url: https://open.feishu.cn/open-apis/bot/v2/hook/xxx\nswanlab_lark_secret: your-webhook-secret  # Add this line\n```\n\n**When it's OK to skip secret**:\n- Local development and testing\n- Internal networks with restricted access\n- Non-sensitive training experiments\n\n**When secret is required**:\n- Production training jobs\n- 
Training with proprietary data\n- Multi-team shared Lark groups\n\n#### Error: \"Failed to register Lark callback\"\n\n**Cause**: Invalid webhook URL or network connectivity issues.\n\n**Diagnostic steps**:\n```bash\n# 1. Test webhook URL manually\ncurl -X POST \"YOUR_WEBHOOK_URL\" \\\n  -H 'Content-Type: application/json' \\\n  -d '{\"msg_type\":\"text\",\"content\":{\"text\":\"Test from Axolotl\"}}'\n\n# 2. Check SwanLab version\npip show swanlab\n\n# 3. Verify webhook URL format\n# Should start with: https://open.feishu.cn/open-apis/bot/v2/hook/\n```\n\n**Solution**:\n1. Verify webhook URL is correct (copy from Lark bot settings)\n2. Check network connectivity to Lark API\n3. Ensure webhook is not expired (Lark webhooks can expire)\n4. Regenerate webhook URL in Lark bot settings if needed\n\n#### Lark notifications not received\n\n**Cause**: Multiple possible causes.\n\n**Diagnostic checklist**:\n\n1. **Check training logs** for Lark registration confirmation:\n   ```text\n   # Expected log message (rank 0 only):\n   INFO: Registered Lark notification callback with HMAC authentication\n   ```\n\n2. **Verify webhook in Lark**: Test webhook manually (see above)\n\n3. **Check distributed training**: Only rank 0 sends notifications\n   ```bash\n   # If running multi-GPU, check rank 0 logs specifically\n   grep \"Registered Lark\" logs/rank_0.log\n   ```\n\n4. **Verify SwanLab is initialized**: Lark callback needs SwanLab to be running\n   ```yaml\n   use_swanlab: true  # Must be enabled\n   swanlab_project: my-project  # Must be set\n   ```\n\n5. **Check Lark bot permissions**: Ensure bot is added to the target group chat\n\n#### Duplicate Lark notifications in multi-GPU training\n\n**Expected Behavior**: Should NOT happen - only rank 0 sends notifications.\n\n**If you see duplicates**:\n1. Check that all GPUs are using the same config file\n2. Verify plugin is loaded correctly on all ranks\n3. Check logs for unexpected Lark initialization on non-zero ranks\n4. 
Ensure `RANK` or `LOCAL_RANK` environment variables are set correctly\n\n**Solution**: This is a bug if it occurs. Report with:\n- Full training command\n- Logs from all ranks\n- Config file\n\n## Comparison: SwanLab vs WandB\n\n| Feature | SwanLab | WandB |\n|---------|---------|-------|\n| Open Source | ✅ Yes | ❌ No |\n| Self-Hosting | ✅ Easy | ⚠️ Complex |\n| Free Tier | ✅ Generous | ⚠️ Limited |\n| Chinese Support | ✅ Native | ⚠️ Limited |\n| Offline Mode | ✅ Full support | ✅ Supported |\n| Integration | 🆕 New | ✅ Mature |\n\n## Advanced Usage\n\n### Custom Logging\n\nYou can add custom metrics in your callbacks:\n\n```python\nimport swanlab\n\n# In your custom callback\nswanlab.log({\n    \"custom_metric\": value,\n    \"epoch\": epoch_num\n})\n```\n\n### Experiment Comparison\n\n```bash\n# Compare multiple experiments\nswanlab compare run1 run2 run3\n```\n\n## Support\n\n- **Documentation**: [https://docs.swanlab.cn](https://docs.swanlab.cn)\n- **GitHub**: [https://github.com/SwanHubX/SwanLab](https://github.com/SwanHubX/SwanLab)\n- **Issues**: Report bugs at [GitHub Issues](https://github.com/SwanHubX/SwanLab/issues)\n\n## License\n\nThis integration follows the Axolotl Community License Agreement.\n\n## Acknowledgements\n\nThis integration is built on top of:\n- [SwanLab](https://github.com/SwanHubX/SwanLab) - Experiment tracking tool\n- [Transformers](https://github.com/huggingface/transformers) - SwanLabCallback\n- [Axolotl](https://github.com/axolotl-ai-cloud/axolotl) - Training framework\n"
  },
  {
    "path": "src/axolotl/integrations/swanlab/__init__.py",
    "content": "\"\"\"SwanLab integration plugin for Axolotl\"\"\"\n\nfrom axolotl.integrations.swanlab.args import SwanLabConfig\nfrom axolotl.integrations.swanlab.plugins import SwanLabPlugin\n\n__all__ = [\"SwanLabConfig\", \"SwanLabPlugin\"]\n"
  },
  {
    "path": "src/axolotl/integrations/swanlab/args.py",
    "content": "\"\"\"SwanLab configuration arguments\"\"\"\n\nfrom pydantic import BaseModel, Field, field_validator, model_validator\n\n\nclass SwanLabConfig(BaseModel):\n    \"\"\"SwanLab configuration subset\"\"\"\n\n    use_swanlab: bool | None = Field(\n        default=True,\n        json_schema_extra={\n            \"description\": \"Enable SwanLab experiment tracking and visualization\"\n        },\n    )\n    swanlab_project: str | None = Field(\n        default=None,\n        json_schema_extra={\"description\": \"Your SwanLab project name\"},\n    )\n    swanlab_experiment_name: str | None = Field(\n        default=None,\n        json_schema_extra={\"description\": \"Set the name of your SwanLab experiment\"},\n    )\n    swanlab_description: str | None = Field(\n        default=None,\n        json_schema_extra={\"description\": \"Description for your SwanLab experiment\"},\n    )\n    swanlab_mode: str | None = Field(\n        default=None,\n        json_schema_extra={\n            \"description\": '\"cloud\" to sync to SwanLab cloud, \"local\" for local only, \"offline\" to save metadata locally, \"disabled\" to turn off SwanLab'\n        },\n    )\n    swanlab_workspace: str | None = Field(\n        default=None,\n        json_schema_extra={\n            \"description\": \"SwanLab workspace name (organization or username)\"\n        },\n    )\n    swanlab_api_key: str | None = Field(\n        default=None,\n        json_schema_extra={\n            \"description\": \"SwanLab API key for authentication. 
Can also be set via SWANLAB_API_KEY environment variable\"\n        },\n    )\n    swanlab_log_model: bool | None = Field(\n        default=False,\n        json_schema_extra={\n            \"description\": \"Whether to log model checkpoints to SwanLab (feature coming soon)\"\n        },\n    )\n    swanlab_web_host: str | None = Field(\n        default=None,\n        json_schema_extra={\n            \"description\": \"Web address for SwanLab cloud environment (for private deployment)\"\n        },\n    )\n    swanlab_api_host: str | None = Field(\n        default=None,\n        json_schema_extra={\n            \"description\": \"API address for SwanLab cloud environment (for private deployment)\"\n        },\n    )\n    swanlab_lark_webhook_url: str | None = Field(\n        default=None,\n        json_schema_extra={\n            \"description\": \"Lark (Feishu) webhook URL for sending training notifications to team chat\"\n        },\n    )\n    swanlab_lark_secret: str | None = Field(\n        default=None,\n        json_schema_extra={\n            \"description\": \"Secret for Lark webhook HMAC signature authentication (optional)\"\n        },\n    )\n    swanlab_log_completions: bool | None = Field(\n        default=True,\n        json_schema_extra={\n            \"description\": \"Enable logging RLHF completions to SwanLab for qualitative analysis (DPO/KTO/ORPO/GRPO)\"\n        },\n    )\n    swanlab_completion_log_interval: int | None = Field(\n        default=100,\n        json_schema_extra={\n            \"description\": \"Number of training steps between completion table logging to SwanLab\"\n        },\n    )\n    swanlab_completion_max_buffer: int | None = Field(\n        default=128,\n        json_schema_extra={\n            \"description\": \"Maximum number of completions to buffer before logging (prevents memory leaks)\"\n        },\n    )\n\n    @field_validator(\"swanlab_mode\")\n    @classmethod\n    def validate_swanlab_mode(cls, v):\n        
\"\"\"Validate swanlab_mode is one of the allowed values.\"\"\"\n        if v is None:\n            return v\n\n        valid_modes = [\"cloud\", \"local\", \"offline\", \"disabled\"]\n        if v not in valid_modes:\n            raise ValueError(\n                f\"Invalid swanlab_mode: '{v}'.\\n\\n\"\n                f\"Valid options: {', '.join(valid_modes)}\\n\\n\"\n                f\"Examples:\\n\"\n                f\"  swanlab_mode: cloud     # Sync to SwanLab cloud\\n\"\n                f\"  swanlab_mode: local     # Local only, no cloud sync\\n\"\n                f\"  swanlab_mode: offline   # Save metadata locally\\n\"\n                f\"  swanlab_mode: disabled  # Turn off SwanLab\\n\"\n            )\n        return v\n\n    @field_validator(\"swanlab_project\")\n    @classmethod\n    def validate_swanlab_project(cls, v):\n        \"\"\"Validate swanlab_project is non-empty when provided.\"\"\"\n        if v is not None and isinstance(v, str) and len(v.strip()) == 0:\n            raise ValueError(\n                \"swanlab_project cannot be an empty string.\\n\\n\"\n                \"Either:\\n\"\n                \"  1. Provide a valid project name: swanlab_project: my-project\\n\"\n                \"  2. Remove the swanlab_project field entirely\\n\"\n            )\n        return v\n\n    @model_validator(mode=\"after\")\n    def validate_swanlab_enabled_requires_project(self):\n        \"\"\"Validate that if use_swanlab is True, swanlab_project must be set.\"\"\"\n        if self.use_swanlab is True and not self.swanlab_project:\n            raise ValueError(\n                \"SwanLab enabled (use_swanlab: true) but 'swanlab_project' is not set.\\n\\n\"\n                \"Solutions:\\n\"\n                \"  1. Add 'swanlab_project: your-project-name' to your config\\n\"\n                \"  2. 
Set 'use_swanlab: false' to disable SwanLab\\n\\n\"\n                \"Example:\\n\"\n                \"  use_swanlab: true\\n\"\n                \"  swanlab_project: my-llm-training\\n\"\n            )\n        return self\n"
  },
  {
    "path": "src/axolotl/integrations/swanlab/callbacks.py",
    "content": "\"\"\"SwanLab callbacks for Axolotl trainers.\n\nThis module provides HuggingFace Trainer callbacks for logging\nRLHF completions to SwanLab.\n\"\"\"\n\nfrom transformers import (\n    TrainerCallback,\n    TrainerControl,\n    TrainerState,\n    TrainingArguments,\n)\n\nfrom axolotl.integrations.swanlab.completion_logger import CompletionLogger\nfrom axolotl.utils.logging import get_logger\n\nLOG = get_logger(__name__)\n\n\nclass SwanLabRLHFCompletionCallback(TrainerCallback):\n    \"\"\"Callback for logging RLHF completions to SwanLab.\n\n    This callback periodically logs model completions (prompts, chosen/rejected\n    responses, rewards) to SwanLab during RLHF training for qualitative analysis.\n\n    Supports DPO, KTO, ORPO, and GRPO trainers.\n\n    Example usage:\n        >>> callback = SwanLabRLHFCompletionCallback(\n        ...     log_interval=100,  # Log every 100 steps\n        ...     max_completions=128,  # Keep last 128 completions\n        ... )\n        >>> trainer.add_callback(callback)\n\n    Attributes:\n        logger: CompletionLogger instance\n        log_interval: Number of steps between SwanLab logging\n        trainer_type: Auto-detected trainer type (dpo/kto/orpo/grpo)\n    \"\"\"\n\n    def __init__(\n        self,\n        log_interval: int = 100,\n        max_completions: int = 128,\n        table_name: str = \"rlhf_completions\",\n    ):\n        \"\"\"Initialize SwanLab RLHF completion callback.\n\n        Args:\n            log_interval: Log to SwanLab every N steps. Default: 100\n            max_completions: Maximum completions to buffer. Default: 128\n            table_name: SwanLab table name. 
Default: \"rlhf_completions\"\n        \"\"\"\n        super().__init__()\n        self.logger = CompletionLogger(maxlen=max_completions)\n        self.log_interval = log_interval\n        self.table_name = table_name\n        self.trainer_type: str | None = None  # Auto-detected\n        self._last_logged_step = 0\n\n    def on_init_end(\n        self,\n        args: TrainingArguments,\n        state: TrainerState,\n        control: TrainerControl,\n        **kwargs,\n    ):\n        \"\"\"Detect trainer type on initialization.\"\"\"\n        trainer = kwargs.get(\"trainer\")\n        if trainer is not None:\n            trainer_name = trainer.__class__.__name__\n            if \"DPO\" in trainer_name:\n                self.trainer_type = \"dpo\"\n            elif \"KTO\" in trainer_name:\n                self.trainer_type = \"kto\"\n            elif \"ORPO\" in trainer_name:\n                self.trainer_type = \"orpo\"\n            elif \"GRPO\" in trainer_name:\n                self.trainer_type = \"grpo\"\n            else:\n                self.trainer_type = \"unknown\"\n\n            LOG.info(\n                f\"SwanLab RLHF completion logging enabled for {trainer_name} \"\n                f\"(type: {self.trainer_type})\"\n            )\n\n    def on_log(\n        self,\n        args: TrainingArguments,\n        state: TrainerState,\n        control: TrainerControl,\n        logs: dict | None = None,\n        **kwargs,\n    ):\n        \"\"\"Capture completions from logs and buffer them.\n\n        Different trainers log completions in different formats:\n        - DPO: logs['dpo/chosen'], logs['dpo/rejected'], logs['dpo/reward_diff']\n        - KTO: logs['kto/completion'], logs['kto/label'], logs['kto/reward']\n        - ORPO: logs['orpo/chosen'], logs['orpo/rejected']\n        - GRPO: logs['grpo/completion'], logs['grpo/reward']\n\n        Note: This is a placeholder implementation. Actual log keys depend\n        on the TRL trainer implementation. 
You may need to patch the trainers\n        to expose completion data in logs.\n        \"\"\"\n        if logs is None or self.trainer_type is None:\n            return\n\n        step = state.global_step\n\n        # DPO completions\n        if self.trainer_type == \"dpo\":\n            if all(key in logs for key in [\"dpo/prompt\", \"dpo/chosen\", \"dpo/rejected\"]):\n                self.logger.add_dpo_completion(\n                    step=step,\n                    prompt=logs.get(\"dpo/prompt\", \"\"),\n                    chosen=logs.get(\"dpo/chosen\", \"\"),\n                    rejected=logs.get(\"dpo/rejected\", \"\"),\n                    reward_diff=logs.get(\"dpo/reward_diff\"),\n                )\n\n        # KTO completions\n        elif self.trainer_type == \"kto\":\n            if all(key in logs for key in [\"kto/prompt\", \"kto/completion\"]):\n                self.logger.add_kto_completion(\n                    step=step,\n                    prompt=logs.get(\"kto/prompt\", \"\"),\n                    completion=logs.get(\"kto/completion\", \"\"),\n                    label=logs.get(\"kto/label\", False),\n                    reward=logs.get(\"kto/reward\"),\n                )\n\n        # ORPO completions\n        elif self.trainer_type == \"orpo\":\n            if all(\n                key in logs for key in [\"orpo/prompt\", \"orpo/chosen\", \"orpo/rejected\"]\n            ):\n                self.logger.add_orpo_completion(\n                    step=step,\n                    prompt=logs.get(\"orpo/prompt\", \"\"),\n                    chosen=logs.get(\"orpo/chosen\", \"\"),\n                    rejected=logs.get(\"orpo/rejected\", \"\"),\n                    log_odds_ratio=logs.get(\"orpo/log_odds_ratio\"),\n                )\n\n        # GRPO completions\n        elif self.trainer_type == \"grpo\":\n            if all(key in logs for key in [\"grpo/prompt\", \"grpo/completion\"]):\n                self.logger.add_grpo_completion(\n        
            step=step,\n                    prompt=logs.get(\"grpo/prompt\", \"\"),\n                    completion=logs.get(\"grpo/completion\", \"\"),\n                    reward=logs.get(\"grpo/reward\"),\n                    advantage=logs.get(\"grpo/advantage\"),\n                )\n\n        # Periodically log to SwanLab\n        if step - self._last_logged_step >= self.log_interval:\n            if len(self.logger) > 0:\n                self.logger.log_to_swanlab(table_name=self.table_name)\n                self.logger.clear()\n                self._last_logged_step = step\n\n    def on_train_end(\n        self,\n        args: TrainingArguments,\n        state: TrainerState,\n        control: TrainerControl,\n        **kwargs,\n    ):\n        \"\"\"Log remaining completions at end of training.\"\"\"\n        if len(self.logger) > 0:\n            LOG.info(\n                f\"Training complete, logging final {len(self.logger)} completions to SwanLab\"\n            )\n            self.logger.log_to_swanlab(table_name=self.table_name)\n            self._last_logged_step = state.global_step\n"
  },
  {
    "path": "src/axolotl/integrations/swanlab/completion_logger.py",
    "content": "\"\"\"SwanLab completion logger for RLHF/DPO/KTO/ORPO/GRPO training.\n\nThis module provides utilities for logging model completions during\npreference training to SwanLab for qualitative analysis.\n\"\"\"\n\nfrom collections import deque\nfrom collections.abc import Mapping\nfrom typing import Any\n\nfrom axolotl.utils.logging import get_logger\n\nLOG = get_logger(__name__)\n\n\nclass CompletionLogger:\n    \"\"\"Memory-bounded logger for RLHF completions.\n\n    Stores prompts, completions, and rewards in fixed-size deques to prevent\n    memory leaks during long training runs. Logs completion tables to SwanLab\n    for qualitative analysis of model outputs.\n\n    Example usage:\n        >>> logger = CompletionLogger(maxlen=128)\n        >>> logger.add_dpo_completion(\n        ...     step=0,\n        ...     prompt=\"What is AI?\",\n        ...     chosen=\"Artificial Intelligence is...\",\n        ...     rejected=\"AI means...\",\n        ...     reward_diff=0.5\n        ... )\n        >>> logger.log_to_swanlab()\n\n    Attributes:\n        maxlen: Maximum number of completions to store (older ones are dropped)\n        data: Deque storing completion dictionaries\n    \"\"\"\n\n    def __init__(self, maxlen: int = 128):\n        \"\"\"Initialize completion logger with bounded buffer.\n\n        Args:\n            maxlen: Maximum number of completions to store. 
When the buffer\n                is full, oldest completions are automatically discarded.\n                Default: 128 (sufficient for most RLHF runs without memory issues)\n        \"\"\"\n        self.maxlen = maxlen\n        self.data: deque[Mapping[str, Any]] = deque(maxlen=maxlen)\n\n    def add_dpo_completion(\n        self,\n        step: int,\n        prompt: str,\n        chosen: str,\n        rejected: str,\n        reward_diff: float | None = None,\n    ) -> None:\n        \"\"\"Add a DPO completion to the buffer.\n\n        Args:\n            step: Training step number\n            prompt: Input prompt\n            chosen: Chosen (preferred) completion\n            rejected: Rejected (non-preferred) completion\n            reward_diff: Reward difference (chosen - rejected), if available\n        \"\"\"\n        entry = {\n            \"step\": step,\n            \"prompt\": prompt,\n            \"chosen\": chosen,\n            \"rejected\": rejected,\n        }\n        if reward_diff is not None:\n            entry[\"reward_diff\"] = reward_diff\n\n        self.data.append(entry)\n\n    def add_kto_completion(\n        self,\n        step: int,\n        prompt: str,\n        completion: str,\n        label: bool,\n        reward: float | None = None,\n    ) -> None:\n        \"\"\"Add a KTO completion to the buffer.\n\n        Args:\n            step: Training step number\n            prompt: Input prompt\n            completion: Model-generated completion\n            label: True if desirable, False if undesirable\n            reward: Reward score, if available\n        \"\"\"\n        entry = {\n            \"step\": step,\n            \"prompt\": prompt,\n            \"completion\": completion,\n            \"label\": \"desirable\" if label else \"undesirable\",\n        }\n        if reward is not None:\n            entry[\"reward\"] = reward\n\n        self.data.append(entry)\n\n    def add_orpo_completion(\n        self,\n        step: int,\n    
    prompt: str,\n        chosen: str,\n        rejected: str,\n        log_odds_ratio: float | None = None,\n    ) -> None:\n        \"\"\"Add an ORPO completion to the buffer.\n\n        Args:\n            step: Training step number\n            prompt: Input prompt\n            chosen: Chosen (preferred) completion\n            rejected: Rejected (non-preferred) completion\n            log_odds_ratio: Log odds ratio between chosen and rejected\n        \"\"\"\n        entry = {\n            \"step\": step,\n            \"prompt\": prompt,\n            \"chosen\": chosen,\n            \"rejected\": rejected,\n        }\n        if log_odds_ratio is not None:\n            entry[\"log_odds_ratio\"] = log_odds_ratio\n\n        self.data.append(entry)\n\n    def add_grpo_completion(\n        self,\n        step: int,\n        prompt: str,\n        completion: str,\n        reward: float | None = None,\n        advantage: float | None = None,\n    ) -> None:\n        \"\"\"Add a GRPO completion to the buffer.\n\n        Args:\n            step: Training step number\n            prompt: Input prompt\n            completion: Model-generated completion\n            reward: Reward score from reward model\n            advantage: Advantage estimate (reward - baseline)\n        \"\"\"\n        entry = {\n            \"step\": step,\n            \"prompt\": prompt,\n            \"completion\": completion,\n        }\n        if reward is not None:\n            entry[\"reward\"] = reward\n        if advantage is not None:\n            entry[\"advantage\"] = advantage\n\n        self.data.append(entry)\n\n    def log_to_swanlab(self, table_name: str = \"completions\") -> bool:\n        \"\"\"Log buffered completions to SwanLab as a table.\n\n        Creates a SwanLab echarts Table with all buffered completions.\n        Only logs if SwanLab is initialized and data is available.\n\n        Args:\n            table_name: Name of the table in SwanLab dashboard.\n                
Default: \"completions\"\n\n        Returns:\n            True if logging succeeded, False otherwise\n        \"\"\"\n        if not self.data:\n            LOG.debug(\"No completions to log to SwanLab\")\n            return False\n\n        try:\n            import swanlab\n\n            if swanlab.get_run() is None:\n                LOG.debug(\"SwanLab not initialized, skipping completion logging\")\n                return False\n\n            # Convert deque to list of dicts\n            completions = list(self.data)\n\n            # Extract headers from first entry (all entries should have same structure)\n            headers = list(completions[0].keys())\n\n            # Build rows: each completion becomes one row\n            rows = []\n            for completion in completions:\n                row = [completion.get(header, \"\") for header in headers]\n                rows.append(row)\n\n            # Log to SwanLab as echarts Table\n            swanlab.log({table_name: swanlab.echarts.Table().add(headers, rows)})\n\n            LOG.info(f\"Logged {len(rows)} completions to SwanLab table '{table_name}'\")\n            return True\n\n        except ImportError:\n            LOG.warning(\n                \"SwanLab not installed, cannot log completions. 
\"\n                \"Install with: pip install swanlab\"\n            )\n            return False\n        except Exception as err:  # pylint: disable=broad-except\n            LOG.exception(\"Failed to log completions to SwanLab: %s\", err)\n            return False\n\n    def clear(self) -> None:\n        \"\"\"Clear all buffered completions.\"\"\"\n        self.data.clear()\n\n    def __len__(self) -> int:\n        \"\"\"Return number of buffered completions.\"\"\"\n        return len(self.data)\n\n    def __repr__(self) -> str:\n        \"\"\"String representation showing buffer status.\"\"\"\n        return (\n            f\"CompletionLogger(maxlen={self.maxlen}, \"\n            f\"buffered={len(self.data)}/{self.maxlen})\"\n        )\n"
  },
  {
    "path": "src/axolotl/integrations/swanlab/plugins.py",
    "content": "\"\"\"SwanLab Plugin for Axolotl\"\"\"\n\nfrom __future__ import annotations\n\nfrom typing import TYPE_CHECKING\n\nfrom axolotl.integrations.base import BasePlugin\nfrom axolotl.utils.logging import get_logger\n\nif TYPE_CHECKING:\n    from transformers import TrainerCallback\n\n    from axolotl.utils.dict import DictDefault\n\nLOG = get_logger(__name__)\n\n\nclass SwanLabPlugin(BasePlugin):\n    \"\"\"\n    SwanLab integration plugin for Axolotl.\n\n    Provides experiment tracking, visualization, and logging capabilities\n    using SwanLab (https://swanlab.cn).\n\n    Usage in config.yaml:\n        plugins:\n          - axolotl.integrations.swanlab.SwanLabPlugin\n\n        use_swanlab: true\n        swanlab_project: my-project\n        swanlab_experiment_name: my-experiment\n        swanlab_mode: cloud  # or 'local', 'offline', 'disabled'\n    \"\"\"\n\n    def __init__(self):\n        super().__init__()\n        self.swanlab_initialized = False\n        LOG.info(\"SwanLab plugin initialized\")\n\n    def get_input_args(self) -> str:\n        \"\"\"Returns the configuration model for SwanLab integration.\"\"\"\n        return \"axolotl.integrations.swanlab.SwanLabConfig\"\n\n    def register(self, cfg: dict):\n        \"\"\"Register SwanLab plugin with configuration and conflict detection.\"\"\"\n        LOG.info(\"Registering SwanLab plugin\")\n\n        # === Conflict Detection: Required Fields ===\n\n        # Check if SwanLab is enabled\n        if cfg.get(\"use_swanlab\"):\n            # 1. Validate project name is set\n            if not cfg.get(\"swanlab_project\"):\n                raise ValueError(\n                    \"SwanLab enabled but 'swanlab_project' is not set.\\n\\n\"\n                    \"Solutions:\\n\"\n                    \"  1. Add 'swanlab_project: your-project-name' to your config\\n\"\n                    \"  2. 
Set 'use_swanlab: false' to disable SwanLab\\n\\n\"\n                    \"See: src/axolotl/integrations/swanlab/README.md for examples\"\n                )\n\n            # 2. Validate swanlab_mode value\n            valid_modes = [\"cloud\", \"local\", \"offline\", \"disabled\"]\n            mode = cfg.get(\"swanlab_mode\")\n            if mode and mode not in valid_modes:\n                raise ValueError(\n                    f\"Invalid swanlab_mode: '{mode}'.\\n\\n\"\n                    f\"Valid options: {', '.join(valid_modes)}\\n\\n\"\n                    f\"Example:\\n\"\n                    f\"  swanlab_mode: cloud  # Sync to SwanLab cloud\\n\"\n                    f\"  swanlab_mode: local  # Local only, no cloud sync\\n\"\n                )\n\n            # 3. Check API key for cloud mode\n            import os\n\n            mode = cfg.get(\"swanlab_mode\", \"cloud\")  # Default is cloud\n            if mode == \"cloud\":\n                api_key = cfg.get(\"swanlab_api_key\") or os.environ.get(\n                    \"SWANLAB_API_KEY\"\n                )\n                if not api_key:\n                    LOG.warning(\n                        \"SwanLab cloud mode enabled but no API key found.\\n\"\n                        \"SwanLab may fail to initialize during training.\\n\\n\"\n                        \"Solutions:\\n\"\n                        \"  1. Set SWANLAB_API_KEY environment variable:\\n\"\n                        \"     export SWANLAB_API_KEY=your-api-key\\n\"\n                        \"  2. Add 'swanlab_api_key: your-api-key' to config (less secure)\\n\"\n                        \"  3. Run 'swanlab login' before training\\n\"\n                        \"  4. 
Use 'swanlab_mode: local' for offline tracking\\n\"\n                    )\n\n        # === Conflict Detection: Multi-Logger Performance Warning ===\n\n        # Detect all active logging tools\n        active_loggers = []\n        if cfg.get(\"use_wandb\"):\n            active_loggers.append(\"WandB\")\n        if cfg.get(\"use_mlflow\"):\n            active_loggers.append(\"MLflow\")\n        if cfg.get(\"comet_api_key\") or cfg.get(\"comet_project_name\"):\n            active_loggers.append(\"Comet\")\n        if cfg.get(\"use_swanlab\"):\n            active_loggers.append(\"SwanLab\")\n\n        if len(active_loggers) > 1:\n            LOG.warning(\n                f\"\\n{'=' * 70}\\n\"\n                f\"Multiple logging tools enabled: {', '.join(active_loggers)}\\n\"\n                f\"{'=' * 70}\\n\"\n                f\"This may cause:\\n\"\n                f\"  - Performance overhead (~1-2% per logger, cumulative)\\n\"\n                f\"  - Increased memory usage\\n\"\n                f\"  - Longer training time per step\\n\"\n                f\"  - Potential config/callback conflicts\\n\\n\"\n                f\"Recommendations:\\n\"\n                f\"  - Choose ONE primary logging tool for production training\\n\"\n                f\"  - Use multiple loggers only for:\\n\"\n                f\"    * Migration period (transitioning between tools)\\n\"\n                f\"    * Short comparison runs\\n\"\n                f\"    * Debugging specific tool issues\\n\"\n                f\"  - Monitor system resources (CPU, memory) during training\\n\"\n                f\"{'=' * 70}\\n\"\n            )\n\n            if len(active_loggers) >= 3:\n                LOG.error(\n                    f\"\\n{'!' * 70}\\n\"\n                    f\"WARNING: {len(active_loggers)} logging tools enabled simultaneously!\\n\"\n                    f\"{'!' 
* 70}\\n\"\n                    f\"This is likely unintentional and WILL significantly impact performance.\\n\"\n                    f\"Expected overhead: ~{len(active_loggers) * 1.5:.1f}% per training step.\\n\\n\"\n                    f\"STRONGLY RECOMMEND:\\n\"\n                    f\"  - Disable all but ONE logging tool\\n\"\n                    f\"  - Use config inheritance to manage multiple configs\\n\"\n                    f\"{'!' * 70}\\n\"\n                )\n\n        # === Auto-Enable Logic ===\n\n        # Enable SwanLab if project is specified\n        if cfg.get(\"swanlab_project\") and not cfg.get(\"use_swanlab\"):\n            cfg[\"use_swanlab\"] = True\n            LOG.info(\"Automatically enabled use_swanlab because swanlab_project is set\")\n\n    def pre_model_load(self, cfg: DictDefault):\n        \"\"\"Initialize SwanLab before model loading with runtime checks.\"\"\"\n        if not cfg.use_swanlab:\n            return\n\n        # === Runtime Check: Import Availability ===\n        try:\n            import swanlab\n        except ImportError as err:\n            raise ImportError(\n                \"SwanLab is not installed.\\n\\n\"\n                \"Install with:\\n\"\n                \"  pip install swanlab\\n\\n\"\n                \"Or add to requirements:\\n\"\n                \"  swanlab>=0.3.0\\n\\n\"\n                f\"Original error: {err}\"\n            ) from err\n\n        # Log SwanLab version\n        try:\n            swanlab_version = swanlab.__version__\n            LOG.info(f\"SwanLab version: {swanlab_version}\")\n        except AttributeError:\n            LOG.warning(\"Could not determine SwanLab version\")\n\n        # === Runtime Check: Distributed Training Setup ===\n        from axolotl.utils.distributed import get_world_size, is_main_process\n\n        world_size = get_world_size()\n        if world_size > 1:\n            mode = getattr(cfg, \"swanlab_mode\", \"cloud\")\n            LOG.info(\n                
f\"\\n{'=' * 70}\\n\"\n                f\"Distributed training detected (world_size={world_size})\\n\"\n                f\"SwanLab mode: {mode}\\n\"\n                f\"{'=' * 70}\\n\"\n                f\"Behavior:\\n\"\n                f\"  - Only rank 0 will initialize SwanLab\\n\"\n                f\"  - Other ranks will skip SwanLab to avoid conflicts\\n\"\n            )\n\n            if mode == \"cloud\":\n                LOG.info(\n                    f\"  - Only rank 0 will upload to SwanLab cloud\\n\"\n                    f\"  - Other ranks run without SwanLab overhead\\n\"\n                    f\"{'=' * 70}\\n\"\n                )\n\n        # Only initialize SwanLab on the main process (rank 0)\n        # to avoid creating multiple runs in distributed training\n        if not is_main_process():\n            LOG.debug(\"Skipping SwanLab initialization on non-main process\")\n            return\n\n        # Initialize SwanLab run (passing all params directly to init)\n        try:\n            init_kwargs = self._get_swanlab_init_kwargs(cfg)\n            swanlab.init(**init_kwargs)\n            self.swanlab_initialized = True\n            LOG.info(f\"SwanLab initialized with project: {cfg.swanlab_project}\")\n\n            # Register Lark notification callback (if configured)\n            self._register_lark_callback(cfg)\n\n            # Log configuration (with error handling)\n            try:\n                config_dict = self._prepare_config_for_logging(cfg)\n                swanlab.config.update(config_dict)\n                LOG.debug(\"Successfully logged config to SwanLab\")\n            except Exception as config_err:  # pylint: disable=broad-except\n                LOG.warning(\n                    f\"Failed to log config to SwanLab: {config_err}. 
Continuing anyway.\"\n                )\n\n        except Exception as err:  # pylint: disable=broad-except\n            LOG.exception(\"Failed to initialize SwanLab: %s\", err)\n            self.swanlab_initialized = False\n\n    def add_callbacks_pre_trainer(self, cfg: DictDefault, model):\n        \"\"\"Add SwanLab callbacks before trainer creation.\"\"\"\n        callbacks: list[TrainerCallback] = []\n\n        if not cfg.use_swanlab:\n            return callbacks\n\n        if not self.swanlab_initialized:\n            LOG.warning(\"SwanLab not initialized, skipping callback registration\")\n            return callbacks\n\n        try:\n            from axolotl.utils.callbacks.swanlab import (\n                CustomSwanLabCallback,\n                SaveAxolotlConfigtoSwanLabCallback,\n            )\n\n            # Add our custom lightweight SwanLabCallback\n            # (avoids omegaconf/antlr4 version conflicts)\n            swanlab_callback = CustomSwanLabCallback()\n            callbacks.append(swanlab_callback)\n            LOG.info(\"Added CustomSwanLabCallback for metrics logging\")\n\n            # Add Axolotl config logging callback\n            if cfg.axolotl_config_path:\n                config_callback = SaveAxolotlConfigtoSwanLabCallback(\n                    cfg.axolotl_config_path\n                )\n                callbacks.append(config_callback)\n                LOG.info(\"Added SaveAxolotlConfigtoSwanLabCallback\")\n\n        except ImportError as err:\n            LOG.exception(\"Failed to import SwanLab callbacks: %s\", err)\n\n        return callbacks\n\n    def post_trainer_create(self, cfg: DictDefault, trainer):\n        \"\"\"Post-trainer creation hook.\"\"\"\n        if cfg.use_swanlab and self.swanlab_initialized:\n            try:\n                import swanlab\n\n                # Log additional trainer information (with safe conversion)\n                trainer_config = {\n                    \"total_steps\": 
int(trainer.state.max_steps)\n                    if trainer.state.max_steps\n                    else None,\n                    \"num_train_epochs\": float(trainer.args.num_train_epochs)\n                    if trainer.args.num_train_epochs\n                    else None,\n                    \"train_batch_size\": int(trainer.args.train_batch_size)\n                    if hasattr(trainer.args, \"train_batch_size\")\n                    else None,\n                    \"gradient_accumulation_steps\": int(\n                        trainer.args.gradient_accumulation_steps\n                    )\n                    if trainer.args.gradient_accumulation_steps\n                    else None,\n                }\n                # Remove None values\n                trainer_config = {\n                    k: v for k, v in trainer_config.items() if v is not None\n                }\n\n                if trainer_config:\n                    swanlab.config.update(trainer_config)\n                    LOG.info(\"Logged trainer configuration to SwanLab\")\n            except Exception as err:  # pylint: disable=broad-except\n                LOG.debug(f\"Failed to log trainer config to SwanLab: {err}\")\n\n            # Register RLHF completion logging callback if enabled\n            self._register_completion_callback(cfg, trainer)\n\n    def _get_swanlab_init_kwargs(self, cfg: DictDefault) -> dict:\n        \"\"\"Prepare kwargs for swanlab.init().\n\n        Passes all configuration parameters directly to swanlab.init()\n        instead of using environment variables as an intermediate layer.\n\n        Returns:\n            dict: Keyword arguments for swanlab.init()\n        \"\"\"\n        init_kwargs = {}\n\n        # Project name (required)\n        if cfg.swanlab_project:\n            init_kwargs[\"project\"] = cfg.swanlab_project\n\n        # Experiment name\n        if cfg.swanlab_experiment_name:\n            init_kwargs[\"experiment_name\"] = 
cfg.swanlab_experiment_name\n\n        # Description\n        if cfg.swanlab_description:\n            init_kwargs[\"description\"] = cfg.swanlab_description\n\n        # Workspace (organization)\n        if cfg.swanlab_workspace:\n            init_kwargs[\"workspace\"] = cfg.swanlab_workspace\n\n        # Mode: cloud, local, offline, disabled\n        if cfg.swanlab_mode:\n            init_kwargs[\"mode\"] = cfg.swanlab_mode\n\n        # API key (pass directly instead of via env var)\n        if cfg.swanlab_api_key:\n            init_kwargs[\"api_key\"] = cfg.swanlab_api_key\n\n        # Private deployment hosts (pass directly instead of via env var)\n        if cfg.swanlab_web_host:\n            init_kwargs[\"web_host\"] = cfg.swanlab_web_host\n\n        if cfg.swanlab_api_host:\n            init_kwargs[\"api_host\"] = cfg.swanlab_api_host\n\n        # Log model checkpoints (coming soon in SwanLab)\n        if cfg.swanlab_log_model:\n            init_kwargs[\"log_model\"] = cfg.swanlab_log_model\n\n        # Custom branding - adds Axolotl identifier to SwanLab UI\n        # This helps identify runs from Axolotl vs other frameworks\n        init_kwargs[\"config\"] = {\"UPPERFRAME\": \"🦎 Axolotl\"}\n\n        return init_kwargs\n\n    def _prepare_config_for_logging(self, cfg: DictDefault) -> dict:\n        \"\"\"Prepare configuration dict for logging to SwanLab.\"\"\"\n\n        def safe_convert(value):\n            \"\"\"Convert value to JSON-serializable type.\"\"\"\n            if value is None:\n                return None\n            if isinstance(value, (int, float, bool)):\n                return value\n            if isinstance(value, str):\n                return value\n            # Convert everything else to string\n            return str(value)\n\n        try:\n            # Extract important training parameters with safe conversion\n            config_dict = {\n                \"base_model\": safe_convert(getattr(cfg, \"base_model\", \"\")),\n        
        \"model_type\": safe_convert(getattr(cfg, \"model_type\", \"\")),\n                \"sequence_len\": safe_convert(getattr(cfg, \"sequence_len\", None)),\n                \"micro_batch_size\": safe_convert(\n                    getattr(cfg, \"micro_batch_size\", None)\n                ),\n                \"gradient_accumulation_steps\": safe_convert(\n                    getattr(cfg, \"gradient_accumulation_steps\", None)\n                ),\n                \"num_epochs\": safe_convert(getattr(cfg, \"num_epochs\", None)),\n                \"max_steps\": safe_convert(getattr(cfg, \"max_steps\", None)),\n                \"learning_rate\": safe_convert(getattr(cfg, \"learning_rate\", None)),\n                \"lr_scheduler\": safe_convert(getattr(cfg, \"lr_scheduler\", \"\")),\n                \"optimizer\": safe_convert(getattr(cfg, \"optimizer\", \"\")),\n                \"warmup_ratio\": safe_convert(getattr(cfg, \"warmup_ratio\", None)),\n                \"weight_decay\": safe_convert(getattr(cfg, \"weight_decay\", None)),\n                \"seed\": safe_convert(getattr(cfg, \"seed\", None)),\n                \"bf16\": safe_convert(getattr(cfg, \"bf16\", None)),\n                \"tf32\": safe_convert(getattr(cfg, \"tf32\", None)),\n                \"flash_attention\": safe_convert(getattr(cfg, \"flash_attention\", None)),\n                \"sample_packing\": safe_convert(getattr(cfg, \"sample_packing\", None)),\n            }\n\n            # Add FSDP/parallel config - only boolean flags\n            if hasattr(cfg, \"fsdp_config\") and cfg.fsdp_config:\n                config_dict[\"fsdp_enabled\"] = True\n                config_dict[\"fsdp_version\"] = safe_convert(\n                    getattr(cfg, \"fsdp_version\", None)\n                )\n\n            if hasattr(cfg, \"deepspeed\") and cfg.deepspeed:\n                config_dict[\"deepspeed_enabled\"] = True\n\n            # Add context parallel info\n            if hasattr(cfg, 
\"context_parallel_size\"):\n                config_dict[\"context_parallel_size\"] = safe_convert(\n                    getattr(cfg, \"context_parallel_size\", None)\n                )\n            if hasattr(cfg, \"tensor_parallel_size\"):\n                config_dict[\"tensor_parallel_size\"] = safe_convert(\n                    getattr(cfg, \"tensor_parallel_size\", None)\n                )\n            if hasattr(cfg, \"dp_shard_size\"):\n                config_dict[\"dp_shard_size\"] = safe_convert(\n                    getattr(cfg, \"dp_shard_size\", None)\n                )\n\n            # Remove None values and empty strings\n            config_dict = {\n                k: v\n                for k, v in config_dict.items()\n                if v is not None and v != \"\" and v != \"None\"\n            }\n\n            return config_dict\n        except Exception as err:  # pylint: disable=broad-except\n            LOG.warning(f\"Failed to prepare config for logging: {err}\")\n            # Return minimal config\n            try:\n                lr = getattr(cfg, \"learning_rate\", None)\n                lr_value = float(lr) if lr is not None else None\n            except (TypeError, ValueError):\n                lr_value = None\n            return {\n                \"base_model\": str(getattr(cfg, \"base_model\", \"unknown\")),\n                \"learning_rate\": lr_value,\n            }\n\n    def _register_lark_callback(self, cfg: DictDefault):\n        \"\"\"Register Lark (Feishu) notification callback if configured.\n\n        Lark notifications enable sending training updates to team chat channels,\n        useful for production monitoring and team collaboration.\n\n        Args:\n            cfg: Configuration object with Lark webhook settings\n        \"\"\"\n        # Check if Lark webhook URL is configured\n        lark_webhook_url = getattr(cfg, \"swanlab_lark_webhook_url\", None)\n        if not lark_webhook_url:\n            return  # Lark 
not configured, skip\n\n        try:\n            import swanlab\n            from swanlab.plugin.notification import LarkCallback\n\n            # Get optional secret for HMAC signature authentication\n            lark_secret = getattr(cfg, \"swanlab_lark_secret\", None)\n\n            # Create Lark callback with webhook URL and optional secret\n            lark_callback = LarkCallback(\n                webhook_url=lark_webhook_url,\n                secret=lark_secret,\n            )\n\n            # Register callback with SwanLab\n            swanlab.register_callbacks([lark_callback])\n\n            if lark_secret:\n                LOG.info(\n                    \"Registered Lark notification callback with HMAC authentication\"\n                )\n            else:\n                LOG.info(\"Registered Lark notification callback (no HMAC secret)\")\n                LOG.warning(\n                    \"Lark webhook has no secret configured. \"\n                    \"For production use, set 'swanlab_lark_secret' to enable HMAC signature verification.\"\n                )\n\n        except ImportError as err:\n            LOG.warning(\n                f\"Failed to import SwanLab Lark plugin: {err}\\n\\n\"\n                \"Lark notifications require SwanLab >= 0.3.0 with plugin support.\\n\"\n                \"Install with: pip install 'swanlab>=0.3.0'\\n\\n\"\n                \"Continuing without Lark notifications...\"\n            )\n        except Exception as err:  # pylint: disable=broad-except\n            LOG.exception(\n                \"Failed to register Lark callback: %s\\n\\n\"\n                \"Check your Lark webhook URL and secret configuration.\\n\"\n                \"Continuing without Lark notifications...\",\n                err,\n            )\n\n    def _register_completion_callback(self, cfg: DictDefault, trainer):\n        \"\"\"Register RLHF completion logging callback if enabled and applicable.\n\n        This callback logs model 
completions (prompts, chosen/rejected responses,\n        rewards) to SwanLab during RLHF training for qualitative analysis.\n\n        Args:\n            cfg: Configuration object with completion logging settings\n            trainer: The trainer instance to add callback to\n        \"\"\"\n        # Check if completion logging is enabled\n        log_completions = getattr(cfg, \"swanlab_log_completions\", True)\n        if not log_completions:\n            LOG.debug(\"SwanLab completion logging disabled by config\")\n            return\n\n        # Check if trainer is an RLHF trainer\n        trainer_name = trainer.__class__.__name__\n        rlhf_trainers = [\"DPO\", \"KTO\", \"ORPO\", \"GRPO\", \"CPO\"]\n        is_rlhf_trainer = any(name in trainer_name for name in rlhf_trainers)\n\n        if not is_rlhf_trainer:\n            LOG.debug(\n                f\"Trainer {trainer_name} is not an RLHF trainer, \"\n                \"skipping completion logging callback\"\n            )\n            return\n\n        try:\n            from axolotl.integrations.swanlab.callbacks import (\n                SwanLabRLHFCompletionCallback,\n            )\n\n            # Get configuration parameters\n            log_interval = getattr(cfg, \"swanlab_completion_log_interval\", 100)\n            max_buffer = getattr(cfg, \"swanlab_completion_max_buffer\", 128)\n\n            # Create and register callback\n            completion_callback = SwanLabRLHFCompletionCallback(\n                log_interval=log_interval,\n                max_completions=max_buffer,\n                table_name=\"rlhf_completions\",\n            )\n\n            trainer.add_callback(completion_callback)\n\n            LOG.info(\n                f\"Registered SwanLab RLHF completion logging callback for {trainer_name} \"\n                f\"(log_interval={log_interval}, max_buffer={max_buffer})\"\n            )\n\n        except ImportError as err:\n            LOG.warning(\n                f\"Failed to 
import SwanLab completion callback: {err}\\n\\n\"\n                \"This is a bug - the callback should be available.\\n\"\n                \"Please report this issue.\\n\\n\"\n                \"Continuing without completion logging...\"\n            )\n        except Exception as err:  # pylint: disable=broad-except\n            LOG.exception(\n                \"Failed to register SwanLab completion callback: %s\\n\\n\"\n                \"Continuing without completion logging...\",\n                err,\n            )\n"
  },
  {
    "path": "src/axolotl/integrations/swanlab/profiling.py",
    "content": "\"\"\"SwanLab profiling utilities for Axolotl trainers.\n\nThis module provides decorators and context managers for profiling\ntrainer methods and logging execution times to SwanLab.\n\"\"\"\n\nimport time\nfrom contextlib import contextmanager\nfrom functools import wraps\nfrom typing import Any, Callable\n\nfrom axolotl.utils.logging import get_logger\n\nLOG = get_logger(__name__)\n\n\ndef _log_duration_to_swanlab(trainer: Any, func_name: str, duration: float) -> None:\n    \"\"\"Log a profiling duration to the active SwanLab run, best-effort.\n\n    Nothing is logged unless ``trainer.cfg.use_swanlab`` is truthy and\n    ``swanlab.get_run()`` returns a run. A missing ``swanlab`` package is ignored\n    silently and any other logging error is reported at debug level only, so a\n    profiling failure does not interrupt training.\n\n    Args:\n        trainer: Trainer instance (a missing ``cfg`` attribute disables logging)\n        func_name: Name of the profiled function, embedded in the metric name\n        duration: Measured execution time in seconds\n    \"\"\"\n    use_swanlab = getattr(getattr(trainer, \"cfg\", None), \"use_swanlab\", False)\n    if not use_swanlab:\n        return\n\n    try:\n        import swanlab\n\n        if swanlab.get_run() is not None:\n            trainer_class = trainer.__class__.__name__\n            metric_name = f\"profiling/Time taken: {trainer_class}.{func_name}\"\n            swanlab.log({metric_name: duration})\n    except ImportError:\n        # SwanLab not installed, silently skip\n        pass\n    except Exception as err:  # pylint: disable=broad-except\n        # Log error but don't fail training\n        LOG.debug(f\"Failed to log profiling metric for {func_name}: {err}\")\n\n\n@contextmanager\ndef swanlab_profiling_context(trainer: Any, func_name: str):\n    \"\"\"Context manager for profiling trainer methods.\n\n    Measures execution time and logs to SwanLab if enabled.\n\n    Example usage:\n        >>> with swanlab_profiling_context(self, \"training_step\"):\n        ...     result = do_expensive_computation()\n\n    Args:\n        trainer: Trainer instance (must have cfg attribute with use_swanlab flag)\n        func_name: Name of the function being profiled\n\n    Yields:\n        None\n    \"\"\"\n    start_time = time.perf_counter()\n    try:\n        yield\n    finally:\n        # Also runs when the profiled body raises, so failed calls are timed too.\n        duration = time.perf_counter() - start_time\n        _log_duration_to_swanlab(trainer, func_name, duration)\n\n\ndef swanlab_profile(func: Callable) -> Callable:\n    \"\"\"Decorator to profile and log function execution time to SwanLab.\n\n    Automatically measures execution time of trainer methods and logs\n    to SwanLab as profiling metrics.\n\n    Example usage:\n        >>> class MyTrainer:\n        ...     @swanlab_profile\n        ...     def training_step(self, model, inputs):\n        ...         return super().training_step(model, inputs)\n\n    Args:\n        func: Function to profile (must be a method of a trainer instance)\n\n    Returns:\n        Wrapped function with profiling\n    \"\"\"\n\n    @wraps(func)\n    def wrapper(self, *args, **kwargs):\n        with swanlab_profiling_context(self, func.__name__):\n            return func(self, *args, **kwargs)\n\n    return wrapper\n\n\nclass ProfilingConfig:\n    \"\"\"Configuration for SwanLab profiling.\n\n    This class provides a centralized way to control profiling behavior.\n\n    Attributes:\n        enabled: Whether profiling is enabled globally\n        min_duration_ms: Minimum duration (in ms) to log (filters out very fast ops)\n        log_interval: Log the first call and then every Nth call of each function\n            (to reduce overhead). Only calls that pass the ``min_duration_ms``\n            filter are counted. Must be >= 1.\n    \"\"\"\n\n    def __init__(\n        self,\n        enabled: bool = True,\n        min_duration_ms: float = 0.1,\n        log_interval: int = 1,\n    ):\n        \"\"\"Initialize profiling configuration.\n\n        Args:\n            enabled: Enable profiling. Default: True\n            min_duration_ms: Minimum duration to log (ms). Default: 0.1\n            log_interval: Log every N calls. Default: 1 (log all)\n\n        Raises:\n            ValueError: If ``log_interval`` is less than 1. A value of 0 would\n                otherwise cause a modulo-by-zero in ``should_log``, raised from\n                the ``finally`` block of the profiling context and crashing the\n                profiled call.\n        \"\"\"\n        if log_interval < 1:\n            raise ValueError(f\"log_interval must be >= 1, got {log_interval}\")\n        self.enabled = enabled\n        self.min_duration_ms = min_duration_ms\n        self.log_interval = log_interval\n        self._call_counts: dict[str, int] = {}\n\n    def should_log(self, func_name: str, duration_seconds: float) -> bool:\n        \"\"\"Check if a profiling measurement should be logged.\n\n        Args:\n            func_name: Name of the profiled function\n            duration_seconds: Execution duration in seconds\n\n        Returns:\n            True if should log, False otherwise\n        \"\"\"\n        if not self.enabled:\n            return False\n\n        # Check minimum duration threshold\n        duration_ms = duration_seconds * 1000\n        if duration_ms < self.min_duration_ms:\n            return False\n\n        # Check log interval\n        self._call_counts.setdefault(func_name, 0)\n        self._call_counts[func_name] += 1\n\n        # Always log on first call OR at intervals\n        count = self._call_counts[func_name]\n        if count == 1 or count % self.log_interval == 0:\n            return True\n\n        return False\n\n\n# Global profiling config (can be modified by users)\nDEFAULT_PROFILING_CONFIG = ProfilingConfig()\n\n\n@contextmanager\ndef swanlab_profiling_context_advanced(\n    trainer: Any,\n    func_name: str,\n    config: ProfilingConfig | None = None,\n):\n    \"\"\"Advanced profiling context with configurable behavior.\n\n    Similar to swanlab_profiling_context but with additional configuration\n    options for filtering and throttling profiling logs.\n\n    Example usage:\n        >>> config = ProfilingConfig(min_duration_ms=1.0, log_interval=10)\n        >>> with swanlab_profiling_context_advanced(self, \"forward\", config):\n        ...     output = model(inputs)\n\n    Args:\n        trainer: Trainer instance\n        func_name: Function name\n        config: Profiling configuration. If None, uses DEFAULT_PROFILING_CONFIG\n\n    Yields:\n        None\n    \"\"\"\n    if config is None:\n        config = DEFAULT_PROFILING_CONFIG\n\n    start_time = time.perf_counter()\n    try:\n        yield\n    finally:\n        duration = time.perf_counter() - start_time\n\n        # Filter/throttle first; should_log also updates the per-function count.\n        if config.should_log(func_name, duration):\n            _log_duration_to_swanlab(trainer, func_name, duration)\n"
  },
  {
    "path": "src/axolotl/kernels/__init__.py",
    "content": ""
  },
  {
    "path": "src/axolotl/kernels/geglu.py",
    "content": "\"\"\"Module for definition of GEGLU Triton kernels.\n\nSee \"GLU Variants Improve Transformer\" (https://arxiv.org/abs/2002.05202).\n\nCredit to `unsloth` (https://unsloth.ai/) for inspiration for this implementation.\n\"\"\"\n\nimport torch\nimport triton\nimport triton.language as tl\n\n\n@triton.jit\ndef _geglu_fwd_kernel(\n    gate_ptr,\n    up_ptr,\n    out_ptr,\n    n_elements,\n    BLOCK_SIZE: tl.constexpr,\n):\n    \"\"\"GEGLU forward kernel.\n\n    Args:\n        gate_ptr: Pointer to gate tensor [*, hidden_dim].\n        up_ptr: Pointer to up-projection tensor [*, hidden_dim].\n        out_ptr: Pointer to output tensor [*, hidden_dim].\n        n_elements: Total number of elements in the input tensors.\n        BLOCK_SIZE: Size of thread blocks for parallel computation.\n    \"\"\"\n    block_idx = tl.program_id(0)\n    offsets = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)\n    mask = offsets < n_elements\n\n    gate = tl.load(gate_ptr + offsets, mask=mask, other=0).to(tl.float32)\n    up = tl.load(up_ptr + offsets, mask=mask, other=0)\n\n    # Compute activation in fp32 then convert back\n    gelu_gate = 0.5 * gate * (tl.math.erf(tl.math.rsqrt(2.0) * gate) + 1.0)\n    gelu_gate = gelu_gate.to(up.dtype)\n    result = gelu_gate * up\n\n    tl.store(out_ptr + offsets, result, mask=mask)\n\n\ndef geglu_forward(gate: torch.Tensor, up: torch.Tensor) -> torch.Tensor:\n    \"\"\"GEGLU forward pass.\n\n    Args:\n        gate: Input gate tensor of shape [batch, seq_len, hidden_dim].\n        up: Up-projection tensor of shape [batch, seq_len, hidden_dim].\n\n    Returns:\n        torch.Tensor: Output tensor of shape [batch, seq_len, hidden_dim], allocated\n            on the same device and with the same dtype as `gate`.\n\n    Note:\n        The kernel indexes `gate`, `up`, and the output as flat buffers of\n        `gate.numel()` elements, so `gate` and `up` are expected to be contiguous\n        and to have identical shapes.\n    \"\"\"\n    batch, seq_len, hidden_dim = gate.shape\n    n_elements = gate.numel()\n    # Allocate on the inputs' device: device=\"cuda\" would resolve to the current\n    # CUDA device, which need not be the device `gate` lives on (e.g. cuda:1).\n    out = torch.empty(\n        (batch, seq_len, hidden_dim), dtype=gate.dtype, device=gate.device\n    )\n\n    grid = lambda meta: (triton.cdiv(n_elements, meta[\"BLOCK_SIZE\"]),)  # noqa: E731\n    _geglu_fwd_kernel[grid](\n        gate_ptr=gate,\n        up_ptr=up,\n        out_ptr=out,\n        n_elements=n_elements,\n        BLOCK_SIZE=1024,\n    )\n    return out\n\n\n@triton.jit\ndef _geglu_bwd_kernel(\n    grad_out_ptr,\n    gate_ptr,\n    up_ptr,\n    n_elements,\n    BLOCK_SIZE: tl.constexpr,\n):\n    \"\"\"GEGLU backward kernel. Stores gradient results in-place.\n\n    Args:\n        grad_out_ptr: Pointer to gradient output tensor [*, hidden_dim].\n        gate_ptr: Pointer to gate tensor [*, hidden_dim].\n        up_ptr: Pointer to up-projection tensor [*, hidden_dim].\n        n_elements: Total number of elements in the input tensors.\n        BLOCK_SIZE: Size of thread blocks for parallel computation.\n\n    Note:\n        After kernel execution, tensors are modified in-place:\n        - `grad_out_ptr` contains GEGLU activation output (`h`)\n        - `gate_ptr` contains gradient w.r.t gate (`grad_gate`)\n        - `up_ptr` contains gradient w.r.t up (`grad_up`)\n    \"\"\"\n    block_idx = tl.program_id(0)\n    offsets = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)\n    mask = offsets < n_elements\n\n    grad_out = tl.load(grad_out_ptr + offsets, mask=mask, other=0)\n    gate = tl.load(gate_ptr + offsets, mask=mask, other=0).to(tl.float32)\n    up = tl.load(up_ptr + offsets, mask=mask, other=0)\n\n    # Forward pass\n    gelu_partial = 0.5 * (tl.math.erf(tl.math.rsqrt(2.0) * gate) + 1.0)\n    gelu_gate = gelu_partial * gate\n    gelu_gate = gelu_gate.to(grad_out.dtype)\n\n    # Forward output\n    h = gelu_gate * up\n\n    # Compute gradients\n    grad_up = grad_out * gelu_gate\n\n    # Compute gate gradient using GELU derivative\n    temp = grad_out * up\n    t = 0.3989422804014327  # 1/sqrt(2*pi)\n    dgelu_dgate = gelu_partial + t * gate * tl.exp(-0.5 * gate * gate)\n    grad_gate = temp.to(tl.float32) * dgelu_dgate\n    grad_gate = grad_gate.to(grad_out.dtype)\n\n    # Store results\n    tl.store(grad_out_ptr + offsets, h, mask=mask)\n    tl.store(gate_ptr + offsets, grad_gate, mask=mask)\n    tl.store(up_ptr + offsets, grad_up, mask=mask)\n\n\ndef geglu_backward(\n    grad_output: torch.Tensor, gate: torch.Tensor, up: torch.Tensor\n) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:\n    \"\"\"GEGLU backward pass using in-place operations.\n\n    Args:\n        grad_output: Gradient of loss with respect to output, shape `[batch, seq_len, hidden_dim]`.\n        gate: Gate tensor from forward pass, shape `[batch, seq_len, hidden_dim]`.\n        up: Up-projection tensor from forward pass, shape `[batch, seq_len, hidden_dim]`.\n\n    Returns:\n        Tuple containing:\n            - GEGLU activation output (`h`)\n            - Gradient with respect to gate (`grad_gate`)\n            - Gradient with respect to up (`grad_up`)\n\n    Note:\n        This function modifies its input tensors in-place to store results.\n    \"\"\"\n    n_elements = grad_output.numel()\n\n    grid = lambda meta: (triton.cdiv(n_elements, meta[\"BLOCK_SIZE\"]),)  # noqa: E731\n    _geglu_bwd_kernel[grid](\n        grad_out_ptr=grad_output,\n        gate_ptr=gate,\n        up_ptr=up,\n        n_elements=n_elements,\n        BLOCK_SIZE=1024,\n    )\n\n    return grad_output, gate, up\n"
  },
  {
    "path": "src/axolotl/kernels/lora.py",
    "content": "\"\"\"\nModule for definition of Low-Rank Adaptation (LoRA) Triton kernels.\n\nSee \"LoRA: Low-Rank Adaptation of Large Language Models\"\n(https://arxiv.org/abs/2106.09685).\n\nCredit to `unsloth` (https://unsloth.ai/) for inspiration for this implementation.\n\"\"\"\n\nfrom typing import Callable\n\nimport torch\nfrom bitsandbytes.functional import QuantState\nfrom torch import nn\nfrom torch.distributed.tensor import DTensor\n\nfrom .geglu import geglu_backward, geglu_forward\nfrom .quantize import dequantize\nfrom .swiglu import swiglu_backward, swiglu_forward\nfrom .utils import torch_amp_custom_bwd, torch_amp_custom_fwd\n\n\ndef get_lora_parameters(\n    proj: nn.Module,\n) -> tuple[\n    torch.Tensor,\n    torch.Tensor | None,\n    QuantState | torch.Tensor | None,\n    torch.Tensor | None,\n    torch.Tensor | None,\n    float | None,\n]:\n    \"\"\"\n    Gets LoRA parameters from a projection module.\n\n    Args:\n        proj: The projection module to extract parameters from.\n\n    Returns:\n        A tuple containing the base weights, quantization state, LoRA A and B weights,\n        scaling factor, and base layer bias. 
Quant state, weights, and bias may be\n        `None` if not available.\n    \"\"\"\n    # For DPO or disabled adapters\n    base_layer = proj.base_layer if hasattr(proj, \"base_layer\") else proj\n    W = base_layer.weight\n    b = base_layer.bias\n\n    if not hasattr(proj, \"disable_adapters\") or proj.disable_adapters or proj.merged:\n        quant_state = getattr(W, \"quant_state\", None)\n        if quant_state is None and W.dtype == torch.float8_e4m3fn:\n            quant_state = getattr(base_layer, \"weight_scale_inv\", None)\n        return W, b, quant_state, None, None, None\n\n    quant_state = getattr(W, \"quant_state\", None)\n    if quant_state is None and W.dtype == torch.float8_e4m3fn:\n        quant_state = getattr(base_layer, \"weight_scale_inv\", None)\n\n    active_adapter = (\n        proj.active_adapters[0]\n        if hasattr(proj, \"active_adapters\")\n        else proj.active_adapter\n    )\n\n    linear_A = proj.lora_A[active_adapter]\n    linear_B = proj.lora_B[active_adapter]\n\n    # This manual unsharding is needed for FSDP2 + LoRA kernels compatibility.\n    # We fuse linear layers + LoRA adapters calculations into a single\n    # torch.autograd.Function, bypassing the registered unshard / reshard behavior.\n    # Note that we don't apply resharding later in this module (it gets messy quickly),\n    # but LoRA parameters are generally small enough that this is not an issue.\n    if isinstance(linear_A.weight, DTensor):\n        linear_A.unshard()\n        linear_B.unshard()\n\n    A = linear_A.weight\n    B = linear_B.weight\n    s = proj.scaling[active_adapter]\n\n    return W, b, quant_state, A, B, s\n\n\ndef matmul_lora(\n    X: torch.Tensor,\n    W: torch.Tensor,\n    b: torch.Tensor | None,\n    W_quant: QuantState | torch.Tensor | None,\n    A: torch.Tensor | None,\n    B: torch.Tensor | None,\n    s: float | None,\n    out: torch.Tensor | None = None,\n) -> torch.Tensor:\n    \"\"\"\n    Efficient fused matmul + LoRA 
computation.\n\n    Args:\n        X: Input tensor [*, in_features]\n        W: Base weight matrix [out_features, in_features]\n        W_quant: Quantization state for W\n        A: LoRA A matrix [rank, in_features]\n        B: LoRA B matrix [out_features, rank]\n        s: LoRA scaling factor\n        out: Optional output tensor for inplace operations\n\n    Returns:\n        Result of X @ W + X @ A @ B\n    \"\"\"\n    dtype = X.dtype\n    W = dequantize(W.t(), W_quant)\n\n    reshape = False\n    if X.dim() == 3:\n        batch, seq_len, _ = X.shape\n        X = X.view(-1, X.shape[-1])\n        reshape = True\n\n    out = torch.matmul(X, W, out=out)\n    if W_quant is not None:\n        del W\n\n    if A is not None:\n        A, B = A.t().to(dtype), B.t().to(dtype)  # type: ignore[union-attr]\n        out += s * X @ A @ B\n\n    if b is not None:\n        out += b\n\n    return out.view(batch, seq_len, -1) if reshape else out\n\n\nclass LoRA_MLP(torch.autograd.Function):\n    \"\"\"Optimized LoRA MLP implementation.\"\"\"\n\n    @staticmethod\n    @torch_amp_custom_fwd\n    def forward(\n        ctx,\n        X: torch.Tensor,\n        gate_weight: torch.Tensor,\n        gate_bias: torch.Tensor | None,\n        gate_quant: QuantState | None,\n        gate_A: torch.Tensor | None,\n        gate_B: torch.Tensor | None,\n        gate_scale: float,\n        up_weight: torch.Tensor,\n        up_bias: torch.Tensor | None,\n        up_quant: QuantState | None,\n        up_A: torch.Tensor | None,\n        up_B: torch.Tensor | None,\n        up_scale: float,\n        down_weight: torch.Tensor,\n        down_bias: torch.Tensor | None,\n        down_quant: QuantState | None,\n        down_A: torch.Tensor | None,\n        down_B: torch.Tensor | None,\n        down_scale: float,\n        activation_fn: Callable,\n        activation_fn_backward: Callable,\n        inplace: bool | None = True,\n    ) -> torch.Tensor:\n        \"\"\"\n        Forward pass for LoRA MLP.\n\n        
Args:\n            ctx: Autograd context\n            X: Input features\n            gate_weight: Gate projection weight\n            gate_bias: Gate projection bias\n            gate_quant: Gate quantization state\n            gate_A: Gate LoRA A matrix\n            gate_B: Gate LoRA B matrix\n            gate_scale: Gate LoRA scale\n            up_weight: Up projection weight\n            up_quant: Up projection quantization state\n            up_A: Up projection LoRA A matrix\n            up_B: Up projection LoRA B matrix\n            up_scale: Up projection LoRA scale\n            down_weight: Down projection weight\n            down_bias: Down projection bias\n            down_quant: Down projection quantization state\n            down_A: Down projection LoRA A matrix\n            down_B: Down projection LoRA B matrix\n            down_scale: Down projection LoRA scale\n            activation_fn: Forward activation function\n            activation_fn_backward: Backward activation function\n            inplace: Whether to perform operations in-place\n\n        Returns:\n            Output transformed by multi-layer perceptron and activation function\n        \"\"\"\n        # Compute projections\n        gate = matmul_lora(\n            X, gate_weight, gate_bias, gate_quant, gate_A, gate_B, gate_scale\n        )\n        up = matmul_lora(X, up_weight, up_bias, up_quant, up_A, up_B, up_scale)\n\n        # Activation\n        hidden = activation_fn(gate, up)\n\n        # Down projection\n        output = matmul_lora(\n            hidden, down_weight, down_bias, down_quant, down_A, down_B, down_scale\n        )\n\n        # Save for backward\n        ctx.save_for_backward(X, gate, up, gate_A, gate_B, up_A, up_B, down_A, down_B)\n        ctx.scales = (gate_scale, up_scale, down_scale)\n        ctx.quants = (gate_quant, up_quant, down_quant)\n        ctx.weights = (gate_weight, up_weight, down_weight)\n        ctx.activation_fn = activation_fn\n        
ctx.activation_fn_backward = activation_fn_backward\n        ctx.inplace = inplace\n\n        return output\n\n    @staticmethod\n    @torch_amp_custom_bwd\n    def backward(\n        ctx: torch.autograd.function.FunctionCtx,\n        grad_output: torch.Tensor,\n    ) -> tuple[\n        torch.Tensor | None,\n        None,\n        None,\n        None,\n        torch.Tensor | None,\n        torch.Tensor | None,\n        None,\n        None,\n        None,\n        None,\n        torch.Tensor | None,\n        torch.Tensor | None,\n        None,\n        None,\n        None,\n        None,\n        torch.Tensor | None,\n        torch.Tensor | None,\n        None,\n        None,\n        None,\n        None,\n        None,\n    ]:\n        \"\"\"\n        Performs backward pass computation for LoRA MLP.\n\n        Args:\n            ctx: Context object storing tensors saved during forward pass\n            grad_output: Gradient of loss with respect to layer output\n\n        Returns:\n            Tuple containing gradients for all inputs from forward pass:\n            - Input gradient tensor (or `None`)\n            - `None` for weights/biases/quantization states\n            - LoRA A/B matrix gradients (or `None`)\n            - `None` for scaling factors\n            - `None` for activation functions and flags\n        \"\"\"\n        (\n            X,\n            gate,\n            up,\n            gate_A,\n            gate_B,\n            up_A,\n            up_B,\n            down_A,\n            down_B,\n        ) = ctx.saved_tensors\n        gate_scale, up_scale, down_scale = ctx.scales\n        gate_quant, up_quant, down_quant = ctx.quants\n        gate_weight, up_weight, down_weight = ctx.weights\n\n        # Transpose all LoRA matrices\n        gate_A, gate_B = (\n            gate_A.t() if gate_A is not None else None,\n            gate_B.t() if gate_B is not None else None,\n        )\n        up_A, up_B = (\n            up_A.t() if up_A is not None else 
None,\n            up_B.t() if up_B is not None else None,\n        )\n        down_A, down_B = (\n            down_A.t() if down_A is not None else None,\n            down_B.t() if down_B is not None else None,\n        )\n\n        # Reshape inputs\n        batch, seq_len, hd = X.shape\n        grad_output = grad_output.view(-1, grad_output.shape[-1])\n        X = X.view(-1, X.shape[-1])\n        gate = gate.view(-1, gate.shape[-1])\n        up = up.view(-1, up.shape[-1])\n        dtype = X.dtype\n\n        # Down projection\n        grad_down = matmul_lora(\n            grad_output,\n            down_weight.t(),\n            None,\n            down_quant,\n            down_B,\n            down_A,\n            down_scale,\n        )\n\n        # Activation backward\n        h, grad_gate, grad_up = ctx.activation_fn_backward(grad_down, gate, up)\n\n        # Initialize and compute LoRA gradients\n        d_down_A = d_down_B = d_up_A = d_up_B = d_gate_A = d_gate_B = None\n\n        if down_A is not None and down_B is not None:\n            d_down_A = h.t() @ (grad_output @ down_B.t())\n            d_down_B = (down_A.t() @ h.t()) @ grad_output\n            d_down_A *= down_scale\n            d_down_B *= down_scale\n\n        if up_A is not None and up_B is not None:\n            d_up_A = X.t() @ (grad_up @ up_B.t())\n            d_up_B = (up_A.t() @ X.t()) @ grad_up\n            d_up_A *= up_scale\n            d_up_B *= up_scale\n\n        if gate_A is not None and gate_B is not None:\n            d_gate_A = X.t() @ (grad_gate @ gate_B.t())\n            d_gate_B = (gate_A.t() @ X.t()) @ grad_gate\n            d_gate_A *= gate_scale\n            d_gate_B *= gate_scale\n\n        # Compute input gradients\n        dX = torch.zeros_like(X) if ctx.needs_input_grad[0] else None\n\n        if dX is not None:\n            # Up projection gradients\n            up_weight = dequantize(up_weight.t(), up_quant)\n            if ctx.inplace:\n                dX = 
torch.matmul(grad_up, up_weight.t(), out=X)\n            else:\n                dX = torch.matmul(grad_up, up_weight.t())\n            del up_weight\n\n            # Note the .to(dtype) only where mixing LoRA with base weights\n            if up_A is not None and up_B is not None:\n                dX += grad_up @ up_B.to(dtype).t() @ (up_scale * up_A.to(dtype).t())\n\n            # Gate projection gradients\n            gate_weight = dequantize(gate_weight, gate_quant)\n            dX += grad_gate @ gate_weight\n            del gate_weight\n\n            if gate_A is not None and gate_B is not None:\n                dX += (\n                    grad_gate\n                    @ gate_B.to(dtype).t()\n                    @ (gate_scale * gate_A.to(dtype).t())\n                )\n\n            # Reshape back\n            dX = dX.view(batch, seq_len, hd)\n\n        # Return gradients in correct order matching forward inputs\n        return (\n            dX,\n            None,\n            None,\n            None,\n            d_gate_A.t() if d_gate_A is not None else None,\n            d_gate_B.t() if d_gate_B is not None else None,\n            None,\n            None,\n            None,\n            None,\n            d_up_A.t() if d_up_A is not None else None,\n            d_up_B.t() if d_up_B is not None else None,\n            None,\n            None,\n            None,\n            None,\n            d_down_A.t() if d_down_A is not None else None,\n            d_down_B.t() if d_down_B is not None else None,\n            None,\n            None,\n            None,\n            None,\n            None,\n        )\n\n\ndef apply_lora_mlp_swiglu(self, X: torch.Tensor, inplace: bool = True) -> torch.Tensor:\n    \"\"\"\n    Applies LoRA to MLP layer with SwiGLU activation.\n\n    Args:\n        X: Input tensor for the MLP layer\n        inplace: Whether to perform operations in-place to save memory\n\n    Returns:\n        Output tensor after applying LoRA-adapted MLP 
with SwiGLU activation\n    \"\"\"\n    gateW, gateb, gateW_quant, gateA, gateB, gateS = get_lora_parameters(self.gate_proj)\n    upW, upb, upW_quant, upA, upB, upS = get_lora_parameters(self.up_proj)\n    downW, downb, downW_quant, downA, downB, downS = get_lora_parameters(self.down_proj)\n\n    out = LoRA_MLP.apply(\n        X,\n        gateW,\n        gateb,\n        gateW_quant,\n        gateA,\n        gateB,\n        gateS,\n        upW,\n        upb,\n        upW_quant,\n        upA,\n        upB,\n        upS,\n        downW,\n        downb,\n        downW_quant,\n        downA,\n        downB,\n        downS,\n        swiglu_forward,\n        swiglu_backward,\n        inplace,\n    )\n\n    return out\n\n\ndef apply_lora_mlp_geglu(self, X: torch.Tensor, inplace: bool = True) -> torch.Tensor:\n    \"\"\"\n    Applies LoRA to MLP layer with GEGLU activation.\n\n    Args:\n        X: Input tensor for the MLP layer\n        inplace: Whether to perform operations in-place to save memory\n\n    Returns:\n        Output tensor after applying LoRA-adapted MLP with GEGLU activation\n    \"\"\"\n    gateW, gateb, gateW_quant, gateA, gateB, gateS = get_lora_parameters(self.gate_proj)\n    upW, upb, upW_quant, upA, upB, upS = get_lora_parameters(self.up_proj)\n    downW, downb, downW_quant, downA, downB, downS = get_lora_parameters(self.down_proj)\n    out = LoRA_MLP.apply(\n        X,\n        gateW,\n        gateb,\n        gateW_quant,\n        gateA,\n        gateB,\n        gateS,\n        upW,\n        upb,\n        upW_quant,\n        upA,\n        upB,\n        upS,\n        downW,\n        downb,\n        downW_quant,\n        downA,\n        downB,\n        downS,\n        geglu_forward,\n        geglu_backward,\n        inplace,\n    )\n\n    return out\n\n\nclass LoRA_QKV(torch.autograd.Function):\n    \"\"\"\n    Optimized LoRA QKV implementation with quantization support.\n\n    Implements efficient computation of query, key, value projections with 
LoRA,\n    supporting quantization and memory optimization.\n    \"\"\"\n\n    @staticmethod\n    @torch_amp_custom_fwd\n    def forward(\n        ctx: torch.autograd.function.FunctionCtx,\n        X: torch.Tensor,\n        q_weight: torch.Tensor,\n        q_bias: torch.Tensor | None,\n        q_quant: QuantState | None,\n        q_A: torch.Tensor | None,\n        q_B: torch.Tensor | None,\n        q_scale: float,\n        k_weight: torch.Tensor,\n        k_bias: torch.Tensor | None,\n        k_quant: QuantState | None,\n        k_A: torch.Tensor | None,\n        k_B: torch.Tensor | None,\n        k_scale: float,\n        v_weight: torch.Tensor,\n        v_bias: torch.Tensor | None,\n        v_quant: QuantState | None,\n        v_A: torch.Tensor | None,\n        v_B: torch.Tensor | None,\n        v_scale: float,\n        inplace: bool = True,\n    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:\n        \"\"\"\n        Forward pass computing Q, K, V projections with LoRA.\n\n        Args:\n            ctx: Autograd context\n            X: Input tensor\n            q_weight: Query projection weight\n            q_bias: Query projection bias\n            q_quant: Query quantization state\n            q_A: Query LoRA A matrix\n            q_B: Query LoRA B matrix\n            q_scale: Query LoRA scale\n            k_weight: Key projection weight\n            k_bias: Key projection bias\n            k_quant: Key quantization state\n            k_A: Key LoRA A matrix\n            k_B: Key LoRA B matrix\n            k_scale: Key LoRA scale\n            v_weight: Value projection weight\n            v_bias: Value projection bias\n            v_quant: Value quantization state\n            v_A: Value LoRA A matrix\n            v_B: Value LoRA B matrix\n            v_scale: Value LoRA scale\n            inplace: Whether to perform operations in-place\n\n        Returns:\n            Tuple of (Query, Key, Value) projection tensors\n        \"\"\"\n        Q = 
matmul_lora(X, q_weight, q_bias, q_quant, q_A, q_B, q_scale)\n        K = matmul_lora(X, k_weight, k_bias, k_quant, k_A, k_B, k_scale)\n        V = matmul_lora(X, v_weight, v_bias, v_quant, v_A, v_B, v_scale)\n\n        ctx.save_for_backward(X, q_A, q_B, k_A, k_B, v_A, v_B)\n        ctx.scales = (q_scale, k_scale, v_scale)\n        ctx.quants = (q_quant, k_quant, v_quant)\n        ctx.weights = (q_weight, k_weight, v_weight)\n        ctx.biases = (q_bias, k_bias, v_bias)\n        ctx.inplace = inplace\n\n        return Q, K, V\n\n    @staticmethod\n    @torch_amp_custom_bwd\n    def backward(\n        ctx: torch.autograd.function.FunctionCtx,\n        q_grad: torch.Tensor,\n        k_grad: torch.Tensor,\n        v_grad: torch.Tensor,\n    ) -> tuple[\n        torch.Tensor,\n        None,\n        None,\n        None,\n        torch.Tensor | None,\n        torch.Tensor | None,\n        None,\n        None,\n        None,\n        None,\n        torch.Tensor | None,\n        torch.Tensor | None,\n        None,\n        None,\n        None,\n        None,\n        torch.Tensor | None,\n        torch.Tensor | None,\n        None,\n        None,\n    ]:\n        \"\"\"\n        Backward pass computing gradients for LoRA QKV.\n\n        Args:\n            ctx: Autograd context\n            q_grad: Gradient for query projection\n            k_grad: Gradient for key projection\n            v_grad: Gradient for value projection\n\n        Returns:\n            Tuple containing gradients for all forward inputs\n        \"\"\"\n        X, A_q, B_q, A_k, B_k, A_v, B_v = ctx.saved_tensors\n        q_weight, k_weight, v_weight = ctx.weights\n        q_quant, k_quant, v_quant = ctx.quants\n        q_scale, k_scale, v_scale = ctx.scales\n        dtype = X.dtype\n\n        # Reshape gradients\n        batch, seq_len = X.shape[:2]\n        q_grad = q_grad.view(-1, q_grad.shape[-1])\n        k_grad = k_grad.reshape(-1, k_grad.shape[-1])\n        v_grad = v_grad.view(-1, 
v_grad.shape[-1])\n        X = X.view(-1, X.shape[-1])\n\n        # Pre-transpose X once\n        X_t = X.t()\n\n        # Initialize LoRA gradients as None\n        d_A_q = d_B_q = d_A_k = d_B_k = d_A_v = d_B_v = None\n\n        # Compute q path LoRA gradients if adapters exist\n        if A_q is not None and B_q is not None:\n            A_q_scaled = (q_scale * A_q).to(dtype)\n            B_q_scaled = B_q.to(dtype)\n            d_A_q = torch.mm(X_t, torch.mm(q_grad, B_q_scaled))\n            d_B_q = torch.mm(torch.mm(A_q_scaled, X_t), q_grad)\n\n        # Compute k path LoRA gradients if adapters exist\n        if A_k is not None and B_k is not None:\n            A_k_scaled = (k_scale * A_k).to(dtype)\n            B_k_scaled = B_k.to(dtype)\n            d_A_k = torch.mm(X_t, torch.mm(k_grad, B_k_scaled))\n            d_B_k = torch.mm(torch.mm(A_k_scaled, X_t), k_grad)\n\n        # Compute v path LoRA gradients if adapters exist\n        if A_v is not None and B_v is not None:\n            A_v_scaled = (v_scale * A_v).to(dtype)\n            B_v_scaled = B_v.to(dtype)\n            d_A_v = torch.mm(X_t, torch.mm(v_grad, B_v_scaled))\n            d_B_v = torch.mm(torch.mm(A_v_scaled, X_t), v_grad)\n\n        # Compute input gradient, reusing X memory if possible\n        out_buffer = X if ctx.inplace else None\n\n        # Q path\n        q_weight_t = dequantize(q_weight, q_quant)\n        grad_X = torch.mm(q_grad, q_weight_t, out=out_buffer)\n        del q_weight\n        del q_weight_t\n        if A_q is not None and B_q is not None:\n            # Stay decomposed: dQ @ B^T gives [T, R], then [T, R] @ (s*A) gives [T, in]\n            # This is 65x fewer FLOPs than materializing B@A into [out, in]\n            grad_X.addmm_(torch.mm(q_grad, B_q_scaled), A_q_scaled)\n\n        # K path\n        k_weight_t = dequantize(k_weight, k_quant)\n        grad_X.addmm_(k_grad, k_weight_t)\n        del k_weight\n        del k_weight_t\n        if A_k is not None and B_k is not 
None:\n            grad_X.addmm_(torch.mm(k_grad, B_k_scaled), A_k_scaled)\n\n        # V path\n        v_weight_t = dequantize(v_weight, v_quant)\n        grad_X.addmm_(v_grad, v_weight_t)\n        del v_weight\n        del v_weight_t\n        if A_v is not None and B_v is not None:\n            grad_X.addmm_(torch.mm(v_grad, B_v_scaled), A_v_scaled)\n\n        # Transpose gradients if needed\n        if d_A_q is not None:\n            d_A_q = d_A_q.t()\n            d_B_q = d_B_q.t()  # type: ignore[union-attr]\n        if d_A_k is not None:\n            d_A_k = d_A_k.t()\n            d_B_k = d_B_k.t()  # type: ignore[union-attr]\n        if d_A_v is not None:\n            d_A_v = d_A_v.t()\n            d_B_v = d_B_v.t()  # type: ignore[union-attr]\n\n        return (\n            grad_X.view(batch, seq_len, -1),\n            None,\n            None,\n            None,\n            d_A_q,\n            d_B_q,\n            None,\n            None,\n            None,\n            None,\n            d_A_k,\n            d_B_k,\n            None,\n            None,\n            None,\n            None,\n            d_A_v,\n            d_B_v,\n            None,\n            None,\n        )\n\n\ndef apply_lora_qkv(\n    self, X: torch.Tensor, inplace: bool = True\n) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:\n    \"\"\"\n    Applies LoRA to compute Query, Key, Value projections.\n\n    Args:\n        X: Input tensor\n        inplace: Whether to perform operations in-place\n\n    Returns:\n        Tuple of (Query, Key, Value) projection tensors\n    \"\"\"\n    QW, Qb, QW_quant, QA, QB, QS = get_lora_parameters(self.q_proj)\n    KW, Kb, KW_quant, KA, KB, KS = get_lora_parameters(self.k_proj)\n    VW, Vb, VW_quant, VA, VB, VS = get_lora_parameters(self.v_proj)\n    Q, K, V = LoRA_QKV.apply(\n        X,\n        QW,\n        Qb,\n        QW_quant,\n        QA,\n        QB,\n        QS,\n        KW,\n        Kb,\n        KW_quant,\n        KA,\n        KB,\n        
KS,\n        VW,\n        Vb,\n        VW_quant,\n        VA,\n        VB,\n        VS,\n        inplace,\n    )\n\n    return Q, K, V\n\n\nclass LoRA_O(torch.autograd.Function):\n    \"\"\"Optimized LoRA implementation for output projection.\"\"\"\n\n    @staticmethod\n    @torch_amp_custom_fwd\n    def forward(\n        ctx: torch.autograd.function.FunctionCtx,\n        X: torch.Tensor,\n        W: torch.Tensor,\n        b: torch.Tensor,\n        W_quant: QuantState | None,\n        A: torch.Tensor,\n        B: torch.Tensor,\n        s: float,\n    ) -> torch.Tensor:\n        \"\"\"\n        Forward pass for output projection with LoRA.\n\n        Args:\n            ctx: Autograd context\n            X: Input tensor\n            W: Output projection weight\n            b: Output projection bias\n            W_quant: Weight quantization state\n            A: LoRA A matrix\n            B: LoRA B matrix\n            s: LoRA scaling factor\n\n        Returns:\n            Output projection result\n        \"\"\"\n        XW = matmul_lora(X, W, b, W_quant, A, B, s)\n        ctx.custom_saved_tensors = (\n            W,\n            W_quant,\n            s,\n        )\n        ctx.save_for_backward(A, B, X)\n\n        return XW\n\n    @staticmethod\n    @torch_amp_custom_bwd\n    def backward(\n        ctx: torch.autograd.function.FunctionCtx,\n        dY: torch.Tensor,\n    ) -> tuple[\n        torch.Tensor,\n        None,\n        None,\n        None,\n        torch.Tensor,\n        torch.Tensor,\n        None,\n    ]:\n        \"\"\"\n        Backward pass computing gradients for LoRA output projection.\n\n        Args:\n            ctx: Autograd context\n            dY: Gradient of loss with respect to output\n\n        Returns:\n            Tuple containing gradients for all forward inputs\n        \"\"\"\n        W, W_quant, s = ctx.custom_saved_tensors\n        A, B, X = ctx.saved_tensors\n\n        batch, seq_len, hd = X.shape\n        dY = dY.reshape(-1, 
dY.shape[-1])\n        X = X.reshape(-1, X.shape[-1])\n        dtype = X.dtype\n\n        # Weight projection\n        dY_X = X.t() @ dY\n        d_A = s * dY_X @ B\n        d_B = s * A @ dY_X\n\n        # Get derivative for dX\n        W = dequantize(W.t(), W_quant)\n        dX = dY @ W.t()\n        del W\n\n        A, B = A.to(dtype), B.to(dtype)\n        # Stay decomposed: dY @ B gives [T, R], then [T, R] @ A gives [T, in]\n        dX.addmm_(torch.mm(dY, B), A, alpha=s)\n\n        # W, b, W_quant, A, B, s\n        return dX.view(batch, seq_len, hd), None, None, None, d_A.t(), d_B.t(), None\n\n\ndef apply_lora_o(self, X: torch.Tensor) -> torch.Tensor:\n    \"\"\"\n    Applies LoRA to output projection layer.\n\n    Args:\n        X: Input tensor\n\n    Returns:\n        Transformed output tensor\n    \"\"\"\n    OW, Ob, OW_quant, OA, OB, OS = get_lora_parameters(self.o_proj)\n    output = LoRA_O.apply(X, OW, Ob, OW_quant, OA, OB, OS)\n\n    return output\n"
  },
  {
    "path": "src/axolotl/kernels/quantize.py",
    "content": "\"\"\"Dequantization utilities for `bitsandbytes` and FP8 integration.\"\"\"\n\nimport ctypes\n\nimport bitsandbytes as bnb\nimport torch\nfrom bitsandbytes.functional import QuantState, get_ptr\nfrom packaging.version import Version\n\ncdequantize_blockwise_fp32 = bnb.functional.lib.cdequantize_blockwise_fp32\ncdequantize_blockwise_fp16_nf4 = bnb.functional.lib.cdequantize_blockwise_fp16_nf4\ncdequantize_blockwise_bf16_nf4 = bnb.functional.lib.cdequantize_blockwise_bf16_nf4\n\nCUDA_STREAM: torch.cuda.Stream | None = None\nHAS_CUDA_STREAM: bool = Version(bnb.__version__) > Version(\"0.43.3\")\n\n\ndef dequantize_fp8(\n    W: torch.Tensor,\n    scale_inv: torch.Tensor,\n    dtype: torch.dtype = torch.bfloat16,\n) -> torch.Tensor:\n    \"\"\"Dequantize FP8 block-quantized weights: W_dequant = W_fp8 * scale_inv.\n\n    Args:\n        W: FP8 weight tensor [out_features, in_features] in float8_e4m3fn.\n        scale_inv: Per-block inverse scale [ceil(out/block), ceil(in/block)]\n            or per-tensor scalar.\n        dtype: Output dtype (default bf16).\n\n    Returns:\n        Dequantized tensor in the specified dtype.\n    \"\"\"\n    W_float = W.to(dtype)\n    if scale_inv.numel() == 1:\n        return W_float * scale_inv.to(dtype)\n    if scale_inv.dim() == 2 and W.dim() == 2:\n        sr, sc = scale_inv.shape\n        br = W.shape[0] // sr\n        bc = W.shape[1] // sc\n        # If dimensions are exactly divisible, use fast reshape path\n        if sr * br == W.shape[0] and sc * bc == W.shape[1]:\n            return (\n                W_float.reshape(sr, br, sc, bc) * scale_inv[:, None, :, None].to(dtype)\n            ).reshape(W.shape)\n        # Tail-block handling: compute actual block size (ceil division),\n        # tile scale_inv to cover full shape, then crop to W's dimensions\n        br_ceil = -(-W.shape[0] // sr)  # ceil(rows / scale_rows) = block_size\n        bc_ceil = -(-W.shape[1] // sc)\n        scale_expanded = (\n            
scale_inv.to(dtype)\n            .repeat_interleave(br_ceil, dim=0)\n            .repeat_interleave(bc_ceil, dim=1)\n        )[: W.shape[0], : W.shape[1]]\n        return W_float * scale_expanded\n    return W_float * scale_inv.to(dtype)\n\n\ndef dequantize(\n    W: torch.Tensor,\n    quant_state: QuantState | list | torch.Tensor | None = None,\n    out: torch.Tensor | None = None,\n) -> torch.Tensor:\n    \"\"\"\n    Fast NF4 dequantization using `bitsandbytes` CUDA kernels.\n\n    Performs efficient dequantization of weights from NF4 format using `bitsandbytes`'\n    optimized CUDA implementations. Supports both legacy list and new `QuantState`\n    formats.\n\n    Args:\n        W: Quantized weight tensor to dequantize\n        quant_state: Quantization state containing metadata needed for\n            dequantization. Can be either a `QuantState` object or legacy list format.\n            If None, returns `W` unchanged.\n        out: Optional output tensor for storing dequantized results. Must match\n            expected shape and dtype if provided.\n\n    Returns:\n        Dequantized tensor in the specified dtype (fp16 or bf16). 
Will be transposed if\n        input `W` was transposed.\n\n    Raises:\n        AssertionError: If provided output tensor doesn't match expected shape / dtype.\n\n    Note:\n        Uses CUDA streams for better performance when available in newer `bitsandbytes`\n        versions (>0.43.3).\n    \"\"\"\n    if quant_state is None:\n        return W\n\n    # FP8 path: quant_state is actually scale_inv tensor\n    if W.dtype == torch.float8_e4m3fn:\n        scale_inv = quant_state\n        # Caller may pass W.t() (non-contiguous) — dequantize in original\n        # layout then transpose back so the result shape matches the input.\n        if not W.is_contiguous() and W.dim() == 2:\n            return dequantize_fp8(W.t(), scale_inv).t()\n        return dequantize_fp8(W, scale_inv)\n\n    # Get the target device from input tensor W\n    target_device = W.device\n\n    # Extract quantization state\n    if not isinstance(quant_state, list):\n        # New style quant_state class\n        absmax = quant_state.absmax.to(target_device)\n        shape = quant_state.shape\n        dtype = quant_state.dtype\n        blocksize = quant_state.blocksize\n        offset = quant_state.offset.to(target_device)\n        state2 = quant_state.state2\n        absmax2 = state2.absmax.to(target_device)\n        code2 = state2.code.to(target_device)\n        blocksize2 = state2.blocksize\n    else:\n        # Legacy list format\n        absmax, shape, dtype, blocksize, compressed_stats, _, _ = quant_state\n        absmax = absmax.to(target_device)\n        offset, state2 = compressed_stats\n        offset = offset.to(target_device)\n        absmax2, code2, blocksize2, _, _, _, _ = state2\n        absmax2 = absmax2.to(target_device)\n        code2 = code2.to(target_device)\n\n    # Setup output tensor on the same device as input\n    if out is None:\n        out = torch.empty(shape, dtype=dtype, device=target_device)\n    else:\n        assert out.shape == shape and out.dtype == dtype\n     
   out = out.to(target_device)\n\n    # Dequantize statistics on the target device\n    n_elements_absmax: int = absmax.numel()\n    out_absmax: torch.Tensor = torch.empty(\n        n_elements_absmax, dtype=torch.float32, device=target_device\n    )\n    ptr_out_absmax: int = get_ptr(out_absmax)\n\n    # Use CUDA stream if available\n    if HAS_CUDA_STREAM:\n        global CUDA_STREAM\n        if CUDA_STREAM is None:\n            CUDA_STREAM = torch.cuda.current_stream(target_device)\n\n        cdequantize_blockwise_fp32(\n            get_ptr(code2),\n            get_ptr(absmax),\n            get_ptr(absmax2),\n            ptr_out_absmax,\n            ctypes.c_int(blocksize2),\n            ctypes.c_int(n_elements_absmax),\n            CUDA_STREAM,\n        )\n    else:\n        cdequantize_blockwise_fp32(\n            get_ptr(code2),\n            get_ptr(absmax),\n            get_ptr(absmax2),\n            ptr_out_absmax,\n            ctypes.c_int(blocksize2),\n            ctypes.c_int(n_elements_absmax),\n        )\n\n    out_absmax += offset\n\n    # Choose appropriate dequantization function\n    fx = (\n        cdequantize_blockwise_fp16_nf4\n        if dtype == torch.float16\n        else cdequantize_blockwise_bf16_nf4\n    )\n\n    # Dequantize weights\n    if HAS_CUDA_STREAM:\n        fx(\n            get_ptr(None),\n            get_ptr(W),\n            ptr_out_absmax,\n            get_ptr(out),\n            ctypes.c_int(blocksize),\n            ctypes.c_int(out.numel()),\n            CUDA_STREAM,\n        )\n    else:\n        fx(\n            get_ptr(None),\n            get_ptr(W),\n            ptr_out_absmax,\n            get_ptr(out),\n            ctypes.c_int(blocksize),\n            ctypes.c_int(out.numel()),\n        )\n\n    # Handle transposed data\n    is_transposed: bool = W.shape[0] == 1\n    return out.t() if is_transposed else out\n"
  },
  {
    "path": "src/axolotl/kernels/swiglu.py",
    "content": "\"\"\"\nModule for definition of SwiGLU Triton kernels.\n\nSee \"GLU Variants Improve Transformer\" (https://arxiv.org/abs/2002.05202).\n\nCredit to `unsloth` (https://unsloth.ai/) for inspiration for this implementation.\n\"\"\"\n\nimport torch\nimport triton\nimport triton.language as tl\n\n\n@triton.jit\ndef _swiglu_fwd_kernel(\n    gate_ptr,\n    up_ptr,\n    out_ptr,\n    n_elements,\n    block_size: tl.constexpr,\n):\n    \"\"\"\n    SwiGLU forward kernel. The kernel computes activation in fp32 precision for better\n    numerical stability, then converts back to original dtype for the final result.\n\n    Args:\n        gate_ptr: Pointer to gate tensor `[*, hidden_dim]`.\n        up_ptr: Pointer to up-projection tensor `[*, hidden_dim]`.\n        out_ptr: Pointer to output tensor `[*, hidden_dim]`.\n        n_elements: Total number of elements in the input tensors.\n        block_size: Size of thread blocks for parallel computation.\n    \"\"\"\n    block_idx = tl.program_id(0)\n    offsets = block_idx * block_size + tl.arange(0, block_size)\n    mask = offsets < n_elements\n\n    # Load gate in fp32, keep up in original dtype\n    gate = tl.load(gate_ptr + offsets, mask=mask, other=0).to(tl.float32)\n    up = tl.load(up_ptr + offsets, mask=mask, other=0)\n\n    # Compute activation in fp32 then convert back\n    f = gate * tl.sigmoid(gate)\n    f = f.to(up.dtype)\n    result = f * up\n\n    tl.store(out_ptr + offsets, result, mask=mask)\n\n\n@triton.jit\ndef _swiglu_bwd_kernel(\n    grad_out_ptr,\n    gate_ptr,\n    up_ptr,\n    n_elements,\n    block_size: tl.constexpr,\n):\n    \"\"\"\n    SwiGLU backward kernel. 
Stores gradient results in-place.\n\n    Args:\n        grad_out_ptr: Pointer to gradient output tensor `[*, hidden_dim]`.\n        gate_ptr: Pointer to gate tensor `[*, hidden_dim]`.\n        up_ptr: Pointer to up-projection tensor `[*, hidden_dim]`.\n        n_elements: Total number of elements in the input tensors.\n        block_size: Size of thread blocks for parallel computation.\n\n    Note:\n        After kernel execution, tensors are modified in-place:\n        - `grad_out_ptr` contains forward output (`h`)\n        - `gate_ptr` contains gradient w.r.t gate (`grad_gate`)\n        - `up_ptr` contains gradient w.r.t up (`grad_up`)\n    \"\"\"\n    block_idx = tl.program_id(0)\n    offsets = block_idx * block_size + tl.arange(0, block_size)\n    mask = offsets < n_elements\n\n    # Load values - only convert gate to fp32\n    grad_out = tl.load(grad_out_ptr + offsets, mask=mask, other=0)\n    gate = tl.load(gate_ptr + offsets, mask=mask, other=0).to(tl.float32)\n    up = tl.load(up_ptr + offsets, mask=mask, other=0)\n\n    # Compute SiLU and forward output\n    sigmoid_gate = tl.sigmoid(gate)\n    silu_gate = sigmoid_gate * gate\n    silu_gate = silu_gate.to(grad_out.dtype)\n    h = silu_gate * up\n\n    # Compute gradients\n    grad_up = grad_out * silu_gate  # gradient for up is grad_out * SiLU(gate)\n\n    # Compute gate gradient\n    temp = grad_out * up\n    grad_gate = temp.to(tl.float32) * sigmoid_gate * (1.0 + gate * (1.0 - sigmoid_gate))\n    grad_gate = grad_gate.to(grad_out.dtype)\n\n    # Store results with correct gradient ordering\n    tl.store(grad_out_ptr + offsets, h, mask=mask)\n    tl.store(gate_ptr + offsets, grad_gate, mask=mask)  # grad wrt gate\n    tl.store(up_ptr + offsets, grad_up, mask=mask)  # grad wrt up\n\n\ndef swiglu_forward(gate: torch.Tensor, up: torch.Tensor) -> torch.Tensor:\n    \"\"\"\n    SwiGLU forward pass. 
Computes SwiGLU activation: `x * sigmoid(x) * up`, where\n    `x` is the gate tensor.\n\n    Args:\n        gate: Input gate tensor of shape `[batch, seq_len, hidden_dim]`.\n        up: Up-projection tensor of shape `[batch, seq_len, hidden_dim]`.\n\n    Returns:\n        Output tensor of shape `[batch, seq_len, hidden_dim]`.\n    \"\"\"\n    batch, seq_len, hidden_dim = gate.shape\n    n_elements = gate.numel()\n    out = torch.empty((batch, seq_len, hidden_dim), dtype=gate.dtype, device=\"cuda\")\n\n    grid = lambda meta: (triton.cdiv(n_elements, meta[\"block_size\"]),)  # noqa: E731\n    _swiglu_fwd_kernel[grid](\n        gate_ptr=gate,\n        up_ptr=up,\n        out_ptr=out,\n        n_elements=n_elements,\n        block_size=1024,\n    )\n\n    return out\n\n\ndef swiglu_backward(\n    grad_output: torch.Tensor, gate: torch.Tensor, up: torch.Tensor\n) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:\n    \"\"\"\n    SwiGLU backward pass using in-place operations.\n\n    Args:\n        grad_output: Gradient of loss with respect to output, shape `[batch, seq_len, hidden_dim]`.\n        gate: Gate tensor from forward pass, shape `[batch, seq_len, hidden_dim]`.\n        up: Up-projection tensor from forward pass, shape `[batch, seq_len, hidden_dim]`.\n\n    Returns:\n        Tuple containing:\n            - Forward pass output (`h`)\n            - Gradient with respect to gate (`df`)\n            - Gradient with respect to up-projection (`de`)\n    \"\"\"\n    n_elements = grad_output.numel()\n\n    grid = lambda meta: (triton.cdiv(n_elements, meta[\"block_size\"]),)  # noqa: E731\n    _swiglu_bwd_kernel[grid](\n        grad_out_ptr=grad_output,\n        gate_ptr=gate,\n        up_ptr=up,\n        n_elements=n_elements,\n        block_size=1024,\n    )\n\n    # After kernel execution, tensors contain:\n    # grad_output: h (forward output)\n    # gate: grad_gate (grad wrt gate)\n    # up: grad_up (grad wrt up)\n    return grad_output, gate, up\n"
  },
  {
    "path": "src/axolotl/kernels/utils.py",
    "content": "\"\"\"Utilities for `axolotl.kernels` submodules.\"\"\"\n\nimport torch\nfrom packaging.version import Version\n\nif Version(torch.__version__) < Version(\"2.4.0\"):\n    torch_amp_custom_fwd = torch.cuda.amp.custom_fwd\n    torch_amp_custom_bwd = torch.cuda.amp.custom_bwd\nelse:\n    torch_amp_custom_fwd = torch.amp.custom_fwd(device_type=\"cuda\")\n    torch_amp_custom_bwd = torch.amp.custom_bwd(device_type=\"cuda\")\n"
  },
  {
    "path": "src/axolotl/loaders/__init__.py",
    "content": "\"\"\"Init for axolotl.loaders module\"\"\"\n\n# flake8: noqa\n\nfrom .adapter import load_adapter, load_lora\nfrom .constants import MULTIMODAL_AUTO_MODEL_MAPPING\nfrom .model import ModelLoader\nfrom .processor import load_processor\nfrom .tokenizer import load_tokenizer\n"
  },
  {
    "path": "src/axolotl/loaders/adapter.py",
    "content": "\"\"\"Adapter loading functionality, including LoRA / QLoRA and associated utils\"\"\"\n\nimport os\nimport types\nfrom typing import Any\n\nimport bitsandbytes as bnb\nimport torch\nfrom bitsandbytes.nn import Params4bit\nfrom peft import (\n    AdaptionPromptConfig,\n    LoftQConfig,\n    LoraConfig,\n    PeftConfig,\n    PeftMixedModel,\n    PeftModel,\n    TaskType,\n    get_peft_model,\n)\nfrom transformers import PreTrainedModel\n\nfrom axolotl.loaders.utils import get_linear_embedding_layers\nfrom axolotl.telemetry.errors import send_errors\nfrom axolotl.utils.dict import DictDefault\nfrom axolotl.utils.logging import get_logger\n\nLOG = get_logger(__name__)\n\n\ndef setup_quantized_meta_for_peft(model: torch.nn.Module):\n    \"\"\"Replaces `quant_state.to` with a dummy function to prevent PEFT from moving `quant_state` to meta device\"\"\"\n\n    def temp_to_method(self, *args, **kwargs):\n        return self\n\n    for param in model.parameters():\n        if isinstance(param, Params4bit) and param.quant_state is not None:\n            param.quant_state._orig_to = param.quant_state.to\n            param.quant_state.to = types.MethodType(temp_to_method, param.quant_state)\n\n\ndef setup_quantized_peft_meta_for_training(model: torch.nn.Module):\n    \"\"\"Replaces dummy `quant_state.to` method with the original function to allow training to continue\"\"\"\n    for param in model.parameters():\n        if isinstance(param, Params4bit) and hasattr(param.quant_state, \"_orig_to\"):\n            param.quant_state.to = param.quant_state._orig_to\n            param.quant_state._orig_to = None\n\n\ndef find_all_linear_names(model):\n    cls = (bnb.nn.Linear4bit, bnb.nn.Linear8bitLt, torch.nn.Linear)\n    lora_module_names = set()\n    for name, module in model.named_modules():\n        if (\n            isinstance(module, cls)\n            or \"Linear\" in module.__class__.__name__\n            and module.__class__.__name__ not in 
(\"LlamaLinearScalingRotaryEmbedding\",)\n        ):\n            names = name.split(\".\")\n            lora_module_names.add(names[0] if len(names) == 1 else names[-1])\n\n    embedding_modules = get_linear_embedding_layers(model.config.model_type)\n    output_embedding = embedding_modules[1]\n    if output_embedding in lora_module_names:  # needed for 16-bit\n        lora_module_names.remove(output_embedding)\n\n    return list(lora_module_names)\n\n\ndef load_lora(\n    model: PreTrainedModel,\n    cfg: DictDefault,\n    inference: bool = False,\n    config_only: bool = False,\n) -> tuple[PreTrainedModel | PeftModel | PeftMixedModel | None, PeftConfig | None]:\n    lora_target_modules = cfg.lora_target_modules or []\n    lora_target_parameters = cfg.lora_target_parameters or []\n\n    if cfg.lora_target_linear:\n        linear_names = find_all_linear_names(model)\n        LOG.info(f\"found linear modules: {repr(sorted(linear_names))}\")\n        lora_target_modules_as_list = (\n            lora_target_modules\n            if isinstance(lora_target_modules, list)\n            else [lora_target_modules]\n        )\n        lora_target_modules = list(set(lora_target_modules_as_list + linear_names))\n\n    lora_config_kwargs = {}\n    loftq_bits = cfg.peft and cfg.peft.loftq_config and cfg.peft.loftq_config.loftq_bits\n    if loftq_bits:\n        lora_config_kwargs[\"loftq_config\"] = LoftQConfig(loftq_bits=loftq_bits)\n        lora_config_kwargs[\"init_lora_weights\"] = \"loftq\"\n    if cfg.peft_init_lora_weights:\n        lora_config_kwargs[\"init_lora_weights\"] = cfg.peft_init_lora_weights\n    if cfg.peft_use_dora:\n        lora_config_kwargs[\"use_dora\"] = cfg.peft_use_dora\n        LOG.info(\"Initializing LoRA weights using dora. 
This might take longer.\")\n    if cfg.peft_use_rslora:\n        lora_config_kwargs[\"use_rslora\"] = cfg.peft_use_rslora\n    if cfg.peft_layer_replication:\n        lora_config_kwargs[\"layer_replication\"] = cfg.peft_layer_replication\n    if cfg.peft_trainable_token_indices:\n        lora_config_kwargs[\"trainable_token_indices\"] = cfg.peft_trainable_token_indices\n    if cfg.peft_ensure_weight_tying is not None:\n        lora_config_kwargs[\"ensure_weight_tying\"] = cfg.peft_ensure_weight_tying\n\n    # Determine the correct PEFT task type\n    model_cls = type(model).__name__\n    if \"SequenceClassification\" in model_cls:\n        task_type = TaskType.SEQ_CLS\n    elif \"TokenClassification\" in model_cls:\n        task_type = TaskType.TOKEN_CLS\n    else:\n        task_type = TaskType.CAUSAL_LM\n\n    lora_config = LoraConfig(\n        r=cfg.lora_r,\n        lora_alpha=cfg.lora_alpha,\n        target_modules=lora_target_modules,\n        target_parameters=lora_target_parameters,\n        layers_to_transform=cfg.peft_layers_to_transform,\n        layers_pattern=cfg.peft_layers_pattern,\n        lora_dropout=cfg.lora_dropout,\n        fan_in_fan_out=cfg.lora_fan_in_fan_out,\n        modules_to_save=cfg.lora_modules_to_save if cfg.lora_modules_to_save else None,\n        bias=\"none\",\n        task_type=task_type,\n        **lora_config_kwargs,\n    )\n\n    if config_only:\n        return None, lora_config\n\n    rank = int(os.environ.get(\"LOCAL_RANK\", 0))\n\n    if (\n        cfg.fsdp_config\n        and cfg.adapter\n        and cfg.fsdp_config.cpu_ram_efficient_loading\n        and rank != 0\n    ):\n        setup_quantized_meta_for_peft(model)\n\n    model_kwargs: Any = {}\n    if cfg.peft_autocast_adapter_dtype is not None:\n        model_kwargs[\"autocast_adapter_dtype\"] = cfg.peft_autocast_adapter_dtype\n\n    if cfg.lora_model_dir:\n        LOG.debug(\"Loading pretrained PEFT - LoRA\")\n        if cfg.lora_on_cpu:\n            
model_kwargs[\"max_memory\"] = {\"cpu\": \"256GiB\"}\n            model_kwargs[\"device_map\"] = {\"\": \"cpu\"}\n        model = PeftModel.from_pretrained(\n            model,\n            cfg.lora_model_dir,\n            is_trainable=(not inference),\n            **model_kwargs,\n        )\n    else:\n        model = get_peft_model(model, lora_config, **model_kwargs)\n\n    # FP8 models: LoRA A/B inherit FP8 dtype from base weights, but training\n    # requires a compute dtype (bf16/fp16). Cast trainable LoRA params.\n    if cfg.torch_dtype:\n        _fp8_cast_dtype = cfg.torch_dtype\n    elif torch.cuda.is_available() and torch.cuda.is_bf16_supported():\n        _fp8_cast_dtype = torch.bfloat16\n    else:\n        _fp8_cast_dtype = torch.float16\n    for _name, param in model.named_parameters():\n        if param.requires_grad and param.dtype == torch.float8_e4m3fn:\n            param.data = param.data.to(_fp8_cast_dtype)\n\n    if rank == 0:\n        try:\n            model.print_trainable_parameters()\n        except AttributeError as exc:\n            LOG.warning(\n                \"Exception caught during model.print_trainable_parameters(): %s\", exc\n            )\n    elif (\n        cfg.fsdp_config\n        and cfg.adapter\n        and cfg.fsdp_config.cpu_ram_efficient_loading\n        and rank != 0\n    ):\n        setup_quantized_peft_meta_for_training(model)\n\n    return model, lora_config\n\n\n@send_errors\ndef load_adapter(\n    model: PreTrainedModel,\n    cfg: DictDefault,\n    adapter: str | None,\n    inference: bool = False,\n) -> tuple[PreTrainedModel | PeftModel | PeftMixedModel, PeftConfig | None]:\n    if adapter is None:\n        return model, None\n    if hasattr(model, \"enable_input_require_grads\"):\n        model.enable_input_require_grads()\n    if adapter in [\"lora\", \"qlora\"]:\n        peft_model, lora_config = load_lora(model, cfg, inference=inference)\n        return peft_model, lora_config\n    if adapter == 
\"llama-adapter\":\n        peft_model, lora_config = load_llama_adapter(model, cfg)\n        return peft_model, lora_config\n\n    raise NotImplementedError(f\"{adapter} PEFT adapter not available\")\n\n\ndef load_llama_adapter(\n    model: PreTrainedModel, cfg: DictDefault\n) -> tuple[PeftModel | PeftMixedModel, PeftConfig]:\n    peft_config = AdaptionPromptConfig(\n        adapter_layers=cfg.peft_adapter.layers,  # layers (L)\n        adapter_len=cfg.peft_adapter.len,  # prompt length (K)\n        task_type=\"CAUSAL_LM\",\n    )\n\n    if cfg.lora_model_dir:\n        LOG.debug(\"Loading pretrained PEFT - llama_adapter\")\n        peft_model = PeftModel.from_pretrained(\n            model,\n            cfg.lora_model_dir,\n            torch_dtype=torch.float16,\n        )\n    else:\n        peft_model = get_peft_model(model, peft_config)\n\n    peft_model.print_trainable_parameters()\n\n    return peft_model, peft_config\n"
  },
  {
    "path": "src/axolotl/loaders/adapters/__init__.py",
    "content": ""
  },
  {
    "path": "src/axolotl/loaders/constants.py",
    "content": "\"\"\"Shared constants for axolotl.loaders module\"\"\"\n\nfrom transformers import AutoModelForImageTextToText\nfrom transformers.models.auto.modeling_auto import (\n    MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES,\n)\n\nMULTIMODAL_AUTO_MODEL_MAPPING = dict(MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES)\n\nMULTIMODAL_AUTO_MODEL_MAPPING[\"lfm2-vl\"] = AutoModelForImageTextToText\n\ntry:\n    from transformers import VoxtralForConditionalGeneration\n\n    # transformers >4.53.2\n    MULTIMODAL_AUTO_MODEL_MAPPING[\"voxtral\"] = VoxtralForConditionalGeneration\nexcept ImportError:\n    pass\n"
  },
  {
    "path": "src/axolotl/loaders/model.py",
    "content": "\"\"\"\nModel loader class implementation for loading, configuring, and patching various models.\n\"\"\"\n\nimport gc\nimport math\nimport os\nfrom functools import cached_property\nfrom importlib.util import find_spec\nfrom typing import Any\n\nimport peft\nimport torch\nimport transformers\nimport transformers.modeling_utils\nfrom accelerate import init_empty_weights\nfrom accelerate.parallelism_config import ParallelismConfig\nfrom peft import (\n    PeftConfig,\n    PeftMixedModel,\n    PeftModel,\n    PeftModelForCausalLM,\n    prepare_model_for_kbit_training,\n)\nfrom torch.distributed import DeviceMesh\nfrom transformers import (\n    AutoModelForCausalLM,\n    AutoModelForImageTextToText,\n    AwqConfig,\n    BitsAndBytesConfig,\n    GPTQConfig,\n    PreTrainedModel,\n    PreTrainedTokenizerBase,\n)\nfrom transformers.integrations.deepspeed import (\n    HfTrainerDeepSpeedConfig,\n    is_deepspeed_zero3_enabled,\n)\n\nfrom axolotl.common.architectures import MOE_ARCH_BLOCK\nfrom axolotl.integrations.base import PluginManager\nfrom axolotl.loaders.adapter import load_adapter, load_lora\nfrom axolotl.loaders.constants import MULTIMODAL_AUTO_MODEL_MAPPING\nfrom axolotl.loaders.patch_manager import PatchManager\nfrom axolotl.loaders.utils import (\n    get_linear_embedding_layers,\n    get_module_class_from_name,\n    load_model_config,\n)\nfrom axolotl.models.mamba import fix_mamba_attn_for_loss\nfrom axolotl.telemetry.errors import send_errors\nfrom axolotl.utils.bench import log_gpu_memory_usage\nfrom axolotl.utils.dict import DictDefault\nfrom axolotl.utils.distributed import (\n    build_parallelism_config,\n    get_device_count,\n    get_device_type,\n)\nfrom axolotl.utils.logging import get_logger\nfrom axolotl.utils.model_shard_quant import load_sharded_model_quant\nfrom axolotl.utils.schemas.enums import RLType\n\nLOG = get_logger(__name__)\nPLUGIN_MANAGER = PluginManager.get_instance()\n\n\nclass ModelLoader:\n    \"\"\"Manages model 
configuration, initialization and application of patches during\n    model loading.\n\n    This class orchestrates the entire process of loading a model from configuration to\n    final preparation. It handles device mapping, quantization, attention mechanisms,\n    adapter integration, and various optimizations.\n\n    The loading process includes:\n        - Loading and validating model configuration\n        - Applying monkey patches for optimizations / fixes\n        - Setting up device mapping (including multi-GPU configurations)\n        - Configuring quantization\n        - Setting attention mechanisms (Flash Attention, SDPA, etc.)\n        - Loading and initializing the model\n        - Applying adapters (LoRA, QLoRA, etc.)\n\n    Attributes:\n        model: The loaded model instance (available after load() is called).\n        model_kwargs: Dictionary of keyword arguments passed to model initialization.\n        base_model: Name or path of the base model to load.\n        model_type: Type of model to load (e.g., `AutoModelForCausalLM`).\n        model_config: Configuration object for the model.\n        auto_model_loader: class used for loading the model (default:\n            `AutoModelForCausalLM`).\n    \"\"\"\n\n    use_parallel_config: bool | None = False\n    parallelism_config: ParallelismConfig | None = None\n    device_mesh: DeviceMesh | None = None\n\n    def __init__(\n        self,\n        cfg: DictDefault,\n        tokenizer: PreTrainedTokenizerBase,\n        *,\n        inference: bool = False,\n        reference_model: bool = False,\n        **kwargs,\n    ):\n        \"\"\"Initializes the ModelLoader.\n\n        Args:\n            cfg: Configuration dictionary with model and training settings.\n            tokenizer: Tokenizer instance associated with the model.\n            processor: Optional processor for multimodal models. Defaults to None.\n            inference: Whether the model is being loaded for inference mode. 
Defaults\n                to False.\n            reference_model: Whether this is a reference model (used in setups like DPO\n                training). Defaults to False.\n            **kwargs: Additional keyword arguments (ignored).\n        \"\"\"\n        self.cfg = cfg\n        self.tokenizer = tokenizer\n        self.inference: bool = inference\n        self.reference_model: bool = reference_model\n\n        # Init model kwargs\n        self.model_kwargs: dict[str, Any] = {}\n        if cfg.overrides_of_model_kwargs:\n            for key, val in cfg.overrides_of_model_kwargs.items():\n                self.model_kwargs[key] = val\n\n        # Init model\n        self.model: PreTrainedModel | PeftModel | PeftMixedModel\n        self.base_model = cfg.base_model\n        self.model_type = cfg.type_of_model\n\n        # Init model config\n        self.model_config = load_model_config(cfg)\n        self.auto_model_loader = AutoModelForCausalLM\n\n        # Initialize the patch manager\n        self.patch_manager = PatchManager(\n            cfg=cfg,\n            model_config=self.model_config,\n            inference=inference,\n        )\n\n    @cached_property\n    def has_flash_attn(self) -> bool:\n        \"\"\"Check if flash attention is installed.\"\"\"\n        return find_spec(\"flash_attn\") is not None\n\n    @property\n    def is_fsdp_enabled(self):\n        \"\"\"Property that determines if FSDP is enabled.\"\"\"\n        return self.cfg.fsdp_config is not None or self.cfg.fsdp is not None\n\n    @property\n    def is_qlora_and_fsdp_enabled(self):\n        \"\"\"Property that determines if FSDP with QLoRA is enabled.\"\"\"\n        return self.is_fsdp_enabled and self.cfg.adapter == \"qlora\"\n\n    @send_errors\n    def load(self) -> tuple[PreTrainedModel | PeftModelForCausalLM, PeftConfig | None]:\n        \"\"\"Load and prepare the model with all configurations and patches.\n\n        Returns:\n            A tuple with the loaded model and its LoRA 
configuration (if applicable).\n        \"\"\"\n        # Initial setup and patches\n        self.patch_manager.apply_pre_model_load_patches()\n        self._apply_pre_model_load_setup()\n\n        # Build the model\n        PLUGIN_MANAGER.pre_model_load(self.cfg)\n        self.patch_manager.apply_post_plugin_pre_model_load_patches()\n\n        skip_move_to_device = self._build_model()\n        self.patch_manager.apply_post_model_build_patches(self.model)\n\n        PLUGIN_MANAGER.post_model_build(self.cfg, self.model)\n\n        # Post-build model configuration\n        self._apply_post_model_load_setup()\n\n        # Load adapters (LoRA, etc.)\n        PLUGIN_MANAGER.pre_lora_load(self.cfg, self.model)\n        lora_config = self._load_adapters()\n        PLUGIN_MANAGER.post_lora_load(self.cfg, self.model)\n\n        # Apply remaining patches and finalize\n        self._apply_post_lora_load_setup(skip_move_to_device)\n        self.patch_manager.apply_post_model_load_patches(self.model)\n        PLUGIN_MANAGER.post_model_load(self.cfg, self.model)\n\n        return self.model, lora_config\n\n    def _apply_pre_model_load_setup(self):\n        \"\"\"Apply patches and setup configurations before model loading.\"\"\"\n        if self.use_parallel_config is not None:\n            self.use_parallel_config = (\n                self.cfg.fsdp_config\n                or (self.cfg.tensor_parallel_size and self.cfg.tensor_parallel_size > 1)\n                or (\n                    self.cfg.context_parallel_size\n                    and self.cfg.context_parallel_size > 1\n                )\n            )\n            if self.cfg.fsdp_config and self.cfg.fsdp_version != 2:\n                self.use_parallel_config = False\n\n        if self.use_parallel_config:\n            self._set_parallel_config()\n        self._set_auto_model_loader()\n        self._set_device_map_config()\n        if self.cfg.revision_of_model:\n            self.model_kwargs[\"revision\"] = 
self.cfg.revision_of_model\n        if self.cfg.use_kernels:\n            self.model_kwargs[\"use_kernels\"] = self.cfg.use_kernels\n            if \"allow_all_kernels\" not in self.model_kwargs:\n                self.model_kwargs[\"allow_all_kernels\"] = self.cfg.use_kernels\n        self._set_quantization_config()\n        self._set_attention_config()\n        self._check_model_requirements()\n\n    def _apply_post_model_load_setup(self):\n        \"\"\"Configure the model after it has been loaded.\"\"\"\n        # Handle PeftModel if needed\n        if (\n            isinstance(self.model, (peft.PeftModel, peft.PeftModelForCausalLM))\n            and not self.is_qlora_and_fsdp_enabled\n        ):\n            self.model = self.model.merge_and_unload()\n\n        self._configure_experts_implementation()\n        self._apply_activation_checkpointing()\n        self._resize_token_embeddings()\n        self._adjust_model_config()\n        self._configure_embedding_dtypes()\n        self._configure_qat()\n        log_gpu_memory_usage(LOG, \"Memory usage after model load\", 0)\n\n    def _configure_experts_implementation(self):\n        if self.cfg.experts_implementation is not None:\n            self.model.set_experts_implementation(self.cfg.experts_implementation)\n\n    def _apply_activation_checkpointing(self):\n        if self.cfg.activation_offloading is True:\n            from axolotl.core.trainers.mixins.activation_checkpointing import (\n                ac_wrap_hf_model,\n            )\n\n            # ^^ importing this at the module level breaks plugins\n            ac_wrap_hf_model(self.model)\n\n    def _resize_token_embeddings(self):\n        \"\"\"Resize token embeddings if needed.\"\"\"\n        embeddings_len = (\n            math.ceil(len(self.tokenizer) / 32) * 32\n            if self.cfg.resize_token_embeddings_to_32x\n            else len(self.tokenizer)\n        )\n        if hasattr(self.model, \"get_input_embeddings\") and (\n            
self.model.get_input_embeddings().num_embeddings < embeddings_len\n            or (\n                self.model.get_input_embeddings().num_embeddings > embeddings_len\n                and self.cfg.shrink_embeddings\n            )\n        ):\n            resize_kwargs = {}\n            if self.cfg.mean_resizing_embeddings is not None and (\n                self.model_config.model_type != \"llava\"\n            ):\n                resize_kwargs[\"mean_resizing\"] = self.cfg.mean_resizing_embeddings\n            self.model.resize_token_embeddings(embeddings_len, **resize_kwargs)\n        else:\n            self.model.tie_weights()\n\n    def _adjust_model_config(self):\n        if (\n            hasattr(self.model, \"config\")\n            and hasattr(self.model.config, \"max_position_embeddings\")\n            and self.model.config.max_position_embeddings\n            and self.cfg.sequence_len > self.model.config.max_position_embeddings\n        ):\n            LOG.warning(\n                \"increasing model.config.max_position_embeddings from \"\n                f\"{self.model.config.max_position_embeddings} to {self.cfg.sequence_len}\"\n            )\n            self.model.config.max_position_embeddings = self.cfg.sequence_len\n\n        if (\n            hasattr(self.model, \"config\")\n            and hasattr(self.model.config, \"bos_token_id\")\n            and self.model.config.bos_token_id\n            and self.model.config.bos_token_id != self.tokenizer.bos_token_id\n        ):\n            self.model.config.bos_token_id = self.tokenizer.bos_token_id\n\n        if (\n            hasattr(self.model, \"config\")\n            and hasattr(self.model.config, \"eos_token_id\")\n            and self.model.config.eos_token_id\n            and self.model.config.eos_token_id != self.tokenizer.eos_token_id\n        ):\n            self.model.config.eos_token_id = self.tokenizer.eos_token_id\n\n    def _configure_embedding_dtypes(self):\n        \"\"\"Configure 
embedding module dtypes.\"\"\"\n        # Get embedding modules\n        embedding_modules = get_linear_embedding_layers(self.cfg.model_config_type)\n\n        # Initial dtype conversion\n        if not self.is_fsdp_enabled:\n            # We don't run this during FSDP because this will leave mixed and bfloat16\n            # dtypes in the model which FSDP doesn't like\n            if self.cfg.load_in_4bit and self.cfg.embeddings_skip_upcast:\n                embedding_modules = []\n            self._convert_embedding_modules_dtype(\n                embedding_modules,\n                dist_dtype=torch.float32,\n                before_kbit_train_or_finetune=True,\n            )\n\n        # Handle DeepSpeed Zero3\n        if (\n            is_deepspeed_zero3_enabled()\n            or os.getenv(\"ACCELERATE_DEEPSPEED_ZERO_STAGE\") == \"3\"\n        ):\n            self._set_z3_leaf_modules()\n\n        # Apply gradient checkpointing if needed\n        needs_fa2_dtype = self.cfg.adapter or self.is_fsdp_enabled\n        if self.cfg.adapter in [\"lora\", \"qlora\"]:\n            needs_fa2_dtype = True\n            if self.cfg.gradient_checkpointing:\n                self.model.gradient_checkpointing_enable(\n                    gradient_checkpointing_kwargs=self.cfg.gradient_checkpointing_kwargs\n                )\n\n        self._prepare_model_for_quantization()\n\n        # Convert dtypes if needed\n        should_convert = (\n            # LlamaRMSNorm layers are in fp32 after kbit_training or full finetune, so\n            # we need to convert them back to fp16/bf16 for flash-attn compatibility.\n            (\n                (\n                    needs_fa2_dtype\n                    or self.cfg.flash_attention\n                    or self.cfg.flex_attention\n                    or self.cfg.sage_attention\n                )\n                and not self.is_qlora_and_fsdp_enabled\n            )\n            or (\n                # CCE requires embedding layers to 
be in fp16/bf16 for backward pass\n                self.cfg.cut_cross_entropy\n            )\n        )\n\n        if should_convert:\n            LOG.info(\"Converting modules to %s\", self.cfg.torch_dtype)\n            self._convert_embedding_modules_dtype(\n                embedding_modules=embedding_modules,\n                dist_dtype=self.cfg.torch_dtype,\n                before_kbit_train_or_finetune=False,\n            )\n\n    def _configure_qat(self):\n        \"\"\"Configure QAT.\"\"\"\n        if self.cfg.qat:\n            from axolotl.utils.quantization import prepare_model_for_qat\n\n            prepare_model_for_qat(\n                self.model,\n                self.cfg.qat.weight_dtype,\n                self.cfg.qat.group_size,\n                self.cfg.qat.activation_dtype,\n                self.cfg.qat.quantize_embedding,\n            )\n\n    def _load_adapters(self) -> PeftConfig | None:\n        \"\"\"Load LoRA or other adapters.\"\"\"\n        # Load LoRA or adapter\n        lora_config = None\n        if not self.reference_model or self.cfg.lora_model_dir:\n            # If we're not loading the reference model, then we're loading the model\n            # for training. 
Then, the DPO trainer doesn't want the PEFT model loaded\n            # over it, it just wants the LoRA / PEFT config.\n            if (\n                self.cfg.adapter\n                and self.cfg.rl in [RLType.DPO, RLType.IPO, RLType.KTO]\n                and not self.cfg.merge_lora\n            ):\n                _, lora_config = load_lora(\n                    self.model, self.cfg, inference=False, config_only=True\n                )\n            else:\n                self.model, lora_config = load_adapter(\n                    self.model, self.cfg, self.cfg.adapter\n                )\n\n        return lora_config\n\n    def _apply_post_lora_load_setup(self, skip_move_to_device: bool):\n        \"\"\"Apply final optimizations and patches.\"\"\"\n        # Place model on accelerator\n        if (\n            self.cfg.ddp\n            and not self.cfg.load_in_8bit\n            and not (self.cfg.rl and self.cfg.load_in_4bit)\n            and not skip_move_to_device\n        ):\n            self.model.to(f\"{str(get_device_type())}:{self.cfg.local_rank}\")\n\n        if get_device_count() > 1 and int(os.getenv(\"WORLD_SIZE\", \"1\")) == 1:\n            self.model.is_parallelizable = True\n            self.model.model_parallel = True\n\n        if not any(\n            param.requires_grad\n            for _, param in self.model.named_parameters(recurse=True)\n        ):\n            LOG.warning(\"There are no parameters that require gradient updates\")\n\n        if self.cfg.flash_optimum:\n            from optimum.bettertransformer import BetterTransformer\n\n            self.model = BetterTransformer.transform(self.model)\n\n        if self.cfg.adapter is not None:\n            log_gpu_memory_usage(LOG, \"after adapters\", self.model.device)\n\n        for _ in range(3):\n            gc.collect()\n            torch.cuda.empty_cache()\n\n    def _set_parallel_config(self):\n        \"\"\"Set parallelism configuration (DP, FSDP, TP, CP) in 
PartialState/Accelerator\"\"\"\n        parallelism_config, device_mesh = build_parallelism_config(self.cfg)\n        if parallelism_config:\n            self.parallelism_config = parallelism_config\n            self.device_mesh = device_mesh\n\n    def _set_auto_model_loader(self):\n        \"\"\"Set `self.auto_model_loader`. Defaults to `transformers.AutoModelForCausalLM`\n        (set at `__init__`). When using a multimodal model, `self.auto_model_loader`\n        should be set according to the type of the model.\n        \"\"\"\n        if self.cfg.is_multimodal:\n            self.auto_model_loader = MULTIMODAL_AUTO_MODEL_MAPPING.get(\n                self.model_config.model_type, AutoModelForImageTextToText\n            )\n            if isinstance(self.auto_model_loader, str):\n                self.auto_model_loader = AutoModelForImageTextToText\n\n    def _set_device_map_config(self):\n        \"\"\"Setup `device_map` according to config\"\"\"\n        device_map = self.cfg.device_map\n        max_memory = self.cfg.max_memory\n\n        if self.cfg.gpu_memory_limit:\n            gpu_memory_limit = (\n                str(self.cfg.gpu_memory_limit) + \"GiB\"\n                if isinstance(self.cfg.gpu_memory_limit, int)\n                else self.cfg.gpu_memory_limit\n            )\n\n            max_memory = {}\n            num_device = get_device_count()\n            for i in range(num_device):\n                max_memory[i] = gpu_memory_limit\n            max_memory[\"cpu\"] = \"256GiB\"  # something sufficiently large to fit anything\n\n        if max_memory is not None:\n            # Based on https://github.com/togethercomputer/OpenChatKit/blob/main/inference/bot.py\n            from accelerate import infer_auto_device_map\n\n            with init_empty_weights():\n                model_canvas = self.auto_model_loader.from_config(\n                    self.model_config,\n                    trust_remote_code=self.cfg.trust_remote_code or False,\n         
       )\n            model_canvas.tie_weights()\n            device_map = infer_auto_device_map(\n                model_canvas,\n                max_memory=max_memory,\n                dtype=self.cfg.torch_dtype,\n            )\n            # We can discard max_memory now as we have a device map set up\n            max_memory = None\n\n        self.model_kwargs[\"torch_dtype\"] = self.cfg.torch_dtype\n        self.model_kwargs[\"dtype\"] = self.cfg.torch_dtype\n\n        is_ds_zero3 = is_deepspeed_zero3_enabled()\n\n        # FSDP requires control over device placement, so don't set device_map when FSDP is enabled\n        if self.is_fsdp_enabled:\n            # For QLoRA + FSDP, we still need to set device_map (pinning the whole model to\n            # this process's LOCAL_RANK device) for proper initialization\n            if self.is_qlora_and_fsdp_enabled:\n                self.model_kwargs[\"device_map\"] = {\n                    \"\": int(os.environ.get(\"LOCAL_RANK\", 0))\n                }\n            # For other FSDP cases, don't set device_map at all\n        elif not is_ds_zero3:\n            self.model_kwargs[\"device_map\"] = device_map\n\n            # quantize_moe_experts quantizes expert weights on-the-fly during loading,\n            # so the actual VRAM usage is much less than bf16 estimates.\n            # When device_map is \"auto\", accelerate's infer_auto_device_map computes\n            # the device map at bf16 size (before quantization), causing it to offload\n            # layers to CPU, which BnB then rejects. Force single-GPU placement to\n            # prevent this. 
Only applies to the non-FSDP, non-ZeRO3 path (DDP/single).\n            if getattr(self.cfg, \"quantize_moe_experts\", False) and device_map in (\n                \"auto\",\n                None,\n            ):\n                self.model_kwargs[\"device_map\"] = {\n                    \"\": int(os.environ.get(\"LOCAL_RANK\", 0))\n                }\n\n            cur_device = get_device_type()\n            if \"mps\" in str(cur_device):\n                self.model_kwargs[\"device_map\"] = \"mps:0\"\n            elif \"npu\" in str(cur_device):\n                self.model_kwargs[\"device_map\"] = \"npu:0\"\n\n        # TODO: can we put the reference model on it's own gpu? I think we have to move\n        # logits around to calculate loss\n        # if cfg.rl:\n        #     if torch.cuda.device_count() > 1:\n        #         if reference_model:\n        #             model_kwargs[\"device_map\"] = \"cuda:\" + str(\n        #                 torch.cuda.current_device() + 1\n        #             )\n        #         else:\n        #             model_kwargs[\"device_map\"] = \"cuda:\" + str(torch.cuda.current_device())\n\n    def _set_quantization_config(self):\n        \"\"\"Set up quantization config (bitsandbytes, awq, gptq, etc.)\"\"\"\n\n        if self.cfg.model_quantization_config == \"Mxfp4Config\":\n            from transformers import Mxfp4Config\n\n            mxfp4_kwargs = {}\n            if self.cfg.model_quantization_config_kwargs:\n                mxfp4_kwargs = self.cfg.model_quantization_config_kwargs\n            self.model_kwargs[\"quantization_config\"] = Mxfp4Config(**mxfp4_kwargs)\n\n        if self.cfg.gptq:\n            if not hasattr(self.model_config, \"quantization_config\"):\n                LOG.warning(\n                    \"model config does not contain quantization_config information\"\n                )\n            else:\n                if self.cfg.gptq_disable_exllama is not None:\n                    
self.model_config.quantization_config[\"disable_exllama\"] = (\n                        self.cfg.gptq_disable_exllama\n                    )\n                self.model_kwargs[\"quantization_config\"] = GPTQConfig(\n                    **self.model_config.quantization_config\n                )\n        if (\n            self.cfg.adapter in [\"qlora\", \"lora\"]\n            and hasattr(self.model_config, \"quantization_config\")\n            and self.model_config.quantization_config[\"quant_method\"]\n            in [\"gptq\", \"awq\", \"bitsandbytes\"]\n        ):\n            if self.model_config.quantization_config[\"quant_method\"] == \"gptq\":\n                self.model_kwargs[\"quantization_config\"] = GPTQConfig(\n                    **self.model_config.quantization_config\n                )\n            elif self.model_config.quantization_config[\"quant_method\"] == \"awq\":\n                self.model_kwargs[\"quantization_config\"] = AwqConfig(\n                    **self.model_config.quantization_config\n                )\n            elif (\n                self.model_config.quantization_config[\"quant_method\"] == \"bitsandbytes\"\n            ):\n                self.model_kwargs[\"quantization_config\"] = BitsAndBytesConfig(\n                    **self.model_config.quantization_config\n                )\n        elif self.cfg.adapter == \"qlora\" and self.cfg.load_in_4bit:\n            bnb_config = {\n                \"load_in_4bit\": True,\n                \"llm_int8_threshold\": 6.0,\n                \"llm_int8_has_fp16_weight\": False,\n                \"bnb_4bit_compute_dtype\": self.cfg.torch_dtype,\n                \"bnb_4bit_use_double_quant\": True,\n                \"bnb_4bit_quant_type\": \"nf4\",\n                \"bnb_4bit_quant_storage\": torch.bfloat16,\n            }\n            if self.cfg.model_config_type in [\"jamba\", \"qwen2_moe\"] and not (\n                self.cfg.deepspeed or self.is_fsdp_enabled\n            ):\n           
     # for some reason, this causes the loss to be off by an order of magnitude\n                # but deepspeed needs this still in bfloat16\n                bnb_config[\"bnb_4bit_quant_storage\"] = torch.float32\n            if self.cfg.model_config_type == \"falcon_h1\":\n                # output projection cannot be quantized for Falcon-H1 models\n                bnb_config[\"llm_int8_skip_modules\"] = [\"out_proj\"]\n\n            if self.cfg.bnb_config_kwargs:\n                bnb_config.update(self.cfg.bnb_config_kwargs)\n\n            self.model_kwargs[\"quantization_config\"] = BitsAndBytesConfig(\n                **bnb_config,\n            )\n        elif self.cfg.adapter == \"lora\" and self.cfg.load_in_8bit:\n            bnb_config = {\n                \"load_in_8bit\": True,\n            }\n            # Exclude mamba blocks from int8 quantization for jamba\n            if self.cfg.model_config_type == \"jamba\":\n                bnb_config[\"llm_int8_skip_modules\"] = [\"mamba\"]\n            if self.cfg.model_config_type == \"falcon_h1\":\n                # output projection cannot be quantized for Falcon-H1 models\n                bnb_config[\"llm_int8_skip_modules\"] = [\"out_proj\"]\n            self.model_kwargs[\"quantization_config\"] = BitsAndBytesConfig(\n                **bnb_config,\n            )\n\n    def _set_attention_config(self):\n        \"\"\"Sample packing uses custom FA2 patch\"\"\"\n        if self.cfg.attn_implementation:\n            self.model_kwargs[\"attn_implementation\"] = self.cfg.attn_implementation\n        elif self.cfg.flex_attention:\n            self.model_kwargs[\"attn_implementation\"] = \"flex_attention\"\n            self.model_config._attn_implementation = \"flex_attention\"\n\n        elif self.cfg.flash_attention:\n            if not self.cfg.sample_packing and self.cfg.s2_attention:\n                pass\n            self.model_kwargs[\"attn_implementation\"] = \"flash_attention_2\"\n            
self.model_config._attn_implementation = \"flash_attention_2\"\n        elif self.cfg.sdp_attention:\n            self.model_kwargs[\"attn_implementation\"] = \"sdpa\"\n            self.model_config._attn_implementation = \"sdpa\"\n        elif self.cfg.sage_attention:\n            # sets FA2 attention to re-use same internal handling like masking\n            self.model_kwargs[\"attn_implementation\"] = \"flash_attention_2\"\n            self.model_config._attn_implementation = \"flash_attention_2\"\n        elif self.cfg.eager_attention:\n            self.model_kwargs[\"attn_implementation\"] = \"eager\"\n            self.model_config._attn_implementation = \"eager\"\n\n        if self.cfg.low_cpu_mem_usage:\n            self.model_kwargs[\"low_cpu_mem_usage\"] = True\n\n    def _check_model_requirements(self):\n        if self.cfg.model_config_type in [\"lfm2-vl\", \"lfm2\"]:\n            from transformers.utils.import_utils import is_causal_conv1d_available\n\n            if is_causal_conv1d_available():\n                raise ImportError(\n                    \"The 'causal-conv1d' package is installed but causes compatibility issues with LFM2 models. 
\"\n                    \"Please uninstall it by running: `pip uninstall -y causal-conv1d`\"\n                )\n\n    def _configure_zero3_memory_efficient_loading(\n        self,\n    ) -> HfTrainerDeepSpeedConfig | None:\n        \"\"\"\n        Set the deepspeed config to load the model into RAM first before moving to VRAM.\n\n        IMPORTANT\n        ==========\n\n        We need to return `hf_ds_cfg` as it needs to exist before model loading for zero3.\n        HfTrainerDeepSpeedConfig is a class that is used to configure the DeepSpeed training.\n        It is not passed anywhere in the model loading function, just need to exist.\n        \"\"\"\n        hf_ds_cfg = None\n\n        if os.getenv(\"ACCELERATE_DEEPSPEED_ZERO_STAGE\") == \"3\":\n            hf_ds_cfg = HfTrainerDeepSpeedConfig(self.cfg.deepspeed)\n            hf_ds_cfg.fill_match(\n                \"train_micro_batch_size_per_gpu\", self.cfg.micro_batch_size\n            )\n            hf_ds_cfg.fill_match(\n                \"gradient_accumulation_steps\", self.cfg.gradient_accumulation_steps\n            )\n            hf_ds_cfg.fill_match(\n                \"train_batch_size\",\n                int(os.getenv(\"WORLD_SIZE\", \"1\"))\n                * self.cfg.micro_batch_size\n                * self.cfg.gradient_accumulation_steps,\n            )\n            if \"device_map\" in self.model_kwargs:\n                del self.model_kwargs[\"device_map\"]\n\n            transformers.modeling_utils.is_deepspeed_zero3_enabled = lambda: True\n            transformers.integrations.deepspeed.is_deepspeed_zero3_enabled = lambda: (\n                True\n            )\n\n        return hf_ds_cfg\n\n    def _load_model_from_config(self, model_loader_class=None) -> PreTrainedModel:\n        \"\"\"\n        Load model with random initialization using from_config.\n\n        Uses the selected loader when provided; otherwise falls back to the auto loader.\n        \"\"\"\n        loader = model_loader_class 
or self.auto_model_loader\n        if loader in [AutoModelForCausalLM, AutoModelForImageTextToText]:\n            model = loader.from_config(\n                config=self.model_config,\n                trust_remote_code=self.cfg.trust_remote_code or False,\n            )\n        else:\n            model = loader(config=self.model_config)\n\n        return model\n\n    def _load_model_from_pretrained(self, model_loader_class=None) -> PreTrainedModel:\n        \"\"\"Load model from pretrained weights.\"\"\"\n        loader = model_loader_class or self.auto_model_loader\n        kwargs = {\n            \"config\": self.model_config,\n            \"trust_remote_code\": self.cfg.trust_remote_code or False,\n            **self.model_kwargs,\n        }\n        return loader.from_pretrained(self.base_model, **kwargs)\n\n    def _build_model(self) -> bool:\n        \"\"\"Load model, with load strategy depending on config.\"\"\"\n        skip_move_to_device = False\n\n        if self.cfg.tensor_parallel_size > 1:\n            self.model_kwargs[\"tp_size\"] = self.cfg.tensor_parallel_size\n            self.model_kwargs[\"tp_plan\"] = \"auto\"\n            self.model_kwargs[\"device_mesh\"] = self.device_mesh\n            if \"device_map\" in self.model_kwargs:\n                del self.model_kwargs[\"device_map\"]  # not compatible with `tp_plan`\n\n        if self.is_fsdp_enabled:\n            if self.cfg.fsdp_config.cpu_ram_efficient_loading:\n                skip_move_to_device = True\n                # Don't delete device_map for QLoRA + FSDP - it was set correctly in\n                # _set_device_map_config\n                if (\n                    \"device_map\" in self.model_kwargs\n                    and not self.is_qlora_and_fsdp_enabled\n                ):\n                    del self.model_kwargs[\"device_map\"]\n            elif self.is_qlora_and_fsdp_enabled:\n                skip_move_to_device = True\n\n            if (\n                
self.cfg.tensor_parallel_size <= 1\n                and self.cfg.fsdp_config.cpu_ram_efficient_loading\n                and self.cfg.fsdp_version == 2\n            ):\n                # setting device_map for TP is not supported\n                local_rank = int(os.getenv(\"LOCAL_RANK\", \"0\"))\n                if local_rank == 0:\n                    self.model_kwargs[\"device_map\"] = \"cpu\"\n                else:\n                    self.model_kwargs[\"device_map\"] = \"meta\"\n\n        if (\n            self.is_qlora_and_fsdp_enabled\n            and self.cfg.fsdp_config.cpu_ram_efficient_loading\n            and (\n                self.cfg.model_config_type == \"dbrx\"\n                or self.cfg.qlora_sharded_model_loading\n            )\n        ):\n            if self.cfg.reinit_weights:\n                LOG.warning(\n                    \"reinit_weights is not supported with sharded quantized loading. \"\n                    \"Loading from pretrained weights instead.\"\n                )\n            quant_storage = self.cfg.torch_dtype\n            quantization_config = getattr(\n                self.model_config, \"quantization_config\", None\n            )\n            quantization_config = (\n                quantization_config or self.model_kwargs[\"quantization_config\"]\n            )\n            self.model = load_sharded_model_quant(\n                self.base_model,\n                self.model_config,\n                self.cfg,\n                quant_storage=quant_storage,\n                quantization_config=quantization_config,\n            )\n            skip_move_to_device = True\n        elif self.model_type == \"MambaLMHeadModel\":\n            if self.cfg.reinit_weights:\n                LOG.warning(\n                    \"reinit_weights is not supported with MambaLMHeadModel. 
\"\n                    \"Loading from pretrained weights instead.\"\n                )\n            # FIXME this is janky at best and hacked together to make it work\n            MambaLMHeadModel = fix_mamba_attn_for_loss()\n\n            self.model_kwargs[\"dtype\"] = self.model_kwargs[\"torch_dtype\"]\n            self.model_kwargs[\"device\"] = torch.cuda.current_device()\n            self.model_kwargs.pop(\"torch_dtype\", None)\n            self.model_kwargs.pop(\"device_map\", None)\n\n            self.model = MambaLMHeadModel.from_pretrained(\n                self.base_model,\n                **self.model_kwargs,\n            )\n        else:\n            # Please don't remove underscore binding without reading the fn docstring\n            _ = self._configure_zero3_memory_efficient_loading()\n\n            if (\n                self.model_type\n                and self.model_type != \"AutoModelForCausalLM\"\n                and not self.cfg.trust_remote_code\n                and not self.cfg.gptq\n            ):\n                # Use model type from transformers\n                model_loader_class = getattr(transformers, self.model_type)\n            else:\n                # Use auto model loader (handles gptq and default cases)\n                model_loader_class = self.auto_model_loader\n\n            self.model_kwargs[\"dtype\"] = self.model_kwargs[\"torch_dtype\"]\n            if self.cfg.reinit_weights:\n                self.model = self._load_model_from_config(model_loader_class)\n            else:\n                self.model = self._load_model_from_pretrained(model_loader_class)\n\n        if is_deepspeed_zero3_enabled():\n            skip_move_to_device = True\n\n        if self.cfg.tensor_parallel_size > 1:\n            # workaround for upstream 4.54.0 not setting _tp_size or _device_mesh\n            # TODO(wing): remove once 4.54.1 is released\n            if self.model._tp_size != self.cfg.tensor_parallel_size:\n                
self.model._tp_size = self.cfg.tensor_parallel_size\n                self.model._device_mesh = self.model_kwargs[\"device_mesh\"]\n\n        if self.cfg.experimental_skip_move_to_device is not None:\n            skip_move_to_device = self.cfg.experimental_skip_move_to_device\n\n        return skip_move_to_device\n\n    def _set_z3_leaf_modules(self):\n        from deepspeed.utils import set_z3_leaf_modules\n\n        moe_type = self.cfg.model_config_type_text or self.cfg.model_config_type\n        if moe_type in MOE_ARCH_BLOCK:\n            moe_blocks = MOE_ARCH_BLOCK[moe_type]\n            moe_blocks = [moe_blocks] if isinstance(moe_blocks, str) else moe_blocks\n            set_z3_leaf_modules(\n                self.model,\n                [\n                    get_module_class_from_name(self.model, module_name)\n                    for module_name in moe_blocks\n                ],\n            )\n\n    def _prepare_model_for_quantization(self):\n        \"\"\"Prepare loaded model for quantization.\"\"\"\n        skip_prepare_model_for_kbit_training = False\n        if self.cfg.model_config_type == \"qwen\" and self.cfg.adapter == \"lora\":\n            # Qwen doesn't play nicely with LoRA if this is enabled\n            skip_prepare_model_for_kbit_training = True\n\n        loftq_bits = (\n            self.cfg.peft\n            and self.cfg.peft.loftq_config\n            and self.cfg.peft.loftq_config.loftq_bits\n        )\n        if self.cfg.adapter == \"lora\" and loftq_bits:\n            skip_prepare_model_for_kbit_training = True\n\n        if (\n            self.is_qlora_and_fsdp_enabled\n            or (self.is_fsdp_enabled and self.cfg.fsdp_config.cpu_ram_efficient_loading)\n            or is_deepspeed_zero3_enabled()\n        ):\n            # Make sure everything is in the same dtype\n            skip_prepare_model_for_kbit_training = True\n\n        if getattr(self.model, \"_moe_experts_quantized\", False):\n            # Parametrized expert tensors 
dequantize on access — would OOM.\n            skip_prepare_model_for_kbit_training = True\n\n        if (\n            not skip_prepare_model_for_kbit_training\n            and self.cfg.adapter in [\"lora\", \"qlora\"]\n            and (self.cfg.load_in_8bit or self.cfg.load_in_4bit)\n        ):\n            LOG.info(\"converting PEFT model w/ prepare_model_for_kbit_training\")\n            self.model = prepare_model_for_kbit_training(\n                self.model, use_gradient_checkpointing=self.cfg.gradient_checkpointing\n            )\n\n    def _convert_embedding_modules_dtype(\n        self,\n        embedding_modules: list[str],\n        dist_dtype: torch.dtype,\n        before_kbit_train_or_finetune: bool,\n    ):\n        dest = {\"dtype\": dist_dtype}\n        if self.cfg.lora_on_cpu:\n            dest[\"device\"] = \"cpu\"\n        for name, module in self.model.named_modules():\n            if \"norm\" in name:\n                module.to(dist_dtype)\n            if before_kbit_train_or_finetune:\n                if name.endswith(\".gate\"):\n                    module.to(dist_dtype)\n                if self.model_config.model_type == \"btlm\":\n                    # don't upcast lm_head for btlm\n                    continue\n            if any(m in name for m in embedding_modules) and hasattr(module, \"weight\"):\n                module.to(**dest)\n"
  },
  {
    "path": "src/axolotl/loaders/patch_manager.py",
    "content": "\"\"\"Patch manager class implementation to complement `axolotl.loaders.ModelLoader`.\n\nApplies pre- and post-model load patches for various fixes and optimizations.\n\"\"\"\n\nimport importlib.util\nimport os\nfrom functools import cached_property\n\nimport addict\nimport transformers\nfrom transformers import PretrainedConfig, PreTrainedModel\nfrom transformers.modeling_flash_attention_utils import is_flash_attn_available\n\nfrom axolotl.integrations.base import PluginManager\nfrom axolotl.monkeypatch.multipack import (\n    SUPPORTED_MULTIPACK_MODEL_TYPES,\n    patch_for_multipack,\n)\nfrom axolotl.utils.dict import DictDefault\nfrom axolotl.utils.logging import get_logger\n\nLOG = get_logger(__name__)\nPLUGIN_MANAGER = PluginManager.get_instance()\n\n\nclass PatchManager:\n    \"\"\"Manages the application of patches during the model loading process.\"\"\"\n\n    @staticmethod\n    def apply_pre_config_load_patches(cfg: DictDefault):\n        \"\"\"\n        Apply patches that must be set up before config loading.\n        This is for patches that intercept remote code loading from HuggingFace,\n        which needs to be in place before AutoConfig.from_pretrained() is called.\n\n        Args:\n            cfg: Configuration dictionary with model and training settings.\n        \"\"\"\n        if (\n            hasattr(cfg, \"base_model_config\")\n            and cfg.base_model_config\n            and \"kimi-linear\" in cfg.base_model_config.lower()\n        ):\n            from axolotl.monkeypatch.models.kimi_linear.patch_kimi_linear import (\n                patch_kimi_config,\n            )\n\n            patch_kimi_config()\n\n    @staticmethod\n    def apply_pre_tokenizer_load_patches(cfg: DictDefault):\n        \"\"\"\n        Apply patches that must be set up before tokenizer loading.\n        This is for patches that intercept remote code loading from HuggingFace,\n        which needs to be in place before AutoTokenizer.from_pretrained() 
is called.\n\n        Args:\n            cfg: Configuration dictionary with model and training settings.\n        \"\"\"\n        if (\n            hasattr(cfg, \"tokenizer_config\")\n            and cfg.tokenizer_config\n            and \"kimi-linear\" in cfg.tokenizer_config.lower()\n        ):\n            from axolotl.monkeypatch.models.kimi_linear.patch_kimi_linear import (\n                patch_kimi_tokenizer,\n            )\n\n            patch_kimi_tokenizer()\n\n    def __init__(\n        self,\n        cfg: DictDefault,\n        model_config: PretrainedConfig | addict.Dict,\n        inference: bool = False,\n    ):\n        \"\"\"Initialize the `PatchManager`.\n\n        Args:\n            cfg: Configuration dictionary with model and training settings.\n            model_config: Configuration object for the model.\n            inference: Whether the model is being loaded for inference mode.\n        \"\"\"\n        self.cfg = cfg\n        self.model_config = model_config\n        self.inference = inference\n\n    @cached_property\n    def has_flash_attn(self) -> bool:\n        \"\"\"Check if flash attention is installed.\"\"\"\n        return importlib.util.find_spec(\"flash_attn\") is not None\n\n    def apply_pre_model_load_patches(self):\n        \"\"\"Apply pre-model load patches based on config.\"\"\"\n        self._deactivate_hf_async_load()\n        self._apply_transformers_patches()\n        # self._apply_flex_attention_patches()\n        self._apply_flash_attention_patches()\n        self._apply_chunked_cross_entropy_patch()\n        self._apply_sageattn_patches()\n        self._apply_flash_attn_4_patches()\n        self._apply_fsdp_patches()\n        self._apply_adapter_patches()\n        self._apply_model_specific_patches()\n        self._apply_fp8_patches()\n        self._apply_flash_attention_peft_patches()\n        self._apply_gradient_checkpointing_patches()\n        self._patch_attention()\n        self._apply_multipack_patches()\n        
self._patch_loss_llama()\n        self._patch_llama_derived_model()\n        self._apply_mistral_cross_entropy_patch()\n        self._apply_self_attention_lora_patch()\n        self._apply_fsdp2_bnb_patches()\n        self._apply_patch_deepspeed_zero3()\n        self._apply_voxtral_patches()\n        self._apply_apertus_patches()\n        self._apply_trl_vllm_patches()\n        self._apply_trl_trainer_utils_patches()\n\n    def apply_post_plugin_pre_model_load_patches(self):\n        \"\"\"Apply post plugin-pre_model_load load patches based on config.\"\"\"\n        self._apply_tiled_mlp(self.cfg.model_config_type)\n        self._apply_moe_expert_quantization_patch()\n\n    def _apply_transformers_patches(self):\n        from axolotl.monkeypatch.transformers.trainer_loss_calc import (\n            patch_evaluation_loop,\n            patch_maybe_log_save_evaluate,\n        )\n\n        patch_evaluation_loop()\n        patch_maybe_log_save_evaluate()\n\n        if self.cfg.context_parallel_size > 1:\n            from axolotl.monkeypatch.transformers.trainer_context_parallel import (\n                patch_prepare_context_parallel_inputs,\n            )\n\n            patch_prepare_context_parallel_inputs()\n\n    def apply_post_model_build_patches(self, model: PreTrainedModel):\n        \"\"\"Apply patches right after model build, before post-load setup.\"\"\"\n        self._finalize_moe_expert_quantization(model)\n\n    def apply_post_model_load_patches(self, model: PreTrainedModel):\n        \"\"\"Apply patches that require the model instance.\"\"\"\n        self._apply_llama_flash_attn_patches(model)\n        self._apply_unsloth_patches(model)\n        self._apply_lora_kernel_patch(model)\n        self._apply_scaling_softmax_patch(model)\n\n    def _apply_flash_attention_patches(self):\n        \"\"\"Apply patches related to Flash Attention.\"\"\"\n        if self.cfg.xformers_attention and self.cfg.sample_packing:\n            from axolotl.monkeypatch.attention 
import patch_xformers_attn_over_fa2\n\n            patch_xformers_attn_over_fa2()\n            self.cfg.flash_attention = True\n\n    def _apply_chunked_cross_entropy_patch(self):\n        if self.cfg.chunked_cross_entropy:\n            from axolotl.monkeypatch.loss.chunked import patch_chunked_ce_loss_fn\n\n            if self.cfg.chunked_cross_entropy_num_chunks:\n                patch_chunked_ce_loss_fn(self.cfg.chunked_cross_entropy_num_chunks)\n            else:\n                patch_chunked_ce_loss_fn()\n\n    def _apply_fsdp_patches(self):\n        \"\"\"Apply patches for FSDP configurations.\"\"\"\n        if self.cfg.fsdp_config:\n            from axolotl.monkeypatch.accelerate.fsdp2 import (\n                patch_initialize_missing_keys_for_fsdp,\n            )\n\n            patch_initialize_missing_keys_for_fsdp()\n\n        if self.cfg.context_parallel_size > 1 or (\n            self.cfg.fsdp_config and str(self.cfg.fsdp_version) == \"2\"\n        ):\n            from axolotl.monkeypatch.accelerate.parallelism_config import (\n                patch_parallelism_config,\n            )\n\n            patch_parallelism_config()\n        if self.cfg.fsdp_config and str(self.cfg.fsdp_version) == \"2\":\n            from axolotl.monkeypatch.accelerate.fsdp2 import (\n                patch_accelerate_fsdp2,\n                patch_tied_keys_for_meta_device,\n            )\n\n            patch_accelerate_fsdp2()\n            if self.cfg.fsdp_config.cpu_ram_efficient_loading:\n                patch_tied_keys_for_meta_device()\n            if self.cfg.rl:\n                from axolotl.monkeypatch.trainer.trl import patch_trl_prepare_fsdp2\n\n                patch_trl_prepare_fsdp2()\n\n        # if self.cfg.fsdp_config:\n        #     # see transformers#39152\n        #     from axolotl.monkeypatch.trainer_fsdp_optim import (\n        #         patch_training_loop_for_fsdp,\n        #     )\n        #\n        #     patch_training_loop_for_fsdp()\n\n    def 
_apply_adapter_patches(self):\n        \"\"\"Apply patches for adapter configurations.\"\"\"\n        if self.cfg.adapter and self.cfg.embeddings_skip_upcast:\n            from axolotl.monkeypatch.peft.utils import patch_peft_prep_code\n\n            patch_peft_prep_code()\n\n    def _apply_flex_attention_patches(self):\n        \"\"\"Apply patches for flexible attention.\"\"\"\n        if self.cfg.flex_attention:\n            from axolotl.monkeypatch.attention.flex_attn import (\n                patch_flex_wrapper,\n            )\n\n            flex_attn_compile_kwargs = self.cfg.flex_attn_compile_kwargs or {}\n            patch_flex_wrapper(**flex_attn_compile_kwargs)\n\n    def _apply_sageattn_patches(self):\n        \"\"\"Apply patches for SageAttention.\"\"\"\n        if self.cfg.sage_attention:\n            from axolotl.monkeypatch.attention.sage_attn import patch_sageattn\n\n            patch_sageattn()\n\n    def _apply_flash_attn_4_patches(self):\n        \"\"\"Auto-apply FA4 when flash_attention is enabled and FA4 is available on SM90+.\"\"\"\n        if not self.cfg.flash_attention:\n            return\n\n        from axolotl.monkeypatch.attention.flash_attn_4 import patch_flash_attn_4\n\n        patch_flash_attn_4(self.model_config)\n\n    def _apply_model_specific_patches(self):\n        \"\"\"Apply patches specific to model architectures.\"\"\"\n        if (\n            self.cfg.model_config_type == \"llama4\"\n            and self.cfg.llama4_linearized_experts\n        ):\n            from axolotl.monkeypatch.models.llama4.modeling import (\n                patch_llama4_linearized_modeling,\n            )\n\n            patch_llama4_linearized_modeling()\n\n        if self.cfg.model_config_type == \"qwen3_next\" and self.cfg.sample_packing:\n            from axolotl.monkeypatch.models.qwen3_next.modeling import (\n                patch_qwen3_next_modeling_packing,\n            )\n\n            patch_qwen3_next_modeling_packing()\n\n        if 
self.cfg.model_config_type == \"qwen3_5\" and self.cfg.sample_packing:\n            from axolotl.monkeypatch.models.qwen3_5.modeling import (\n                patch_qwen3_5_modeling_packing,\n            )\n\n            patch_qwen3_5_modeling_packing()\n\n        if self.cfg.model_config_type == \"qwen3_5_moe\" and self.cfg.sample_packing:\n            from axolotl.monkeypatch.models.qwen3_5.modeling import (\n                patch_qwen3_5_moe_modeling_packing,\n            )\n\n            patch_qwen3_5_moe_modeling_packing()\n\n        if (\n            self.cfg.model_config_type in [\"qwen3_5\", \"qwen3_5_moe\"]\n            and self.cfg.is_multimodal\n            and self.cfg.flash_attention\n        ):\n            from axolotl.monkeypatch.models.qwen3_5.modeling import (\n                patch_qwen3_5_vlm_flash_attention,\n            )\n\n            patch_qwen3_5_vlm_flash_attention()\n\n        if self.cfg.model_config_type == \"kimi_linear\":\n            from axolotl.monkeypatch.models.kimi_linear.patch_kimi_linear import (\n                patch_kimi_model,\n            )\n\n            patch_kimi_model()\n\n    def _apply_fp8_patches(self):\n        \"\"\"Apply patches for FP8 support.\"\"\"\n        if self.cfg.fp8:\n            from axolotl.monkeypatch.trainer_accelerator_args import (\n                patch_create_accelerate_code_for_fp8,\n            )\n\n            patch_create_accelerate_code_for_fp8(\n                self.cfg.fp8_enable_fsdp_float8_all_gather\n            )\n\n    def _apply_flash_attention_peft_patches(self):\n        \"\"\"Apply patches for Flash Attention with PEFT.\"\"\"\n        if self.cfg.adapter:\n            from axolotl.monkeypatch.transformers_fa_utils import (\n                patch_fa_peft_integration,\n            )\n\n            patch_fa_peft_integration()\n\n    def _apply_gradient_checkpointing_patches(self):\n        \"\"\"Apply patches for gradient checkpointing.\"\"\"\n        if (\n            
self.cfg.gradient_checkpointing\n            and self.cfg.activation_offloading == \"legacy\"\n        ):\n            from axolotl.monkeypatch.gradient_checkpointing import (\n                hf_grad_checkpoint_offload_wrapper,\n            )\n\n            transformers.modeling_utils.checkpoint = hf_grad_checkpoint_offload_wrapper\n        elif (\n            self.cfg.gradient_checkpointing\n            and self.cfg.activation_offloading == \"offload_disk\"\n        ):\n            from axolotl.monkeypatch.gradient_checkpointing import (\n                hf_grad_checkpoint_disk_offload_wrapper,\n            )\n\n            transformers.modeling_utils.checkpoint = (\n                hf_grad_checkpoint_disk_offload_wrapper\n            )\n\n    def _apply_mistral_cross_entropy_patch(self):\n        \"\"\"Apply Mistral cross entropy patch if configured.\"\"\"\n        if (\n            self.cfg.model_config_type == \"mistral\"\n            and self.cfg.flash_attn_cross_entropy_loss\n        ):\n            from axolotl.monkeypatch.mistral_attn_hijack_flash import (\n                patch_mistral_cross_entropy,\n            )\n\n            patch_mistral_cross_entropy()\n\n    def _apply_self_attention_lora_patch(self):\n        \"\"\"Apply self-attention LoRA patches if configured.\"\"\"\n        if self.cfg.lora_qkv_kernel or self.cfg.lora_o_kernel:\n            # Only patch if conditions are met\n            can_patch = (\n                self.cfg.lora_dropout == 0\n                if hasattr(self.cfg, \"lora_dropout\")\n                else True\n            )  # default to True if lora_dropout is not set\n\n            if not can_patch:\n                LOG.warning(\"Cannot patch self-attention - requires no dropout\")\n                return\n\n            from axolotl.monkeypatch.lora_kernels import patch_self_attn_lora\n\n            patch_self_attn_lora(self.cfg)\n\n    def _apply_multipack_patches(self):\n        \"\"\"Apply multipack patches if 
necessary.\"\"\"\n        if (\n            self.cfg.model_config_type in SUPPORTED_MULTIPACK_MODEL_TYPES\n            and (self.cfg.flash_attention or self.cfg.flex_attention)\n            and self.cfg.sample_packing\n        ):\n            # Get automap config if it exists\n            auto_map_config = None\n            if isinstance(self.model_config, dict) and \"auto_map\" in self.model_config:\n                auto_map_config = self.model_config[\"auto_map\"]\n            elif hasattr(self.model_config, \"auto_map\"):\n                auto_map_config = self.model_config.auto_map\n\n            # Determine if the model has remote code\n            if auto_map_config is not None:\n                has_remote_code = \"AutoModelForCausalLM\" in auto_map_config\n            else:\n                has_remote_code = False\n\n            if has_remote_code and self.cfg.trust_remote_code is not None:\n                # If explicitly set in YAML, prefer that\n                has_remote_code = self.cfg.trust_remote_code\n\n            patch_for_multipack(\n                self.cfg.model_config_type,\n                model_name=self.cfg.base_model,\n                has_remote_code=has_remote_code,\n            )\n\n        if self.cfg.sample_packing:\n            from axolotl.monkeypatch.data.batch_dataset_fetcher import (\n                apply_multipack_dataloader_patch,\n            )\n\n            LOG.info(\"Applying multipack dataloader patch for sample packing...\")\n            apply_multipack_dataloader_patch()\n\n    def _apply_fsdp2_bnb_patches(self):\n        \"\"\"Apply FSDP2 BNB patches.\"\"\"\n        if (\n            self.cfg.fsdp_config\n            and str(self.cfg.fsdp_version) == \"2\"\n            and (self.cfg.load_in_4bit or self.cfg.load_in_8bit)\n        ):\n            from axolotl.monkeypatch.fsdp2_qlora import (\n                apply_init_dtype_attrs_patch,\n                apply_init_sharded_param_patch,\n                
apply_init_unsharded_param_patch,\n                apply_linear8bitlt_save_patch,\n            )\n\n            apply_init_sharded_param_patch()\n            apply_init_unsharded_param_patch()\n            apply_init_dtype_attrs_patch()\n            if self.cfg.load_in_8bit:\n                apply_linear8bitlt_save_patch()\n\n    def _deactivate_hf_async_load(self):\n        \"\"\"Load weights synchronously so they can be converted and not OOM.\"\"\"\n        if self.cfg.load_in_4bit or self.cfg.load_in_8bit:\n            os.environ[\"HF_DEACTIVATE_ASYNC_LOAD\"] = \"1\"\n\n    def _apply_moe_expert_quantization_patch(self):\n        \"\"\"Patch transformers weight loading and PEFT for MoE expert quantization.\"\"\"\n        has_target_params = bool(getattr(self.cfg, \"lora_target_parameters\", None))\n\n        if not self.cfg.quantize_moe_experts and not has_target_params:\n            return\n\n        from axolotl.monkeypatch.moe_quant import (\n            patch_peft_target_parameters_matching,\n        )\n\n        if self.cfg.quantize_moe_experts:\n            from axolotl.monkeypatch.moe_quant import patch_moe_quantization_on_load\n\n            patch_moe_quantization_on_load(self.cfg)\n\n        patch_peft_target_parameters_matching()\n\n    def _finalize_moe_expert_quantization(self, model: PreTrainedModel):\n        \"\"\"Log quantization results and set model flag for downstream use.\"\"\"\n        import torch\n\n        model._moe_experts_quantized = False\n        if self.cfg.quantize_moe_experts:\n            from axolotl.monkeypatch.moe_quant import get_moe_quantized_count\n\n            count = get_moe_quantized_count()\n            if count > 0:\n                import gc\n\n                model._moe_experts_quantized = True\n                LOG.info(\n                    \"Quantized %d MoE expert parameter(s) to %s during model loading\",\n                    count,\n                    \"4-bit\" if self.cfg.load_in_4bit else \"8-bit\",\n        
        )\n                gc.collect()\n                torch.cuda.empty_cache()\n\n    def _apply_tiled_mlp(self, model_type: str):\n        if self.cfg.tiled_mlp:\n            from axolotl.monkeypatch.tiled_mlp import (\n                patch_tiled_mlp,\n            )\n\n            patch_tiled_mlp(\n                model_type,\n                use_original_mlp=self.cfg.tiled_mlp_use_original_mlp,\n                cfg_num_shards=self.cfg.tiled_mlp_num_shards,\n            )\n\n    def _apply_voxtral_patches(self):\n        \"\"\"Apply patches for Voxtral model.\"\"\"\n        if self.cfg.model_config_type == \"voxtral\":\n            from axolotl.monkeypatch.models.voxtral.modeling import (\n                patch_voxtral_conditional_generation_forward,\n            )\n\n            patch_voxtral_conditional_generation_forward()\n\n    def _patch_attention(self):\n        \"\"\"Apply attention-specific patches based on model type.\"\"\"\n        if not (self.cfg.flash_attention and hasattr(self.model_config, \"model_type\")):\n            return\n\n        if self.model_config.model_type == \"btlm\":\n            from axolotl.monkeypatch.btlm_attn_hijack_flash import (\n                replace_btlm_attn_with_flash_attn,\n            )\n\n            replace_btlm_attn_with_flash_attn(self.cfg.base_model)\n\n        if self.model_config.model_type == \"stablelm_epoch\" and self.cfg.sample_packing:\n            from axolotl.monkeypatch.stablelm_attn_hijack_flash import (\n                replace_stablelm_attn_with_flash_attn,\n            )\n\n            replace_stablelm_attn_with_flash_attn(self.cfg.base_model)\n\n        if self.model_config.model_type in (\"mistral3\", \"llava\"):\n            from axolotl.monkeypatch.models.pixtral.modeling_flash_attention_utils import (\n                apply_patch_is_packed_sequence,\n            )\n\n            apply_patch_is_packed_sequence()\n\n    def _patch_loss_llama(self):\n        \"\"\"Patch loss functions and other 
optimizations for LLaMA models.\"\"\"\n        if not self.cfg.is_llama_derived_model:\n            return\n\n        if self.cfg.flash_attn_cross_entropy and self.has_flash_attn:\n            from axolotl.monkeypatch.llama_attn_hijack_flash import (\n                patch_fa_llama_cross_entropy,\n            )\n\n            patch_fa_llama_cross_entropy()\n        elif self.cfg.unsloth_cross_entropy_loss:\n            from axolotl.monkeypatch.unsloth_ import integrate_cross_entropy_loss_patch\n\n            integrate_cross_entropy_loss_patch(model_type=\"llama\")\n\n        if self.cfg.flash_attn_rms_norm and self.has_flash_attn:\n            from axolotl.monkeypatch.llama_attn_hijack_flash import patch_llama_rms_norm\n\n            patch_llama_rms_norm()\n        elif self.cfg.unsloth_rms_norm:\n            from axolotl.monkeypatch.unsloth_ import patch_unsloth_layernorm\n\n            patch_unsloth_layernorm()\n\n        if self.cfg.unsloth_lora_qkv or self.cfg.unsloth_lora_o:\n            from axolotl.monkeypatch.unsloth_ import patch_self_attn_lora\n\n            patch_self_attn_lora()\n\n    def _patch_llama_flash_attention(self):\n        \"\"\"Apply Flash Attention patches for LLaMA models.\"\"\"\n        from axolotl.monkeypatch.llama_attn_hijack_flash import (\n            replace_llama_attn_with_flash_attn,\n        )\n\n        if self.cfg.s2_attention:\n            LOG.info(\"patching w/ flash-enabled, shifted-sparse attention\")\n            replace_llama_attn_with_flash_attn(\n                cross_entropy=self.cfg.flash_attn_cross_entropy,\n                rms_norm=self.cfg.flash_attn_rms_norm,\n                use_shifted_sparse_attn=True,\n            )\n        elif self.cfg.flash_attn_cross_entropy or self.cfg.flash_attn_rms_norm:\n            replace_llama_attn_with_flash_attn(\n                cross_entropy=self.cfg.flash_attn_cross_entropy,\n                rms_norm=self.cfg.flash_attn_rms_norm,\n            )\n\n    def 
_patch_llama_xformers_attention(self):\n        \"\"\"Apply xformers attention patches for LLaMA models.\"\"\"\n        from axolotl.monkeypatch.llama_attn_hijack_xformers import (\n            hijack_llama_attention,\n        )\n\n        LOG.info(\"Patching with xformers attention...\")\n        hijack_llama_attention()\n\n    def _patch_llama_derived_model(self):\n        \"\"\"Modify all llama derived models in one block.\"\"\"\n        if self.cfg.is_llama_derived_model and not (\n            self.cfg.model_config_type in SUPPORTED_MULTIPACK_MODEL_TYPES\n            and (self.cfg.flash_attention or self.cfg.flex_attention)\n            and self.cfg.sample_packing\n        ):\n            if self.cfg.flash_attention:\n                self._patch_llama_flash_attention()\n            elif self.cfg.xformers_attention:\n                self._patch_llama_xformers_attention()\n            elif self.cfg.s2_attention:\n                raise NotImplementedError(\n                    \"Shifted-sparse attention not currently implemented without flash attention.\"\n                )\n\n    def _apply_llama_flash_attn_patches(self, model):\n        \"\"\"Apply LLaMA-specific flash attention patches.\"\"\"\n        if (\n            self.model_config.model_type in [\"llama\", \"llama4\"]\n            and not self.cfg.trust_remote_code\n            and not self.cfg.gptq\n            and self.cfg.flash_attention\n            and is_flash_attn_available()\n            and not self.inference\n        ):\n            # TODO(MengqingCao): split these patches separately\n            from axolotl.monkeypatch.llama_attn_hijack_flash import (\n                is_xformers_swiglu_available,\n                replace_llama_mlp_with_swiglu,\n            )\n\n            if self.cfg.flash_attn_fuse_mlp and is_xformers_swiglu_available():\n                LOG.info(\"Patching with SwiGLU...\")\n                replace_llama_mlp_with_swiglu(model)\n\n    def _apply_unsloth_patches(self, 
model):\n        \"\"\"Apply unsloth optimization patches.\"\"\"\n        if self.cfg.unsloth_lora_mlp:\n            from axolotl.monkeypatch.unsloth_ import integrate_lora_mlp_patch\n\n            integrate_lora_mlp_patch(peft_model=model)\n\n        if self.cfg.unsloth_lora_qkv or self.cfg.unsloth_lora_o:\n            from axolotl.monkeypatch.unsloth_ import integrate_lora_patch\n\n            integrate_lora_patch(peft_model=model, cfg=self.cfg)\n\n        if self.cfg.unsloth_rope:\n            from axolotl.monkeypatch.unsloth_ import integrate_rope_embeddings\n\n            integrate_rope_embeddings()\n\n    def _apply_lora_kernel_patch(self, model):\n        \"\"\"Apply LoRA kernel patches.\"\"\"\n        if (\n            self.cfg.lora_mlp_kernel\n            or self.cfg.lora_qkv_kernel\n            or self.cfg.lora_o_kernel\n        ):\n            from axolotl.monkeypatch.lora_kernels import apply_lora_kernel_patches\n\n            apply_lora_kernel_patches(model=model, cfg=self.cfg)\n\n    def _apply_patch_deepspeed_zero3(self):\n        try:\n            from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled\n\n            from axolotl.monkeypatch.deepspeed_utils import apply_deepspeed_patches\n\n            if self.cfg.activation_offloading is True and (\n                is_deepspeed_zero3_enabled()\n                or os.getenv(\"ACCELERATE_DEEPSPEED_ZERO_STAGE\") == \"3\"\n            ):\n                apply_deepspeed_patches()\n        except ImportError as e:\n            LOG.warning(f\"DeepSpeed patches not applied: {e}\")\n\n    def _apply_apertus_patches(self):\n        \"\"\"Apply patches for Apertus model.\"\"\"\n        if self.cfg.model_config_type == \"apertus\":\n            from axolotl.monkeypatch.models.apertus.activation import (\n                patch_apertus_xielu_activation,\n            )\n\n            patch_apertus_xielu_activation()\n\n    def _apply_trl_vllm_patches(self):\n        \"\"\"Apply TRL vLLM 
patches for batched weight sync, NaN logprobs fix, and scalar handling.\"\"\"\n        if (\n            self.cfg.rl\n            and getattr(self.cfg, \"trl\", None)\n            and getattr(self.cfg.trl, \"use_vllm\", False)\n        ):\n            from axolotl.monkeypatch.trainer.trl_vllm import patch_trl_vllm\n\n            patch_trl_vllm()\n\n    def _apply_trl_trainer_utils_patches(self):\n        \"\"\"Replace trl.trainer.utils.{selective_log_softmax, entropy_from_logits} with Triton kernels.\"\"\"\n        if not self.cfg.rl:\n            return\n\n        try:\n            from axolotl.monkeypatch.trainer.utils import (\n                entropy_from_logits,\n                selective_log_softmax,\n            )\n        except (ImportError, ModuleNotFoundError):\n            LOG.warning(\"Triton not available — skipping trl.trainer.utils patches\")\n            return\n\n        import trl.trainer.utils\n\n        # Guard against repeated calls: only stash the original if trl still\n        # points at its own implementation (not our wrapper).\n        if trl.trainer.utils.selective_log_softmax is not selective_log_softmax:\n            from axolotl.monkeypatch.trainer import utils as _axolotl_trainer_utils\n\n            _axolotl_trainer_utils.selective_log_softmax_original = (\n                trl.trainer.utils.selective_log_softmax\n            )\n            trl.trainer.utils.selective_log_softmax = selective_log_softmax\n\n        if trl.trainer.utils.entropy_from_logits is not entropy_from_logits:\n            trl.trainer.utils.entropy_from_logits = entropy_from_logits\n\n        LOG.info(\n            \"Patched trl.trainer.utils with Triton selective_log_softmax and entropy_from_logits\"\n        )\n\n    def _apply_scaling_softmax_patch(self, model: PreTrainedModel):\n        \"\"\"Apply Scaling Softmax (SSMax) patch.  
Ref: https://arxiv.org/abs/2501.19399\"\"\"\n        if self.cfg.scaling_softmax:\n            from axolotl.monkeypatch.scaled_softmax_attn import (\n                patch_scaled_softmax_attention,\n            )\n\n            patch_scaled_softmax_attention(\n                scaling_factor_init=self.cfg.scaling_softmax_factor or 0.43,\n                bias=self.cfg.scaling_softmax_bias or 0.0,\n                model=model,\n            )\n"
  },
  {
    "path": "src/axolotl/loaders/processor.py",
    "content": "\"\"\"Processor loading functionality for multi-modal models\"\"\"\n\nimport transformers\nfrom transformers import (\n    AutoProcessor,\n    PreTrainedTokenizerBase,\n)\n\nfrom axolotl.telemetry.errors import send_errors\nfrom axolotl.utils.dict import DictDefault\nfrom axolotl.utils.logging import get_logger\n\nLOG = get_logger(__name__)\n\n\n@send_errors\ndef load_processor(cfg: DictDefault, tokenizer: PreTrainedTokenizerBase):\n    processor_cls = AutoProcessor\n    if cfg.processor_type:\n        processor_cls = getattr(transformers, cfg.processor_type)\n\n    # Build common kwargs for processor loading\n    processor_kwargs = {}\n    if cfg.revision_of_model:\n        processor_kwargs[\"revision\"] = cfg.revision_of_model\n\n    if cfg.tokenizer_use_mistral_common:\n\n        def _patch_mistralcommontokenizer():\n            \"\"\"\n            Transformers v5 stops reading the sub-processor.\n\n            We need to patch this, so both processors use this.\n            \"\"\"\n            import transformers.tokenization_mistral_common as tokenization_mistral_common\n\n            from axolotl.utils.mistral import HFMistralTokenizer\n\n            tokenization_mistral_common.MistralCommonBackend = HFMistralTokenizer\n\n        _patch_mistralcommontokenizer()\n\n        from transformers import VoxtralProcessor\n\n        if processor_cls == VoxtralProcessor:\n            return VoxtralProcessor.from_pretrained(\n                cfg.processor_config,\n                **processor_kwargs,\n            )\n\n        from axolotl.utils.mistral import Mistral3Processor\n\n        return Mistral3Processor(\n            tokenizer=tokenizer,\n        )\n\n    processor_kwargs[\"trust_remote_code\"] = cfg.trust_remote_code or False\n\n    processor = processor_cls.from_pretrained(\n        cfg.processor_config,\n        **processor_kwargs,\n    )\n    processor.tokenizer = tokenizer\n\n    # Attempt to load image size from processor if available\n    
if (\n        cfg.image_size is None\n        and hasattr(processor, \"size\")\n        and any(dim in processor.size for dim in [\"width\", \"height\"])\n    ):\n        im_width = None\n        im_height = None\n        if \"width\" in processor.size:\n            im_width = processor.size[\"width\"]\n        if \"height\" in processor.size:\n            im_height = processor.size[\"height\"]\n\n        # If both width and height are set, use a tuple\n        if im_width is not None and im_height is not None:\n            cfg.image_size = (im_width, im_height)\n        # If only width is set, use as integer\n        elif im_width is not None:\n            cfg.image_size = im_width\n        # If only height is set, use as integer\n        elif im_height is not None:\n            cfg.image_size = im_height\n\n        LOG.debug(f\"Loaded image size: {cfg.image_size} from processor\")\n\n    return processor\n"
  },
  {
    "path": "src/axolotl/loaders/tokenizer.py",
    "content": "\"\"\"Tokenizer loading functionality and associated utils\"\"\"\n\nimport json\nimport os\n\nimport transformers\nfrom transformers import (\n    AddedToken,\n    AutoTokenizer,\n    PreTrainedTokenizer,\n)\n\nfrom axolotl.integrations.base import PluginManager\nfrom axolotl.loaders.utils import get_linear_embedding_layers, load_model_config\nfrom axolotl.prompt_tokenizers import LLAMA_DEFAULT_EOS_TOKEN\nfrom axolotl.telemetry.errors import send_errors\nfrom axolotl.utils.chat_templates import get_chat_template_from_config\nfrom axolotl.utils.dict import DictDefault\nfrom axolotl.utils.distributed import (\n    barrier,\n    is_local_main_process,\n    is_main_process,\n)\nfrom axolotl.utils.logging import get_logger\n\nLOG = get_logger(__name__)\nPLUGIN_MANAGER = PluginManager.get_instance()\n\n\ndef modify_tokenizer_files(\n    tokenizer_path: str,\n    token_mappings: dict[int, str],\n    output_dir: str,\n    revision: str = \"main\",\n) -> str:\n    \"\"\"\n    Modify tokenizer files to replace added_tokens strings, save to output directory,\n    and return the path to the modified tokenizer.\n\n    This only works with reserved tokens that were added to the tokenizer, not tokens\n    already part of the vocab.\n\n    Args:\n        tokenizer_path: Path or name of the original tokenizer\n        token_mappings: Dict mapping {token_id (int): new_token_string}\n        output_dir: Directory to save the modified tokenizer\n        revision: Model revision/branch/tag/commit to load from (HF Hub)\n\n    Returns:\n        Path to the modified tokenizer directory\n\n    Ref: https://github.com/huggingface/transformers/issues/27974#issuecomment-1854188941\n    \"\"\"\n    # Create the tokenizer directory in output_dir if it doesn't exist\n    tokenizer_dir = os.path.join(output_dir, \"tokenizer\")\n    os.makedirs(tokenizer_dir, exist_ok=True)\n\n    if is_local_main_process():\n        # Load the tokenizer\n        temp_tokenizer = 
AutoTokenizer.from_pretrained(\n            tokenizer_path, use_fast=True, revision=revision\n        )\n\n        # Save the tokenizer to the output directory\n        temp_tokenizer.save_pretrained(tokenizer_dir)\n\n        # Get the token IDs and map them to their new values\n        token_id_mappings = {\n            int(token_id): new_value for token_id, new_value in token_mappings.items()\n        }\n\n        # 1. Update tokenizer_config.json - added_tokens_decoder\n        config_path = os.path.join(tokenizer_dir, \"tokenizer_config.json\")\n        if os.path.exists(config_path):\n            with open(config_path, \"r\", encoding=\"utf-8\") as f:\n                config_data = json.load(f)\n\n            # Update added_tokens_decoder\n            if \"added_tokens_decoder\" in config_data:\n                for token_id, new_value in token_id_mappings.items():\n                    token_id_str = str(token_id)\n                    if token_id_str in config_data[\"added_tokens_decoder\"]:\n                        config_data[\"added_tokens_decoder\"][token_id_str][\"content\"] = (\n                            new_value\n                        )\n                    else:\n                        raise ValueError(\n                            f\"Token ID {token_id_str} not found in added_tokens_decoder\"\n                        )\n\n            # Write the updated config back\n            with open(config_path, \"w\", encoding=\"utf-8\") as f:\n                json.dump(config_data, f, indent=2)\n\n        # 2. 
Update tokenizer.json - added_tokens\n        tokenizer_path = os.path.join(tokenizer_dir, \"tokenizer.json\")\n        if os.path.exists(tokenizer_path):\n            with open(tokenizer_path, \"r\", encoding=\"utf-8\") as f:\n                tokenizer_data = json.load(f)\n\n            # Update added_tokens\n            if \"added_tokens\" in tokenizer_data:\n                for token_id, new_value in token_id_mappings.items():\n                    for i, token_entry in enumerate(tokenizer_data[\"added_tokens\"]):\n                        if token_entry[\"id\"] == token_id:\n                            tokenizer_data[\"added_tokens\"][i][\"content\"] = new_value\n                            break\n                    else:\n                        # Reaching this section means the token_id was not found in tokenizer.json added_tokens\n                        raise ValueError(\n                            f\"Token ID {token_id} not found in added_tokens\"\n                        )\n            if \"model\" in tokenizer_data and \"vocab\" in tokenizer_data[\"model\"]:\n                for token_id, new_value in token_id_mappings.items():\n                    for entry_val, entry_id in tokenizer_data[\"model\"][\"vocab\"].items():\n                        if entry_id == token_id:\n                            del tokenizer_data[\"model\"][\"vocab\"][entry_val]\n                            tokenizer_data[\"model\"][\"vocab\"][new_value] = token_id\n                            break\n\n            # Write the updated tokenizer data back\n            with open(tokenizer_path, \"w\", encoding=\"utf-8\") as f:\n                json.dump(tokenizer_data, f, indent=2)\n\n    barrier()\n    return tokenizer_dir\n\n\n@send_errors\ndef load_tokenizer(cfg: DictDefault) -> PreTrainedTokenizer:\n    \"\"\"Load and configure the tokenizer based on the provided config.\"\"\"\n\n    # Apply patches that need to be in place before tokenizer loading\n    from 
axolotl.loaders.patch_manager import PatchManager\n\n    PatchManager.apply_pre_tokenizer_load_patches(cfg)\n\n    def _load_mistral_common_tokenizer(cfg: DictDefault):\n        \"\"\"Load mistral-common tokenizer\"\"\"\n        from axolotl.utils.mistral import HFMistralTokenizer\n\n        # Load the HF-compatible wrapper around MistralTokenizer\n        kwargs = {}\n        if cfg.revision_of_model:\n            kwargs[\"revision\"] = cfg.revision_of_model\n        tokenizer = HFMistralTokenizer.from_pretrained(cfg.tokenizer_config, **kwargs)\n\n        return tokenizer\n\n    if cfg.tokenizer_use_mistral_common:\n        return _load_mistral_common_tokenizer(cfg)\n\n    model_config = load_model_config(cfg)\n    tokenizer_kwargs = {}\n    use_fast = True  # this is the default\n\n    if cfg.tokenizer_use_fast is not None:\n        use_fast = cfg.tokenizer_use_fast\n    if cfg.tokenizer_legacy is not None:\n        # True is the default w/ https://github.com/huggingface/transformers/pull/25224\n        tokenizer_kwargs[\"legacy\"] = cfg.tokenizer_legacy\n    if cfg.revision_of_model:\n        tokenizer_kwargs[\"revision\"] = cfg.revision_of_model\n\n    tokenizer_cls = AutoTokenizer\n    if cfg.tokenizer_type:\n        tokenizer_cls = getattr(transformers, cfg.tokenizer_type)\n\n    # Set base tokenizer path\n    tokenizer_path = cfg.tokenizer_config\n\n    # Apply token string overrides if specified\n    if cfg.added_tokens_overrides:\n        # Modify tokenizer files and get path to modified tokenizer\n        modify_kwargs = {\"output_dir\": cfg.output_dir}\n        if cfg.revision_of_model:\n            modify_kwargs[\"revision\"] = cfg.revision_of_model\n        tokenizer_path = modify_tokenizer_files(\n            tokenizer_path, cfg.added_tokens_overrides, **modify_kwargs\n        )\n\n    tokenizer = tokenizer_cls.from_pretrained(\n        tokenizer_path,\n        trust_remote_code=cfg.trust_remote_code or False,\n        use_fast=use_fast,\n        
**tokenizer_kwargs,\n    )\n\n    if (\n        tokenizer.__class__.__name__\n        in [\n            \"LlamaTokenizer\",\n            \"LlamaTokenizerFast\",\n            \"CodeLlamaTokenizer\",\n            \"CodeLlamaTokenizerFast\",\n        ]\n        and hasattr(tokenizer, \"pad_token\")\n        and not tokenizer.pad_token\n    ):\n        # set a pad_token, but use eos_token so we don't add a new token\n        tokenizer.pad_token = LLAMA_DEFAULT_EOS_TOKEN\n\n    if tokenizer.__class__.__name__ == \"GPTNeoXTokenizerFast\":\n        tokenizer.add_special_tokens({\"pad_token\": \"[PAD]\"})  # nosec B105\n        os.environ[\"TOKENIZERS_PARALLELISM\"] = \"false\"\n\n    # Mistral's official FA implementation requires left padding\n    if cfg.is_mistral_derived_model and cfg.flash_attention and not cfg.sample_packing:\n        tokenizer.padding_side = \"left\"\n\n    # Qwen base only has single token, so we need to set the special tokens\n    # the following check is for Qwen1 base models\n    if cfg.is_qwen_derived_model and hasattr(tokenizer, \"eod_id\"):\n        token_ids = [\"bos_token_id\", \"eos_token_id\", \"pad_token_id\", \"unk_token_id\"]\n        for attr_name in token_ids:\n            if getattr(tokenizer, attr_name) is None:\n                setattr(tokenizer, attr_name, tokenizer.eod_id)\n\n        token_names = [\"bos_token\", \"eos_token\", \"pad_token\", \"unk_token\"]\n        for attr_name in token_names:\n            if getattr(tokenizer, attr_name) is None:\n                setattr(tokenizer, attr_name, \"<|endoftext|>\")\n\n    additional_special_tokens = None\n    if cfg.special_tokens:\n        special_tokens = cfg.special_tokens.to_dict()\n        additional_special_tokens = special_tokens.pop(\n            \"additional_special_tokens\", None\n        )\n        lora_modules_to_save = get_linear_embedding_layers(model_config.model_type)\n        for k, val in special_tokens.items():\n            # check if new special token is not 
already in tokenizer and\n            # is adapter training to make sure lora_modules_to_save is set\n\n            if (\n                (getattr(tokenizer, k) is None or getattr(tokenizer, k) != val)\n                and (len(tokenizer.encode(val, add_special_tokens=False)) > 2)\n                and cfg.adapter\n                and (\n                    not cfg.lora_modules_to_save\n                    or not all(\n                        x in cfg.lora_modules_to_save for x in lora_modules_to_save\n                    )\n                )\n                and k != \"pad_token\"\n            ):\n                lora_modules_to_save_str = \", \".join(\n                    [f\"`{x}`\" for x in lora_modules_to_save]\n                )\n                raise ValueError(\n                    f\"Please set lora_modules_to_save to [{lora_modules_to_save_str}] \"\n                    \"when using an adapter and changing the special tokens.\"\n                )\n\n            tokenizer.add_special_tokens(\n                {k: AddedToken(val, rstrip=False, lstrip=False, normalized=False)}\n            )\n\n        # If we add bos_token and eos_token, we need to update the post processor to\n        # handle them correctly.\n        # https://github.com/huggingface/transformers/pull/24132\n        bos_or_eos_in_special_tokens = (\n            \"bos_token\" in cfg.special_tokens and \"eos_token\" in cfg.special_tokens\n        )\n        if (\n            tokenizer.__class__.__name__\n            in (\n                \"LlamaTokenizerFast\",\n                \"CodeLlamaTokenizerFast\",\n            )\n            and bos_or_eos_in_special_tokens\n        ):\n            tokenizer.update_post_processor()\n\n    if cfg.tokens:\n        tokenizer.add_tokens(\n            [\n                AddedToken(token, rstrip=False, lstrip=False, normalized=False)\n                for token in cfg.tokens\n            ]\n        )\n\n    # Additional special tokens are a List, and need to 
be treated differently than regular special\n    # tokens. We add them after we have called `add_tokens` in case these additional special tokens\n    # are new tokens.\n    #\n    # Usage:\n    #\n    # ```py\n    # special_tokens:\n    #   additional_special_tokens: [\"<|im_start|>\", \"<|im_end|>\"]\n    # ```\n    if additional_special_tokens is not None:\n        tokenizer.add_special_tokens(\n            {\"additional_special_tokens\": additional_special_tokens}\n        )\n\n    if is_main_process():\n        LOG.debug(f\"EOS: {tokenizer.eos_token_id} / {tokenizer.eos_token}\")\n        LOG.debug(f\"BOS: {tokenizer.bos_token_id} / {tokenizer.bos_token}\")\n        LOG.debug(f\"PAD: {tokenizer.pad_token_id} / {tokenizer.pad_token}\")\n        LOG.debug(f\"UNK: {tokenizer.unk_token_id} / {tokenizer.unk_token}\")\n\n    if cfg.chat_template:\n        chat_template_string = get_chat_template_from_config(\n            cfg=cfg,\n            tokenizer=tokenizer,\n        )\n        if cfg.default_system_message and cfg.chat_template == \"chatml\":\n            chat_template_string = chat_template_string.replace(\n                \"You are a helpful assistant.\", cfg.default_system_message\n            )\n\n        tokenizer.chat_template = chat_template_string\n    elif getattr(tokenizer, \"chat_template\", None) is None:\n        LOG.info(\n            \"No Chat template selected. Consider adding a chat template for easier inference.\"\n        )\n\n    # make the tokenizer.pad call quieter 🤐\n    if hasattr(tokenizer, \"deprecation_warnings\"):\n        tokenizer.deprecation_warnings[\"Asking-to-pad-a-fast-tokenizer\"] = True\n\n    return tokenizer\n"
  },
  {
    "path": "src/axolotl/loaders/utils.py",
    "content": "\"\"\"Utilities for axolotl.loaders module\"\"\"\n\nimport contextlib\nfrom typing import Type\n\nimport addict\nimport torch\nimport transformers\nfrom transformers import AutoConfig, PretrainedConfig, PreTrainedModel\n\nfrom axolotl.utils.dict import DictDefault\nfrom axolotl.utils.logging import get_logger\n\nLOG = get_logger(__name__)\n\n\ndef get_module_class_from_name(\n    module: torch.nn.Module, name: str\n) -> Type[torch.nn.Module] | None:\n    \"\"\"Gets a class from a module by its name. Copied from `accelerate.utils.dataclasses`\n    (https://github.com/huggingface/accelerate/blob/main/src/accelerate/utils/dataclasses.py#L2805).\n\n    Args:\n        module: The module to get the class from.\n        name: The name of the class.\n\n    Returns:\n        The class type of the matching module, or `None` if no match is found.\n    \"\"\"\n    modules_children = list(module.children())\n    if module.__class__.__name__ == name:\n        return module.__class__\n\n    if len(modules_children) == 0:\n        return None\n\n    for child_module in modules_children:\n        module_class = get_module_class_from_name(child_module, name)\n        if module_class is not None:\n            return module_class\n\n    return None\n\n\ndef check_model_config(cfg: DictDefault, model_config: PretrainedConfig):\n    \"\"\"Validates and adjusts model config based on `axolotl` config.\n\n    This function performs several important checks and adjustments:\n        - Disables model caching for better memory efficiency\n        - Handles multimodal model-specific configurations\n        - Validates quantization settings\n        - Ensures proper LoRA configuration when using adapters with new tokens\n\n    Args:\n        cfg: Dictionary mapping `axolotl` config keys to values.\n        model_config: The model's configuration object from `transformers`.\n\n    Raises:\n        ValueError: If a multimodal model lacks text configuration, if GPTQ settings\n      
      are inconsistent, or if LoRA `modules_to_save` is improperly configured\n            with new tokens.\n    \"\"\"\n    if hasattr(model_config, \"use_cache\"):\n        model_config.use_cache = False\n\n    if cfg.is_multimodal:\n        # For multimodal configs, use_cache is set in the text_config\n        if hasattr(model_config, \"get_text_config\"):\n            text_config = model_config.get_text_config()\n            if hasattr(text_config, \"use_cache\"):\n                text_config.use_cache = False\n        else:\n            raise ValueError(\n                \"No text config found for multimodal model. Please raise an Issue with model details.\"\n            )\n\n        # Check if image_size is not set and load image size from model config if available\n        if (\n            cfg.image_size is None\n            and hasattr(model_config, \"vision_config\")\n            and hasattr(model_config.vision_config, \"image_size\")\n        ):\n            image_size = model_config.vision_config.image_size\n            if isinstance(image_size, list):\n                cfg.image_size = tuple(image_size)\n            else:\n                cfg.image_size = image_size\n            LOG.debug(f\"Loaded image size: {cfg.image_size} from model config\")\n\n    quant_config_exists = (\n        hasattr(model_config, \"quantization_config\")\n        and model_config.quantization_config\n    )\n\n    # Detect compressed-tensors config\n    is_compressed_tensors_config = (\n        quant_config_exists\n        and model_config.quantization_config.get(\"quant_method\") == \"compressed-tensors\"\n    )\n\n    if is_compressed_tensors_config:\n        if model_config.quantization_config.get(\"config_groups\"):\n            LOG.warning(\n                \"Found `config_groups` in a compressed-tensors config. 
\"\n                \"QAT integration with llmcompressor is not tested.\"\n            )\n        # Skip further quant checks for compressed-tensors\n        return\n\n    quant_config_method_is_gptq = (\n        quant_config_exists\n        and \"quant_method\" in model_config.quantization_config\n        and model_config.quantization_config[\"quant_method\"] == \"gptq\"\n    )\n\n    if cfg.gptq and not quant_config_method_is_gptq:\n        raise ValueError(\n            \"model_config.quantization_config is not set or quant_method is not set to gptq. \"\n            \"Please make sure to point to a GPTQ model.\"\n        )\n\n    lora_modules_to_save = get_linear_embedding_layers(model_config.model_type)\n    if (\n        cfg.adapter\n        and cfg.tokens\n        and (\n            not cfg.lora_modules_to_save\n            or not all(x in cfg.lora_modules_to_save for x in lora_modules_to_save)\n        )\n    ):\n        lora_modules_to_save_joined = \", \".join(\n            map(lambda x: f\"`{x}`\", lora_modules_to_save)\n        )\n        raise ValueError(\n            \"`lora_modules_to_save` not properly set when adding new tokens. \"\n            f\"Please include [{lora_modules_to_save_joined}] in `lora_modules_to_save`.\"\n        )\n\n    if (\n        cfg.tensor_parallel_size\n        and cfg.tensor_parallel_size > 1\n        and hasattr(model_config, \"tie_word_embeddings\")\n        and model_config.tie_word_embeddings\n    ):\n        raise ValueError(\n            \"Tensor parallelism is incompatible with models configured with `tie_word_embeddings` enabled. 
\"\n            \"Please use a model without `tie_word_embeddings`, or disable tensor parallelism.\"\n        )\n\n\ndef load_model_config(cfg: DictDefault) -> PretrainedConfig | addict.Dict:\n    \"\"\"Loads and configures a model configuration from HuggingFace or local sources.\n\n    This function determines the appropriate model config source, loads it, applies any\n    necessary overrides, and validates it for compatibility with the `axolotl` config.\n\n    If `cfg.cls_model_config` is set, a custom config class from transformers will be\n    used instead of `AutoConfig` (e.g., 'LlamaConfig', 'MistralConfig').\n\n    Args:\n        cfg: Dictionary mapping `axolotl` config keys to values.\n\n    Returns:\n        A configured model configuration object (`AutoConfig` instance), or a simple\n            dictionary configuration for special cases like Mamba models.\n\n    Raises:\n        ValueError: If configuration loading fails for reasons other than special cases\n            that are handled (e.g., Mamba models).\n    \"\"\"\n    model_config_name = cfg.base_model_config or cfg.base_model\n    if not model_config_name and cfg.tokenizer_config:\n        model_config_name = cfg.tokenizer_config\n    trust_remote_code = cfg.trust_remote_code is True\n    config_kwargs = {}\n    if cfg.revision_of_model:\n        config_kwargs[\"revision\"] = cfg.revision_of_model\n    if cfg.num_labels:\n        # num_labels is used to initialize classifier models\n        config_kwargs[\"num_labels\"] = cfg.num_labels\n\n    config_cls = AutoConfig\n    if cfg.cls_model_config:\n        config_cls = getattr(transformers, cfg.cls_model_config)\n\n    try:\n        model_config = config_cls.from_pretrained(\n            model_config_name,\n            trust_remote_code=trust_remote_code,\n            **config_kwargs,\n        )\n    except ValueError as error:\n        if \"mamba\" in model_config_name:\n            return addict.Dict(\n                {\n                    
\"model_type\": \"mamba\",\n                }\n            )\n        raise error\n\n    if cfg.overrides_of_model_config:\n        for key, val in cfg.overrides_of_model_config.items():\n            setattr(model_config, key, val)\n\n    check_model_config(cfg, model_config)\n\n    return model_config\n\n\ndef ensure_dtype(model: PreTrainedModel, dtype: torch.dtype = torch.bfloat16):\n    \"\"\"Ensures all modules in the model are converted to the specified data type.\"\"\"\n    for name, module in model.named_modules():\n        weight_mismatch = False\n        with contextlib.suppress(AttributeError):\n            weight_mismatch = module.weight.dtype != dtype\n\n        bias_mismatch = False\n        with contextlib.suppress(AttributeError):\n            bias_mismatch = module.bias.dtype != dtype\n\n        if weight_mismatch:\n            LOG.debug(\n                f\"Converting module {name}.weight: {module.weight.dtype} -> {dtype}\"\n            )\n        if bias_mismatch:\n            LOG.debug(f\"Converting module {name}.bias: {module.bias.dtype} -> {dtype}\")\n        if weight_mismatch or bias_mismatch:\n            module.to(dtype)\n\n\ndef get_linear_embedding_layers(model_type: str) -> list[str]:\n    \"\"\"Returns layer names of linear embeddings needed for LoRA based on model type.\"\"\"\n    if model_type == \"gpt_neox\":\n        return [\"embed_in\", \"embed_out\"]\n    if model_type == \"falcon\":\n        return [\"word_embeddings\", \"lm_head\"]\n    return [\"embed_tokens\", \"lm_head\"]\n"
  },
  {
    "path": "src/axolotl/logging_config.py",
    "content": "\"\"\"Common logging module for axolotl.\"\"\"\n\nimport logging\nimport os\nfrom logging import Formatter, Logger, LogRecord\nfrom logging.config import dictConfig\nfrom typing import Any, Dict\n\nfrom colorama import Fore, Style, init\n\nDEFAULT_AXOLOTL_LOG_LEVEL = \"INFO\"\nDEFAULT_LOG_LEVEL = \"WARNING\"\n\n\nclass AxolotlOrWarnErrorFilter(logging.Filter):\n    \"\"\"\n    Allows ANY WARNING or higher (unless overridden by LOG_LEVEL). Allows axolotl.* at\n    INFO or higher (unless overridden by AXOLOTL_LOG_LEVEL). Drops all other records\n    (i.e. non-axolotl.INFO, DEBUG, etc. by default).\n    \"\"\"\n\n    def __init__(self, **kwargs):\n        super().__init__(**kwargs)\n\n        axolotl_log_level = os.getenv(\n            \"AXOLOTL_LOG_LEVEL\", DEFAULT_AXOLOTL_LOG_LEVEL\n        ).upper()\n        other_log_level = os.getenv(\"LOG_LEVEL\", DEFAULT_LOG_LEVEL).upper()\n\n        try:\n            # py311+ only\n            level_mapping = logging.getLevelNamesMapping()\n            self.axolotl_level = level_mapping[axolotl_log_level]\n            self.other_level = level_mapping[other_log_level]\n        except AttributeError:\n            # For py310, use getLevelName directly\n            self.axolotl_level = logging.getLevelName(axolotl_log_level)\n            self.other_level = logging.getLevelName(other_log_level)\n\n    def filter(self, record: LogRecord) -> bool:\n        # General filter\n        if record.levelno >= self.other_level:\n            return True\n\n        # Axolotl filter\n        return (\n            record.name.startswith(\"axolotl\") and record.levelno >= self.axolotl_level\n        )\n\n\nclass AxolotlLogger(Logger):\n    \"\"\"Logger that applies filtering to non-axolotl loggers.\"\"\"\n\n    def __init__(self, name: str, level: int = logging.NOTSET):\n        super().__init__(name, level)\n        if not name.startswith(\"axolotl\"):\n            self.addFilter(AxolotlOrWarnErrorFilter())\n\n\nclass 
ColorfulFormatter(Formatter):\n    \"\"\"\n    Formatter to add coloring to log messages by log type\n    \"\"\"\n\n    COLORS = {\n        \"WARNING\": Fore.YELLOW,\n        \"ERROR\": Fore.RED,\n        \"CRITICAL\": Fore.RED + Style.BRIGHT,\n    }\n\n    def format(self, record):\n        record.rank = int(os.getenv(\"LOCAL_RANK\", \"0\"))\n        record.rank_fmt = f\" [RANK:{record.rank}]\" if record.rank != 0 else \"\"\n        log_message = super().format(record)\n        return self.COLORS.get(record.levelname, \"\") + log_message + Fore.RESET\n\n\nDEFAULT_LOGGING_CONFIG: Dict[str, Any] = {\n    \"version\": 1,\n    \"disable_existing_loggers\": False,\n    \"formatters\": {\n        \"simple\": {\n            \"format\": \"[%(asctime)s] [%(levelname)s] [%(name)s.%(funcName)s:%(lineno)d] [PID:%(process)d] %(message)s\",\n        },\n        \"colorful\": {\n            \"()\": ColorfulFormatter,\n            \"format\": \"[%(asctime)s] [%(levelname)s] [%(name)s.%(funcName)s:%(lineno)d] [PID:%(process)d]%(rank_fmt)s %(message)s\",\n        },\n        \"concise\": {\n            \"format\": \"[%(asctime)s] [%(levelname)s] [%(name)s] %(message)s\",\n        },\n        \"concise_color\": {\n            \"()\": ColorfulFormatter,\n            \"format\": \"[%(asctime)s] [%(levelname)s] [%(name)s]%(rank_fmt)s %(message)s\",\n        },\n    },\n    \"filters\": {\n        \"ax_or_warn\": {\n            \"()\": \"axolotl.logging_config.AxolotlOrWarnErrorFilter\",\n        },\n    },\n    \"handlers\": {\n        \"console\": {\n            \"class\": \"logging.StreamHandler\",\n            \"formatter\": \"concise\",\n            \"filters\": [\"ax_or_warn\"],\n            \"stream\": \"ext://sys.stdout\",\n        },\n        \"color_console\": {\n            \"class\": \"logging.StreamHandler\",\n            \"formatter\": \"concise_color\",\n            \"filters\": [\"ax_or_warn\"],\n            \"stream\": \"ext://sys.stdout\",\n        },\n        
\"ax_file_only\": {\n            \"class\": \"logging.StreamHandler\",\n            \"level\": \"DEBUG\",\n            \"formatter\": \"simple\",\n            \"stream\": \"ext://axolotl.utils.tee.file_only_stream\",\n        },\n        \"root_file_only\": {\n            \"class\": \"logging.StreamHandler\",\n            \"level\": \"DEBUG\",\n            \"formatter\": \"simple\",\n            \"stream\": \"ext://axolotl.utils.tee.file_only_stream\",\n        },\n    },\n    \"root\": {\n        \"handlers\": [\"console\", \"root_file_only\"],\n        \"level\": os.getenv(\"LOG_LEVEL\", DEFAULT_LOG_LEVEL).upper(),\n    },\n    \"loggers\": {\n        \"axolotl\": {\n            \"handlers\": [\"color_console\", \"ax_file_only\"],\n            \"level\": os.getenv(\"AXOLOTL_LOG_LEVEL\", DEFAULT_AXOLOTL_LOG_LEVEL).upper(),\n            \"propagate\": False,\n        },\n    },\n}\n\n\ndef configure_logging():\n    \"\"\"Configure with default logging\"\"\"\n    init()  # Initialize colorama\n\n    dictConfig(DEFAULT_LOGGING_CONFIG)\n    logging.setLoggerClass(AxolotlLogger)\n\n    # Route Python warnings through logging so they reach file handlers\n    logging.captureWarnings(True)\n\n    # Set default `ACCELERATE_LOG_LEVEL` to `LOG_LEVEL` if available and not set\n    if \"ACCELERATE_LOG_LEVEL\" not in os.environ:\n        os.environ[\"ACCELERATE_LOG_LEVEL\"] = os.getenv(\n            \"LOG_LEVEL\", DEFAULT_LOG_LEVEL\n        ).upper()\n"
  },
  {
    "path": "src/axolotl/models/__init__.py",
    "content": ""
  },
  {
    "path": "src/axolotl/models/mamba/__init__.py",
    "content": "\"\"\"\nModeling module for Mamba models\n\"\"\"\n\nimport importlib\n\n\ndef check_mamba_ssm_installed():\n    mamba_ssm_spec = importlib.util.find_spec(\"mamba_ssm\")\n    if mamba_ssm_spec is None:\n        raise ImportError(\n            \"MambaLMHeadModel requires mamba_ssm. Please install it with `pip install -e .[mamba-ssm]`\"\n        )\n\n\ndef fix_mamba_attn_for_loss():\n    check_mamba_ssm_installed()\n\n    from mamba_ssm.models import mixer_seq_simple\n\n    from .modeling_mamba import MambaLMHeadModel as MambaLMHeadModelFixed\n\n    mixer_seq_simple.MambaLMHeadModel = MambaLMHeadModelFixed\n    return mixer_seq_simple.MambaLMHeadModel\n"
  },
  {
    "path": "src/axolotl/models/mamba/configuration_mamba.py",
    "content": "\"\"\"\nHF Transformers MambaConfig\n\"\"\"\n\nfrom transformers import PretrainedConfig\n\n\nclass MambaConfig(PretrainedConfig):\n    \"\"\"\n    modeling configuration for state space model/mamba\n    \"\"\"\n\n    model_type = \"mamba\"\n\n    def __init__(\n        self,\n        vocab_size=50280,\n        d_model=2560,\n        n_layer=64,\n        rms_norm=True,\n        residual_in_fp32=True,\n        fused_add_norm=True,\n        pad_vocab_size_multiple=8,\n        pad_token_id=50277,\n        bos_token_id=0,\n        eos_token_id=0,\n        tie_word_embeddings=False,\n        **kwargs,\n    ):\n        self.vocab_size = vocab_size\n        self.d_model = d_model\n        self.n_layer = n_layer\n        self.rms_norm = rms_norm\n        self.residual_in_fp32 = residual_in_fp32\n        self.fused_add_norm = fused_add_norm\n        self.pad_vocab_size_multiple = pad_vocab_size_multiple\n        super().__init__(\n            pad_token_id=pad_token_id,\n            bos_token_id=bos_token_id,\n            eos_token_id=eos_token_id,\n            tie_word_embeddings=tie_word_embeddings,\n            **kwargs,\n        )\n"
  },
  {
    "path": "src/axolotl/models/mamba/modeling_mamba.py",
    "content": "import os\nfrom collections import namedtuple\nfrom functools import partial\nfrom typing import Optional, Union\n\nimport torch\nfrom mamba_ssm.models.mixer_seq_simple import MixerModel, _init_weights\nfrom mamba_ssm.utils.generation import GenerationMixin\nfrom mamba_ssm.utils.hf import load_config_hf, load_state_dict_hf\nfrom torch import nn\nfrom torch.nn import CrossEntropyLoss\n\nfrom axolotl.models.mamba.configuration_mamba import MambaConfig\n\n\nclass MambaLMHeadModel(nn.Module, GenerationMixin):\n    def __init__(\n        self,\n        d_model: int,\n        n_layer: int,\n        vocab_size: int,\n        initializer_cfg=None,\n        pad_vocab_size_multiple: int = 1,\n        device=None,\n        dtype=None,\n        **backbone_kwargs,\n    ) -> None:\n        factory_kwargs = {\"device\": device, \"dtype\": dtype}\n        super().__init__()\n        if vocab_size % pad_vocab_size_multiple != 0:\n            vocab_size += pad_vocab_size_multiple - (\n                vocab_size % pad_vocab_size_multiple\n            )\n        self.config = MambaConfig(\n            vocab_size=vocab_size,\n            d_model=d_model,\n            n_layer=n_layer,\n            pad_vocab_size_multiple=pad_vocab_size_multiple,\n        )\n        self.backbone = MixerModel(\n            d_model=d_model,\n            n_layer=n_layer,\n            vocab_size=vocab_size,\n            initializer_cfg=initializer_cfg,\n            **backbone_kwargs,\n            **factory_kwargs,\n        )\n        self.lm_head = nn.Linear(d_model, vocab_size, bias=False, **factory_kwargs)\n\n        # Initialize weights and apply final processing\n        self.apply(\n            partial(\n                _init_weights,\n                n_layer=n_layer,\n                **(initializer_cfg if initializer_cfg is not None else {}),\n            )\n        )\n        self.tie_weights()\n\n    def tie_weights(self):\n        self.lm_head.weight = 
self.backbone.embedding.weight\n\n    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):\n        return self.backbone.allocate_inference_cache(\n            batch_size, max_seqlen, dtype=dtype, **kwargs\n        )\n\n    def forward(\n        self,\n        input_ids,\n        position_ids=None,\n        inference_params=None,\n        num_last_tokens=0,\n        labels=None,\n        **kwargs,\n    ):\n        \"\"\"\n        \"position_ids\" is just to be compatible with Transformer generation. We don't use it.\n        num_last_tokens: if > 0, only return the logits for the last n tokens\n        \"\"\"\n        hidden_states = self.backbone(input_ids, inference_params=inference_params)\n        if num_last_tokens > 0:\n            hidden_states = hidden_states[:, -num_last_tokens:]\n        lm_logits = self.lm_head(hidden_states)\n\n        CausalLMOutput = namedtuple(\"CausalLMOutput\", [\"logits\"])\n        return CausalLMOutput(logits=lm_logits)\n\n        loss = None\n        if labels is not None:\n            logits = lm_logits\n            # Shift so that tokens < n predict n\n            shift_logits = logits[..., :-1, :].contiguous()\n            shift_labels = labels[..., 1:].contiguous()\n            # Flatten the tokens\n            loss_fct = CrossEntropyLoss()\n            shift_logits = shift_logits.view(-1, self.config.vocab_size)\n            shift_labels = shift_labels.view(-1)\n            # Enable model parallelism\n            shift_labels = shift_labels.to(shift_logits.device)\n            loss = loss_fct(shift_logits, shift_labels)\n            CausalLMOutput = namedtuple(\"CausalLMOutput\", [\"logits\", \"loss\"])\n            print(loss)\n            return CausalLMOutput(logits=lm_logits, loss=loss)\n\n        else:\n            CausalLMOutput = namedtuple(\"CausalLMOutput\", [\"logits\"])\n            return CausalLMOutput(logits=lm_logits)\n\n    def save_pretrained(\n        self,\n        
save_directory: Union[str, os.PathLike],\n        state_dict: Optional[dict] = None,\n        **kwargs,\n    ):\n        if state_dict is None:\n            state_dict = self.state_dict()\n        torch.save(state_dict, os.path.join(save_directory, \"pytorch_model.bin\"))\n\n    @classmethod\n    def from_pretrained(cls, pretrained_model_name, device=None, dtype=None, **kwargs):\n        config = load_config_hf(pretrained_model_name)\n        model = cls(**config, device=device, dtype=dtype, **kwargs)\n        model.load_state_dict(\n            load_state_dict_hf(pretrained_model_name, device={\"\": device}, dtype=dtype)\n        )\n        return model\n"
  },
  {
    "path": "src/axolotl/monkeypatch/__init__.py",
    "content": ""
  },
  {
    "path": "src/axolotl/monkeypatch/accelerate/__init__.py",
    "content": ""
  },
  {
    "path": "src/axolotl/monkeypatch/accelerate/fsdp2.py",
    "content": "\"\"\"\nmonkeypatch for accelerate fsdp2 fix when modifying ordereddict during interation, and saving full state dicts\n\"\"\"\n\nimport copy\nimport functools\nimport os\nimport sys\n\nimport torch\nimport torch.distributed as dist\nfrom torch import nn\n\nfrom axolotl.utils.bench import log_gpu_memory_usage\nfrom axolotl.utils.logging import get_logger\n\nLOG = get_logger(__name__)\n\n\ndef fsdp2_load_full_state_dict(\n    _accelerator, model: torch.nn.Module, full_sd: dict, offload_to_cpu: bool = False\n):\n    \"\"\"\n    Loads the full state dict (could be only on rank 0) into the sharded model. This is done by broadcasting the\n    parameters from rank 0 to all other ranks. This function modifies the model in-place.\n    Args:\n        accelerator (`Accelerator`): The accelerator instance\n        model (`torch.nn.Module`):\n            The model to load the state dict into, expected to be on meta device or a VRAM spike can occur\n        full_sd (`dict`): The full state dict to load, can only be on rank 0\n    \"\"\"\n    from torch.distributed.tensor import distribute_tensor\n\n    LOG.info(\"Broadcasting full state dict to all ranks...\")\n    import time\n\n    start_time = time.time()\n\n    meta_sharded_sd = model.state_dict()\n    sharded_sd = {}\n    for param_name, sharded_meta_param in meta_sharded_sd.items():\n        full_tensor = None\n        if _accelerator.is_main_process:\n            full_tensor = full_sd[param_name]\n            full_tensor = full_tensor.to(sharded_meta_param.dtype)\n\n        if hasattr(sharded_meta_param, \"device_mesh\"):\n            device_mesh = sharded_meta_param.device_mesh\n            if _accelerator.is_main_process:\n                full_tensor = full_tensor.to(device_mesh.device_type)\n            else:\n                full_tensor = torch.empty(\n                    sharded_meta_param.size(),\n                    device=device_mesh.device_type,\n                    
dtype=sharded_meta_param.dtype,\n                )\n            sharded_param = distribute_tensor(\n                full_tensor,\n                device_mesh,\n                sharded_meta_param.placements,\n                src_data_rank=0,\n            )\n        else:\n            # Non-sharded parameters\n            if _accelerator.is_main_process:\n                sharded_param = full_tensor.to(torch.device(\"cuda\"))\n            else:\n                # broadcast manually\n                sharded_param = torch.empty_like(\n                    sharded_meta_param,\n                    device=torch.device(\"cuda\"),\n                    dtype=sharded_meta_param.dtype,\n                )\n            dist.broadcast(sharded_param, src=0)\n\n        if offload_to_cpu:\n            sharded_param = sharded_param.cpu()\n\n        sharded_sd[param_name] = nn.Parameter(sharded_param)\n\n        del full_tensor\n        full_sd[param_name] = None\n\n    model.load_state_dict(sharded_sd, assign=True, strict=True)\n    end_time = time.time()\n    LOG.debug(\n        f\"Time taken to load full state dict: {(end_time - start_time):.2f} seconds\"\n    )\n    log_gpu_memory_usage(LOG, \"Memory usage after broadcasting full state dict\", 0)\n    return model\n\n\ndef get_state_dict(self, model, unwrap=True):\n    \"\"\"\n    Returns the state dictionary of a model sent through [`Accelerator.prepare`] potentially without full\n    precision.\n\n    Args:\n        model (`torch.nn.Module`):\n            A PyTorch model sent through [`Accelerator.prepare`]\n        unwrap (`bool`, *optional*, defaults to `True`):\n            Whether to return the original underlying state_dict of `model` or to return the wrapped state_dict\n\n    Returns:\n        `dict`: The state dictionary of the model potentially without full precision.\n\n    Example:\n\n    ```python\n    >>> import torch\n    >>> from accelerate import Accelerator\n\n    >>> accelerator = Accelerator()\n    >>> net = 
torch.nn.Linear(2, 2)\n    >>> net = accelerator.prepare(net)\n    >>> state_dict = accelerator.get_state_dict(net)\n    ```\n    \"\"\"\n    from accelerate import DistributedType\n    from accelerate.utils import compare_versions\n\n    if self.distributed_type == DistributedType.DEEPSPEED:\n        zero3_sharding = self.deepspeed_config[\"zero_optimization\"][\"stage\"] == 3\n        tp_sharding = (\n            self.deepspeed_config.get(\"tensor_parallel\", {}).get(\"autotp_size\", 0) > 1\n        )\n        if zero3_sharding or tp_sharding:\n            if model.zero_gather_16bit_weights_on_model_save():\n                if tp_sharding and not compare_versions(\"deepspeed\", \">=\", \"0.16.4\"):\n                    raise ImportError(\n                        \"Deepspeed TP requires deepspeed >= 0.16.4, Please update DeepSpeed via `pip install deepspeed -U`.\"\n                    )\n                state_dict = (\n                    model._consolidated_16bit_state_dict()\n                    if tp_sharding\n                    else model._zero3_consolidated_16bit_state_dict()\n                )\n            else:\n                raise ValueError(\n                    \"Cannot get 16bit model weights because `stage3_gather_16bit_weights_on_model_save` in DeepSpeed config is False. \"\n                    \"To save the model weights in 16bit, set `stage3_gather_16bit_weights_on_model_save` to True in DeepSpeed config file or \"\n                    \"set `zero3_save_16bit_model` to True when using `accelerate config`. 
\"\n                    \"To save the full checkpoint, run `model.save_checkpoint(save_dir)` and use `zero_to_fp32.py` to recover weights.\"\n                )\n        else:\n            from deepspeed.checkpoint.utils import clone_tensors_for_torch_save\n\n            state_dict = clone_tensors_for_torch_save(\n                self.unwrap_model(model).state_dict()\n            )\n    elif self.is_fsdp2:\n        # https://github.com/pytorch/torchtune/blob/main/torchtune/training/_distributed.py#L465\n        from torch.distributed.tensor import DTensor\n\n        state_dict = {}\n        sharded_state_dict = model.state_dict()\n        for param_name, param in sharded_state_dict.items():\n            if param.is_cpu:\n                param = param.to(torch.device(\"cuda\"))\n\n            if isinstance(param, DTensor):\n                param = param.full_tensor()\n\n            if torch.distributed.get_rank() == 0:\n                state_dict[param_name] = param.cpu()\n            torch.distributed.barrier()\n    elif self.distributed_type == DistributedType.FSDP:\n        from torch.distributed.fsdp import (\n            FullStateDictConfig,\n            FullyShardedDataParallel as FSDP,\n            StateDictType,\n        )\n\n        full_state_dict_config = FullStateDictConfig(\n            offload_to_cpu=True, rank0_only=True\n        )\n        with FSDP.state_dict_type(\n            model, StateDictType.FULL_STATE_DICT, full_state_dict_config\n        ):\n            state_dict = model.state_dict()\n    else:\n        if unwrap:\n            model = self.unwrap_model(model)\n        state_dict = model.state_dict()\n\n    return state_dict\n\n\ndef patch_peft_param_wrapper_for_fsdp2():\n    \"\"\"Patch PEFT's _LoraParameterProxy.forward for FSDP2 DTensor compatibility.\n\n    PEFT's ParamWrapper applies LoRA via torch.nn.utils.parametrize, which adds\n    delta_weight to the base weight W inside _LoraParameterProxy.forward().\n    Under FSDP2, W may be a 
DTensor (from FSDP unshard) while delta_weight is a\n    regular Tensor (or vice versa), causing a RuntimeError on mixed types.\n\n    This patch promotes the non-DTensor operand to match the DTensor's spec\n    using DTensor.from_local(), which is free for Replicate placement (just\n    metadata wrapping, no communication).\n    \"\"\"\n    from peft.tuners.lora.layer import _LoraParameterProxy\n\n    if getattr(_LoraParameterProxy, \"_axolotl_fsdp2_patched\", False):\n        return\n\n    _original_forward = _LoraParameterProxy.forward\n\n    # NOTE: Replaces (not wraps) forward; assumes original is just `W + self.delta_weight`.\n    def _patched_forward(self, W):\n        from torch.distributed.tensor import DTensor\n\n        delta = self.delta_weight\n        w_is_dt = isinstance(W, DTensor)\n        d_is_dt = isinstance(delta, DTensor)\n\n        with torch.nn.utils.parametrize.cached():\n            if w_is_dt == d_is_dt:\n                return W + delta\n            if w_is_dt:\n                return W + DTensor.from_local(delta, W.device_mesh, W.placements)\n            return DTensor.from_local(W, delta.device_mesh, delta.placements) + delta\n\n    _LoraParameterProxy.forward = _patched_forward\n    _LoraParameterProxy._axolotl_fsdp2_patched = True\n    LOG.info(\"Patched PEFT _LoraParameterProxy.forward for FSDP2 DTensor compatibility\")\n\n\ndef _process_lora_module_for_fsdp(module, fsdp2_kwargs):\n    \"\"\"Helper function to process LoRA modules for FSDP2.\"\"\"\n    from peft.tuners.lora.layer import ParamWrapper\n    from torch.distributed.fsdp import fully_shard\n\n    # Skip ParamWrapper — its lora_A/B must not be independently sharded.\n    # The parent decoder layer's FSDP wrapper handles unsharding them.\n    # TODO: review if we even need to shard them separately in first place.\n    if isinstance(module, ParamWrapper):\n        return False\n\n    log_bias_dtype_mismatch = False\n\n    # Linear4Bit will keep it's bias term in fp32. 
If the weight dtype is in bf16 we are not able to\n    # wrap this. Therefore we must ensure the bias has the same dtype as the weight\n    if hasattr(module.base_layer, \"bias\") and module.base_layer.bias is not None:\n        if module.base_layer.weight.dtype != module.base_layer.bias.dtype:\n            log_bias_dtype_mismatch = True\n            module.base_layer.bias.data = module.base_layer.bias.data.to(\n                module.base_layer.weight.dtype\n            )\n\n    for active_adapter in module.active_adapters:\n        if module.lora_A:\n            fully_shard(module.lora_A[active_adapter], **fsdp2_kwargs)\n        if module.lora_B:\n            fully_shard(module.lora_B[active_adapter], **fsdp2_kwargs)\n        if module.lora_magnitude_vector:\n            fully_shard(module.lora_magnitude_vector[active_adapter], **fsdp2_kwargs)\n\n    # lora_embedding_A/B are ParameterDicts containing nn.Parameter (Tensors),\n    # not nn.Module. fully_shard() only accepts nn.Module, so we cannot shard\n    # individual embedding Parameters. Instead, shard the entire LoraLayer module. fully_shard() can be used hierarchically because it does not\n    # override groups already assigned by fully_shard(), so modules\n    # where fully_shard() was already called are not affected [see https://docs.pytorch.org/docs/stable/distributed.fsdp.fully_shard.html]\n    if module.lora_embedding_A or module.lora_embedding_B:\n        from torch.distributed.fsdp import FSDPModule\n\n        if not isinstance(module, FSDPModule):\n            fully_shard(module, **fsdp2_kwargs)\n\n    return log_bias_dtype_mismatch\n\n\ndef fsdp2_prepare_model(accelerator, model: torch.nn.Module) -> torch.nn.Module:\n    \"\"\"Prepares the model for FSDP2 in-place. 
Also returns the model to avoid misuse of the original model.\n\n    Args:\n        accelerator (`Accelerator`): The accelerator instance\n        model (`torch.nn.Module`): The model to prepare\n\n    Returns:\n        `torch.nn.Module`: Prepared model\n    \"\"\"\n    from accelerate.utils import get_module_children_bottom_up, is_compiled_module\n    from accelerate.utils.fsdp_utils import fsdp2_prepare_auto_wrap_policy\n    from accelerate.utils.modeling import get_non_persistent_buffers\n    from peft import PeftModel\n    from peft.tuners.lora import LoraLayer\n    from torch.distributed.fsdp import (\n        CPUOffloadPolicy,\n        FSDPModule,\n        MixedPrecisionPolicy,\n        fully_shard,\n    )\n\n    is_type_fsdp = isinstance(model, FSDPModule) or (\n        is_compiled_module(model) and isinstance(model._orig_mod, FSDPModule)\n    )\n    if is_type_fsdp:\n        return model\n\n    fsdp2_plugin = accelerator.state.fsdp_plugin\n\n    original_sd = model.state_dict()\n\n    from torch.distributed.fsdp.wrap import (\n        size_based_auto_wrap_policy,\n        transformer_auto_wrap_policy,\n    )\n\n    # We need the `auto_wrap_policy` original type to create a custom poilicy function for sharding\n    # This is because `fully_shard` doesn't support old auto wrap policies, rather we have to imitate the behaviour\n    if fsdp2_plugin.auto_wrap_policy is transformer_auto_wrap_policy:\n        pass  # auto_wrap_policy_type = \"transformer\"\n    elif fsdp2_plugin.auto_wrap_policy is size_based_auto_wrap_policy:\n        pass  # auto_wrap_policy_type = \"size\"\n\n    # We set `auto_wrap_policy` to `functools.partial` to avoid creating it again\n    # This is because of `apply_activation_checkpointing` which will can reuse this function\n    fsdp2_plugin.set_auto_wrap_policy(model)\n\n    if fsdp2_plugin.activation_checkpointing:\n        from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (\n            CheckpointImpl,\n         
   apply_activation_checkpointing,\n            checkpoint_wrapper,\n        )\n\n        # Apply activation checkpointing before applying `fully_shard`\n        apply_activation_checkpointing(\n            model,\n            checkpoint_wrapper_fn=functools.partial(\n                checkpoint_wrapper,\n                checkpoint_impl=CheckpointImpl.NO_REENTRANT,\n            ),\n            auto_wrap_policy=fsdp2_plugin.auto_wrap_policy,\n        )\n\n    mesh = getattr(accelerator.state, \"device_mesh\", None)\n\n    # Disable memory pinning if requested\n    offload_to_cpu = isinstance(fsdp2_plugin.cpu_offload, CPUOffloadPolicy)\n    if offload_to_cpu and os.environ.get(\"FSDP_CPU_OFFLOAD_PIN_MEMORY\", \"\") == \"false\":\n        fsdp2_plugin.cpu_offload.pin_memory = False\n\n    fsdp2_kwargs = {\n        \"reshard_after_forward\": fsdp2_plugin.reshard_after_forward,\n        \"offload_policy\": fsdp2_plugin.cpu_offload,\n        # `fully_shard` doesn't accept `None` in case of `MixedPrecisionPolicy`\n        \"mp_policy\": fsdp2_plugin.mixed_precision_policy or MixedPrecisionPolicy(),\n        \"mesh\": (\n            mesh[tuple(accelerator.state.parallelism_config.fsdp_dim_names)]\n            if mesh is not None\n            else None\n        ),\n    }\n    model_has_params4bit = False\n    for _, param in model.named_parameters():\n        # this is a temporary fix whereby loading models with bnb params cannot be moved from\n        # GPU to a meta device due with FSDP2 because torch operations don't return the original class type\n        # bypassing the move to meta will still cause the VRAM spike, but at least it still will load\n        if param.__class__.__name__ == \"Params4bit\":\n            model_has_params4bit = True\n            break\n\n    if fsdp2_plugin.cpu_ram_efficient_loading and not model_has_params4bit:\n        # Context: `fully_shard` moves the model to GPU if it was on CPU, however it can also be on `meta` and then it stays there 
even after `fully_shard`\n        # For this reason, we need to move the model to `meta` device, as then sharding happens on `meta` device\n        # If we kept the model on CPU (`cpu_ram_efficient_loading` has model be on CPU on all ranks, though non-main ranks only have `torch.emtpy`), `fully_shard` would move it to GPU\n        # Afterwards, when we call `fsdp2_load_full_state_dict`, us creating the state_dict would result into briefly having two copies of model state_dict on the GPU -> VRAM spike\n\n        # We need to keep the original non-persistent buffers, as those MAY not be in the state_dict, resulting in them staying on meta device\n        # Also, these buffers aren't getting sharded by default\n        # We get the FQNs of all non-persistent buffers, to re-register them after\n        non_persistent_buffer_fqns = get_non_persistent_buffers(\n            model, recurse=True, fqns=True\n        )\n        original_non_persistent_buffers = copy.deepcopy(\n            {k: v for k, v in model.named_buffers() if k in non_persistent_buffer_fqns}\n        )\n        # We move the model to meta device, as then sharding happens on meta device\n        model = model.to(torch.device(\"meta\"))\n        # We need to re-tie the weights, not exactly sure why, but if we don't do this, reference to `lm_head/embed_tokens` stay hanging -> more VRAM usage\n        # We assume `transformers` models have a `tie_weights` method if they support it\n        if hasattr(model, \"tie_weights\"):\n            model.tie_weights()\n\n    is_peft_model = isinstance(model, PeftModel)\n\n    # Patch PEFT's _LoraParameterProxy for DTensor compatibility if any\n    # ParamWrapper modules exist (used for target_parameters / 3D expert params).\n    if is_peft_model:\n        from peft.tuners.lora.layer import ParamWrapper\n\n        if any(isinstance(m, ParamWrapper) for m in model.modules()):\n            patch_peft_param_wrapper_for_fsdp2()\n\n    auto_wrap_policy = 
fsdp2_prepare_auto_wrap_policy(fsdp2_plugin, model)\n    log_bias_dtype_mismatch = False\n    if auto_wrap_policy is not None:\n        for module in get_module_children_bottom_up(model)[:-1]:\n            if is_peft_model and isinstance(module, LoraLayer):\n                module_log_bias_mismatch = _process_lora_module_for_fsdp(\n                    module, fsdp2_kwargs\n                )\n                log_bias_dtype_mismatch |= module_log_bias_mismatch\n            if auto_wrap_policy(module) and not isinstance(module, FSDPModule):\n                fully_shard(module, **fsdp2_kwargs)\n\n    fully_shard(model, **fsdp2_kwargs)\n\n    if log_bias_dtype_mismatch:\n        LOG.warning(\n            \"Bias dtype mismatch detected in LoRA base linear layer. Bias parameters have been cast to weight dtype.\"\n        )\n\n    if fsdp2_plugin.cpu_ram_efficient_loading:\n        fsdp2_load_full_state_dict(\n            accelerator, model, original_sd, offload_to_cpu=offload_to_cpu\n        )\n\n    if fsdp2_plugin.cpu_ram_efficient_loading and not model_has_params4bit:\n        # We re-register the buffers, as they may not be in the state_dict\n        for fqn, buffer_tensor in original_non_persistent_buffers.items():\n            buffer_tensor = buffer_tensor.to(accelerator.device)\n\n            if \".\" in fqn:\n                parent_fqn, local_buffer_name = fqn.rsplit(\".\", 1)\n                parent_module = model.get_submodule(parent_fqn)\n            else:\n                local_buffer_name = fqn\n                parent_module = model\n\n            parent_module.register_buffer(\n                local_buffer_name, buffer_tensor, persistent=False\n            )\n\n        # We need to tie the weights again, as call to `load_full_state_dict` breaks the tie\n        # Needs to be called both here and above\n        # removing this call makes the have slightly different loss\n        # removing the call above leads to extra memory usage as explained in the comment 
above\n        if hasattr(model, \"tie_weights\"):\n            model.tie_weights()\n    return model\n\n\ndef patch_tied_keys_for_meta_device():\n    \"\"\"Patch _adjust_tied_keys_with_tied_pointers to skip meta tensors.\n\n    Meta tensors all share data_ptr()==0, causing every parameter to be incorrectly\n    grouped as \"tied\". Skipping them is safe since they have no real storage.\n    \"\"\"\n    from collections import defaultdict\n\n    from transformers import PreTrainedModel\n\n    def _patched_adjust_tied_keys_with_tied_pointers(self, missing_keys):\n        param_pointers = defaultdict(list)\n        for param_name, param_value in self.state_dict().items():\n            if param_value.is_meta:\n                continue\n            param_pointers[param_value.data_ptr()].append(param_name)\n\n        tied_param_names = [\n            names\n            for names in param_pointers.values()\n            if len(names) > 1\n            and not any(name in self.all_tied_weights_keys.keys() for name in names)\n            and not all(name in missing_keys for name in names)\n        ]\n\n        tied_weights_keys_by_pointers = {\n            param_name: group[0]\n            for group in tied_param_names\n            for param_name in group[1:]\n        }\n        self.all_tied_weights_keys.update(tied_weights_keys_by_pointers)\n\n    PreTrainedModel._adjust_tied_keys_with_tied_pointers = (\n        _patched_adjust_tied_keys_with_tied_pointers\n    )\n\n\ndef patch_initialize_missing_keys_for_fsdp():\n    \"\"\"Patch _initialize_missing_keys to skip re-initialization on FSDP non-rank-0.\n\n    When using cpu_ram_efficient_loading, non-rank-0 processes load weights on\n    meta device and move them to CPU as empty tensors. 
Without this patch,\n    initialize_weights() re-initializes ALL parameters (via guarded init\n    functions), which is slow and uses extra RAM per process.\n\n    The fix marks all params/buffers with _is_hf_initialized=True before calling\n    the original method, so guarded init functions (init.normal_, init.zeros_,\n    etc.) become no-ops on non-rank-0 processes. The real weights arrive later\n    via FSDP broadcast from rank 0.\n\n    Upstream fix: https://github.com/huggingface/transformers/pull/44473\n    Remove this patch once transformers includes the fix in a stable release.\n    \"\"\"\n    from transformers import PreTrainedModel\n    from transformers.modeling_utils import is_fsdp_enabled, is_local_dist_rank_0\n\n    if getattr(PreTrainedModel._initialize_missing_keys, \"_axolotl_patched\", False):\n        return\n\n    _original_initialize_missing_keys = PreTrainedModel._initialize_missing_keys\n\n    def _patched_initialize_missing_keys(self, is_quantized: bool) -> None:\n        if is_fsdp_enabled() and not is_local_dist_rank_0():\n            for key in self.state_dict():\n                try:\n                    param_or_buffer = self.get_parameter_or_buffer(key)\n                    param_or_buffer._is_hf_initialized = True\n                except AttributeError:\n                    pass  # may happen when handling pre-quantized weights\n            self._is_hf_initialized = True\n\n        _original_initialize_missing_keys(self, is_quantized)\n\n    PreTrainedModel._initialize_missing_keys = _patched_initialize_missing_keys\n    PreTrainedModel._initialize_missing_keys._axolotl_patched = True\n\n\ndef patch_accelerate_fsdp2():\n    import accelerate\n\n    accelerate.accelerator.fsdp2_prepare_model = fsdp2_prepare_model\n    accelerate.Accelerator.get_state_dict = get_state_dict\n    setattr(\n        sys.modules[\"accelerate\"],\n        \"Accelerator.get_state_dict\",\n        get_state_dict,\n    )\n"
  },
  {
    "path": "src/axolotl/monkeypatch/accelerate/parallelism_config.py",
    "content": "\"\"\"\nworkaround to allow parallelism config for pure CP\n\"\"\"\n\nimport os\nimport warnings\n\nfrom accelerate import DistributedType\n\n\ndef _validate_accelerator(self, accelerator):\n    _warnings = set()\n    if not accelerator.multi_device and self.total_size == 1:\n        # No distributed setup, valid parallelism config\n        return\n\n    # We need this to ensure DDP works\n    if self.total_size == 1:\n        self._set_size(\"dp_replicate\", accelerator.num_processes)\n\n    if self.total_size != accelerator.num_processes:\n        raise ValueError(\n            f\"ParallelismConfig total_size ({self.total_size}) does not match \"\n            f\"num_processes ({accelerator.num_processes}). Please adjust dp_replicate_size/ \"\n            f\"dp_shard_size/tp_size/cp_size.\"\n        )\n\n    # allow parallelism config when not using fsdp if using pure context parallelism\n    allow_parallelism_config = False\n\n    if (\n        self.cp_size > 1\n        and self.dp_shard_size <= 1\n        and os.environ.get(\"ACCELERATE_ALLOW_CP_STANDALONE\", \"false\").lower() == \"true\"\n    ):\n        allow_parallelism_config = True\n\n    if (\n        self.total_size > 1\n        and not allow_parallelism_config\n        and not (accelerator.is_fsdp2 or accelerator.multi_device)\n    ):\n        raise ValueError(\n            f\"ParallelismConfig is only compatible DistributedType.FSDP (version 2) or DistributedType.Multi{{Device}}, but got {accelerator.distributed_type}.\"\n        )\n\n    for parallelism, size in self._sizes.items():\n        if size == 1 and getattr(self, f\"{parallelism}_handler\", None) is not None:\n            _warnings.add(\n                f\"ParallelismConfig.{parallelism}_handler is set, but {parallelism}_size is set to 1. 
This handler will be ignored.\"\n            )\n\n    if _warnings and accelerator.is_main_process:\n        warnings.warn(\n            \"ParallelismConfig has the following warnings:\\n\" + \"\\n\".join(_warnings),\n            UserWarning,\n            stacklevel=2,\n        )\n\n\ndef patched_is_fsdp2(self) -> bool:\n    \"\"\"\n    Patched version of is_fsdp2 that guards against a None fsdp_plugin.\n    \"\"\"\n    # The new logic checks if fsdp_plugin exists before accessing its attributes\n    return (\n        self.distributed_type == DistributedType.FSDP\n        and self.fsdp_plugin\n        and self.fsdp_plugin.fsdp_version == 2\n    )\n\n\ndef patch_parallelism_config():\n    from accelerate.accelerator import AcceleratorState, ParallelismConfig\n\n    ParallelismConfig._validate_accelerator = _validate_accelerator\n    AcceleratorState.is_fsdp2 = property(patched_is_fsdp2)\n\n\ndef patch_prepare_cp():\n    import contextlib\n\n    from accelerate import Accelerator\n\n    def patched_prepare_cp(self, *args):\n        if self.parallelism_config.cp_backend == \"deepspeed\":\n            return args\n\n        @contextlib.contextmanager\n        def _noop_cp_context(\n            buffers=None, buffer_seq_dims=None, no_restore_buffers=None\n        ):\n            yield\n\n        self._cp_context = _noop_cp_context\n        return args\n\n    Accelerator._prepare_cp = patched_prepare_cp\n"
  },
  {
    "path": "src/axolotl/monkeypatch/attention/__init__.py",
    "content": "\"\"\"\nattention module for attention monkeypatches\n\"\"\"\n\nfrom transformers.integrations.flash_attention import flash_attention_forward\n\n\ndef patch_xformers_attn_over_fa2():\n    from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS\n\n    from .xformers import xformers_attention_forward\n\n    ALL_ATTENTION_FUNCTIONS[\"flash_attention_2\"] = xformers_attention_forward\n\n\ndef unpatch_xformers_attn_over_fa2():\n    from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS\n\n    ALL_ATTENTION_FUNCTIONS[\"flash_attention_2\"] = flash_attention_forward()\n"
  },
  {
    "path": "src/axolotl/monkeypatch/attention/flash_attn_4.py",
    "content": "\"\"\"Transparently upgrade FA2 to FA4 when available on SM90+ hardware.\"\"\"\n\nimport torch\n\nfrom axolotl.utils.logging import get_logger\n\nLOG = get_logger(__name__)\n\n\ndef _get_head_dims(model_config):\n    \"\"\"Extract (head_dim, head_dim_v) from a model config.\n\n    Handles composite models (e.g. Qwen3.5 VL) via text_config and\n    MLA models (DeepSeek/Kimi) that have separate Q/V head dimensions.\n    \"\"\"\n    cfg = model_config\n    if hasattr(cfg, \"text_config\"):\n        cfg = cfg.text_config\n\n    # MLA models: Q head_dim = qk_nope + qk_rope, V head_dim = v_head_dim\n    if hasattr(cfg, \"qk_nope_head_dim\") and hasattr(cfg, \"qk_rope_head_dim\"):\n        head_dim = cfg.qk_nope_head_dim + cfg.qk_rope_head_dim\n        head_dim_v = getattr(cfg, \"v_head_dim\", head_dim)\n        return head_dim, head_dim_v\n\n    # Standard models\n    if hasattr(cfg, \"head_dim\"):\n        return cfg.head_dim, cfg.head_dim\n    if hasattr(cfg, \"hidden_size\") and hasattr(cfg, \"num_attention_heads\"):\n        head_dim = cfg.hidden_size // cfg.num_attention_heads\n        return head_dim, head_dim\n\n    return None, None\n\n\ndef patch_flash_attn_4(model_config=None):\n    \"\"\"Patch _lazy_imports to redirect FA2 imports to FA4 if available on supported hardware.\"\"\"\n    if not torch.cuda.is_available():\n        return\n\n    major, _ = torch.cuda.get_device_capability()\n    # Matches flash_attn/cute/interface.py: arch / 10 in [9, 10, 11]\n    if major not in (9, 10, 11):\n        return\n\n    try:\n        from flash_attn.cute import (  # noqa: F401\n            flash_attn_func,\n            flash_attn_varlen_func,\n        )\n    except ImportError:\n        LOG.info(\n            \"Flash Attention 4 is available for your GPU and offers faster training speeds. 
\"\n            \"To enable: pip install flash-attn-4\"\n        )\n        return\n\n    # Validate head dimensions against FA4's own constraints\n    head_dim = None\n    if model_config is not None:\n        head_dim, head_dim_v = _get_head_dims(model_config)\n        if head_dim is not None:\n            try:\n                from flash_attn.cute.interface import _validate_head_dims\n            except ImportError:\n                LOG.warning(\n                    \"Could not import _validate_head_dims from flash_attn.cute.interface, \"\n                    \"unable to verify head dimension compatibility, falling back to FA2\"\n                )\n                return\n\n            # alignment = 16 // element_size; bf16/fp16 = 2 bytes -> alignment = 8\n            alignment = 8\n            try:\n                _validate_head_dims(head_dim, head_dim_v, major, alignment)\n            except AssertionError as exc:\n                LOG.warning(\n                    \"Model head dimensions not supported by FA4, \"\n                    \"falling back to FA2: %s\",\n                    exc,\n                )\n                return\n\n    import transformers.modeling_flash_attention_utils as fa_utils\n\n    if getattr(fa_utils._lazy_imports, \"_axolotl_patched\", False):\n        return\n\n    def _patched_lazy_imports(\n        implementation, attention_wrapper=None, allow_all_kernels=False\n    ):\n        return (\n            flash_attn_func,\n            flash_attn_varlen_func,\n            fa_utils._pad_input,\n            fa_utils._unpad_input,\n        )\n\n    _patched_lazy_imports._axolotl_patched = True\n    fa_utils._lazy_imports = _patched_lazy_imports\n    LOG.info(\n        \"Flash Attention 4 enabled (head_dim=%s)\",\n        head_dim if model_config else \"unknown\",\n    )\n"
  },
  {
    "path": "src/axolotl/monkeypatch/attention/flex_attn.py",
    "content": "\"\"\"Flex attention monkey patch\"\"\"\n\nimport sys\n\nimport torch\nimport transformers\nfrom packaging import version\nfrom transformers.utils.import_utils import _torch_version, is_torch_less_or_equal\n\nfrom axolotl.utils.logging import get_logger\n\nLOG = get_logger(__name__)\n\n\ndef patch_flex_wrapper(**flex_attn_compile_kwargs):\n    # TODO remove this patch when transformers#37285 is merged and in a release\n    is_torch_2_6 = torch.__version__.startswith(\"2.6\")\n\n    if not is_torch_2_6:\n        return\n\n    from torch.nn.attention.flex_attention import flex_attention\n\n    class WrappedFlexAttention:\n        \"\"\"\n        We are doing a singleton class so that flex attention is compiled once when it's first called.\n        \"\"\"\n\n        _instance = None\n        _is_flex_compiled = False\n        _compiled_flex_attention = None\n\n        def __new__(cls, *args, **kwargs):\n            if cls._instance is None:\n                # Create a new instance if one doesn't already exist\n                cls._instance = super().__new__(cls)\n            return cls._instance\n\n        @classmethod\n        def del_singleton(cls):\n            cls._instance = None\n\n        @torch.compiler.disable(recursive=False)\n        def __init__(self, training):\n            \"\"\"\n            Initialize or update the singleton instance.\n            \"\"\"\n            self.training = None\n            if not self._is_flex_compiled or training != self.training:\n                self.training = training\n                if is_torch_less_or_equal(\"2.5.1\"):\n                    self._compiled_flex_attention = torch.compile(\n                        flex_attention, dynamic=False\n                    )\n                # In PyTorch 2.6.0, there's a known issue with flex attention compilation which may\n                # cause errors. 
The suggested fix is to compile with \"max-autotune-no-cudagraphs\"\n                # see https://github.com/pytorch/pytorch/issues/146260 for training\n                elif version.parse(_torch_version).base_version == \"2.6.0\" and training:\n                    self._compiled_flex_attention = torch.compile(\n                        flex_attention, dynamic=False, mode=\"max-autotune-no-cudagraphs\"\n                    )\n                # Fallback, usually the most recent torch 2.7.x+ versions\n                else:\n                    LOG.info(\n                        \"Compiling flex attention with kwargs: %s. This may take a while...\",\n                        flex_attn_compile_kwargs,\n                    )\n                    self._compiled_flex_attention = torch.compile(\n                        flex_attention,\n                        **flex_attn_compile_kwargs,\n                    )\n                    LOG.info(\"Flex attention compiled successfully.\")\n\n                self._is_flex_compiled = True\n\n        def __call__(self):\n            return self._compiled_flex_attention\n\n    transformers.integrations.flex_attention.WrappedFlexAttention = WrappedFlexAttention\n    sys.modules[\n        \"transformers.integrations.flex_attention\"\n    ].WrappedFlexAttention = WrappedFlexAttention\n"
  },
  {
    "path": "src/axolotl/monkeypatch/attention/sage_attn.py",
    "content": "\"\"\"\nMonkeypatch for SageAttention for use with transformers.\n\nhttps://github.com/thu-ml/SageAttention/\n\"\"\"\n\nimport torch\nfrom transformers.integrations.sdpa_attention import repeat_kv\n\nfrom axolotl.utils.logging import get_logger\n\nLOG = get_logger(__name__)\n\nsageattn = None  # pylint: disable=invalid-name\nsageattn_varlen = None  # pylint: disable=invalid-name\n\n\ndef _is_sageattn_available():\n    \"\"\"Determine if SageAttention is available\"\"\"\n    try:\n        import sageattention  # noqa: F401 # pylint: disable=unused-import\n\n        return True\n    except ImportError:\n        return False\n\n\nif _is_sageattn_available():\n    # import sageattn here if available\n    from sageattention import sageattn, sageattn_varlen\n\n\ndef _check_sageattn_imported():\n    \"\"\"Check if SageAttention is imported. Raises an ImportError if not.\"\"\"\n    if sageattn is None:\n        raise ImportError(\n            \"SageAttention is not installed. Please install it from source: \"\n            \"`pip install git+https://github.com/thu-ml/SageAttention.git@1718ddc06dbc694bcf3c6b49ac28c1921aa2d8bd`\"\n        )\n\n\ndef sage_attention_forward(\n    module: torch.nn.Module,\n    query: torch.Tensor,\n    key: torch.Tensor,\n    value: torch.Tensor,\n    attention_mask: torch.Tensor | None = None,\n    dropout: float = 0.0,\n    scaling: float | None = None,\n    is_causal: bool | None = None,\n    **kwargs,\n) -> tuple[torch.Tensor, None]:\n    \"\"\"\n    Forward pass for SageAttention compatible with transformers attention interfaces.\n\n    https://github.com/thu-ml/SageAttention/\n    \"\"\"\n\n    _check_sageattn_imported()\n\n    if kwargs.get(\"output_attentions\", False) or kwargs.get(\"head_mask\") is not None:\n        raise NotImplementedError(\n            \"SageAttention does not support `output_attentions=True` or `head_mask`.\"\n        )\n\n    # The base sageattn API does not support dropout.\n    if dropout > 
0.0:\n        raise NotImplementedError(\"SageAttention does not support dropout.\")\n\n    # Handle Grouped-Query Attention (GQA) and Multi-Query Attention (MQA)\n    if hasattr(module, \"num_key_value_groups\"):\n        key = repeat_kv(key, module.num_key_value_groups)\n        value = repeat_kv(value, module.num_key_value_groups)\n\n    # Calculate is_causal following transformers\n    assert is_causal is not False, \"is_causal must be True or None\"\n    is_causal = True\n\n    position_ids = kwargs.get(\"position_ids\", None)\n    query_length = query.shape[2]\n\n    cu_seqlens_q = kwargs.get(\"cu_seqlens_q\", None)\n    cu_seqlens_k = kwargs.get(\"cu_seqlens_k\", None)\n    max_length_q = kwargs.get(\"max_length_q\", None)\n    max_length_k = kwargs.get(\"max_length_k\", None)\n\n    # Sample packing uses position_ids, so we check for it first\n    if position_ids is not None and (\n        max_length_q is not None\n        or (query_length != 1 and not (torch.diff(position_ids, dim=-1) >= 0).all())\n    ):\n        # transpose inputs to NHD layout for use with FA2 utils\n        query = query.transpose(1, 2)\n        key = key.transpose(1, 2)\n        value = value.transpose(1, 2)\n\n        batch_size = query.size(0)\n\n        from transformers.modeling_flash_attention_utils import (\n            prepare_fa2_from_position_ids,\n        )\n\n        if cu_seqlens_q is None or cu_seqlens_k is None:\n            query, key, value, indices_q, cu_seq_lens, max_seq_lens = (\n                prepare_fa2_from_position_ids(query, key, value, position_ids)\n            )\n\n            cu_seqlens_q, cu_seqlens_k = cu_seq_lens\n            max_length_q, max_length_k = max_seq_lens\n\n        else:\n            query = query.reshape(-1, query.size(-2), query.size(-1))\n            key = key.reshape(-1, key.size(-2), key.size(-1))\n            value = value.reshape(-1, value.size(-2), value.size(-1))\n\n        attn_output_unpad = sageattn_varlen(\n            
q=query,\n            k=key,\n            v=value,\n            cu_seqlens_q=cu_seqlens_q,\n            cu_seqlens_k=cu_seqlens_k,\n            max_seqlen_q=max_length_q,\n            max_seqlen_k=max_length_k,\n            is_causal=is_causal,\n            sm_scale=scaling,\n            smooth_k=False,  # reduces loss 0 / nan grad norms\n            tensor_layout=\"NHD\",\n        )\n\n        attn_output = attn_output_unpad.view(\n            batch_size, -1, attn_output_unpad.size(-2), attn_output_unpad.size(-1)\n        )\n\n    elif attention_mask is not None:\n        # NOTE: When used without `pad_to_sequence_len`, the loss becomes unstable after a few steps.\n\n        assert attention_mask.ndim == 2, \"Attention mask must be 2D\"\n\n        from transformers.modeling_flash_attention_utils import (\n            _upad_input,\n        )\n\n        # transpose inputs to NHD layout for use with FA2 utils\n        query = query.transpose(1, 2)\n        key = key.transpose(1, 2)\n        value = value.transpose(1, 2)\n\n        batch_size = query.shape[0]\n\n        query, key, value, indices_q, cu_seq_lens, max_seq_lens = _upad_input(\n            query, key, value, attention_mask, query_length\n        )\n        cu_seqlens_q, cu_seqlens_k = cu_seq_lens\n        max_seqlen_q, max_seqlen_k = max_seq_lens\n\n        attn_output_unpad = sageattn_varlen(\n            q=query,\n            k=key,\n            v=value,\n            cu_seqlens_q=cu_seqlens_q,\n            cu_seqlens_k=cu_seqlens_k,\n            max_seqlen_q=max_seqlen_q,\n            max_seqlen_k=max_seqlen_k,\n            is_causal=is_causal,\n            sm_scale=scaling,\n            tensor_layout=\"NHD\",\n        )\n\n        from flash_attn.bert_padding import pad_input\n\n        attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)\n    else:\n        # Use standard sageattn\n        # The input layout for transformers models is (batch_size, num_heads, seq_len, 
head_dim),\n        # which corresponds to SageAttention's \"HND\" layout.\n        attn_output = sageattn(\n            q=query,\n            k=key,\n            v=value,\n            tensor_layout=\"HND\",\n            is_causal=is_causal,\n            sm_scale=scaling,\n        )\n\n        # SageAttention with \"HND\" returns (batch, heads, seq_len, head_dim)\n        # Transformers expects (batch, seq_len, heads, head_dim) for the output\n        # So we need to transpose dimensions 1 and 2\n        attn_output = attn_output.transpose(1, 2).contiguous()\n\n    return attn_output, None\n\n\ndef patch_sageattn():\n    \"\"\"Patch SageAttention for use with transformers.\"\"\"\n\n    _check_sageattn_imported()\n\n    from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS\n\n    # Replace flash attention with sage attention\n    ALL_ATTENTION_FUNCTIONS.register(\"flash_attention_2\", sage_attention_forward)\n\n    # Note: New method after transformers refactor to use ALL_MASK_ATTENTION_FUNCTIONS\n    # Register sage_attention with the global attention interface\n    # ALL_ATTENTION_FUNCTIONS.register(\"sage_attention\", sage_attention_forward)\n\n    # from transformers.masking_utils import ALL_MASK_ATTENTION_FUNCTIONS, flash_attention_mask\n\n    # ALL_MASK_ATTENTION_FUNCTIONS.register(\"sage_attention\", flash_attention_mask)\n\n    LOG.info(\"SageAttention patched successfully\")\n"
  },
  {
    "path": "src/axolotl/monkeypatch/attention/xformers.py",
    "content": "\"\"\"\nxformers attention implementation for packing\n\"\"\"\n\nfrom typing import Optional\n\nimport torch\nimport xformers\nimport xformers.ops.fmha\nfrom transformers.modeling_flash_attention_utils import (\n    _upad_input,\n)\n\nfrom axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids\n\nxformers_attention = xformers.ops.fmha.memory_efficient_attention\n\n\ndef xformers_attention_forward(\n    module: torch.nn.Module,\n    query: torch.Tensor,\n    key: torch.Tensor,\n    value: torch.Tensor,\n    attention_mask: Optional[torch.Tensor] = None,\n    position_ids: Optional[torch.LongTensor] = None,\n    dropout: float = 0.0,\n    scaling: Optional[float] = None,\n    sliding_window: Optional[int] = None,\n    softcap: Optional[float] = None,\n    cu_seq_lens_q: Optional[torch.LongTensor] = None,\n    cu_seq_lens_k: Optional[torch.LongTensor] = None,\n    max_length_q: Optional[int] = None,\n    max_length_k: Optional[int] = None,\n    **kwargs,\n):\n    # Get dimensions\n    # query: [batch, heads, seq_len, hidden_dim]\n    batch_size = query.size(0)\n    query_length = query.shape[2]\n    key_length = key.shape[2]\n\n    # Default causal mask\n    attn_bias = xformers.ops.LowerTriangularMask()\n\n    # Check if we have sliding window attention\n    has_sliding_window = sliding_window is not None and sliding_window < query_length\n\n    # Transpose dimensions for xformers (Q: [b, h, s, d] -> [b, s, h, d])\n    query = query.transpose(1, 2)\n    key = key.transpose(1, 2)\n    value = value.transpose(1, 2)\n\n    # Get GQA parameters\n    num_attention_heads = module.config.num_attention_heads\n    num_key_value_heads = module.config.num_key_value_heads\n    head_dim = query.size(-1)\n    is_gqa = num_attention_heads != num_key_value_heads\n    n_groups = num_attention_heads // num_key_value_heads if is_gqa else 1\n\n    # If position_ids is provided and check all examples do not contain only 1 sequence, If tensor in increasing\n    # 
then we probably have one sequence, otherwise it is packed. Additionally check we are in pre-fill/training stage.\n    # Use `flash_attn_varlen_func` to prevent cross-example attention and also allow padding free approach\n    if position_ids is not None and (\n        max_length_q is not None\n        or (query_length != 1 and not (torch.diff(position_ids, dim=-1) >= 0).all())\n    ):\n        if cu_seq_lens_q is None or cu_seq_lens_k is None:\n            cu_seq_lens_q = get_cu_seqlens_from_pos_ids(position_ids)[0]\n            cu_seq_lens_q = cu_seq_lens_q.squeeze()\n            seq_lengths = cu_seq_lens_q[1:] - cu_seq_lens_q[:-1]\n            attn_bias = (\n                xformers.ops.fmha.attn_bias.BlockDiagonalCausalMask.from_seqlens(\n                    q_seqlen=seq_lengths.tolist(),\n                )\n            )\n        else:\n            query = query.reshape(-1, query.size(-2), query.size(-1))\n            key = key.reshape(-1, key.size(-2), key.size(-1))\n            value = value.reshape(-1, value.size(-2), value.size(-1))\n\n        # Handle GQA\n        if is_gqa:\n            key = key.repeat_interleave(n_groups, dim=2)\n            value = value.repeat_interleave(n_groups, dim=2)\n\n    elif attention_mask is not None:\n        query, key, value, _, cu_seq_lens, _ = _upad_input(\n            query, key, value, attention_mask, query_length\n        )\n        cu_seq_lens_q, cu_seq_lens_k = cu_seq_lens\n        seq_lengths = []\n        for i in range(len(cu_seq_lens_q) - 1):\n            seq_lengths.append(cu_seq_lens_q[i + 1] - cu_seq_lens_q[i])\n        attn_bias = xformers.ops.fmha.attn_bias.BlockDiagonalCausalMask.from_seqlens(\n            q_seqlen=seq_lengths,\n            kv_seqlen=seq_lengths,\n        )\n\n        # Handle GQA\n        if is_gqa:\n            key = key.repeat_interleave(n_groups, dim=2)\n            value = value.repeat_interleave(n_groups, dim=2)\n    else:\n        # Handle Group Query Attention (GQA) using 
view/expand approach from reference\n        key = key.view(batch_size, key_length, num_key_value_heads, 1, head_dim)\n        value = value.view(batch_size, key_length, num_key_value_heads, 1, head_dim)\n        key = key.expand(\n            batch_size, key_length, num_key_value_heads, n_groups, head_dim\n        )\n        value = value.expand(\n            batch_size, key_length, num_key_value_heads, n_groups, head_dim\n        )\n\n        if module.training:\n            key = key.reshape(batch_size, key_length, num_attention_heads, head_dim)\n            value = value.reshape(batch_size, key_length, num_attention_heads, head_dim)\n\n            if has_sliding_window:\n                query = query.view(\n                    1, batch_size * query_length, num_attention_heads, head_dim\n                )\n                key = key.view(\n                    1, batch_size * key_length, num_attention_heads, head_dim\n                )\n                value = value.view(\n                    1, batch_size * key_length, num_attention_heads, head_dim\n                )\n        else:\n            query = query.view(\n                batch_size, query_length, num_key_value_heads, n_groups, head_dim\n            )\n\n            # If we need a sliding window attention\n            if has_sliding_window:\n                query = query.view(\n                    1,\n                    batch_size * query_length,\n                    num_key_value_heads,\n                    n_groups,\n                    head_dim,\n                )\n                key = key.view(\n                    1, batch_size * key_length, num_key_value_heads, n_groups, head_dim\n                )\n                value = value.view(\n                    1, batch_size * key_length, num_key_value_heads, n_groups, head_dim\n                )\n\n    # Run the xformers attention\n    attn_output = xformers_attention(\n        query,\n        key,\n        value,\n        attn_bias=attn_bias,\n    
)\n\n    attn_output = attn_output.view(\n        batch_size, -1, attn_output.size(-2), attn_output.size(-1)\n    )\n    return attn_output, None\n"
  },
  {
    "path": "src/axolotl/monkeypatch/btlm_attn_hijack_flash.py",
    "content": "\"\"\"\nFlash attention monkey patch for cerebras btlm model\n\"\"\"\n\nimport importlib\nfrom typing import Optional, Tuple\n\nimport torch\nfrom accelerate import init_empty_weights\nfrom flash_attn.flash_attn_interface import flash_attn_func\nfrom transformers import AutoConfig, AutoModelForCausalLM\n\nfrom axolotl.utils.logging import get_logger\n\nLOG = get_logger(__name__)\n\n\ndef replace_btlm_attn_with_flash_attn(model_name=\"cerebras/btlm-3b-8k-base\"):\n    # this is a wonky hack to get the remotely loaded module\n    model_config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)\n    # we need to load the model here in order for modeling_btlm to be available\n    with init_empty_weights():\n        AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True)\n    module_name = model_config.__class__.__module__.replace(\n        \".configuration_btlm\", \".modeling_btlm\"\n    )\n    modeling_btlm = importlib.import_module(module_name)\n    modeling_btlm.BTLMAttention._attn = flashattn_attn\n\n\ndef flashattn_attn(\n    self,\n    query: torch.Tensor,\n    key: Optional[torch.Tensor] = None,\n    value: Optional[torch.Tensor] = None,\n    attention_mask: Optional[torch.Tensor] = None,\n    head_mask: Optional[torch.Tensor] = None,\n    position_bias: Optional[torch.Tensor] = None,\n) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:\n    softmax_scale = (\n        1 / (key.size(-1) ** self.attn_scale_power) if self.scale_attn_weights else None\n    )\n\n    query = query.permute(0, 2, 1, 3)\n    key = key.permute(0, 2, 1, 3)\n    value = value.permute(0, 2, 1, 3)\n\n    # Perform Flash attention\n    attn_output = flash_attn_func(\n        query,\n        key,\n        value,\n        dropout_p=0.0,  # Assuming you have this attribute\n        softmax_scale=softmax_scale,  # Set this if you have specific scaling in mind\n        causal=not self.is_cross_attention,  # Assuming you have this attribute\n        
return_attn_probs=False,  # Set this based on your needs\n    )\n\n    # Optional: Apply head mask if it's not None\n    if head_mask is not None:\n        attn_output *= head_mask\n\n    attn_output = attn_output.permute(0, 2, 1, 3)\n\n    return attn_output, None  # We don't have explicit attn_weights in Flash attention\n"
  },
  {
    "path": "src/axolotl/monkeypatch/data/__init__.py",
    "content": ""
  },
  {
    "path": "src/axolotl/monkeypatch/data/batch_dataset_fetcher.py",
    "content": "\"\"\"Monkey patches for the dataset fetcher to handle batches of packed indexes.\"\"\"\n\nimport torch\nfrom torch.utils.data._utils.fetch import _BaseDatasetFetcher\nfrom torch.utils.data._utils.worker import _worker_loop\n\n_ORIGINAL_MAP_DATASET_FETCHER = None\n_ORIGINAL_WORKER_LOOP = None\n_IS_PATCHED = False\n\n\nclass _MapDatasetFetcher(_BaseDatasetFetcher):\n    \"\"\"\n    Custom dataset fetcher that handles nested batch structures from\n    MultipackBatchSampler.\n    \"\"\"\n\n    def fetch(self, possibly_batched_index):\n        if isinstance(possibly_batched_index[0], list):\n            # Handle nested structure from MultipackBatchSampler\n            data = [None for i in possibly_batched_index]\n            for i, possibly_batched_index_ in enumerate(possibly_batched_index):\n                if self.auto_collation:\n                    if (\n                        hasattr(self.dataset, \"__getitems__\")\n                        and self.dataset.__getitems__\n                    ):\n                        data[i] = self.dataset.__getitems__(possibly_batched_index_)\n                    else:\n                        data[i] = [self.dataset[idx] for idx in possibly_batched_index_]\n                else:\n                    data[i] = self.dataset[possibly_batched_index_]\n        else:\n            # Standard batch handling\n            if self.auto_collation:\n                if hasattr(self.dataset, \"__getitems__\") and self.dataset.__getitems__:\n                    data = self.dataset.__getitems__(possibly_batched_index)\n                else:\n                    data = [self.dataset[idx] for idx in possibly_batched_index]\n            else:\n                data = self.dataset[possibly_batched_index]\n        return self.collate_fn(data)\n\n\ndef patch_fetchers():\n    \"\"\"Apply patches to PyTorch's DataLoader components.\"\"\"\n    torch.utils.data._utils.fetch._MapDatasetFetcher = _MapDatasetFetcher\n    
torch.utils.data.dataloader._utils.fetch._MapDatasetFetcher = _MapDatasetFetcher\n\n\ndef patched_worker_loop(*args, **kwargs):\n    \"\"\"Worker loop that ensures patches are applied in worker processes.\"\"\"\n    patch_fetchers()\n    return _worker_loop(*args, **kwargs)\n\n\ndef apply_multipack_dataloader_patch():\n    \"\"\"\n    This patch allows DataLoader to correctly process batches that contain multiple bins\n    of packed sequences.\n    \"\"\"\n    # pylint: disable=global-statement\n    global _ORIGINAL_MAP_DATASET_FETCHER, _ORIGINAL_WORKER_LOOP, _IS_PATCHED\n\n    if _IS_PATCHED:\n        return\n\n    # Store original implementations\n    _ORIGINAL_MAP_DATASET_FETCHER = torch.utils.data._utils.fetch._MapDatasetFetcher\n    _ORIGINAL_WORKER_LOOP = torch.utils.data._utils.worker._worker_loop\n\n    # Apply patches\n    patch_fetchers()\n    torch.utils.data._utils.worker._worker_loop = patched_worker_loop\n\n    _IS_PATCHED = True\n\n\ndef remove_multipack_dataloader_patch():\n    \"\"\"Remove the monkeypatch and restore original PyTorch DataLoader behavior.\"\"\"\n    # pylint: disable=global-statement\n    global _IS_PATCHED\n\n    if not _IS_PATCHED:\n        return\n\n    if _ORIGINAL_MAP_DATASET_FETCHER:\n        torch.utils.data._utils.fetch._MapDatasetFetcher = _ORIGINAL_MAP_DATASET_FETCHER\n        torch.utils.data.dataloader._utils.fetch._MapDatasetFetcher = (\n            _ORIGINAL_MAP_DATASET_FETCHER\n        )\n\n    if _ORIGINAL_WORKER_LOOP:\n        torch.utils.data._utils.worker._worker_loop = _ORIGINAL_WORKER_LOOP\n\n    _IS_PATCHED = False\n"
  },
  {
    "path": "src/axolotl/monkeypatch/deepspeed_utils.py",
    "content": "import importlib\nimport importlib.util\n\nfrom axolotl.utils.logging import get_logger\n\nLOG = get_logger(__name__)\n\n\ndef patch_checkpoint_wrapper_setattr():\n    \"\"\"\n    Patch CheckpointWrapper to properly forward DeepSpeed attributes to wrapped modules.\n\n    This fixes the issue where CheckpointWrapper doesn't forward ds_* attributes\n    (like ds_grads_remaining) to the actual wrapped module, causing DeepSpeed\n    ZeRO-3 to fail when gradient checkpointing is enabled.\n\n    This issue occurs specifically with:\n    - QLoRA + DeepSpeed ZeRO-3\n    - gradient_checkpointing: true\n    - activation_offloading: true\n\n    References:\n    - https://github.com/deepspeedai/DeepSpeed/issues/7203\n    - https://github.com/deepspeedai/DeepSpeed/blob/38d1a9eb64c9e01e32eccc50b25ba18925287441/deepspeed/runtime/zero/parameter_offload.py#L424-L458\n    - https://github.com/axolotl-ai-cloud/axolotl/pull/3102\n    \"\"\"\n\n    try:\n        from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (\n            CheckpointWrapper,\n        )\n\n        # Check if already patched\n        if hasattr(CheckpointWrapper, \"_axolotl_setattr_patched\"):\n            LOG.debug(\"CheckpointWrapper already patched\")\n            return\n\n        original_setattr = CheckpointWrapper.__setattr__\n\n        def new_setattr(self, name: str, value) -> None:\n            if name.startswith(\"ds_\") and hasattr(self, \"_checkpoint_wrapped_module\"):\n                setattr(self._checkpoint_wrapped_module, name, value)\n                LOG.debug(\n                    f\"Forwarded {name} to wrapped module {type(self._checkpoint_wrapped_module).__name__}\"\n                )\n            else:\n                original_setattr(self, name, value)\n\n        CheckpointWrapper.__setattr__ = new_setattr\n        CheckpointWrapper._axolotl_setattr_patched = True\n\n        LOG.info(\"CheckpointWrapper patched to forward DeepSpeed attributes\")\n\n    
except ImportError as e:\n        LOG.debug(f\"CheckpointWrapper not available: {e}\")\n    except Exception as e:\n        LOG.warning(f\"Failed to patch CheckpointWrapper: {e}\")\n\n\ndef apply_deepspeed_patches():\n    \"\"\"\n    Apply DeepSpeed-related patches\n    \"\"\"\n    if importlib.util.find_spec(\"deepspeed\") is not None:\n        patch_checkpoint_wrapper_setattr()\n    else:\n        LOG.debug(\"DeepSpeed not available, skipping patches\")\n"
  },
  {
    "path": "src/axolotl/monkeypatch/fsdp2_qlora.py",
    "content": "\"\"\"\nMonkeypatch to add Params4bit and Int8Params support to FSDP2. This enables QLoRA + FSDP2\nand 8-bit LoRA + FSDP2, as well as our LoRA / QLoRA Triton kernels to work with FSDP2.\n\nThis patch modifies the _init_sharded_param and init_unsharded_param methods in FSDPParam\nto handle bitsandbytes Params4bit and Int8Params parameters, preserving their quantization\nmetadata through the FSDP2 shard/unshard cycle.\n\"\"\"\n\nimport importlib\nimport inspect\n\nfrom axolotl.monkeypatch.utils import detab_code\nfrom axolotl.utils.logging import get_logger\n\nLOG = get_logger(__name__)\n\n\ndef apply_init_sharded_param_patch():\n    \"\"\"Apply patch to FSDPParam._init_sharded_param to support Params4bit.\"\"\"\n    if getattr(apply_init_sharded_param_patch, \"_axolotl_patched\", False):\n        return\n    from torch.distributed.fsdp._fully_shard._fsdp_param import FSDPParam\n\n    # Get original source\n    original_source = inspect.getsource(FSDPParam._init_sharded_param)\n    original_source, _ = detab_code(original_source)\n\n    # Define the replacement\n    original_param_creation = \"\"\"    self.sharded_param = nn.Parameter(self.to_sharded_dtensor(sharded_param))\n    self.sharded_param.requires_grad_(param.requires_grad)\"\"\"\n\n    patched_param_creation = \"\"\"    import bitsandbytes as bnb\n    if isinstance(param, bnb.nn.modules.Params4bit):\n        self.sharded_param = bnb.nn.modules.Params4bit(\n            data=sharded_param,\n            requires_grad=param.requires_grad,\n            quant_state=param.quant_state,\n            blocksize=param.blocksize,\n            compress_statistics=param.compress_statistics,\n            quant_type=param.quant_type,\n            quant_storage=param.quant_storage,\n            module=param.module,\n            bnb_quantized=param.bnb_quantized,\n        )\n        self.sharded_param = self.to_sharded_dtensor(self.sharded_param)\n    elif isinstance(param, bnb.nn.modules.Int8Params):\n        
self.sharded_param = bnb.nn.modules.Int8Params(\n            data=sharded_param,\n            requires_grad=param.requires_grad,\n            has_fp16_weights=param.has_fp16_weights,\n            CB=None,\n            SCB=param.SCB,\n        )\n        self.sharded_param = self.to_sharded_dtensor(self.sharded_param)\n    else:\n        self.sharded_param = nn.Parameter(\n            self.to_sharded_dtensor(sharded_param),\n            requires_grad=param.requires_grad,\n        )\"\"\"\n\n    # Apply the replacement\n    if original_param_creation in original_source:\n        patched_source = original_source.replace(\n            original_param_creation, patched_param_creation\n        )\n        patched_source = patched_source.replace(\n            \"def _init_sharded_param(\",\n            \"def patched_init_sharded_param(\",\n            1,\n        )\n\n        # Load necessary imports\n        module_name = FSDPParam.__module__\n        module = importlib.import_module(module_name)\n\n        items_to_import = []\n        for item in dir(module):\n            if item in patched_source:\n                items_to_import.append(item)\n\n        exec(  # nosec B102\n            f\"from {module_name} import ({', '.join(items_to_import)})\",\n            globals(),\n        )\n        exec(patched_source, globals())  # nosec B102\n\n        # Replace the method\n        FSDPParam._init_sharded_param = patched_init_sharded_param\n        apply_init_sharded_param_patch._axolotl_patched = True\n        LOG.info(\"Successfully applied FSDP _init_sharded_param patch\")\n    else:\n        LOG.warning(\"Could not find target code for _init_sharded_param patching\")\n\n\ndef apply_init_unsharded_param_patch():\n    \"\"\"Apply patch to FSDPParam.init_unsharded_param to support Params4bit.\"\"\"\n    if getattr(apply_init_unsharded_param_patch, \"_axolotl_patched\", False):\n        return\n    from torch.distributed.fsdp._fully_shard._fsdp_param import FSDPParam\n\n    # 
Get original source\n    original_source = inspect.getsource(FSDPParam.init_unsharded_param)\n    original_source, _ = detab_code(original_source)\n\n    # Define the replacement\n    original_param_creation = \"\"\"        self._unsharded_param = nn.Parameter(\n            unsharded_param, requires_grad=self.sharded_param.requires_grad\n        )\"\"\"\n\n    patched_param_creation = \"\"\"        import bitsandbytes as bnb\n        local_tensor = self.sharded_param._local_tensor\n        if isinstance(local_tensor, bnb.nn.modules.Params4bit):\n            self._unsharded_param = bnb.nn.modules.Params4bit(\n                data=unsharded_param,\n                requires_grad=self.sharded_param.requires_grad,\n                quant_state=local_tensor.quant_state,\n                blocksize=local_tensor.blocksize,\n                compress_statistics=local_tensor.compress_statistics,\n                quant_type=local_tensor.quant_type,\n                quant_storage=local_tensor.quant_storage,\n                module=local_tensor.module,\n                bnb_quantized=local_tensor.bnb_quantized,\n            )\n        elif isinstance(local_tensor, bnb.nn.modules.Int8Params):\n            self._unsharded_param = bnb.nn.modules.Int8Params(\n                data=unsharded_param,\n                requires_grad=self.sharded_param.requires_grad,\n                has_fp16_weights=local_tensor.has_fp16_weights,\n                CB=unsharded_param,\n                SCB=local_tensor.SCB,\n            )\n        else:\n            self._unsharded_param = nn.Parameter(\n                unsharded_param, requires_grad=self.sharded_param.requires_grad\n            )\"\"\"\n\n    # Apply the replacement\n    if original_param_creation in original_source:\n        patched_source = original_source.replace(\n            original_param_creation, patched_param_creation\n        )\n        patched_source = patched_source.replace(\n            \"def init_unsharded_param(\",\n            
\"def patched_init_unsharded_param(\",\n            1,\n        )\n\n        # Load necessary imports\n        module_name = FSDPParam.__module__\n        module = importlib.import_module(module_name)\n\n        items_to_import = []\n        for item in dir(module):\n            if item in patched_source:\n                items_to_import.append(item)\n\n        exec(  # nosec B102\n            f\"from {module_name} import ({', '.join(items_to_import)})\",\n            globals(),\n        )\n        exec(patched_source, globals())  # nosec B102\n\n        # Replace the method\n        FSDPParam.init_unsharded_param = patched_init_unsharded_param\n        apply_init_unsharded_param_patch._axolotl_patched = True\n        LOG.info(\"Successfully applied FSDP init_unsharded_param patch\")\n    else:\n        LOG.warning(\"Could not find target code for patching\")\n\n\ndef apply_linear8bitlt_save_patch():\n    \"\"\"Patch Linear8bitLt._save_to_state_dict to handle DTensor-wrapped Int8Params.\n\n    After FSDP2 sharding, Linear8bitLt.weight is a DTensor wrapping Int8Params.\n    BnB's _save_to_state_dict accesses self.weight.SCB directly, but DTensor\n    doesn't proxy custom attribute access to its _local_tensor. 
This patch\n    temporarily unwraps the DTensor during saving so BnB can find the SCB attribute.\n    \"\"\"\n    if getattr(apply_linear8bitlt_save_patch, \"_axolotl_patched\", False):\n        return\n    import bitsandbytes as bnb\n    from torch.distributed.tensor import DTensor\n\n    original_save = bnb.nn.Linear8bitLt._save_to_state_dict\n\n    def _patched_save_to_state_dict(self, destination, prefix, keep_vars):\n        # Use _parameters dict directly to bypass nn.Module.__setattr__ type check.\n        weight = self._parameters[\"weight\"]\n        unwrapped = False\n        if isinstance(weight, DTensor) and hasattr(weight, \"_local_tensor\"):\n            self._parameters[\"weight\"] = weight._local_tensor\n            unwrapped = True\n        try:\n            original_save(self, destination, prefix, keep_vars)\n        finally:\n            if unwrapped:\n                self._parameters[\"weight\"] = weight\n\n    bnb.nn.Linear8bitLt._save_to_state_dict = _patched_save_to_state_dict\n    apply_linear8bitlt_save_patch._axolotl_patched = True\n    LOG.info(\"Patched Linear8bitLt._save_to_state_dict for DTensor compatibility\")\n\n\ndef apply_init_dtype_attrs_patch():\n    \"\"\"Prevent FSDP2 mixed precision from casting non-float quantized params.\n\n    When mixed precision is enabled (e.g., bf16), FSDP2's init_dtype_attrs sets\n    param_dtype=bf16 for ALL params. During all-gather, _to_dtype_if_needed casts\n    the sharded param to param_dtype. 
For non-float params (uint8 packed 4-bit,\n    int8 quantized) without FSDP2 extensions, this destroys the quantized data.\n\n    Params4bit handles this via fsdp_pre/post_all_gather extensions, but our\n    parametrize-based expert quantization uses plain nn.Parameter(uint8/int8)\n    without extensions.\n    \"\"\"\n    if getattr(apply_init_dtype_attrs_patch, \"_axolotl_patched\", False):\n        return\n    from torch.distributed.fsdp._fully_shard._fsdp_param import FSDPParam\n\n    original_init_dtype_attrs = FSDPParam.init_dtype_attrs\n\n    def patched_init_dtype_attrs(self, mp_policy):\n        original_init_dtype_attrs(self, mp_policy)\n        # Skip casting non-float quantized params (uint8/int8) without FSDP2\n        # extensions — the parametrization chain handles dequantization.\n        if self.param_dtype is not None and not self.sharded_param.is_floating_point():\n            local = self.sharded_param\n            if hasattr(local, \"_local_tensor\"):\n                local = local._local_tensor\n            if not hasattr(local, \"fsdp_pre_all_gather\"):\n                self.param_dtype = None\n\n    FSDPParam.init_dtype_attrs = patched_init_dtype_attrs\n    apply_init_dtype_attrs_patch._axolotl_patched = True\n    LOG.info(\"Patched FSDPParam.init_dtype_attrs for non-float quantized params\")\n"
  },
  {
    "path": "src/axolotl/monkeypatch/gradient_checkpointing/__init__.py",
    "content": "\"\"\"custom checkpointing utils\"\"\"\n\nimport importlib\nfrom functools import partial\n\nfrom packaging import version\n\nfrom axolotl.monkeypatch.gradient_checkpointing.offload_cpu import (  # noqa: F401\n    CPU_Offloaded_Gradient_Checkpointer,\n)\nfrom axolotl.monkeypatch.gradient_checkpointing.offload_disk import (\n    Disco,\n)\n\ntransformers_version = version.parse(importlib.metadata.version(\"transformers\"))\nif transformers_version > version.parse(\"4.51.3\"):\n    from transformers.modeling_layers import GradientCheckpointingLayer\n\n    def uses_gc_layers(decoder_layer):\n        return isinstance(decoder_layer.func.__self__, GradientCheckpointingLayer)\n\nelse:\n\n    def uses_gc_layers(_):\n        return False\n\n\ndef hf_grad_checkpoint_offload_wrapper(decoder_layer, *args, use_reentrant=None):\n    if uses_gc_layers(decoder_layer):\n        return CPU_Offloaded_Gradient_Checkpointer.apply(\n            decoder_layer,\n            *args,\n        )\n\n    return CPU_Offloaded_Gradient_Checkpointer.apply(\n        (\n            decoder_layer.func.__self__\n            if isinstance(decoder_layer, partial)\n            else decoder_layer.__self__\n        ),\n        *args,\n    )\n\n\ndef hf_grad_checkpoint_disk_offload_wrapper(decoder_layer, *args, use_reentrant=None):\n    if uses_gc_layers(decoder_layer):\n        return Disco.apply(\n            decoder_layer,\n            *args,\n        )\n\n    return Disco.apply(\n        (\n            decoder_layer.func.__self__\n            if isinstance(decoder_layer, partial)\n            else decoder_layer.__self__\n        ),\n        *args,\n    )\n"
  },
  {
    "path": "src/axolotl/monkeypatch/gradient_checkpointing/offload_cpu.py",
    "content": "\"\"\"CPU offloaded checkpointing\"\"\"\n\n# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport inspect\n\nimport torch\nfrom packaging import version\nfrom torch.utils.checkpoint import (\n    set_device_states,\n)\n\n# support different pytorch versions\nhas_device_type = \"device_type\" in inspect.signature(set_device_states).parameters\n\ntorch_version = version.parse(torch.__version__)\n\nif torch_version < version.parse(\"2.4.0\"):\n    torch_cuda_amp_custom_fwd = torch.cuda.amp.custom_fwd\n    torch_cuda_amp_custom_bwd = torch.cuda.amp.custom_bwd\nelse:\n    torch_cuda_amp_custom_fwd = torch.amp.custom_fwd(device_type=\"cuda\")\n    torch_cuda_amp_custom_bwd = torch.amp.custom_bwd(device_type=\"cuda\")\n\n\nclass CPU_Offloaded_Gradient_Checkpointer(torch.autograd.Function):\n    \"\"\"\n    Saves VRAM by smartly offloading to RAM.\n    Tiny hit to performance, since we mask the movement via non blocking calls.\n    \"\"\"\n\n    @staticmethod\n    @torch_cuda_amp_custom_fwd\n    def forward(ctx, forward_function, hidden_states, *args):\n        saved_hidden_states = hidden_states.to(\"cpu\", non_blocking=True)\n        with torch.no_grad():\n            output = forward_function(hidden_states, *args)\n        ctx.save_for_backward(saved_hidden_states)\n        ctx.forward_function = forward_function\n        ctx.args = args\n        return output\n\n    
@staticmethod\n    @torch_cuda_amp_custom_bwd\n    def backward(ctx, dY):\n        (hidden_states,) = ctx.saved_tensors\n        hidden_states = hidden_states.to(\"cuda\", non_blocking=True).detach()\n        hidden_states.requires_grad = True\n        with torch.enable_grad():\n            output = ctx.forward_function(hidden_states, *ctx.args)\n            # Newer HF models (e.g. Qwen3MoE) using GradientCheckpointingLayer\n            # return a plain tensor, not a tuple.  Older models return a tuple\n            # (or list); only a single-element sequence such as (hidden_states,)\n            # is supported here, because backward receives a single gradient dY.\n            # The unpack below raises ValueError for longer sequences.\n            if isinstance(output, (tuple, list)):\n                (output,) = output\n        torch.autograd.backward(output, dY)\n        return (\n            None,\n            hidden_states.grad,\n        ) + (None,) * len(ctx.args)\n"
  },
  {
    "path": "src/axolotl/monkeypatch/gradient_checkpointing/offload_disk.py",
    "content": "\"\"\"\nDISCO - DIsk-based Storage and Checkpointing with Optimized prefetching\n\"\"\"\n\n# Copyright 2025 Axolotl AI. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport atexit\nimport concurrent.futures\nimport os\nimport queue\nimport shutil\nimport tempfile\nimport threading\nimport time\nimport uuid\nfrom collections import deque\nfrom concurrent.futures import Future\nfrom typing import Dict\n\nimport torch\n\nfrom axolotl.utils.logging import get_logger\n\ntorch_cuda_amp_custom_fwd = torch.amp.custom_fwd(device_type=\"cuda\")\ntorch_cuda_amp_custom_bwd = torch.amp.custom_bwd(device_type=\"cuda\")\n\n# Setup logger\nlogger = get_logger(__name__)\n\n\nclass DiskOffloadManager:\n    \"\"\"\n    Manages offloaded tensors and handles prefetching in a separate thread.\n    Includes synchronization to prevent race conditions.\n    \"\"\"\n\n    def __init__(\n        self,\n        prefetch_size: int = 3,\n        prefetch_to_gpu: bool = True,\n        save_workers: int = 4,\n    ):\n        \"\"\"\n        Args:\n            prefetch_size: Maximum number of tensors to prefetch in the background.\n            prefetch_to_gpu: Whether to prefetch tensors directly to GPU memory.\n            save_workers: Maximum number of concurrent save operations.\n        \"\"\"\n        self.temp_dir = tempfile.mkdtemp(prefix=\"disco_\")\n\n        # Track tensor paths and their status\n        self.tensor_paths: deque = deque()  # 
Ordered history of tensor paths (LIFO)\n        self.file_locks: Dict[\n            str, threading.Lock\n        ] = {}  # Maps file_path -> threading.Lock()\n        # Maps file_path -> status (\"saving\", \"ready\", \"prefetching\", \"loaded\", \"deleted\")\n        self.file_status: Dict[str, str] = {}\n\n        self.max_prefetch = prefetch_size\n        self.prefetch_to_gpu = prefetch_to_gpu\n\n        # Thread synchronization\n        self.manager_lock = threading.RLock()  # Used for thread-safe operations\n\n        # Prefetch queue and cache\n        self.prefetch_queue: queue.Queue = queue.Queue()\n        self.prefetch_cache: Dict[str, torch.Tensor] = {}  # Maps file_path -> tensor\n\n        # Save queue and thread pool\n        self.save_queue: queue.Queue = queue.Queue()\n        self.save_pool = concurrent.futures.ThreadPoolExecutor(max_workers=save_workers)\n        self.save_futures: Dict[str, Future] = {}\n        self.save_semaphore = threading.Semaphore(\n            save_workers * 2\n        )  # Limit concurrent save operations\n\n        # Start prefetch worker thread\n        self.stop_event = threading.Event()\n        # start multiple threads for prefetching\n        self.prefetch_worker_count = 2\n        self.prefetch_workers = []\n        for _ in range(self.prefetch_worker_count):\n            worker = threading.Thread(target=self._prefetch_worker, daemon=True)\n            worker.start()\n            self.prefetch_workers.append(worker)\n\n        # Start save worker thread\n        self.save_worker = threading.Thread(target=self._save_worker, daemon=True)\n        self.save_worker.start()\n        self.idx = 0\n\n        atexit.register(self.cleanup)\n\n    def _save_worker(self):\n        \"\"\"Background thread that processes the save queue\"\"\"\n        while not self.stop_event.is_set():\n            try:\n                save_item = self.save_queue.get(timeout=0.5)\n                if save_item is None:\n                    
continue\n\n                tensor, file_path = save_item\n\n                # Submit the save task to the thread pool\n                future = self.save_pool.submit(\n                    self._save_tensor_to_disk, tensor, file_path\n                )\n                with self.manager_lock:\n                    self.save_futures[file_path] = future\n\n                self.save_queue.task_done()\n\n            except queue.Empty:\n                time.sleep(0.01)  # Small sleep to prevent CPU spinning\n                continue\n\n    def _save_tensor_to_disk(self, tensor: torch.Tensor, file_path: str):\n        \"\"\"Actually save the tensor to disk\"\"\"\n        try:\n            # Save tensor to disk\n            cpu_tensor = tensor.detach().cpu()\n            torch.save(cpu_tensor, file_path)\n            del cpu_tensor\n\n            with self.manager_lock:\n                # Mark file as ready\n                self.file_status[file_path] = \"ready\"\n\n            # Release semaphore\n            self.save_semaphore.release()\n\n            return True\n        except FileNotFoundError as e:\n            logger.error(f\"Error saving tensor to {file_path}: {e}\")\n            with self.manager_lock:\n                self.file_status[file_path] = \"error\"\n\n            # Release semaphore\n            self.save_semaphore.release()\n\n            return False\n\n    def _prefetch_worker(self):\n        \"\"\"Background thread that loads tensors from disk ahead of time\"\"\"\n        while not self.stop_event.is_set():\n            try:\n                file_path = self.prefetch_queue.get(timeout=0.5)\n                if file_path is None:\n                    continue\n\n                # Check if file is available and not already in cache\n                with self.manager_lock:\n                    if (\n                        file_path not in self.file_status\n                        or self.file_status[file_path] == \"deleted\"\n                    ):\n  
                      self.prefetch_queue.task_done()\n                    if file_path in self.prefetch_cache:\n                        self.prefetch_queue.task_done()\n                        continue\n\n                    # If file is still being saved, wait for it\n                    if (\n                        self.file_status[file_path] == \"saving\"\n                        and file_path in self.save_futures\n                    ):\n                        # Re-queue this prefetch request with a little delay\n                        self.prefetch_queue.task_done()\n                        time.sleep(0.1)\n                        self.prefetch_queue.put(file_path)\n                        continue\n\n                    # Mark file as being prefetched\n                    self.file_status[file_path] = \"prefetching\"\n\n                # Load tensor from disk and store in cache\n                try:\n                    if os.path.exists(file_path):\n                        if self.prefetch_to_gpu:\n                            tensor = torch.load(\n                                file_path,\n                                map_location=torch.device(\"cuda\"),\n                                weights_only=True,\n                            )\n                        else:\n                            tensor = torch.load(file_path, weights_only=True)\n\n                        with self.manager_lock:\n                            self.prefetch_cache[file_path] = tensor\n                            self.file_status[file_path] = \"ready\"\n                    else:\n                        with self.manager_lock:\n                            if self.file_status.get(file_path) != \"deleted\":\n                                logger.warning(\n                                    f\"Prefetch error: File not found {file_path}\"\n                                )\n                                self.file_status[file_path] = \"missing\"\n\n                except 
FileNotFoundError as e:\n                    with self.manager_lock:\n                        if self.file_status.get(file_path) != \"deleted\":\n                            logger.warning(f\"Prefetch error for {file_path}: {e}\")\n                            self.file_status[file_path] = \"error\"\n\n                self.prefetch_queue.task_done()\n\n            except queue.Empty:\n                time.sleep(0.01)  # Small sleep to prevent CPU spinning\n                continue\n\n    def save_tensor(self, tensor: torch.Tensor):\n        \"\"\"Save tensor to disk asynchronously and return file path with thread-safe operations\"\"\"\n        # Generate unique file path\n        self.idx += 1\n        file_path: str = os.path.join(\n            self.temp_dir, f\"{self.idx:06d}-{uuid.uuid4()}.pt\"\n        )\n\n        with self.manager_lock:\n            # Mark file as being saved\n            self.file_locks[file_path] = threading.Lock()\n            self.file_status[file_path] = \"saving\"\n            # Add to history\n            self.tensor_paths.append(file_path)\n\n        # Acquire semaphore to limit concurrent save operations\n        self.save_semaphore.acquire()\n        # Queue tensor for saving in background\n        self.save_queue.put((tensor.detach(), file_path))\n\n        return file_path\n\n    def wait_for_save(self, file_path, timeout=None) -> None:\n        \"\"\"Wait for a tensor to be saved to disk\"\"\"\n        start_time = time.time()\n        while timeout is None or time.time() - start_time < timeout:\n            with self.manager_lock:\n                if self.file_status.get(file_path) == \"ready\":\n                    return\n                if self.file_status.get(file_path) in [\"error\", \"missing\", \"deleted\"]:\n                    return\n\n                if file_path in self.save_futures:\n                    future = self.save_futures[file_path]\n                    if future.done():\n                        return\n\n    
        # Small sleep to prevent CPU spinning\n            time.sleep(0.01)\n\n        # Timeout\n        logger.warning(f\"Timeout waiting for tensor to be saved: {file_path}\")\n        return\n\n    def load_tensor(self, file_path, target_device=\"cuda\"):\n        \"\"\"Load tensor from disk or prefetch cache with proper synchronization\"\"\"\n        # Wait for tensor to be saved if it's still in progress\n        self.wait_for_save(file_path)\n\n        tensor = None\n\n        # Try to get from cache first\n        with self.manager_lock:\n            # Check if tensor is already in cache\n            if file_path in self.prefetch_cache:\n                tensor = self.prefetch_cache[file_path]\n                del self.prefetch_cache[file_path]\n                self.file_status[file_path] = \"loaded\"\n\n        if tensor is not None:\n            # Ensure tensor is on correct device\n            if target_device != \"cpu\" and tensor.device.type == \"cpu\":\n                tensor = tensor.to(target_device, non_blocking=True)\n            return tensor\n\n        # If not in cache, load directly from disk\n        try:\n            if not os.path.exists(file_path):\n                logger.error(f\"File not found for loading: {file_path}\")\n                raise FileNotFoundError(f\"File not found: {file_path}\")\n\n            tensor = torch.load(file_path, weights_only=True)\n\n            with self.manager_lock:\n                self.file_status[file_path] = \"loaded\"\n\n            if target_device != \"cpu\":\n                tensor = tensor.to(target_device, non_blocking=True)\n\n            return tensor\n\n        except Exception as e:\n            logger.error(f\"Error loading tensor from {file_path}: {e}\")\n            raise\n\n    def _safe_delete_file(self, file_path):\n        \"\"\"Safely delete a file with proper synchronization\"\"\"\n        with self.manager_lock:\n            # Make sure any save operation is completed\n            if 
file_path in self.save_futures:\n                future = self.save_futures[file_path]\n                try:\n                    if not future.done():\n                        future.cancel()\n                    del self.save_futures[file_path]\n                except FileNotFoundError as e:\n                    logger.warning(\n                        f\"Error canceling save operation for {file_path}: {e}\"\n                    )\n\n            # Only delete if file exists and is not being prefetched\n            status = self.file_status.get(file_path)\n            if status in [\"ready\", \"loaded\", \"error\", \"missing\"]:\n                try:\n                    if os.path.exists(file_path):\n                        os.remove(file_path)\n                    self.file_status[file_path] = \"deleted\"\n                    return True\n                except FileNotFoundError as e:\n                    logger.warning(f\"Error deleting file {file_path}: {e}\")\n            return False\n\n    def trigger_prefetch(self, n=None):\n        \"\"\"Trigger prefetching of the next N tensors with proper synchronization\"\"\"\n        if n is None:\n            n = self.max_prefetch\n\n        prefetch_paths = []\n        with self.manager_lock:\n            # Find files that are ready to be prefetched (not already in cache or being prefetched)\n            for path in reversed(self.tensor_paths):\n                if (\n                    path not in self.prefetch_cache\n                    and self.file_status.get(path) == \"ready\"\n                ):\n                    prefetch_paths.append(path)\n                    if len(prefetch_paths) >= n:\n                        break\n\n        # Queue files for prefetching\n        for path in prefetch_paths:\n            self.prefetch_queue.put(path)\n\n    def cleanup_tensor(self, file_path: str):\n        \"\"\"Clean up a specific tensor file after it's been used\"\"\"\n        with self.manager_lock:\n            if 
file_path in self.tensor_paths:\n                self.tensor_paths.remove(file_path)\n\n            # Remove from prefetch cache if present\n            if file_path in self.prefetch_cache:\n                del self.prefetch_cache[file_path]\n\n            # Remove from save futures if present\n            if file_path in self.save_futures:\n                future = self.save_futures[file_path]\n                if not future.done():\n                    future.cancel()\n                del self.save_futures[file_path]\n\n        # Try to delete the file\n        self._safe_delete_file(file_path)\n\n    def cleanup(self):\n        \"\"\"Clean up all temp files and stop prefetch thread with proper synchronization\"\"\"\n        self.stop_event.set()\n\n        # Cancel all pending save operations\n        with self.manager_lock:\n            for _, future in self.save_futures.items():\n                if not future.done():\n                    future.cancel()\n            self.save_futures.clear()\n\n        # Drain the save queue\n        while not self.save_queue.empty():\n            try:\n                self.save_queue.get_nowait()\n                self.save_queue.task_done()\n            except queue.Empty:\n                break\n\n        # Shutdown the save pool\n        self.save_pool.shutdown(wait=False)\n\n        # Join the save worker thread\n        if self.save_worker.is_alive():\n            self.save_worker.join(timeout=2.0)\n\n        # Join the prefetch worker threads\n        for thread in self.prefetch_workers:\n            if thread.is_alive():\n                thread.join(timeout=2.0)\n\n        # Clear cache and remove all temporary files\n        with self.manager_lock:\n            self.prefetch_cache.clear()\n            paths_to_delete = list(self.tensor_paths)\n            self.tensor_paths.clear()\n\n        # Delete all temporary files\n        for path in paths_to_delete:\n            self._safe_delete_file(path)\n\n        # Remove 
temp directory\n        try:\n            if os.path.exists(self.temp_dir):\n                shutil.rmtree(self.temp_dir, ignore_errors=True)\n        except FileNotFoundError as e:\n            logger.warning(f\"Error removing temporary directory {self.temp_dir}: {e}\")\n\n\nclass Disco(torch.autograd.Function):\n    \"\"\"\n    Disco: DIsk-based Storage and Checkpointing with Optimized prefetching\n    Advanced disk-based gradient checkpointer with prefetching.\n    \"\"\"\n\n    # Shared manager instance across all checkpointing operations\n    _manager = None\n\n    @staticmethod\n    def get_instance(prefetch_size=1, prefetch_to_gpu=True, save_workers=4):\n        \"\"\"Get or create the offload manager\"\"\"\n        if Disco._manager is None:\n            Disco._manager = DiskOffloadManager(\n                prefetch_size=prefetch_size,\n                prefetch_to_gpu=prefetch_to_gpu,\n                save_workers=save_workers,\n            )\n        return Disco._manager\n\n    @staticmethod\n    @torch_cuda_amp_custom_fwd\n    def forward(\n        ctx,\n        forward_function,\n        hidden_states,\n        *args,\n        prefetch_size=1,\n        prefetch_to_gpu=True,\n        save_workers=4,\n    ):\n        \"\"\"Forward pass that offloads activations to disk asynchronously\"\"\"\n        # Get or create the manager\n        manager = Disco.get_instance(\n            prefetch_size=prefetch_size,\n            prefetch_to_gpu=prefetch_to_gpu,\n            save_workers=save_workers,\n        )\n\n        # Save tensor to disk asynchronously\n        file_path = manager.save_tensor(hidden_states)\n\n        # Run forward pass immediately without waiting for save to complete\n        with torch.no_grad():\n            output = forward_function(hidden_states, *args)\n\n        # Store what we need for backward\n        ctx.save_for_backward(torch.tensor([0]))  # Dummy tensor\n        ctx.file_path = file_path\n        ctx.forward_function = 
forward_function\n        ctx.args = args\n\n        return output\n\n    @staticmethod\n    @torch_cuda_amp_custom_bwd\n    def backward(ctx, *grad_outputs):\n        \"\"\"Backward pass that loads activations from disk with prefetching\"\"\"\n        # Get the manager\n        manager = Disco._manager\n\n        # Trigger prefetching for future tensors\n        # This happens at the start of backward, so should have time to complete\n        manager.trigger_prefetch()\n\n        # Load hidden states from disk or prefetch cache\n        file_path = ctx.file_path\n        try:\n            # Ensure the file is saved before we try to load it\n            manager.wait_for_save(file_path)\n\n            hidden_states = manager.load_tensor(file_path)\n            hidden_states.requires_grad = True\n\n            # Compute gradients\n            with torch.enable_grad():\n                output = ctx.forward_function(hidden_states, *ctx.args)\n\n                # Handle tuple outputs properly\n                if isinstance(output, tuple):\n                    if len(grad_outputs) == len(output):\n                        torch.autograd.backward(output, grad_outputs)\n                    else:\n                        torch.autograd.backward(output, grad_outputs[0])\n                else:\n                    torch.autograd.backward(output, grad_outputs[0])\n\n            # Clean up the file after we're done with it\n            manager.cleanup_tensor(file_path)\n\n            return (\n                (\n                    None,  # forward_function\n                    hidden_states.grad,  # hidden_states grad\n                )\n                + (None,) * len(ctx.args)  # for each arg\n                + (\n                    None,  # prefetch_size\n                    None,  # prefetch_to_gpu\n                    None,  # save_workers\n                )\n            )\n\n        except Exception as e:\n            logger.error(f\"Error in backward pass: {e}\")\n      
      # Clean up the file even on error\n            manager.cleanup_tensor(file_path)\n            raise\n"
  },
  {
    "path": "src/axolotl/monkeypatch/llama_attn_hijack_flash.py",
    "content": "\"\"\"Flash attention monkey patch for llama model\"\"\"\n\n# copied from https://github.com/lm-sys/FastChat/blob/main/fastchat/train/llama_flash_attn_monkey_patch.py\n\nimport importlib.util\nimport warnings\nfrom typing import Optional, Tuple\n\nimport torch\nimport transformers\nfrom einops import rearrange\nfrom flash_attn.bert_padding import pad_input, unpad_input\nfrom transformers.models.llama.modeling_llama import (\n    LlamaMLP,\n    apply_rotary_pos_emb,\n    repeat_kv,\n)\n\nfrom axolotl.monkeypatch.utils import set_module_name\nfrom axolotl.utils.logging import get_logger\n\ntry:\n    from flash_attn.flash_attn_interface import (\n        flash_attn_varlen_qkvpacked_func,\n    )\nexcept ImportError:\n    from flash_attn.flash_attn_interface import (\n        flash_attn_unpadded_qkvpacked_func as flash_attn_varlen_qkvpacked_func,\n    )\n\n\nLOG = get_logger(__name__)\n\n\ndef is_xformers_available() -> bool:\n    return importlib.util.find_spec(\"xformers\") is not None\n\n\ndef is_xformers_swiglu_available() -> bool:\n    if not is_xformers_available():\n        return False\n\n    from xformers.ops.common import get_xformers_operator\n\n    try:\n        get_xformers_operator(\"swiglu_packedw\")()\n        return True\n    except RuntimeError as exc:\n        if \"No such operator xformers::swiglu_packedw \" in str(exc):\n            return False\n        return True\n\n\ndef replace_llama_mlp_with_swiglu(model):\n    if is_xformers_swiglu_available():\n        from axolotl.monkeypatch.xformers_ import FusedMLP\n    else:\n        raise RuntimeError(\"xformers SwiGLU not available for this environment\")\n\n    for name, module in model.named_modules():\n        if isinstance(module, LlamaMLP):\n            mlp = FusedMLP(\n                module.config, module.gate_proj, module.up_proj, module.down_proj\n            )\n            set_module_name(model, name, mlp)\n\n\ndef patch_fa_llama_cross_entropy():\n    LOG.info(\n        
\"patching transformers.loss.loss_utils.fixed_cross_entropy with flash_attn.ops.triton.cross_entropy\"\n    )\n    from flash_attn.ops.triton.cross_entropy import (\n        cross_entropy_loss as flash_attn_cross_entropy_loss,\n    )\n\n    def fa2_fixed_cross_entropy(\n        source,\n        target,\n        num_items_in_batch: int = None,\n        ignore_index: int = -100,\n        **kwargs,\n    ):\n        reduction = \"sum\" if num_items_in_batch is not None else \"mean\"\n        loss, _ = flash_attn_cross_entropy_loss(\n            source, target, ignore_index=ignore_index\n        )\n        if reduction == \"sum\":\n            loss = loss.sum() / num_items_in_batch\n        else:\n            loss = loss.sum() / (target != ignore_index).sum()\n        return loss\n\n    transformers.loss.loss_utils.fixed_cross_entropy = fa2_fixed_cross_entropy\n\n\ndef patch_llama_rms_norm():\n    try:\n        from flash_attn.ops.rms_norm import RMSNorm\n\n        class LlamaRMSNorm(RMSNorm):\n            \"\"\"Patched LLamaRMSNorm\"\"\"\n\n            def __init__(self, hidden_size, eps=1e-6):\n                super().__init__(hidden_size, eps=eps)\n\n        LOG.info(\"patching with flash_attn.ops.rms_norm\")\n        transformers.models.llama.modeling_llama.LlamaRMSNorm = LlamaRMSNorm\n    except ImportError:\n        LOG.warning(\n            \"optimized flash-attention RMSNorm not found (run `pip install 'git+https://github.com/Dao-AILab/flash-attention.git#egg=dropout_layer_norm&subdirectory=csrc/layer_norm'`)\"\n        )\n\n\ndef replace_llama_attn_with_flash_attn(\n    cross_entropy: Optional[bool] = False,\n    rms_norm: Optional[bool] = False,\n    use_shifted_sparse_attn: Optional[bool] = False,\n):\n    transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = _prepare_decoder_attention_mask\n    if use_shifted_sparse_attn:\n        transformers.models.llama.modeling_llama.LlamaAttention.forward = (\n            
flashattn_forward_with_s2attn\n        )\n\n    # skip only if explicitly disabled\n    if cross_entropy:\n        patch_fa_llama_cross_entropy()\n\n    # skip only if explicitly disabled\n    if rms_norm:\n        patch_llama_rms_norm()\n\n\n# Disable the transformation of the attention mask in LlamaModel as the flash attention\n# requires the attention mask to be the same as the key_padding_mask\ndef _prepare_decoder_attention_mask(\n    self,\n    attention_mask,\n    input_shape,\n    inputs_embeds,\n    past_key_values_length,\n):\n    # [bsz, seq_len]\n    return attention_mask\n\n\nGROUP_SIZE_RATIO = 1 / 4\n\n\ndef flashattn_forward_with_s2attn(\n    self,\n    hidden_states: torch.Tensor,\n    attention_mask: Optional[torch.Tensor] = None,\n    position_ids: Optional[torch.Tensor] = None,\n    past_key_value: Optional[Tuple[torch.Tensor]] = None,\n    output_attentions: bool = False,\n    use_cache: bool = False,\n    padding_mask: Optional[torch.LongTensor] = None,\n    cu_seqlens: Optional[torch.Tensor] = None,\n    max_seqlen: Optional[torch.Tensor] = None,\n) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:\n    \"\"\"Input shape: Batch x Time x Channel\n\n    From: https://github.com/dvlab-research/LongLoRA/blob/main/llama_attn_replace.py\n\n    attention_mask: [bsz, q_len]\n\n    `cu_seqlens` will be ignored if provided\n    `max_seqlen` will be ignored if provided\n    \"\"\"\n    if output_attentions:\n        warnings.warn(\n            \"Output attentions is not supported for patched `LlamaAttention`, returning `None` instead.\",\n            stacklevel=2,\n        )\n\n    bsz, q_len, _ = hidden_states.size()\n\n    query_states = (\n        self.q_proj(hidden_states)\n        .view(bsz, q_len, self.num_heads, self.head_dim)\n        .transpose(1, 2)\n    )\n    key_states = (\n        self.k_proj(hidden_states)\n        .view(bsz, q_len, self.num_key_value_heads, self.head_dim)\n        .transpose(1, 2)\n    )\n    
value_states = (\n        self.v_proj(hidden_states)\n        .view(bsz, q_len, self.num_key_value_heads, self.head_dim)\n        .transpose(1, 2)\n    )\n    # [bsz, q_len, nh, hd]\n    # [bsz, nh, q_len, hd]\n\n    cos, sin = self.rotary_emb(value_states, position_ids=position_ids)\n    query_states, key_states = apply_rotary_pos_emb(\n        query_states, key_states, cos, sin, position_ids\n    )\n\n    # Past Key value support\n    if past_key_value is not None:\n        # reuse k, v, self_attention\n        key_states = torch.cat([past_key_value[0], key_states], dim=2)\n        value_states = torch.cat([past_key_value[1], value_states], dim=2)\n\n    past_key_value = (key_states, value_states) if use_cache else None\n\n    # repeat k/v heads if n_kv_heads < n_heads\n    key_states = repeat_kv(key_states, self.num_key_value_groups)\n    value_states = repeat_kv(value_states, self.num_key_value_groups)\n\n    # Flash attention codes from\n    # https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/flash_attention.py\n\n    # transform the data into the format required by flash attention\n    qkv = torch.stack(\n        [query_states, key_states, value_states], dim=2\n    )  # [bsz, nh, 3, q_len, hd]\n    qkv = qkv.transpose(1, 3)  # [bsz, q_len, 3, nh, hd]\n\n    # We have disabled _prepare_decoder_attention_mask in LlamaModel\n    # the attention_mask should be the same as the key_padding_mask\n\n    key_padding_mask = attention_mask.repeat(2, 1)\n    nheads = qkv.shape[-2]\n    # shift\n\n    group_size = int(q_len * GROUP_SIZE_RATIO)\n    if q_len % group_size > 0:\n        raise ValueError(\n            f\"q_len {q_len} should be divisible by group size {group_size}.\"\n        )\n\n    qkv = (\n        qkv.reshape(bsz, q_len, 3, 2, self.num_heads // 2, self.head_dim)\n        .permute(0, 3, 1, 2, 4, 5)\n        .reshape(bsz * 2, q_len, 3, self.num_heads // 2, self.head_dim)\n    )\n    x = rearrange(qkv, \"b s three h d -> b s (three h 
d)\")\n    x_unpad, indices, cu_q_lens, max_s = unpad_input(x, key_padding_mask)\n    cu_q_len_tmp = torch.arange(\n        0, max_s, group_size, device=key_padding_mask.device, dtype=cu_q_lens.dtype\n    )\n    cu_q_len_tmp = torch.stack([cu_q_len_tmp, cu_q_len_tmp + group_size // 2]).repeat(\n        bsz, 1\n    ) + cu_q_lens[:-1].unsqueeze(-1)\n    cu_q_lens = torch.cat([cu_q_len_tmp, cu_q_lens[1:].unsqueeze(-1)], dim=-1).view(-1)\n\n    x_unpad = rearrange(\n        x_unpad, \"nnz (three h d) -> nnz three h d\", three=3, h=nheads // 2\n    )\n    output_unpad = flash_attn_varlen_qkvpacked_func(\n        x_unpad, cu_q_lens, group_size, 0.0, softmax_scale=None, causal=True\n    )\n    output = rearrange(\n        pad_input(\n            rearrange(output_unpad, \"nnz h d -> nnz (h d)\"), indices, bsz * 2, q_len\n        ),\n        \"b s (h d) -> b s h d\",\n        h=nheads // 2,\n    )\n    output = (\n        output.reshape(bsz, 2, q_len, nheads // 2, self.head_dim)\n        .transpose(1, 2)\n        .reshape(bsz, q_len, nheads, self.head_dim)\n    )\n    return self.o_proj(rearrange(output, \"b s h d -> b s (h d)\")), None, past_key_value\n"
  },
  {
    "path": "src/axolotl/monkeypatch/llama_attn_hijack_xformers.py",
    "content": "\"\"\"\nDirectly copied the code from https://raw.githubusercontent.com/oobabooga/text-generation-webui/main/modules/llama_attn_hijack.py and made some adjustments\n\"\"\"\n\nimport warnings\nfrom typing import Optional, Tuple\n\nimport torch\nimport torch.nn.functional as F\nimport transformers.models.llama.modeling_llama\nfrom transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv\n\nfrom axolotl.utils.logging import get_logger\n\nLOG = get_logger(__name__)\n\ntry:\n    import xformers.ops\nexcept ImportError:\n    LOG.error(\"xformers not found! Please install it before trying to use it.\")\n\n\ndef hijack_llama_attention():\n    transformers.models.llama.modeling_llama.LlamaAttention.forward = xformers_forward\n\n\ndef xformers_forward(\n    self,\n    hidden_states: torch.Tensor,\n    attention_mask: Optional[torch.Tensor] = None,\n    position_ids: Optional[torch.LongTensor] = None,\n    past_key_value: Optional[Tuple[torch.Tensor]] = None,\n    output_attentions: bool = False,\n    use_cache: bool = False,\n    padding_mask: Optional[torch.LongTensor] = None,\n    **kwargs,\n) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:\n    bsz, q_len, _ = hidden_states.size()\n\n    if not hasattr(self, \"pretraining_tp\"):\n        self.pretraining_tp = 1\n\n    if self.pretraining_tp > 1:\n        key_value_slicing = (\n            self.num_key_value_heads * self.head_dim\n        ) // self.pretraining_tp\n        query_slices = self.q_proj.weight.split(\n            (self.num_heads * self.head_dim) // self.pretraining_tp, dim=0\n        )\n        key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)\n        value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)\n\n        query_states = [\n            F.linear(hidden_states, query_slices[i]) for i in range(self.pretraining_tp)\n        ]\n        query_states = torch.cat(query_states, dim=-1)\n\n        key_states = [\n    
        F.linear(hidden_states, key_slices[i]) for i in range(self.pretraining_tp)\n        ]\n        key_states = torch.cat(key_states, dim=-1)\n\n        value_states = [\n            F.linear(hidden_states, value_slices[i]) for i in range(self.pretraining_tp)\n        ]\n        value_states = torch.cat(value_states, dim=-1)\n\n    else:\n        query_states = self.q_proj(hidden_states)\n        key_states = self.k_proj(hidden_states)\n        value_states = self.v_proj(hidden_states)\n\n    query_states = query_states.view(\n        bsz, q_len, self.num_heads, self.head_dim\n    ).transpose(1, 2)\n    key_states = key_states.view(\n        bsz, q_len, self.num_key_value_heads, self.head_dim\n    ).transpose(1, 2)\n    value_states = value_states.view(\n        bsz, q_len, self.num_key_value_heads, self.head_dim\n    ).transpose(1, 2)\n    # [bsz, q_len, nh, hd]\n    # [bsz, nh, q_len, hd]\n\n    cos, sin = self.rotary_emb(value_states)\n    query_states, key_states = apply_rotary_pos_emb(\n        query_states, key_states, cos, sin, position_ids\n    )\n    # [bsz, nh, t, hd]\n\n    if past_key_value is not None:\n        # reuse k, v, self_attention\n        key_states = torch.cat([past_key_value[0], key_states], dim=2)\n        value_states = torch.cat([past_key_value[1], value_states], dim=2)\n\n    past_key_value = (key_states, value_states) if use_cache else None\n\n    # repeat k/v heads if n_kv_heads < n_heads\n    key_states = repeat_kv(key_states, self.num_key_value_groups)\n    value_states = repeat_kv(value_states, self.num_key_value_groups)\n\n    if output_attentions:\n        warnings.warn(\n            \"Output attentions is not supported for patched `LlamaAttention`, returning `None` instead.\",\n            stacklevel=2,\n        )\n\n    #\n    # xformers-attn start\n    #\n\n    query_states = query_states.transpose(1, 2)\n    key_states = key_states.transpose(1, 2)\n    value_states = value_states.transpose(1, 2)\n\n    # This is a nasty 
hack. We know attention_mask in transformers is either LowerTriangular or all Zeros.\n    # We therefore check if one element in the upper triangular portion is zero. If it is, then the mask is all zeros.\n    if attention_mask is None or attention_mask[0, 0, 0, 1] == 0:\n        # input and output should be of form (bsz, q_len, num_heads, head_dim)\n        attn_output = xformers.ops.memory_efficient_attention(\n            query_states, key_states, value_states, attn_bias=None\n        )\n    else:\n        # input and output should be of form (bsz, q_len, num_heads, head_dim)\n        attn_output = xformers.ops.memory_efficient_attention(\n            query_states,\n            key_states,\n            value_states,\n            # attn_bias=attention_mask,\n            attn_bias=xformers.ops.LowerTriangularMask(),\n        )\n\n    if attn_output.size() != (bsz, q_len, self.num_heads, self.head_dim):\n        raise ValueError(\n            f\"`attn_output` should be of size {(bsz, q_len, self.num_heads, self.head_dim)}, but is\"\n            f\" {attn_output.size()}\"\n        )\n    attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)\n\n    #\n    # xformers-attn end\n    #\n\n    if self.pretraining_tp > 1:\n        attn_output = attn_output.split(self.hidden_size // self.pretraining_tp, dim=2)\n        o_proj_slices = self.o_proj.weight.split(\n            self.hidden_size // self.pretraining_tp, dim=1\n        )\n        attn_output = sum(\n            F.linear(attn_output[i], o_proj_slices[i])\n            for i in range(self.pretraining_tp)\n        )\n    else:\n        attn_output = self.o_proj(attn_output)\n\n    return attn_output, None, past_key_value\n"
  },
  {
    "path": "src/axolotl/monkeypatch/lora_kernels.py",
    "content": "\"\"\"Module for patching custom LoRA Triton kernels and `torch.autograd` functions.\"\"\"\n\nimport importlib\nimport inspect\nimport logging\nimport types\nfrom typing import Generator, Tuple, Type\n\nimport torch\nfrom peft import PeftModelForCausalLM\nfrom torch import nn\nfrom transformers import AutoConfig\n\nfrom axolotl.kernels.lora import (\n    apply_lora_mlp_geglu,\n    apply_lora_mlp_swiglu,\n    apply_lora_o,\n    apply_lora_qkv,\n)\nfrom axolotl.monkeypatch.utils import detab_code\nfrom axolotl.utils.callbacks.models import get_causal_lm_model_cls_prefix\nfrom axolotl.utils.dict import DictDefault\nfrom axolotl.utils.logging import get_logger\n\nLOG = get_logger(__name__)\n\nQKV_PATCHES = [\n    (\n        \"\"\"\n    query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)\n    key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)\n    value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)\n\"\"\".lstrip(\"\\n\"),\n        \"\"\"\n    query_states, key_states, value_states = self.apply_qkv(hidden_states)\n    query_states = query_states.view(hidden_shape).transpose(1, 2)\n    key_states = key_states.view(hidden_shape).transpose(1, 2)\n    value_states = value_states.view(hidden_shape).transpose(1, 2)\n\"\"\".lstrip(\"\\n\"),\n    ),\n    (\n        \"\"\"\n    query_states = self.q_norm(self.q_proj(hidden_states).view(hidden_shape)).transpose(1, 2)\n    key_states = self.k_norm(self.k_proj(hidden_states).view(hidden_shape)).transpose(1, 2)\n    value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)\n\"\"\".lstrip(\"\\n\"),\n        \"\"\"\n    query_states, key_states, value_states = self.apply_qkv(hidden_states)\n    query_states = self.q_norm(query_states.view(hidden_shape)).transpose(1, 2)\n    key_states = self.k_norm(key_states.view(hidden_shape)).transpose(1, 2)\n    value_states = value_states.view(hidden_shape).transpose(1, 
2)\n\"\"\".lstrip(\"\\n\"),\n    ),\n    (\n        \"\"\"\n    query_states, gate = torch.chunk(\n        self.q_proj(hidden_states).view(*input_shape, -1, self.head_dim * 2), 2, dim=-1\n    )\n    gate = gate.reshape(*input_shape, -1)\n\n    query_states = self.q_norm(query_states.view(hidden_shape)).transpose(1, 2)\n    key_states = self.k_norm(self.k_proj(hidden_states).view(hidden_shape)).transpose(1, 2)\n    value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)\n\"\"\".lstrip(\"\\n\"),\n        \"\"\"\n    query_states, key_states, value_states = self.apply_qkv(hidden_states)\n    query_states, gate = torch.chunk(\n        query_states.view(*input_shape, -1, self.head_dim * 2), 2, dim=-1\n    )\n    gate = gate.reshape(*input_shape, -1)\n\n    query_states = self.q_norm(query_states.view(hidden_shape)).transpose(1, 2)\n    key_states = self.k_norm(key_states.view(hidden_shape)).transpose(1, 2)\n    value_states = value_states.view(hidden_shape).transpose(1, 2)\n\"\"\".lstrip(\"\\n\"),\n    ),\n]\n\nORIGINAL_O_CODE = \"\"\"\n    attn_output = self.o_proj(attn_output)\n\"\"\".lstrip(\"\\n\")\n\nPATCHED_O_CODE = \"\"\"\n    attn_output = self.apply_o(attn_output)\n\"\"\".lstrip(\"\\n\")\n\nSUPPORTED_ACTIVATIONS = [\"silu\", \"gelu\"]\nAPPLY_FN_MAPPING = {\n    \"silu\": apply_lora_mlp_swiglu,\n    \"gelu\": apply_lora_mlp_geglu,\n}\n\n\ndef original_apply_qkv(\n    self: nn.Module, hidden_states: torch.Tensor\n) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:\n    \"\"\"\n    Original implementation of QKV projection without optimizations.\n\n    Args:\n        self: The attention module instance.\n        hidden_states: Input tensor of shape [batch_size, seq_len, hidden_dim].\n\n    Returns:\n        A tuple `(query_states, key_states, value_states)` containing the projected\n            states for query, key, and value.\n    \"\"\"\n    query_states = self.q_proj(hidden_states)\n    key_states = self.k_proj(hidden_states)\n    
value_states = self.v_proj(hidden_states)\n\n    return query_states, key_states, value_states\n\n\ndef original_apply_o(self: nn.Module, hidden_states: torch.Tensor) -> torch.Tensor:\n    \"\"\"\n    Original implementation of output projection without optimizations.\n\n    Args:\n        self: The attention module instance.\n        hidden_states: Input tensor of shape `[`batch_size, seq_len, hidden_dim]`.\n\n    Returns:\n        The output projection result.\n    \"\"\"\n    attn_output = self.o_proj(hidden_states)\n\n    return attn_output\n\n\ndef get_attention_cls_from_config(cfg: DictDefault) -> Type[nn.Module]:\n    \"\"\"\n    Get the appropriate attention class by inspecting the model config.\n    Uses dynamic import to support any model architecture that follows\n    the standard transformers naming convention.\n\n    Args:\n        cfg: Dictionary mapping `axolotl` config keys to values.\n\n    Returns:\n        The appropriate attention class for the model.\n\n    Raises:\n        ValueError: If `base_model` not specified or attention class cannot be imported\n        ImportError: If the model module or attention class doesn't exist\n    \"\"\"\n    if \"base_model\" not in cfg:\n        raise ValueError(\"base_model must be specified in config\")\n\n    # Get model config without loading the model\n    model_config = AutoConfig.from_pretrained(cfg[\"base_model\"])\n    model_type = model_config.model_type\n\n    # Special case for model_type = \"qwen2\"\n    if model_type == \"qwen2\":\n        from transformers.models.qwen2.modeling_qwen2 import Qwen2Attention\n\n        return Qwen2Attention\n\n    if model_type == \"qwen3_vl\":\n        from transformers.models.qwen3_vl.modeling_qwen3_vl import Qwen3VLTextAttention\n\n        return Qwen3VLTextAttention\n\n    if model_type == \"mllama\":\n        from transformers.models.mllama.modeling_mllama import MllamaTextSelfAttention\n\n        return MllamaTextSelfAttention\n\n    if model_type == 
\"llama4\":\n        from transformers.models.llama4.modeling_llama4 import Llama4TextAttention\n\n        return Llama4TextAttention\n\n    if model_type == \"mistral3\":\n        from transformers.models.mistral.modeling_mistral import MistralAttention\n\n        return MistralAttention\n\n    if model_type == \"gemma3_text\":\n        from transformers.models.gemma3.modeling_gemma3 import Gemma3Attention\n\n        return Gemma3Attention\n\n    try:\n        # Dynamically import the module and attention class\n        module_path = f\"transformers.models.{model_type}.modeling_{model_type}\"\n        model_cls_prefix, _ = get_causal_lm_model_cls_prefix(model_type)\n        module = __import__(module_path, fromlist=[f\"{model_cls_prefix}Attention\"])\n        attention_cls = getattr(module, f\"{model_cls_prefix}Attention\")\n\n        return attention_cls\n    except (ImportError, AttributeError) as e:\n        raise ValueError(\n            f\"Axolotl could not import attention class for model_type: {model_type}. \"\n            \"Please raise an Issue and turn off lora kernels to continue training. \"\n            f\"Error: {str(e)}\"\n        ) from e\n\n\ndef patch_self_attn_lora(cfg: DictDefault):\n    \"\"\"\n    Given an `axolotl` config, this method patches the inferred attention class forward\n    pass with optimized LoRA implementations.\n\n    It modifies the attention class to use optimized QKV and output projections. 
The\n    original implementation is preserved and can be restored if needed.\n\n    Args:\n        cfg: Dictionary mapping `axolotl` config keys to values.\n\n    Raises:\n        AssertionError: If the required code blocks are not found in the attention\n            implementation.\n    \"\"\"\n    attention_cls = get_attention_cls_from_config(cfg)\n\n    # Check if already patched\n    if hasattr(attention_cls, \"_original_forward\"):\n        LOG.info(f\"{attention_cls.__name__} already patched\")\n        return\n\n    self_attn_forward = inspect.getsource(attention_cls.forward)\n    attention_cls._original_forward = self_attn_forward\n    self_attn_forward, _ = detab_code(self_attn_forward)\n\n    assert any(qkv_options[0] in self_attn_forward for qkv_options in QKV_PATCHES), (\n        \"Original QKV code not found\"\n    )\n    assert ORIGINAL_O_CODE in self_attn_forward, \"Original O code not found\"\n\n    for qkv_orig, qkv_patched in QKV_PATCHES:\n        if qkv_orig in self_attn_forward:\n            self_attn_forward = self_attn_forward.replace(\n                qkv_orig,\n                qkv_patched,\n            )\n            break\n    self_attn_forward = self_attn_forward.replace(ORIGINAL_O_CODE, PATCHED_O_CODE)\n    self_attn_forward = self_attn_forward.replace(\n        \"def forward(\",\n        \"def axolotl_attn_forward(\",\n        1,\n    )\n\n    # Load necessary imports\n    module_name = attention_cls.__module__\n    module = importlib.import_module(module_name)\n\n    items_to_import = []\n    for item in dir(module):\n        if item in self_attn_forward:\n            items_to_import.append(item)\n\n    exec(\n        f\"from {module_name} import ({', '.join(items_to_import)})\",\n        globals(),\n    )\n    exec(self_attn_forward, globals())\n\n    LOG.info(f\"Patched attention class with LoRA optims: {attention_cls.__name__}\")\n    attention_cls.forward = axolotl_attn_forward\n\n\ndef find_self_attn_in_layer(\n    layer: 
nn.Module,\n) -> Generator[Tuple[nn.Module], None, None]:\n    # general case of most models\n    if hasattr(layer, \"self_attn\"):\n        if all(\n            hasattr(layer.self_attn, proj)\n            for proj in [\"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\"]\n        ):\n            yield layer.self_attn\n\n\ndef find_mlp_in_layer(\n    layer: nn.Module,\n) -> Generator[Tuple[nn.Module, nn.Module, nn.Module, nn.Module], None, None]:\n    # general case of most models\n    if hasattr(layer, \"mlp\"):\n        if all(\n            hasattr(layer.mlp, proj) for proj in [\"gate_proj\", \"up_proj\", \"down_proj\"]\n        ):\n            yield layer.mlp.gate_proj, layer.mlp.up_proj, layer.mlp.down_proj, layer.mlp\n    # llama4 linearized experts\n    if hasattr(layer, \"feedforward\") and hasattr(layer.feedforward, \"shared_expert\"):\n        mlp = layer.feedforward.shared_expert\n        yield mlp.gate_proj, mlp.up_proj, mlp.down_proj, mlp\n    if hasattr(layer, \"feedforward\") and hasattr(layer.feedforward, \"experts\"):\n        if all(\n            hasattr(layer.feedforward.experts, proj)\n            for proj in [\"gate_projs\", \"up_projs\", \"down_projs\"]\n        ):\n            for gate_proj, up_proj, down_proj in zip(\n                layer.feedforward.experts.gate_projs,\n                layer.feedforward.experts.up_projs,\n                layer.feedforward.experts.down_projs,\n                strict=False,\n            ):\n                yield (\n                    gate_proj,\n                    up_proj,\n                    down_proj,\n                    FakeMLP(gate_proj, up_proj, down_proj),\n                )\n\n\ndef get_layers(model: PeftModelForCausalLM) -> list[nn.Module]:\n    \"\"\"\n    Get the layers of the model. 
Handles text-only and multimodal models.\n\n    Args:\n        model: A PEFT model.\n\n    Returns:\n        A list of layers.\n    \"\"\"\n    pretrained_model = model.model\n\n    # check for multimodal models first\n    if hasattr(pretrained_model, \"language_model\"):\n        return pretrained_model.language_model.layers\n    if hasattr(pretrained_model, \"model\"):\n        if hasattr(pretrained_model.model, \"language_model\"):\n            return pretrained_model.model.language_model.layers\n        return pretrained_model.model.layers\n\n    raise NotImplementedError(\n        f\"Model type {model.config.model_type} is not supported yet. Please create an Issue.\"\n    )\n\n\ndef apply_lora_kernel_patches(\n    model: PeftModelForCausalLM, cfg: DictDefault\n) -> PeftModelForCausalLM:\n    \"\"\"\n    Applies optimized Triton kernel patches to a PEFT model.\n\n    Patches a PEFT model with optimized implementations for MLP and attention\n    computations. The optimizations include custom Triton kernels for activation\n    functions and specialized autograd functions for LoRA computations.\n\n    Args:\n        model: A PEFT model to be patched with optimized kernels.\n        cfg: Dictionary mapping `axolotl` config keys to values.\n\n    Returns:\n        PeftModelForCausalLM: The patched model with optimized kernels.\n\n    Raises:\n        TypeError: If the provided model is not a `PeftModelForCausalLM`.\n        NotImplementedError: If the model type is not supported.\n        AssertionError: If multiple adapters are active (currently unsupported).\n\n    Note:\n        The optimizations require LoRA adapters with no dropout and no bias terms. 
The\n            function will skip patching if these conditions aren't met.\n    \"\"\"\n    if not isinstance(model, PeftModelForCausalLM):\n        raise TypeError(\"Model must be a PeftModelForCausalLM\")\n\n    # Get active LoRA adapter config\n    if hasattr(model, \"active_adapters\"):\n        assert len(model.active_adapters) == 1, (\n            \"Axolotl currently does not support LoRA Triton kernels for multiple adapters\"\n        )\n        active_adapter = model.active_adapters[0]\n    else:\n        active_adapter = model.active_adapter\n    lora_config = model.model.peft_config[active_adapter]\n\n    # Only patch if conditions are met\n    can_patch = lora_config.lora_dropout == 0 and lora_config.bias == \"none\"\n\n    if not can_patch:\n        LOG.warning(\"Cannot patch layers - requires no dropout and no bias\")\n        LOG.warning(\"Please specify `lora_dropout: 0` in your axolotl config file\")\n        return model\n\n    # This needs to be reset after patching\n    original_level = LOG.getEffectiveLevel()\n    LOG.setLevel(logging.INFO)\n\n    # Choose activation based on model type\n    activation = None\n    text_config = (\n        model.config.get_text_config()\n        if hasattr(model.config, \"get_text_config\")\n        else model.config\n    )\n    if hasattr(text_config, \"hidden_act\"):\n        activation = text_config.hidden_act\n    elif hasattr(text_config, \"hidden_activation\"):\n        activation = text_config.hidden_activation\n\n    # map activation to supported activation\n    if \"gelu\" in activation:\n        # gemma3 uses gelu_pytorch_tanh\n        activation = \"gelu\"\n\n    if activation not in SUPPORTED_ACTIVATIONS:\n        raise NotImplementedError(f\"Activation {activation} is not supported\")\n\n    layers = get_layers(model)\n\n    # Patch each layer\n    for layer in layers:\n        # Add QKV, O fallback implementations to start\n        # These will be overwritten later (if some conditions apply)\n     
   for self_attn in find_self_attn_in_layer(layer):\n            self_attn.apply_qkv = types.MethodType(original_apply_qkv, self_attn)\n            self_attn.apply_o = types.MethodType(original_apply_o, self_attn)\n\n            if cfg.lora_qkv_kernel:\n                # Query, key, value patching\n                layer_modules = [\n                    getattr(self_attn, linear_proj)\n                    for linear_proj in [\"q_proj\", \"k_proj\", \"v_proj\"]\n                ]\n                can_patch_qkv = all(\n                    hasattr(module, \"lora_A\")\n                    and len(getattr(module, \"lora_magnitude_vector\", []) or []) == 0\n                    for module in layer_modules\n                )\n\n                if can_patch_qkv:\n                    # Add optimized implementation\n                    self_attn.apply_qkv = types.MethodType(apply_lora_qkv, self_attn)\n                else:\n                    LOG.warning_once(\n                        \"Cannot patch some attention QKV projections - requires LoRA \"\n                        \"adapters and no lora_magnitude_vector (DoRA)\"\n                    )\n            if cfg.lora_o_kernel:\n                # Output patching\n                layer_modules = [\n                    getattr(self_attn, linear_proj) for linear_proj in [\"o_proj\"]\n                ]\n                can_patch_o = all(\n                    hasattr(module, \"lora_A\")\n                    and len(getattr(module, \"lora_magnitude_vector\", []) or []) == 0\n                    for module in layer_modules\n                )\n\n                if can_patch_o:\n                    self_attn.apply_o = types.MethodType(apply_lora_o, self_attn)\n                else:\n                    LOG.warning_once(\n                        \"Cannot patch some attention output projection - requires LoRA \"\n                        \"adapters and no lora_magnitude_vector (DoRA)\"\n                    )\n        for gate_proj, 
up_proj, down_proj, mlp in find_mlp_in_layer(layer):\n            if cfg.lora_mlp_kernel:\n                # MLP patching\n                can_patch_mlp = all(\n                    hasattr(proj, \"lora_A\")\n                    and len(getattr(proj, \"lora_magnitude_vector\", []) or []) == 0\n                    for proj in (gate_proj, up_proj, down_proj)\n                )\n\n                if can_patch_mlp:\n                    apply_fn = APPLY_FN_MAPPING[activation]\n                    layer.mlp.forward = types.MethodType(apply_fn, mlp)\n                else:\n                    LOG.warning_once(\n                        \"Cannot patch some MLP layers - requires LoRA adapters and no \"\n                        \"lora_magnitude_vector (DoRA)\"\n                    )\n\n    LOG.setLevel(original_level)\n\n    return model\n\n\nclass FakeMLP(nn.Module):\n    \"\"\"\n    placeholder MLP for triton patching\n    \"\"\"\n\n    gate_proj: nn.Linear\n    up_proj: nn.Linear\n    down_proj: nn.Linear\n\n    def __init__(self, gate_proj, up_proj, down_proj):\n        super().__init__()\n        self.gate_proj = gate_proj\n        self.up_proj = up_proj\n        self.down_proj = down_proj\n"
  },
  {
    "path": "src/axolotl/monkeypatch/loss/__init__.py",
    "content": ""
  },
  {
    "path": "src/axolotl/monkeypatch/loss/chunked.py",
    "content": "\"\"\"\nchunked ce loss\n\"\"\"\n\nfrom typing import List, Optional\n\nimport torch\nimport torch.nn.functional as F\n\n\n# copied and modified from torchtune.modules.loss.CEWithChunkedOutputLoss\nclass CEWithChunkedOutputLoss(torch.nn.Module):\n    \"\"\"\n    Cross-entropy with chunked outputs that saves memory by only upcasting one chunk at a time.\n\n    For more details, please refer to: https://github.com/pytorch/torchtune/pull/1390\n    \"\"\"\n\n    def __init__(self, num_output_chunks: int = 8, ignore_index: int = -100):\n        super().__init__()\n        self.num_output_chunks = num_output_chunks\n        self.ignore_index = ignore_index\n\n    def compute_cross_entropy(\n        self,\n        logits: torch.Tensor,\n        labels: torch.Tensor,\n        normalize: bool = True,\n    ) -> torch.Tensor:\n        \"\"\"\n        Upcast logits to fp32 and compute cross entropy loss.\n        \"\"\"\n        return F.cross_entropy(\n            logits.float(), labels, ignore_index=self.ignore_index, reduction=\"sum\"\n        )\n\n    def forward(\n        self, logits: List[torch.Tensor], labels: torch.Tensor, reduction=\"sum\"\n    ) -> torch.Tensor:\n        \"\"\"\n        Args:\n            logits (List[torch.Tensor]): List of chunked logits of length\n                ``self.num_output_chunks``, where each chunk has shape\n                ``(batch_size, num_tokens / num_output_chunks, vocab_size)``.\n            labels (torch.Tensor): Ground truth labels of shape ``(batch_size, num_tokens)``.\n            reduction (str): The reduction to apply to the output.\n\n        Returns:\n            torch.Tensor: Cross entropy loss of shape (1,).\n        \"\"\"\n\n        total_elements = (labels != self.ignore_index).sum()\n\n        # chunk and reshape labels (bsz, num_tokens, vocab) -> [(bsz*num_tokens/num_chunks, vocab)]\n        labels = [\n            target_chunk.reshape(-1)\n            for target_chunk in 
labels.chunk(self.num_output_chunks, dim=1)\n        ]\n        # reshape logits [(bsz, num_tokens/num_chunks, vocab)] -> [(bsz*num_tokens/num_chunks, vocab)]\n        logits = [\n            logit_chunk.reshape(-1, logit_chunk.size(-1)) for logit_chunk in logits\n        ]\n\n        # compute one chunk at a time\n        total_loss = 0.0\n        for logits_chunk, labels_chunk in zip(logits, labels, strict=False):\n            total_loss += self.compute_cross_entropy(logits_chunk, labels_chunk)\n\n        if reduction == \"sum\":\n            return total_loss\n        return total_loss / total_elements\n\n\ndef _build_chunked_ce_loss_fn(num_output_chunks: int = 8, ignore_index: int = -100):\n    loss_fn_ce = CEWithChunkedOutputLoss(num_output_chunks, ignore_index)\n    loss_fn_ce.compute_cross_entropy = torch.compile(\n        loss_fn_ce.compute_cross_entropy, backend=\"inductor\"\n    )\n    return loss_fn_ce\n\n\ndef get_causal_lm_loss(num_output_chunks: int = 8, ignore_index: int = -100):\n    loss_fn_ce = _build_chunked_ce_loss_fn(num_output_chunks, ignore_index)\n\n    def chunked_fix_cross_entropy(\n        source,\n        target,\n        num_items_in_batch: int = None,\n        ignore_index: int = -100,\n        **kwargs,\n    ):\n        reduction = \"sum\" if num_items_in_batch is not None else \"mean\"\n        logit_chunks = [\n            chunk for chunk in source.chunk(loss_fn_ce.num_output_chunks, dim=1)\n        ]\n        loss = loss_fn_ce(logit_chunks, target, reduction=reduction)\n        if reduction == \"sum\":\n            loss = loss / num_items_in_batch\n        return loss\n\n    def for_causal_lm_chunked_loss(\n        logits,\n        labels,\n        vocab_size: int = None,\n        num_items_in_batch: Optional[int] = None,\n        ignore_index: int = -100,\n        shift_labels: Optional[torch.Tensor] = None,\n        **kwargs,\n    ) -> torch.Tensor:\n        # skip the upcast to float since we handle that in the chunking loss\n   
     if shift_labels is None:\n            # Shift so that tokens < n predict n\n            labels = F.pad(labels, (0, 1), value=ignore_index)\n            shift_labels = labels[..., 1:].contiguous()\n\n        # Skip Flattening the tokens\n        # Enable model parallelism\n        shift_labels = shift_labels.to(logits.device)\n        loss = chunked_fix_cross_entropy(\n            logits, shift_labels, num_items_in_batch, ignore_index, **kwargs\n        )\n        return loss\n\n    return for_causal_lm_chunked_loss\n\n\ndef patch_chunked_ce_loss_fn(num_output_chunks: int = 8, ignore_index: int = -100):\n    import transformers.loss.loss_utils\n\n    for_causal_lm_chunked_loss = get_causal_lm_loss(num_output_chunks, ignore_index)\n    transformers.loss.loss_utils.ForCausalLMLoss = for_causal_lm_chunked_loss\n    transformers.loss.loss_utils.LOSS_MAPPING[\"ForCausalLM\"] = (\n        for_causal_lm_chunked_loss\n    )\n"
  },
  {
    "path": "src/axolotl/monkeypatch/loss/eaft.py",
    "content": "\"\"\"\neaft (entropy-aware focal training) loss implementation\nweights examples by entropy approximation from top-k logits\n\nReference: https://github.com/ymxyll/LlamaFactory-EAFT/blob/e2ce19e8efcc226450ee8f2b81dfe4e69f1f945d/src/llamafactory/train/trainer_utils.py\n\"\"\"\n\nimport torch\nimport torch.nn.functional as F\n\n\ndef eaft_loss(outputs, labels, num_items_in_batch=None, alpha=1.0, k=20):\n    \"\"\"\n    compute eaft loss with entropy weighting\n\n    args:\n        outputs: model outputs containing logits\n        labels: target labels for computing loss\n        num_items_in_batch: for sample packing support\n        alpha: exponent for entropy weighting (default 1.0)\n        k: number of top logits for entropy approximation (default 20)\n    \"\"\"\n    logits = outputs.logits\n\n    shift_logits = logits[..., :-1, :].contiguous()\n    shift_labels = labels[..., 1:].contiguous()\n\n    vocab_size = shift_logits.size(-1)\n    shift_logits_view = shift_logits.view(-1, vocab_size)\n    shift_labels_view = shift_labels.view(-1)\n\n    mask = shift_labels_view != -100\n\n    with torch.no_grad():\n        top_k_logits, _ = torch.topk(\n            shift_logits_view[mask].float(), k=min(k, vocab_size), dim=-1\n        )\n        top_k_probs = F.softmax(top_k_logits, dim=-1)\n        entropy = -(top_k_probs * torch.log(top_k_probs + 1e-10)).sum(dim=-1)\n        weights = torch.pow(entropy, alpha)\n\n    loss_fct = torch.nn.CrossEntropyLoss(reduction=\"none\")\n    per_token_loss = loss_fct(shift_logits_view[mask], shift_labels_view[mask])\n    weighted_loss = per_token_loss * weights\n\n    if num_items_in_batch is not None:\n        loss = weighted_loss.sum() / num_items_in_batch\n    else:\n        loss = weighted_loss.mean()\n\n    return loss\n"
  },
  {
    "path": "src/axolotl/monkeypatch/mistral_attn_hijack_flash.py",
    "content": "\"\"\"Flash attention monkey patch for mistral model\"\"\"\n\nfrom functools import partial\n\nimport transformers\n\nfrom axolotl.utils.logging import get_logger\n\nLOG = get_logger(__name__)\n\n\ndef patch_mistral_cross_entropy():\n    from flash_attn.losses.cross_entropy import CrossEntropyLoss\n\n    LOG.info(\"patching with flash_attn.losses.cross_entropy\")\n    transformers.models.mistral.modeling_mistral.CrossEntropyLoss = partial(\n        CrossEntropyLoss, inplace_backward=True\n    )\n"
  },
  {
    "path": "src/axolotl/monkeypatch/mixtral/__init__.py",
    "content": "\"\"\"\nPatches to support multipack for mixtral\n\"\"\"\n\nimport torch\n\n\ndef patch_mixtral_moe_forward_zero3() -> None:\n    import torch.nn.functional as F\n\n    def mlp_forward(self, hidden_states):\n        current_hidden_states = self.act_fn(self.w1(hidden_states)) * self.w3(\n            hidden_states\n        )\n        current_hidden_states = self.w2(current_hidden_states)\n        return current_hidden_states\n\n    # Ref. https://huggingface.co/deepseek-ai/deepseek-moe-16b-base/blob/main/modeling_deepseek.py\n    def moe_forward(self, hidden_states: torch.Tensor) -> torch.Tensor:\n        batch_size, sequence_length, hidden_dim = hidden_states.shape\n        hidden_states = hidden_states.view(-1, hidden_dim)\n        # router_logits: (batch * sequence_length, n_experts)\n        router_logits = self.gate(hidden_states)\n\n        routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)\n        topk_weight, topk_idx = torch.topk(\n            routing_weights, self.top_k, dim=-1, sorted=False\n        )\n        topk_weight /= topk_weight.sum(dim=-1, keepdim=True)\n        # we cast back to the input dtype\n        topk_weight = topk_weight.to(hidden_states.dtype)\n\n        hidden_states = hidden_states.repeat_interleave(self.top_k, dim=0)\n        y = torch.empty_like(hidden_states)\n        flat_topk_idx = topk_idx.view(-1)\n        for i in range(self.num_experts):\n            expert = self.experts[i]\n            y[flat_topk_idx == i] = expert(hidden_states[flat_topk_idx == i])\n        y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1)\n        final_hidden_states = y.reshape(batch_size, sequence_length, hidden_dim)\n        return final_hidden_states, router_logits\n\n    from transformers.models.mixtral.modeling_mixtral import (\n        MixtralBlockSparseTop2MLP,\n        MixtralSparseMoeBlock,\n    )\n\n    MixtralBlockSparseTop2MLP.forward = mlp_forward\n    MixtralSparseMoeBlock.forward 
= moe_forward\n"
  },
  {
    "path": "src/axolotl/monkeypatch/models/__init__.py",
    "content": ""
  },
  {
    "path": "src/axolotl/monkeypatch/models/apertus/__init__.py",
    "content": ""
  },
  {
    "path": "src/axolotl/monkeypatch/models/apertus/activation.py",
    "content": "\"\"\"Monkeypatch for Apertus to dtype mismatch in XIELU act\"\"\"\n\nfrom torch import Tensor\n\n\ndef patch_apertus_xielu_activation():\n    try:\n        from transformers.activations import XIELUActivation\n    except ImportError as err:\n        raise ImportError(\n            \"Cannot import XIELUActivation. \"\n            \"Please make sure to update your transformers version >= 4.56.1.\"\n        ) from err\n\n    from transformers.activations import logger\n\n    # Store the original method\n    old_fn = XIELUActivation._xielu_cuda\n\n    def _xielu_cuda_fixed(self, x: Tensor) -> Tensor:\n        \"\"\"Firewall function to prevent torch.compile from seeing .item() calls\"\"\"\n        original_shape = x.shape\n        # CUDA kernel expects 3D tensors, reshape if needed\n        while x.dim() < 3:\n            x = x.unsqueeze(0)\n        if x.dim() > 3:\n            x = x.view(-1, 1, x.size(-1))\n        if original_shape != x.shape:\n            logger.warning_once(\n                \"Warning: xIELU input tensor expects 3 dimensions but got (shape: %s). Reshaping to (shape: %s).\",\n                original_shape,\n                x.shape,\n            )\n        result = self._xielu_cuda_obj.forward(\n            x,\n            self.alpha_p.to(x.dtype),\n            self.alpha_n.to(x.dtype),\n            # Temporary until xIELU CUDA fully implemented -> self.{beta,eps}.item()\n            self._beta_scalar,\n            self._eps_scalar,\n            self.with_vector_loads,\n        )\n        return result.view(original_shape)\n\n    # Apply the patch\n    XIELUActivation._xielu_cuda = _xielu_cuda_fixed\n\n    def unpatch():\n        \"\"\"Restore the original method\"\"\"\n        XIELUActivation._xielu_cuda = old_fn\n\n    return unpatch\n"
  },
  {
    "path": "src/axolotl/monkeypatch/models/kimi_linear/__init__.py",
    "content": ""
  },
  {
    "path": "src/axolotl/monkeypatch/models/kimi_linear/configuration_kimi.py",
    "content": "\"\"\"\nKimi-Linear configuration.\n\nSource: https://huggingface.co/moonshotai/Kimi-Linear-48B-A3B-Instruct/blob/main/configuration_kimi.py\nRevision: 6e163f3\n\"\"\"\n\nfrom typing import Optional\n\nfrom transformers.configuration_utils import PretrainedConfig\n\n\nclass KimiLinearConfig(PretrainedConfig):\n    model_type = \"kimi_linear\"\n    keys_to_ignore_at_inference = [\"past_key_values\"]\n\n    def __init__(\n        self,\n        model_type=\"kimi_linear\",\n        vocab_size=163840,\n        hidden_size=4096,\n        head_dim=None,\n        intermediate_size=11008,\n        num_hidden_layers=32,\n        num_attention_heads=32,\n        num_key_value_heads=None,\n        hidden_act=\"silu\",\n        initializer_range=0.02,\n        rms_norm_eps=1e-6,\n        use_cache=True,\n        pad_token_id=0,\n        bos_token_id=1,\n        eos_token_id=2,\n        rope_theta=10000.0,\n        rope_scaling=None,\n        tie_word_embeddings=False,\n        moe_intermediate_size: Optional[int] = None,\n        moe_renormalize: bool = True,\n        moe_router_activation_func: str = \"sigmoid\",\n        num_experts: Optional[int] = None,\n        num_experts_per_token: Optional[int] = None,\n        num_shared_experts: int = 0,\n        routed_scaling_factor: float = 1.0,\n        first_k_dense_replace: int = 0,\n        moe_layer_freq: int = 1,\n        use_grouped_topk: bool = True,\n        num_expert_group: int = 1,\n        topk_group: int = 1,\n        q_lora_rank: Optional[int] = None,\n        kv_lora_rank: Optional[int] = None,\n        qk_nope_head_dim: Optional[int] = None,\n        qk_rope_head_dim: Optional[int] = None,\n        v_head_dim: Optional[int] = None,\n        mla_use_nope: Optional[bool] = False,\n        num_nextn_predict_layers: int = 0,\n        linear_attn_config: Optional[dict] = None,\n        router_aux_loss_coef: float = 0.01,\n        **kwargs,\n    ):\n        self.model_type = model_type\n        
self.vocab_size = vocab_size\n        self.hidden_size = hidden_size\n        self.head_dim = (\n            head_dim if head_dim is not None else hidden_size // num_attention_heads\n        )\n        self.intermediate_size = intermediate_size\n        self.num_hidden_layers = num_hidden_layers\n        self.num_attention_heads = num_attention_heads\n\n        # for backward compatibility\n        if num_key_value_heads is None:\n            num_key_value_heads = num_attention_heads\n\n        self.num_key_value_heads = num_key_value_heads\n        self.hidden_act = hidden_act\n        self.initializer_range = initializer_range\n        self.rms_norm_eps = rms_norm_eps\n        self.use_cache = use_cache\n        self.rope_theta = rope_theta\n        self.rope_scaling = rope_scaling\n\n        self.q_lora_rank = q_lora_rank\n        self.kv_lora_rank = kv_lora_rank\n        self.qk_nope_head_dim = qk_nope_head_dim\n        self.qk_rope_head_dim = qk_rope_head_dim\n        self.v_head_dim = v_head_dim\n        self.mla_use_nope = mla_use_nope\n        # moe config\n        self.num_experts = num_experts\n        self.num_experts_per_token = num_experts_per_token\n        self.moe_renormalize = moe_renormalize\n        self.num_shared_experts = num_shared_experts\n        self.routed_scaling_factor = routed_scaling_factor\n        self.moe_router_activation_func = moe_router_activation_func\n        assert self.moe_router_activation_func in (\"softmax\", \"sigmoid\")\n        self.moe_intermediate_size = moe_intermediate_size\n        self.first_k_dense_replace = first_k_dense_replace\n        self.moe_layer_freq = moe_layer_freq\n        self.use_grouped_topk = use_grouped_topk\n        self.num_expert_group = num_expert_group\n        self.topk_group = topk_group\n        self.num_nextn_predict_layers = num_nextn_predict_layers\n        self.router_aux_loss_coef = router_aux_loss_coef\n\n        if linear_attn_config is not None:\n            assert 
linear_attn_config[\"kda_layers\"] is not None\n            assert linear_attn_config[\"full_attn_layers\"] is not None\n        self.linear_attn_config = linear_attn_config\n\n        super().__init__(\n            pad_token_id=pad_token_id,\n            bos_token_id=bos_token_id,\n            eos_token_id=eos_token_id,\n            tie_word_embeddings=tie_word_embeddings,\n            **kwargs,\n        )\n\n    @property\n    def is_mla(self):\n        return (\n            self.q_lora_rank is not None\n            or self.kv_lora_rank is not None\n            or self.qk_nope_head_dim is not None\n            or self.qk_rope_head_dim is not None\n            or self.v_head_dim is not None\n            or self.mla_use_nope is True\n        )\n\n    @property\n    def is_moe(self):\n        return self.num_experts is not None\n\n    @property\n    def is_linear_attn(self) -> bool:\n        return not (\n            self.linear_attn_config is None\n            or (\n                isinstance(self.linear_attn_config, dict)\n                and self.linear_attn_config[\"kda_layers\"] is not None\n                and len(self.linear_attn_config[\"kda_layers\"]) == 0\n            )\n        )\n\n    def is_kda_layer(self, layer_idx: int):\n        return (\n            self.linear_attn_config is not None\n            and (layer_idx + 1) in self.linear_attn_config[\"kda_layers\"]\n        )\n"
  },
  {
    "path": "src/axolotl/monkeypatch/models/kimi_linear/modeling_kimi.py",
    "content": "\"\"\"\nAdapted Kimi-Linear modeling to enable MoE differentiable.\n\nSource: https://huggingface.co/moonshotai/Kimi-Linear-48B-A3B-Instruct/blob/main/modeling_kimi.py\nRevision: 6e163f3\n\"\"\"\n\nimport math\nfrom collections.abc import Callable\nfrom typing import Any, List, Optional, Tuple, Union\n\nimport torch\nimport torch.nn.functional as F\nimport transformers\nfrom einops import rearrange\nfrom packaging import version\nfrom torch import nn\nfrom transformers.activations import ACT2FN\nfrom transformers.cache_utils import Cache\nfrom transformers.generation import GenerationMixin\nfrom transformers.masking_utils import create_causal_mask\nfrom transformers.modeling_flash_attention_utils import FlashAttentionKwargs\nfrom transformers.modeling_outputs import (\n    BaseModelOutputWithPast,\n    CausalLMOutputWithPast,\n    MoeCausalLMOutputWithPast,\n)\nfrom transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel\nfrom transformers.processing_utils import Unpack\nfrom transformers.pytorch_utils import ALL_LAYERNORM_LAYERS\nfrom transformers.utils import (\n    TransformersKwargs,\n    can_return_tuple,\n    logging,\n)\nfrom transformers.utils.generic import OutputRecorder\n\ntry:\n    from fla.layers.utils import get_unpad_data, index_first_axis, pad_input\n    from fla.modules import FusedRMSNormGated, ShortConvolution\n    from fla.ops.kda import chunk_kda, fused_recurrent_kda\n    from fla.ops.kda.gate import fused_kda_gate\nexcept ImportError as err:\n    raise ImportError(\n        \"Plese run `pip uninstall fla-core flash-linear-attention -y && pip install git+https://github.com/fla-org/flash-linear-attention@v0.4.0`\"\n    ) from err\n\nfrom axolotl.monkeypatch.models.kimi_linear.configuration_kimi import KimiLinearConfig\n\nassert version.parse(transformers.__version__) >= version.parse(\"4.56.0\"), (\n    \"Please upgrade transformers to >= 4.56.0\"\n)\n\nlogger = logging.get_logger(__name__)\n\n\ndef 
load_balancing_loss_func(\n    gate_logits: Union[torch.Tensor, tuple[torch.Tensor], None],\n    num_experts: Optional[int] = None,\n    top_k=2,\n    attention_mask: Optional[torch.Tensor] = None,\n) -> Union[torch.Tensor, int]:\n    \"\"\"Standard Switch Transformer load balancing loss.\"\"\"\n    if gate_logits is None or not isinstance(gate_logits, tuple):\n        return 0\n\n    # Concatenate all layer logits\n    concatenated_gate_logits = torch.cat(\n        [layer_gate for layer_gate in gate_logits], dim=0\n    )\n\n    routing_weights = F.softmax(concatenated_gate_logits, dim=-1)\n    _, selected_experts = torch.topk(routing_weights, top_k, dim=-1)\n    expert_mask = F.one_hot(selected_experts, num_experts)\n\n    tokens_per_expert = torch.mean(expert_mask.float(), dim=0)\n    router_prob_per_expert = torch.mean(routing_weights, dim=0)\n\n    overall_loss = torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0))\n    return overall_loss * num_experts\n\n\nclass KimiDynamicCache:\n    \"\"\"\n    Dynamic cache for Kimi model.\n    Inspired by Qwen3-Next\n    \"\"\"\n\n    is_compileable = False\n\n    def __init__(self, config: KimiLinearConfig):\n        super().__init__()\n        self.config = config\n\n        if config.linear_attn_config is not None:\n            self.layer_types = []\n            for i in range(config.num_hidden_layers):\n                if config.is_kda_layer(i):\n                    self.layer_types.append(\"linear_attention\")\n                else:\n                    self.layer_types.append(\"full_attention\")\n        else:\n            self.layer_types = [\"full_attention\"] * config.num_hidden_layers\n\n        self.transformer_layers = [\n            i\n            for i in range(config.num_hidden_layers)\n            if self.layer_types[i] == \"full_attention\"\n        ]\n\n        linear_layers = [\n            i\n            for i in range(config.num_hidden_layers)\n            if self.layer_types[i] == 
\"linear_attention\"\n        ]\n        self.last_linear_layer = linear_layers[-1] if linear_layers else -1\n\n        self.conv_states = [None for _ in range(config.num_hidden_layers)]\n        self.recurrent_states = [None for _ in range(config.num_hidden_layers)]\n        self.key_cache = [None for _ in range(config.num_hidden_layers)]\n        self.value_cache = [None for _ in range(config.num_hidden_layers)]\n\n    def __len__(self):\n        return len(self.layer_types)\n\n    def update(\n        self,\n        key_states: torch.Tensor,\n        value_states: torch.Tensor,\n        layer_idx: int,\n        cache_kwargs: Optional[dict[str, Any]] = None,\n    ) -> tuple[torch.Tensor, torch.Tensor]:\n        if self.key_cache[layer_idx] is None:\n            self.key_cache[layer_idx] = key_states\n            self.value_cache[layer_idx] = value_states\n        else:\n            self.key_cache[layer_idx] = torch.cat(\n                [self.key_cache[layer_idx], key_states], dim=2\n            )\n            self.value_cache[layer_idx] = torch.cat(\n                [self.value_cache[layer_idx], value_states], dim=2\n            )\n\n        return self.key_cache[layer_idx], self.value_cache[layer_idx]\n\n    def reorder_cache(self, beam_idx: torch.LongTensor):\n        \"\"\"Reorders the cache for beam search, given the selected beam indices.\"\"\"\n        for layer_idx in range(len(self.key_cache)):\n            if self.key_cache[layer_idx] is not None:\n                device = self.key_cache[layer_idx].device\n                beam_idx = beam_idx.to(device)\n                self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(\n                    0, beam_idx\n                )\n                self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(\n                    0, beam_idx\n                )\n\n            if self.conv_states[layer_idx] is not None:\n                device = self.conv_states[layer_idx][0].device\n       
         beam_idx = beam_idx.to(device)\n                q_conv, k_conv, v_conv = self.conv_states[layer_idx]\n                self.conv_states[layer_idx] = (\n                    q_conv.index_select(0, beam_idx),\n                    k_conv.index_select(0, beam_idx),\n                    v_conv.index_select(0, beam_idx),\n                )\n                self.recurrent_states[layer_idx] = self.recurrent_states[\n                    layer_idx\n                ].index_select(0, beam_idx)\n\n    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:\n        \"\"\"Returns the sequence length of the cached states. A layer index can be optionally passed.\"\"\"\n        # take any layer that contains cache and not empty tensor\n        layer_idx = (\n            self.transformer_layers[0]\n            if layer_idx not in self.transformer_layers\n            else layer_idx\n        )\n        if len(self.key_cache) <= layer_idx or self.key_cache[layer_idx] is None:\n            return 0\n        return self.key_cache[layer_idx].shape[-2]\n\n    def get_mask_sizes(\n        self, cache_position: torch.Tensor, layer_idx: int\n    ) -> tuple[int, int]:\n        \"\"\"\n        Return a tuple (kv_length, kv_offset) corresponding to the length and offset that will be returned for\n        the given layer at `layer_idx`.\n        The masks are then prepared according to the given lengths (kv_length, kv_offset) and patterns for each layer.\n        \"\"\"\n        kv_offset = 0\n        query_length = cache_position.shape[0]\n        past_seen_tokens = self.get_seq_length(layer_idx)\n        kv_length = query_length + past_seen_tokens\n        return kv_length, kv_offset\n\n    @property\n    def has_previous_state(self):\n        \"\"\"We have a previous state if the last linear (conv) layer was already updated.\"\"\"\n        if self.last_linear_layer == -1:\n            return False\n        return self.conv_states[self.last_linear_layer] is not None\n\n\nclass 
KimiRMSNorm(nn.Module):\n    def __init__(self, hidden_size, eps=1e-6):\n        \"\"\"\n        KimiRMSNorm is equivalent to T5LayerNorm\n        \"\"\"\n        super().__init__()\n        self.weight = nn.Parameter(torch.ones(hidden_size))\n        self.variance_epsilon = eps\n\n    def forward(self, hidden_states):\n        input_dtype = hidden_states.dtype\n        hidden_states = hidden_states.to(torch.float32)\n        variance = hidden_states.pow(2).mean(-1, keepdim=True)\n        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)\n        return self.weight * hidden_states.to(input_dtype)\n\n\nALL_LAYERNORM_LAYERS.append(KimiRMSNorm)\n\n\nclass KimiBlockSparseMLP(nn.Module):\n    def __init__(\n        self, config: KimiLinearConfig, hidden_size=None, intermediate_size=None\n    ):\n        super().__init__()\n        self.config = config\n        self.ffn_dim = (\n            config.intermediate_size if intermediate_size is None else intermediate_size\n        )\n        self.hidden_dim = config.hidden_size if hidden_size is None else hidden_size\n\n        self.w1 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False)  # gate\n        self.w2 = nn.Linear(self.ffn_dim, self.hidden_dim, bias=False)  # down\n        self.w3 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False)  # up\n\n        self.act_fn = ACT2FN[config.hidden_act]\n\n    def forward(self, hidden_states):\n        current_hidden_states = self.act_fn(self.w1(hidden_states)) * self.w3(\n            hidden_states\n        )\n        current_hidden_states = self.w2(current_hidden_states)\n        return current_hidden_states\n\n\nclass KimiMLP(nn.Module):\n    def __init__(\n        self, config: KimiLinearConfig, hidden_size=None, intermediate_size=None\n    ):\n        super().__init__()\n        self.config = config\n        self.hidden_size = config.hidden_size if hidden_size is None else hidden_size\n        self.intermediate_size = (\n            
config.intermediate_size if intermediate_size is None else intermediate_size\n        )\n        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)\n        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)\n        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)\n        self.act_fn = ACT2FN[config.hidden_act]\n\n    def forward(self, x):\n        down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))\n        return down_proj\n\n\ndef repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:\n    \"\"\"\n    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,\n    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)\n    \"\"\"\n    batch, num_key_value_heads, slen, head_dim = hidden_states.shape\n    if n_rep == 1:\n        return hidden_states\n    hidden_states = hidden_states[:, :, None, :, :].expand(\n        batch, num_key_value_heads, n_rep, slen, head_dim\n    )\n    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)\n\n\ndef eager_attention_forward(\n    module: nn.Module,\n    query: torch.Tensor,\n    key: torch.Tensor,\n    value: torch.Tensor,\n    attention_mask: Optional[torch.Tensor],\n    scaling: float,\n    dropout: float = 0.0,\n    **kwargs: Unpack[TransformersKwargs],\n):\n    key_states = repeat_kv(key, module.num_key_value_groups)\n    value_states = repeat_kv(value, module.num_key_value_groups)\n\n    attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling\n    if attention_mask is not None:\n        causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]\n        attn_weights = attn_weights + causal_mask\n\n    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(\n        query.dtype\n    )\n    attn_weights = nn.functional.dropout(\n        
attn_weights, p=dropout, training=module.training\n    )\n    attn_output = torch.matmul(attn_weights, value_states)\n    attn_output = attn_output.transpose(1, 2).contiguous()\n\n    return attn_output, attn_weights\n\n\nclass KimiMLAAttention(nn.Module):\n    \"\"\"\n    Multi-Latent Attention adapted from deepseek-v3\n    \"\"\"\n\n    def __init__(self, config: KimiLinearConfig, layer_idx: int):\n        nn.Module.__init__(self)\n        self.config = config\n        self.layer_idx = layer_idx\n        self.hidden_size = config.hidden_size\n        self.num_heads = config.num_attention_heads\n        self.num_key_value_heads = config.num_key_value_heads\n        self.num_key_value_groups = self.num_heads // self.num_key_value_heads\n\n        self.rope_theta = config.rope_theta\n        self.attention_dropout = getattr(config, \"attention_dropout\", 0.0)\n\n        try:\n            self.q_lora_rank = config.q_lora_rank\n            self.qk_rope_head_dim = config.qk_rope_head_dim\n            self.kv_lora_rank = config.kv_lora_rank\n            self.v_head_dim = config.v_head_dim\n            self.qk_nope_head_dim = config.qk_nope_head_dim\n            self.q_head_dim = self.qk_nope_head_dim + self.qk_rope_head_dim\n            self.use_nope = config.mla_use_nope\n            self.scaling = self.q_head_dim ** (-0.5)\n        except Exception as e:\n            raise ValueError(\n                f\"Kimi MLA config is not found or not properly formatted: {e}\"\n            ) from e\n\n        assert self.q_lora_rank is None\n        self.q_proj = nn.Linear(\n            self.hidden_size,\n            self.num_heads * self.q_head_dim,\n            bias=False,\n        )\n        self.kv_a_proj_with_mqa = nn.Linear(\n            self.hidden_size,\n            self.kv_lora_rank + self.qk_rope_head_dim,\n            bias=False,\n        )\n        self.kv_a_layernorm = KimiRMSNorm(self.kv_lora_rank)\n        self.kv_b_proj = nn.Linear(\n            
self.kv_lora_rank,\n            self.num_heads\n            * (self.q_head_dim - self.qk_rope_head_dim + self.v_head_dim),\n            bias=False,\n        )\n        self.o_proj = nn.Linear(\n            self.num_heads * self.v_head_dim,\n            self.hidden_size,\n            bias=False,\n        )\n        self.is_causal = True\n        assert self.use_nope\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        attention_mask: Optional[torch.Tensor] = None,\n        past_key_values: Optional[Cache] = None,\n        **kwargs,\n    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:\n        batch_size, seq_length = hidden_states.shape[:-1]\n        query_shape = (batch_size, seq_length, -1, self.q_head_dim)\n        key_shape = (\n            batch_size,\n            seq_length,\n            -1,\n            self.qk_nope_head_dim + self.v_head_dim,\n        )\n\n        q_states = self.q_proj(hidden_states)\n        q_states = q_states.view(query_shape).transpose(1, 2)\n        q_pass, q_rot = torch.split(\n            q_states, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1\n        )\n\n        compressed_kv = self.kv_a_proj_with_mqa(hidden_states)\n        k_pass, k_rot = torch.split(\n            compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1\n        )\n\n        k_pass = (\n            self.kv_b_proj(self.kv_a_layernorm(k_pass)).view(key_shape).transpose(1, 2)\n        )\n        k_pass, value_states = torch.split(\n            k_pass, [self.qk_nope_head_dim, self.v_head_dim], dim=-1\n        )\n\n        k_rot = k_rot.view(batch_size, 1, seq_length, self.qk_rope_head_dim)\n        k_rot = k_rot.expand(*k_pass.shape[:-1], -1)\n\n        query_states = torch.cat((q_pass, q_rot), dim=-1)\n        key_states = torch.cat((k_pass, k_rot), dim=-1)\n\n        if past_key_values is not None:\n            key_states, value_states = past_key_values.update(\n                
key_states, value_states, self.layer_idx\n            )\n\n        if (\n            self.config._attn_implementation == \"flash_attention_2\"\n            and self.q_head_dim != self.v_head_dim\n        ):\n            value_states = F.pad(value_states, [0, self.q_head_dim - self.v_head_dim])\n\n        attention_interface: Callable = eager_attention_forward\n        if self.config._attn_implementation != \"eager\":\n            attention_interface = ALL_ATTENTION_FUNCTIONS[\n                self.config._attn_implementation\n            ]\n\n        attn_output, _ = attention_interface(\n            self,\n            query_states,\n            key_states,\n            value_states,\n            attention_mask,\n            dropout=0.0 if not self.training else self.attention_dropout,\n            scaling=self.scaling,\n            **kwargs,\n        )\n\n        if (\n            self.config._attn_implementation == \"flash_attention_2\"\n            and self.q_head_dim != self.v_head_dim\n        ):\n            attn_output = attn_output[:, :, :, : self.v_head_dim]\n\n        attn_output = attn_output.reshape(batch_size, seq_length, -1).contiguous()\n        attn_output = self.o_proj(attn_output)\n        return attn_output\n\n\nclass KimiDeltaAttention(nn.Module):\n    def __init__(self, config: KimiLinearConfig, layer_idx: int):\n        super().__init__()\n        self.config = config\n        self.mode = \"chunk\"\n\n        self.hidden_size = config.hidden_size\n        self.conv_size = config.linear_attn_config[\"short_conv_kernel_size\"]\n        self.head_dim = config.linear_attn_config[\"head_dim\"]\n        self.num_heads = config.linear_attn_config[\"num_heads\"]\n        self.head_k_dim = self.head_dim\n        self.num_k_heads = self.num_heads\n\n        self.layer_idx = layer_idx\n\n        assert self.mode in [\"chunk\", \"fused_recurrent\"], (\n            f\"Not suppoerted mode `{self.mode}`.\"\n        )\n\n        projection_k_size = 
self.head_k_dim * self.num_k_heads\n        projection_size = self.head_dim * self.num_heads\n\n        self.q_proj = nn.Linear(self.hidden_size, projection_k_size, bias=False)\n        self.k_proj = nn.Linear(self.hidden_size, projection_k_size, bias=False)\n        self.v_proj = nn.Linear(self.hidden_size, projection_size, bias=False)\n\n        self.q_conv1d = ShortConvolution(\n            hidden_size=projection_k_size,\n            kernel_size=self.conv_size,\n            activation=\"silu\",\n        )\n        self.k_conv1d = ShortConvolution(\n            hidden_size=projection_k_size, kernel_size=self.conv_size, activation=\"silu\"\n        )\n        self.v_conv1d = ShortConvolution(\n            hidden_size=projection_size, kernel_size=self.conv_size, activation=\"silu\"\n        )\n\n        self.A_log = torch.nn.Parameter(\n            torch.log(\n                torch.empty(self.num_heads, dtype=torch.float32).uniform_(1, 16)\n            ).view(1, 1, -1, 1)\n        )\n\n        self.f_a_proj = nn.Linear(self.hidden_size, self.head_dim, bias=False)\n        self.f_b_proj = nn.Linear(self.head_dim, projection_size, bias=False)\n\n        self.dt_bias = nn.Parameter(torch.empty(projection_size, dtype=torch.float32))\n\n        self.b_proj = nn.Linear(self.hidden_size, self.num_heads, bias=False)\n\n        self.g_a_proj = nn.Linear(self.hidden_size, self.head_dim, bias=False)\n        self.g_b_proj = nn.Linear(self.head_dim, projection_size, bias=False)\n\n        self.o_norm = FusedRMSNormGated(\n            self.head_dim, eps=config.rms_norm_eps, activation=\"sigmoid\"\n        )\n        self.o_proj = nn.Linear(projection_size, self.hidden_size, bias=False)\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        attention_mask: Optional[torch.Tensor] = None,\n        cache_params: Optional[KimiDynamicCache] = None,\n        **kwargs: Unpack[dict],\n    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:\n     
   if attention_mask is not None:\n            if attention_mask.dim() != 2:\n                attention_mask = kwargs.get(\"padding_mask\", None)\n\n            if attention_mask is not None and attention_mask.dim() != 2:\n                raise ValueError(\n                    \"attention_mask must be a 0-1 matrix of shape [batch_size, seq_len] \"\n                    \"(0 = padding). 3D masks are not supported here.\"\n                )\n        use_cache = cache_params is not None\n        batch_size, q_len, _ = hidden_states.shape\n        mode = \"fused_recurrent\" if q_len <= 64 else self.mode\n        if self.training:\n            assert mode == \"chunk\", \"Only chunk mode is supported in training.\"\n\n        cu_seqlens = kwargs.get(\"cu_seqlens\", None)\n        indices = None\n        if attention_mask is not None:\n            indices, cu_seqlens, _ = get_unpad_data(attention_mask[:, -q_len:])\n            hidden_states = index_first_axis(\n                rearrange(hidden_states, \"b s ... 
-> (b s) ...\"), indices\n            ).unsqueeze(0)\n\n        conv_state_q, conv_state_k, conv_state_v = None, None, None\n        recurrent_state = None\n        if cache_params is not None:\n            if cache_params.conv_states[self.layer_idx] is not None:\n                conv_state_q, conv_state_k, conv_state_v = cache_params.conv_states[\n                    self.layer_idx\n                ]\n            recurrent_state = cache_params.recurrent_states[self.layer_idx]\n        q, conv_state_q = self.q_conv1d(\n            x=self.q_proj(hidden_states),\n            cache=conv_state_q,\n            output_final_state=use_cache,\n            cu_seqlens=cu_seqlens,\n        )\n        k, conv_state_k = self.k_conv1d(\n            x=self.k_proj(hidden_states),\n            cache=conv_state_k,\n            output_final_state=use_cache,\n            cu_seqlens=cu_seqlens,\n        )\n        v, conv_state_v = self.v_conv1d(\n            x=self.v_proj(hidden_states),\n            cache=conv_state_v,\n            output_final_state=use_cache,\n            cu_seqlens=cu_seqlens,\n        )\n        g = self.f_b_proj(self.f_a_proj(hidden_states))\n        g = fused_kda_gate(g, self.A_log, self.head_dim, g_bias=self.dt_bias)\n        beta = self.b_proj(hidden_states).float().sigmoid()\n\n        q, k = map(\n            lambda x: rearrange(x, \"... (h d) -> ... h d\", d=self.head_k_dim), (q, k)\n        )\n        v = rearrange(v, \"... (h d) -> ... 
h d\", d=self.head_dim)\n\n        if mode == \"chunk\":\n            o, recurrent_state = chunk_kda(\n                q=q,\n                k=k,\n                v=v,\n                g=g,\n                beta=beta,\n                initial_state=recurrent_state,\n                output_final_state=True,\n                use_qk_l2norm_in_kernel=True,\n                cu_seqlens=cu_seqlens,\n            )\n        else:\n            o, recurrent_state = fused_recurrent_kda(\n                q=q,\n                k=k,\n                v=v,\n                g=g,\n                beta=beta,\n                initial_state=recurrent_state,\n                output_final_state=True,\n                use_qk_l2norm_in_kernel=True,\n                cu_seqlens=cu_seqlens,\n            )\n        if cache_params is not None:\n            cache_params.recurrent_states[self.layer_idx] = recurrent_state\n            cache_params.conv_states[self.layer_idx] = (\n                conv_state_q,\n                conv_state_k,\n                conv_state_v,\n            )\n\n        g = self.g_b_proj(self.g_a_proj(hidden_states))\n        g = rearrange(g, \"... (h d) -> ... 
h d\", d=self.head_dim)\n        o = self.o_norm(o, g)\n\n        o = rearrange(o, \"b t h d -> b t (h d)\")\n        o = self.o_proj(o)\n        if attention_mask is not None:\n            o = pad_input(o.squeeze(0), indices, batch_size, q_len)\n\n        return o\n\n\nclass KimiMoEGate(nn.Module):\n    \"\"\"\n    MoE Gate that returns router logits.\n    Routing decisions are made in KimiSparseMoeBlock.\n    \"\"\"\n\n    def __init__(self, config: KimiLinearConfig):\n        super().__init__()\n        self.config = config\n        self.num_experts = config.num_experts\n        self.gating_dim = config.hidden_size\n\n        self.weight = nn.Parameter(torch.empty((self.num_experts, self.gating_dim)))\n        self.e_score_correction_bias = nn.Parameter(torch.zeros((self.num_experts,)))\n        self.reset_parameters()\n\n    def reset_parameters(self) -> None:\n        import torch.nn.init as init\n\n        init.kaiming_uniform_(self.weight, a=math.sqrt(5))\n\n    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:\n        \"\"\"\n        Args:\n            hidden_states: [batch_size, seq_len, hidden_dim]\n\n        Returns:\n            router_logits: [batch_size * seq_len, num_experts]\n        \"\"\"\n        _, _, h = hidden_states.shape\n        hidden_states = hidden_states.view(-1, h)\n        router_logits = F.linear(\n            hidden_states.type(torch.float32), self.weight.type(torch.float32), None\n        )\n        return router_logits\n\n    # def forward(self, hidden_states):\n    #     bsz, seq_len, h = hidden_states.shape\n    #     # compute gating score\n    #     hidden_states = hidden_states.view(-1, h)\n    #     logits = F.linear(\n    #         hidden_states.type(torch.float32), self.weight.type(\n    #             torch.float32), None\n    #     )\n    #     if self.moe_router_activation_func == \"sigmoid\":\n    #         scores = logits.sigmoid()\n    #     elif self.moe_router_activation_func == \"softmax\":\n    #    
     scores = logits.softmax(dim=1)\n    #     else:\n    #         raise NotImplementedError(\n    #             f\"insupportable scoring function for MoE gating: {self.moe_router_activation_func}\"\n    #         )\n\n    #     # select top-k experts\n    #     assert not self.training\n    #     scores_for_choice = scores.view(bsz * seq_len, -1)\n    #     scores_for_choice += self.e_score_correction_bias.unsqueeze(0)\n    #     group_scores = (\n    #         scores_for_choice.view(\n    #             bsz * seq_len, self.num_expert_group, -1).topk(2, dim=-1)[0].sum(dim=-1)\n    #     )  # [n, num_expert_group]\n    #     group_idx = torch.topk(\n    #         group_scores, k=self.topk_group, dim=-1, sorted=False\n    #     )[\n    #         1\n    #     ]  # [n, top_k_group]\n    #     group_mask = torch.zeros_like(group_scores)  # [n, num_expert_group]\n    #     group_mask.scatter_(1, group_idx, 1)  # [n, num_expert_group]\n    #     score_mask = (\n    #         group_mask.unsqueeze(-1)\n    #         .expand(\n    #             bsz * seq_len, self.num_expert_group, self.num_experts // self.num_expert_group\n    #         )\n    #         .reshape(bsz * seq_len, -1)\n    #     )  # [n, e]\n    #     tmp_scores = scores_for_choice.masked_fill(\n    #         ~score_mask.bool(), 0.0)  # [n, e]\n    #     _, topk_idx = torch.topk(\n    #         tmp_scores, k=self.top_k, dim=-1, sorted=False\n    #     )\n    #     topk_weight = scores.gather(1, topk_idx)\n\n    #     # norm gate to sum 1\n    #     if self.top_k > 1 and self.moe_renormalize:\n    #         denominator = topk_weight.sum(dim=-1, keepdim=True) + 1e-20\n    #         topk_weight = topk_weight / denominator\n    #     # must multiply the scaling factor\n    #     topk_weight = topk_weight * self.routed_scaling_factor\n\n    #     return topk_idx, topk_weight\n\n\n# class KimiSparseMoeBlock(nn.Module):\n#     \"\"\"\n#     Adapted from Deepseek-V3's MOE implementation\n#     The namings are 
consistent with Kimi's version.\n#     \"\"\"\n\n#     def __init__(self, config: KimiLinearConfig):\n#         super().__init__()\n#         self.config = config\n#         self.hidden_dim = config.hidden_size\n#         self.num_experts = config.num_experts\n#         self.top_k = config.num_experts_per_token\n#         self.moe_renormalize = config.moe_renormalize\n\n#         self.ep_size = 1\n#         self.experts_per_rank = config.num_experts\n#         self.ep_rank = 0\n#         self.experts = nn.ModuleList(\n#             [\n#                 KimiBlockSparseMLP(\n#                     config, intermediate_size=config.moe_intermediate_size\n#                 )\n#                 for _ in range(config.num_experts)\n#             ]\n#         )\n#         self.gate = KimiMoEGate(config)\n#         if config.num_shared_experts is not None:\n#             intermediate_size = config.moe_intermediate_size * config.num_shared_experts\n#             self.shared_experts = KimiMLP(\n#                 config=config, intermediate_size=intermediate_size\n#             )\n\n#     def forward(self, hidden_states):\n#         identity = hidden_states\n#         orig_shape = hidden_states.shape\n#         topk_idx, topk_weight = self.gate(hidden_states)\n#         hidden_states = hidden_states.view(-1, hidden_states.shape[-1])\n#         flat_topk_idx = topk_idx.view(-1)\n#         if not self.training:\n#             y = self.moe_infer(hidden_states, topk_idx,\n#                                topk_weight).view(*orig_shape)\n#         else:\n#             raise NotImplementedError(\n#                 \"Training mode is not supported in KimiSparseMoeBlock\")\n#         if self.config.num_shared_experts is not None:\n#             y = y + self.shared_experts(identity)\n#         return y\n\n#     @torch.no_grad()\n#     def moe_infer(self, x, topk_ids, topk_weight):\n#         cnts = topk_ids.new_zeros((topk_ids.shape[0], len(self.experts)))\n#         cnts.scatter_(1, 
topk_ids, 1)\n#         tokens_per_expert = cnts.sum(dim=0)\n#         idxs = topk_ids.view(-1).argsort()\n#         sorted_tokens = x[idxs // topk_ids.shape[1]]\n\n#         tokens_per_expert = tokens_per_expert.cpu().numpy()\n\n#         outputs = []\n#         start_idx = 0\n#         for i, num_tokens in enumerate(tokens_per_expert):\n#             end_idx = start_idx + num_tokens\n#             if num_tokens == 0:\n#                 continue\n#             expert = self.experts[i + self.ep_rank * self.experts_per_rank]\n#             tokens_for_this_expert = sorted_tokens[start_idx:end_idx]\n#             expert_out = expert(tokens_for_this_expert)\n#             outputs.append(expert_out)\n#             start_idx = end_idx\n\n#         outs = torch.cat(outputs, dim=0) if len(\n#             outputs) else sorted_tokens.new_empty(0)\n\n#         new_x = torch.empty_like(outs)\n#         new_x[idxs] = outs\n#         final_out = (\n#             new_x.view(*topk_ids.shape, -1)\n#             .type(topk_weight.dtype)\n#             .mul_(topk_weight.unsqueeze(dim=-1))\n#             .sum(dim=1)\n#             .type(new_x.dtype)\n#         )\n#         return final_out\n\n\n# Replace the KimiSparseMoeBlock class with this new version\nclass KimiSparseMoeBlock(nn.Module):\n    \"\"\"\n    MoE block adapted from Deepseek-V3.\n    Returns only hidden_states - router_logits captured by OutputRecorder.\n    \"\"\"\n\n    def __init__(self, config: KimiLinearConfig):\n        super().__init__()\n        self.config = config\n        self.hidden_dim = config.hidden_size\n        self.num_experts = config.num_experts\n        self.top_k = config.num_experts_per_token\n        self.moe_renormalize = config.moe_renormalize\n        self.routed_scaling_factor = config.routed_scaling_factor\n        self.num_expert_group = getattr(config, \"num_expert_group\", 1)\n        self.topk_group = getattr(config, \"topk_group\", 1)\n\n        self.experts = nn.ModuleList(\n           
 [\n                KimiBlockSparseMLP(\n                    config, intermediate_size=config.moe_intermediate_size\n                )\n                for _ in range(config.num_experts)\n            ]\n        )\n        self.gate = KimiMoEGate(config)\n\n        if config.num_shared_experts is not None:\n            intermediate_size = config.moe_intermediate_size * config.num_shared_experts\n            self.shared_experts = KimiMLP(\n                config=config, intermediate_size=intermediate_size\n            )\n\n    def route_tokens_to_experts(\n        self,\n        router_logits: torch.Tensor,\n    ) -> tuple[torch.Tensor, torch.Tensor]:\n        \"\"\"\n        Compute routing decisions from router logits.\n\n        Args:\n            router_logits: [num_tokens, num_experts]\n\n        Returns:\n            topk_idx: [num_tokens, top_k]\n            topk_weight: [num_tokens, top_k]\n        \"\"\"\n        num_tokens = router_logits.shape[0]\n\n        if self.training:\n            # Training: use softmax for standard aux loss compatibility\n            scores = F.softmax(router_logits, dim=-1, dtype=torch.float32)\n            topk_weight, topk_idx = torch.topk(scores, self.top_k, dim=-1, sorted=False)\n        else:\n            # Inference: use original sigmoid + group selection\n            scores = router_logits.sigmoid()\n            scores_for_choice = scores + self.gate.e_score_correction_bias.unsqueeze(0)\n\n            # Group-based selection\n            group_scores = (\n                scores_for_choice.view(num_tokens, self.num_expert_group, -1)\n                .topk(2, dim=-1)[0]\n                .sum(dim=-1)\n            )\n            group_idx = torch.topk(\n                group_scores, k=self.topk_group, dim=-1, sorted=False\n            )[1]\n            group_mask = torch.zeros_like(group_scores)\n            group_mask.scatter_(1, group_idx, 1)\n            score_mask = (\n                group_mask.unsqueeze(-1)\n             
   .expand(\n                    num_tokens,\n                    self.num_expert_group,\n                    self.num_experts // self.num_expert_group,\n                )\n                .reshape(num_tokens, -1)\n            )\n            tmp_scores = scores_for_choice.masked_fill(~score_mask.bool(), 0.0)\n            _, topk_idx = torch.topk(tmp_scores, k=self.top_k, dim=-1, sorted=False)\n            topk_weight = scores.gather(1, topk_idx)\n\n        # Normalize and scale\n        if self.top_k > 1 and self.moe_renormalize:\n            denominator = topk_weight.sum(dim=-1, keepdim=True) + 1e-20\n            topk_weight = topk_weight / denominator\n        topk_weight = topk_weight * self.routed_scaling_factor\n\n        return topk_idx, topk_weight\n\n    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:\n        \"\"\"\n        Forward pass returning only hidden_states.\n        Router logits are captured by OutputRecorder for aux loss.\n        \"\"\"\n        identity = hidden_states\n        batch_size, seq_len, hidden_dim = hidden_states.shape\n        num_tokens = batch_size * seq_len\n\n        # Flatten for routing\n        hidden_states_flat = hidden_states.view(num_tokens, hidden_dim)\n\n        # Get router logits - OutputRecorder captures this!\n        router_logits = self.gate(hidden_states)\n\n        # Get routing decisions\n        topk_idx, topk_weight = self.route_tokens_to_experts(router_logits)\n\n        if self.training:\n            final_hidden_states = self._training_forward(\n                hidden_states_flat, topk_idx, topk_weight, num_tokens, hidden_dim\n            )\n        else:\n            final_hidden_states = self._inference_forward(\n                hidden_states_flat, topk_idx, topk_weight\n            )\n\n        final_hidden_states = final_hidden_states.view(batch_size, seq_len, hidden_dim)\n\n        # Add shared experts if present\n        if self.config.num_shared_experts is not None:\n            
final_hidden_states = final_hidden_states + self.shared_experts(identity)\n\n        return final_hidden_states\n\n    def _training_forward(\n        self,\n        hidden_states: torch.Tensor,\n        topk_idx: torch.Tensor,\n        topk_weight: torch.Tensor,\n        num_tokens: int,\n        hidden_dim: int,\n    ) -> torch.Tensor:\n        \"\"\"\n        Differentiable training forward using scatter-gather pattern.\n        \"\"\"\n        # Flatten expert indices: [num_tokens * top_k]\n        flat_topk_idx = topk_idx.view(-1)\n\n        # Sort by expert index to group tokens going to same expert\n        sorted_indices = torch.argsort(flat_topk_idx)\n        inverse_permutation = torch.argsort(sorted_indices)\n\n        # Each token appears top_k times (once per expert choice)\n        token_indices = torch.arange(\n            num_tokens, device=hidden_states.device\n        ).repeat_interleave(self.top_k)\n\n        # Gather tokens and weights in sorted order\n        shuffled_tokens = hidden_states[token_indices[sorted_indices]]\n        shuffled_weights = topk_weight.view(-1)[sorted_indices].unsqueeze(-1)\n\n        # Count tokens per expert\n        tokens_per_expert = F.one_hot(flat_topk_idx, num_classes=self.num_experts).sum(\n            dim=0\n        )\n\n        # Process each expert's batch\n        expert_outputs = []\n        current_pos = 0\n        for i in range(self.num_experts):\n            num_tokens_for_expert = tokens_per_expert[i].item()\n            if num_tokens_for_expert == 0:\n                continue\n\n            expert_input = shuffled_tokens[\n                current_pos : current_pos + num_tokens_for_expert\n            ]\n            expert_output = self.experts[i](expert_input)\n            expert_outputs.append(expert_output)\n            current_pos += num_tokens_for_expert\n\n        # Concatenate all outputs\n        if expert_outputs:\n            concatenated_outputs = torch.cat(expert_outputs, dim=0)\n        
else:\n            concatenated_outputs = torch.zeros(\n                num_tokens * self.top_k,\n                hidden_dim,\n                device=hidden_states.device,\n                dtype=hidden_states.dtype,\n            )\n\n        # Apply weights while still in sorted order\n        weighted_outputs = concatenated_outputs * shuffled_weights\n\n        # Unsort back to original token order\n        unshuffled_outputs = weighted_outputs[inverse_permutation]\n\n        # Sum contributions from all top_k experts for each token\n        final_hidden_states = unshuffled_outputs.view(\n            num_tokens, self.top_k, hidden_dim\n        ).sum(dim=1)\n\n        return final_hidden_states\n\n    @torch.no_grad()\n    def _inference_forward(\n        self,\n        hidden_states: torch.Tensor,\n        topk_idx: torch.Tensor,\n        topk_weight: torch.Tensor,\n    ) -> torch.Tensor:\n        \"\"\"\n        Optimized inference forward (original implementation).\n        \"\"\"\n        cnts = topk_idx.new_zeros((topk_idx.shape[0], len(self.experts)))\n        cnts.scatter_(1, topk_idx, 1)\n        tokens_per_expert = cnts.sum(dim=0)\n        idxs = topk_idx.view(-1).argsort()\n        sorted_tokens = hidden_states[idxs // topk_idx.shape[1]]\n\n        tokens_per_expert_list = tokens_per_expert.cpu().numpy()\n\n        outputs = []\n        start_idx = 0\n        for i, num_tokens in enumerate(tokens_per_expert_list):\n            end_idx = start_idx + num_tokens\n            if num_tokens == 0:\n                continue\n            expert = self.experts[i]\n            tokens_for_expert = sorted_tokens[start_idx:end_idx]\n            expert_out = expert(tokens_for_expert)\n            outputs.append(expert_out)\n            start_idx = end_idx\n\n        outs = torch.cat(outputs, dim=0) if outputs else sorted_tokens.new_empty(0)\n\n        new_x = torch.empty_like(outs)\n        new_x[idxs] = outs\n        final_out = (\n            
new_x.view(*topk_idx.shape, -1)\n            .type(topk_weight.dtype)\n            .mul_(topk_weight.unsqueeze(dim=-1))\n            .sum(dim=1)\n            .type(new_x.dtype)\n        )\n        return final_out\n\n\nclass KimiDecoderLayer(nn.Module):\n    def __init__(self, config: KimiLinearConfig, layer_idx: int):\n        super().__init__()\n        self.hidden_size = config.hidden_size\n        self.config = config\n        if config.is_kda_layer(layer_idx):\n            self.is_linear_attn = True\n            self.self_attn = KimiDeltaAttention(config=config, layer_idx=layer_idx)\n        elif config.is_mla:\n            self.is_linear_attn = False\n            self.self_attn = KimiMLAAttention(config=config, layer_idx=layer_idx)\n        else:\n            raise NotImplementedError\n        if (\n            config.num_experts is not None\n            and layer_idx >= config.first_k_dense_replace\n            and layer_idx % getattr(config, \"moe_layer_freq\", 1) == 0\n        ):\n            self.block_sparse_moe = KimiSparseMoeBlock(config)\n        else:\n            self.mlp = KimiMLP(config)\n        self.input_layernorm = KimiRMSNorm(config.hidden_size, eps=config.rms_norm_eps)\n        self.post_attention_layernorm = KimiRMSNorm(\n            config.hidden_size, eps=config.rms_norm_eps\n        )\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        attention_mask: Optional[torch.Tensor] = None,\n        position_ids: Optional[torch.LongTensor] = None,\n        past_key_values: Optional[Tuple[torch.Tensor]] = None,\n        output_attentions: Optional[bool] = False,\n        use_cache: Optional[bool] = False,\n        **kwargs: Unpack[FlashAttentionKwargs],\n    ) -> Tuple[\n        torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]\n    ]:\n        \"\"\"\n        Args:\n            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`\n            
attention_mask (`torch.FloatTensor`, *optional*): attention mask of size\n                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.\n            output_attentions (`bool`, *optional*):\n                Whether or not to return the attentions tensors of all attention layers. See `attentions` under\n                returned tensors for more detail.\n            use_cache (`bool`, *optional*):\n                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding\n                (see `past_key_values`).\n            past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states\n        \"\"\"\n\n        residual = hidden_states\n\n        hidden_states = self.input_layernorm(hidden_states)\n\n        # Self Attention\n        if self.is_linear_attn is False:\n            hidden_states = self.self_attn(\n                hidden_states=hidden_states,\n                attention_mask=attention_mask,\n                position_ids=position_ids,\n                past_key_values=past_key_values,\n                output_attentions=output_attentions,\n                use_cache=use_cache,\n                **kwargs,\n            )\n        else:\n            hidden_states = self.self_attn(\n                hidden_states=hidden_states,\n                attention_mask=attention_mask,\n                cache_params=past_key_values,\n                output_attentions=output_attentions,\n                use_cache=use_cache,\n                **kwargs,\n            )\n        hidden_states = residual + hidden_states\n\n        # Fully Connected\n        residual = hidden_states\n        hidden_states = self.post_attention_layernorm(hidden_states)\n        if hasattr(self, \"block_sparse_moe\"):\n            hidden_states = self.block_sparse_moe(hidden_states)\n        else:\n            hidden_states = self.mlp(hidden_states)\n        
hidden_states = residual + hidden_states\n\n        return hidden_states\n\n\nclass KimiPreTrainedModel(PreTrainedModel):\n    config_class = KimiLinearConfig\n    base_model_prefix = \"model\"\n    supports_gradient_checkpointing = True\n    _no_split_modules = [\"KimiDecoderLayer\"]\n    _skip_keys_device_placement = \"past_key_values\"\n    _supports_flash_attn_2 = True\n    _can_record_outputs = {\n        \"router_logits\": OutputRecorder(KimiMoEGate, index=0),\n        \"hidden_states\": KimiDecoderLayer,\n        \"attentions\": KimiMLAAttention,\n    }\n    _is_stateful = True\n\n    def _init_weights(self, module):\n        std = self.config.initializer_range\n        if isinstance(module, nn.Linear):\n            module.weight.data.normal_(mean=0.0, std=std)\n            if module.bias is not None:\n                module.bias.data.zero_()\n        elif isinstance(module, nn.Embedding):\n            module.weight.data.normal_(mean=0.0, std=std)\n            if module.padding_idx is not None:\n                module.weight.data[module.padding_idx].zero_()\n\n\nclass KimiLinearModel(KimiPreTrainedModel):\n    def __init__(self, config: KimiLinearConfig):\n        super().__init__(config)\n        self.padding_idx = config.pad_token_id\n        self.vocab_size = config.vocab_size\n\n        self.embed_tokens = nn.Embedding(\n            config.vocab_size, config.hidden_size, self.padding_idx\n        )\n        self.layers = nn.ModuleList(\n            [\n                KimiDecoderLayer(config, layer_idx)\n                for layer_idx in range(config.num_hidden_layers)\n            ]\n        )\n        self.norm = KimiRMSNorm(config.hidden_size, eps=config.rms_norm_eps)\n\n        if getattr(config, \"_attn_implementation\", None) is not None:\n            if config._attn_implementation != \"flash_attention_2\":\n                logger.warning_once(\n                    f\"Ignoring the provided attention implementation {config._attn_implementation}\"\n    
            )\n                logger.warning_once(\"Using flash_attention_2 backend instead.\")\n                config._attn_implementation = \"flash_attention_2\"\n        else:\n            config._attn_implementation = \"flash_attention_2\"\n\n        self._use_flash_attention_2 = config._attn_implementation == \"flash_attention_2\"\n        self.gradient_checkpointing = False\n        # Initialize weights and apply final processing\n        self.post_init()\n\n    def _update_linear_attn_mask(self, attention_mask, cache_position):\n        \"\"\"\n        NOTE: Left-padding is used for linear attention mask.\n        No need for zeroing states when\n            1. Cached forward\n            2. Attending to all inputs\n        \"\"\"\n        linear_attn_mask = attention_mask\n        if cache_position[0] > 0 or (\n            attention_mask is not None and torch.all(attention_mask == 1)\n        ):\n            linear_attn_mask = None\n        return linear_attn_mask\n\n    def forward(\n        self,\n        input_ids: torch.LongTensor = None,\n        attention_mask: Optional[torch.Tensor] = None,\n        position_ids: Optional[torch.LongTensor] = None,\n        past_key_values: Optional[Cache] = None,\n        inputs_embeds: Optional[torch.FloatTensor] = None,\n        cache_position: Optional[torch.LongTensor] = None,\n        use_cache: Optional[bool] = None,\n        **kwargs: Unpack[TransformersKwargs],\n    ) -> Union[Tuple, BaseModelOutputWithPast]:\n        use_cache = use_cache if use_cache is not None else self.config.use_cache\n\n        if (input_ids is None) and (inputs_embeds is None):\n            raise ValueError(\n                \"You must specify exactly one of input_ids or inputs_embeds\"\n            )\n\n        # Get inputs_embeds\n        if inputs_embeds is None:\n            inputs_embeds = self.embed_tokens(input_ids)\n\n        if use_cache and past_key_values is None:\n            past_key_values = 
KimiDynamicCache(config=self.config)\n\n        if cache_position is None:\n            past_seen_tokens = (\n                past_key_values.get_seq_length() if past_key_values is not None else 0\n            )\n            cache_position: torch.Tensor = torch.arange(\n                past_seen_tokens,\n                past_seen_tokens + inputs_embeds.shape[1],\n                device=inputs_embeds.device,\n            )\n\n        if position_ids is None:\n            position_ids = cache_position.unsqueeze(0)\n\n        causal_mask = create_causal_mask(\n            config=self.config,\n            input_embeds=inputs_embeds,\n            attention_mask=attention_mask,\n            cache_position=cache_position,\n            past_key_values=past_key_values,\n            position_ids=position_ids,\n        )\n        linear_attn_mask = self._update_linear_attn_mask(attention_mask, cache_position)\n\n        hidden_states = inputs_embeds\n        if past_key_values is not None:\n            assert isinstance(past_key_values, KimiDynamicCache)\n\n        for decoder_layer in self.layers:\n            layer_mask = (\n                linear_attn_mask if decoder_layer.is_linear_attn else causal_mask\n            )\n\n            hidden_states = decoder_layer(\n                hidden_states,\n                attention_mask=layer_mask,\n                past_key_values=past_key_values,\n                cache_position=cache_position,\n                **kwargs,\n            )\n\n        hidden_states = self.norm(hidden_states)\n\n        return BaseModelOutputWithPast(\n            last_hidden_state=hidden_states,\n            past_key_values=past_key_values,\n        )\n\n\nclass KimiLinearForCausalLM(KimiPreTrainedModel, GenerationMixin):\n    _tied_weights_keys = [\"lm_head.weight\"]\n\n    def __init__(self, config):\n        super().__init__(config)\n        self.model = KimiLinearModel(config)\n        self.vocab_size = config.vocab_size\n        self.lm_head = 
nn.Linear(config.hidden_size, config.vocab_size, bias=False)\n\n        # Initialize weights and apply final processing\n        self.post_init()\n\n    @can_return_tuple\n    def forward(\n        self,\n        input_ids: torch.LongTensor = None,\n        attention_mask: Optional[torch.Tensor] = None,\n        position_ids: Optional[torch.LongTensor] = None,\n        past_key_values: Optional[List[torch.FloatTensor]] = None,\n        inputs_embeds: Optional[torch.FloatTensor] = None,\n        labels: Optional[torch.LongTensor] = None,\n        use_cache: Optional[bool] = None,\n        generation_mode: Optional[bool] = None,\n        cache_position: Optional[torch.LongTensor] = None,\n        **kwargs: Unpack[TransformersKwargs],\n    ) -> Union[Tuple, CausalLMOutputWithPast]:\n        r\"\"\"\n        Args:\n            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):\n                Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,\n                config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored\n                (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.\n\n        Returns:\n\n        Example:\n\n        ```python\n        >>> from transformers import AutoTokenizer, KimiLinearForCausalLM\n\n        >>> model = KimiLinearForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)\n        >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)\n\n        >>> prompt = \"Hey, are you conscious? Can you talk to me?\"\n        >>> inputs = tokenizer(prompt, return_tensors=\"pt\")\n\n        >>> # Generate\n        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)\n        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]\n        \"Hey, are you conscious? 
Can you talk to me?\\nI'm not conscious, but I can talk to you.\"\n        ```\"\"\"\n        outputs = self.model(\n            input_ids=input_ids,\n            attention_mask=attention_mask,\n            position_ids=position_ids,\n            past_key_values=past_key_values,\n            inputs_embeds=inputs_embeds,\n            use_cache=use_cache,\n            cache_position=cache_position,\n            **kwargs,\n        )\n\n        logits = outputs[0]\n        if generation_mode:\n            logits = logits[:, -1:]\n        logits = self.lm_head(logits)\n\n        loss = None\n        if labels is not None:\n            loss = self.loss_function(logits, labels, self.vocab_size, **kwargs)\n\n        aux_loss = None\n        if kwargs.get(\"output_router_logits\", False):\n            aux_loss = load_balancing_loss_func(\n                outputs.router_logits,\n                num_experts=self.config.num_experts,\n                top_k=self.config.num_experts_per_token,\n                attention_mask=attention_mask,\n            )\n            if loss is not None:\n                loss = loss + self.config.router_aux_loss_coef * aux_loss\n\n        return MoeCausalLMOutputWithPast(\n            loss=loss,\n            aux_loss=aux_loss,\n            logits=logits,\n            past_key_values=outputs.past_key_values,\n            hidden_states=outputs.hidden_states,\n            attentions=outputs.attentions,\n        )\n"
  },
  {
    "path": "src/axolotl/monkeypatch/models/kimi_linear/patch_kimi_linear.py",
    "content": "import importlib.resources\nimport importlib.util\nimport sys\nfrom pathlib import Path\n\nfrom axolotl.utils.logging import get_logger\n\nLOG = get_logger(__name__)\n\nKIMI_PATCH_PACKAGE = \"axolotl.monkeypatch.models.kimi_linear\"\n\n\ndef get_patch_file_path(package_dot_path: str, filename: str) -> Path:\n    \"\"\"\n    Return the path to `filename` inside the `package_dot_path` package, resolved\n    via importlib.resources.files, or None if the package cannot be found.\n    \"\"\"\n    try:\n        return importlib.resources.files(package_dot_path) / filename\n    except ModuleNotFoundError:\n        return None\n\n\ndef _load_local_module(module_name: str, filename: str):\n    \"\"\"\n    Load a patch module bundled in this package, reusing the entry in sys.modules\n    if it is already loaded. Returns None if the patch file cannot be located.\n    \"\"\"\n    if module_name in sys.modules:\n        return sys.modules[module_name]\n\n    patch_path = get_patch_file_path(KIMI_PATCH_PACKAGE, filename)\n    if patch_path and patch_path.exists():\n        spec = importlib.util.spec_from_file_location(module_name, patch_path)\n        module = importlib.util.module_from_spec(spec)\n        sys.modules[module_name] = module\n        spec.loader.exec_module(module)\n        return module\n    return None\n\n\ndef _patch_get_class_in_module():\n    \"\"\"\n    Core patch function that hijacks Transformers' dynamic module loading.\n    Returns early if get_class_in_module has already been patched.\n    \"\"\"\n    from transformers.dynamic_module_utils import get_class_in_module\n\n    if hasattr(get_class_in_module, \"_axolotl_patched\"):\n        return\n\n    original_get_class_in_module = get_class_in_module\n\n    # Mapping of module path patterns to (module_name, filename)\n    KIMI_MODULE_MAP = {\n        \"configuration_kimi\": (\"configuration_kimi\", \"configuration_kimi.py\"),\n        \"modeling_kimi\": (\"modeling_kimi\", \"modeling_kimi.py\"),\n        \"tokenization_kimi\": (\"tokenization_kimi\", \"tokenization_kimi.py\"),\n    }\n\n    def patched_get_class_in_module(class_name, module_path, **kwargs):\n        \"\"\"Return the class from our local Kimi module instead of the remote one.\"\"\"\n        for pattern, (module_name, filename) in KIMI_MODULE_MAP.items():\n            if pattern in module_path:\n                module = _load_local_module(module_name, filename)\n                if module:\n                    return getattr(module, class_name)\n                break  # Pattern matched but local file missing; use the original loader\n\n        return original_get_class_in_module(class_name, module_path, **kwargs)\n\n    import transformers.dynamic_module_utils\n\n    transformers.dynamic_module_utils.get_class_in_module = patched_get_class_in_module\n    patched_get_class_in_module._axolotl_patched = True\n\n\ndef patch_kimi():\n    \"\"\"\n    Apply all Kimi patches.\n    Must be called BEFORE loading config/tokenizer/model.\n    \"\"\"\n    _patch_get_class_in_module()\n    LOG.info(\"Kimi patches applied successfully!\")\n\n\n# Backward-compatible aliases; each one applies all Kimi patches.\npatch_kimi_config = patch_kimi\npatch_kimi_tokenizer = patch_kimi\npatch_kimi_model = patch_kimi\n"
  },
  {
    "path": "src/axolotl/monkeypatch/models/kimi_linear/tokenization_kimi.py",
    "content": "\"\"\"\nAdapted Kimi-Linear tokenizer to use proper template defaults and misc fixes.\n\nSource: https://huggingface.co/moonshotai/Kimi-Linear-48B-A3B-Instruct/blob/main/tokenization_kimi.py\nRevision: 919416f\n\"\"\"\n\nimport os\nfrom logging import getLogger\nfrom pathlib import Path\nfrom shutil import copyfile\nfrom typing import (\n    Any,\n    Dict,\n    Iterator,\n    List,\n    Optional,\n    Tuple,\n    Union,\n    cast,\n)\n\nimport tiktoken\nfrom tiktoken.load import load_tiktoken_bpe\nfrom tokenizers import AddedToken\nfrom transformers.models.gpt2.tokenization_gpt2 import bytes_to_unicode\nfrom transformers.tokenization_utils import PreTrainedTokenizer\n\nlogger = getLogger(__name__)\nVOCAB_FILES_NAMES = {\"vocab_file\": \"tiktoken.model\"}\n\n\nclass TikTokenTokenizer(PreTrainedTokenizer):\n    \"\"\"\n    Tokenizing and encoding/decoding text using the Tiktoken tokenizer. See megatron/tokenizer/tiktoken_tokenizer.py.\n\n    This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to\n    this superclass for more information regarding those methods.\n\n    Args:\n        vocab_file (`str`):\n            The path to the Tiktoken model file.\n        bos_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `\"<|begin_of_text|>\",`):\n            The beginning of sequence token that was used during pretraining. Can be used a sequence classifier token.\n        eos_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `\"<|end_of_text|>\"`):\n            The end of sequence token.\n        unk_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `\"<|reserved_special_token_249|>\"`):\n            The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this\n            token instead. 
The second to last item in special_tokens.\n        pad_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `\"<|reserved_special_token_250|>\"`):\n            The token used for padding, for example when batching sequences of different lengths.\n        additional_special_tokens (list of `str`, *optional*):\n            A tuple or a list of additional tokens, which will be marked as `special`, meaning that they will be\n            skipped when decoding if `skip_special_tokens` is set to `True`.\n    \"\"\"\n\n    vocab_files_names = VOCAB_FILES_NAMES\n\n    model_input_names = [\"input_ids\", \"attention_mask\"]\n\n    special_tokens: Dict[str, int]\n\n    num_reserved_special_tokens = 256\n\n    pat_str = \"|\".join(\n        [\n            r\"\"\"[\\p{Han}]+\"\"\",\n            r\"\"\"[^\\r\\n\\p{L}\\p{N}]?[\\p{Lu}\\p{Lt}\\p{Lm}\\p{Lo}\\p{M}&&[^\\p{Han}]]*[\\p{Ll}\\p{Lm}\\p{Lo}\\p{M}&&[^\\p{Han}]]+(?i:'s|'t|'re|'ve|'m|'ll|'d)?\"\"\",\n            r\"\"\"[^\\r\\n\\p{L}\\p{N}]?[\\p{Lu}\\p{Lt}\\p{Lm}\\p{Lo}\\p{M}&&[^\\p{Han}]]+[\\p{Ll}\\p{Lm}\\p{Lo}\\p{M}&&[^\\p{Han}]]*(?i:'s|'t|'re|'ve|'m|'ll|'d)?\"\"\",\n            r\"\"\"\\p{N}{1,3}\"\"\",\n            r\"\"\" ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*\"\"\",\n            r\"\"\"\\s*[\\r\\n]+\"\"\",\n            r\"\"\"\\s+(?!\\S)\"\"\",\n            r\"\"\"\\s+\"\"\",\n        ]\n    )\n\n    def __init__(\n        self,\n        vocab_file,\n        bos_token: Union[str, AddedToken] = \"[BOS]\",  # nosec: B107\n        eos_token: Union[str, AddedToken] = \"[EOS]\",  # nosec: B107\n        unk_token: Union[str, AddedToken, None] = None,\n        pad_token: Union[str, AddedToken, None] = None,\n        additional_special_tokens: List[str] = None,\n        added_tokens_decoder: Optional[dict] = None,\n        **kwargs,\n    ):\n        assert os.path.isfile(vocab_file), vocab_file\n\n        if additional_special_tokens is None:\n            additional_special_tokens = [\n                \"<|im_end|>\",\n       
         \"<|im_user|>\",\n                \"<|im_assistant|>\",\n                \"<|start_header_id|>\",\n                \"<|end_header_id|>\",\n                \"[EOT]\",\n                \"<|im_system|>\",\n                \"<|im_middle|>\",\n            ]\n\n        special_tokens_mapping = {\n            i: added_tokens_decoder[i].content for i in added_tokens_decoder\n        }\n\n        self.vocab_file = vocab_file\n        mergeable_ranks = load_tiktoken_bpe(vocab_file)\n        num_base_tokens = len(mergeable_ranks)\n        self.special_tokens = {\n            special_tokens_mapping.get(i, f\"<|reserved_token_{i}|>\"): i\n            for i in range(\n                num_base_tokens, num_base_tokens + self.num_reserved_special_tokens + 2\n            )\n        }\n\n        self.model = tiktoken.Encoding(\n            name=Path(vocab_file).name,\n            pat_str=self.pat_str,\n            mergeable_ranks=mergeable_ranks,\n            special_tokens=self.special_tokens,\n        )\n        logger.info(f\"Reloaded tiktoken model from {vocab_file}\")\n\n        self.n_words: int = self.model.n_vocab\n        # BOS / EOS token IDs\n        self.bos_id: int = self.special_tokens[str(bos_token)]\n        self.eos_id: int = self.special_tokens[str(eos_token)]\n        logger.info(\n            f\"#words: {self.n_words} - BOS ID: {self.bos_id} - EOS ID: {self.eos_id}\"\n        )\n\n        self.pad_id: int = self.special_tokens[str(pad_token)]\n        self.unk_id: int = self.special_tokens[str(unk_token)]\n\n        self.byte_encoder = bytes_to_unicode()\n        self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}\n\n        self.decoder = {}\n        for i in range(self.n_words):\n            # Taken from https://gist.github.com/xenova/a452a6474428de0182b17605a98631ee\n            decoding = \"\".join(\n                [\n                    self.byte_encoder[ord(char)]\n                    for char in 
self.model.decode_single_token_bytes(i).decode(\n                        \"latin-1\"\n                    )\n                ]\n            )\n            self.decoder[i] = decoding\n\n        self.encoder = {}\n        for i in range(self.n_words):\n            if i in self.decoder:\n                self.encoder[self.decoder[i]] = i\n\n        super().__init__(\n            bos_token=bos_token,\n            eos_token=eos_token,\n            unk_token=unk_token,\n            pad_token=pad_token,\n            additional_special_tokens=additional_special_tokens,\n            **kwargs,\n        )\n        self.all_special_ids_set = set(self.all_special_ids)\n\n    def encode(\n        self, text: str, allow_special_tokens: bool = True, **kwargs\n    ) -> List[int]:\n        \"\"\"\n        Encodes a string into a list of token IDs.\n\n        Args:\n            text (str): The input string to be encoded.\n\n        Returns:\n            list[int]: A list of token IDs.\n        \"\"\"\n        # If there are other args, we should call super().encode because there are a lot of code\n        # to handle those args. supper().encode finally will call _tokenize and _convert_token_to_id.\n        # NOTE: our encode method is not compatible with the super().encode method,\n        #   e.g. 
split_special_tokens' default is True in our encode method.\n        if len(kwargs) > 0:\n            # logger.warning(f\"Calling super().encode with {kwargs}\")\n            return super().encode(text, **kwargs)\n\n        assert type(text) is str\n\n        # The tiktoken tokenizer can handle <=400k chars without\n        # pyo3_runtime.PanicException.\n        TIKTOKEN_MAX_ENCODE_CHARS = 400_000\n\n        # https://github.com/openai/tiktoken/issues/195\n        # Here we iterate over subsequences and split if we exceed the limit\n        # of max consecutive non-whitespace or whitespace characters.\n        MAX_NO_WHITESPACES_CHARS = 25_000\n\n        texts = self.pre_tokenizer_process(text)\n\n        all_substrs = []\n        for text in texts:\n            substrs = (\n                substr\n                for i in range(0, len(text), TIKTOKEN_MAX_ENCODE_CHARS)\n                for substr in self._split_whitespaces_or_nonwhitespaces(\n                    text[i : i + TIKTOKEN_MAX_ENCODE_CHARS], MAX_NO_WHITESPACES_CHARS\n                )\n            )\n            all_substrs.extend(substrs)\n\n        t: List[int] = []\n        for substr in all_substrs:\n            if allow_special_tokens:\n                t.extend(\n                    # we should consider special token as a common token\n                    self.model.encode(\n                        substr,\n                        allowed_special=\"all\",\n                    )\n                )\n            else:\n                t.extend(\n                    # we should consider special token as a common token\n                    self.model.encode(\n                        substr,\n                        disallowed_special=(),\n                    )\n                )\n\n        return t\n\n    def decode(self, token_ids: Union[int, List[int]], **kwargs) -> str:\n        \"\"\"\n        Decodes a list of token IDs into a string.\n\n        Args:\n            token_ids (List[int]): The list of 
token IDs to be decoded.\n\n        Returns:\n            str: The decoded string.\n        \"\"\"\n        # If there are other args, we should call super().decode because there are a lot of code\n        # to handle those args. supper().encode finally will call convert_tokens_to_string and _convert_id_to_token.\n        if len(kwargs) > 0:\n            return super().decode(token_ids, **kwargs)\n\n        if type(token_ids) is int:\n            token_ids = [token_ids]\n\n        return self.model.decode(cast(List[int], token_ids))\n\n    @staticmethod\n    def _split_whitespaces_or_nonwhitespaces(\n        s: str, max_consecutive_slice_len: int\n    ) -> Iterator[str]:\n        \"\"\"\n        Splits the string `s` so that each substring contains no more than `max_consecutive_slice_len`\n        consecutive whitespaces or consecutive non-whitespaces.\n        \"\"\"\n        current_slice_len = 0\n        current_slice_is_space = s[0].isspace() if len(s) > 0 else False\n        slice_start = 0\n\n        for i in range(len(s)):\n            is_now_space = s[i].isspace()\n\n            if current_slice_is_space ^ is_now_space:\n                current_slice_len = 1\n                current_slice_is_space = is_now_space\n            else:\n                current_slice_len += 1\n                if current_slice_len > max_consecutive_slice_len:\n                    yield s[slice_start:i]\n                    slice_start = i\n                    current_slice_len = 1\n        yield s[slice_start:]\n\n    def pre_tokenizer_process(self, text: str) -> List[str]:\n        \"\"\"\n        pre-tokenizes the input text into a list of tokens.\n        This method is used to split the input text into smaller chunks for internal processing.\n        \"\"\"\n        return [text]\n\n    \"\"\" ----- Below are the abstract methods required by PreTrainedTokenizer ----- \"\"\"\n\n    @property\n    def vocab_size(self) -> int:\n        return self.n_words\n\n    def 
get_vocab(self) -> Dict[str, int]:\n        return self.encoder\n\n    def _tokenize(self, text: str, **kwargs) -> List[str]:\n        return [self.decoder[t] for t in self.encode(text)]\n\n    def _convert_token_to_id(self, token: str) -> int:\n        return self.encoder.get(token, self.unk_id)\n\n    def _convert_id_to_token(self, index: int) -> str:\n        return self.decoder.get(index)\n\n    @staticmethod\n    def clean_up_tokenization(out_string: str) -> str:\n        return out_string\n\n    def convert_tokens_to_string(self, tokens: List[str]) -> str:\n        text = \"\".join(tokens)\n        text = bytearray([self.byte_decoder[c] for c in text]).decode(\n            \"utf-8\", \"replace\"\n        )\n        return text\n\n    def save_vocabulary(\n        self, save_directory: str, filename_prefix: Optional[str] = None\n    ) -> Tuple[str]:\n        if not os.path.isdir(save_directory):\n            raise ValueError(\n                f\"vocabulary path ({save_directory}) should be a directory\"\n            )\n        out_vocab_file = os.path.join(\n            save_directory,\n            (filename_prefix + \"-\" if filename_prefix else \"\")\n            + VOCAB_FILES_NAMES[\"vocab_file\"],\n        )\n\n        if os.path.abspath(self.vocab_file) != os.path.abspath(\n            out_vocab_file\n        ) and os.path.isfile(self.vocab_file):\n            copyfile(self.vocab_file, out_vocab_file)\n\n        return (out_vocab_file,)\n\n    def apply_chat_template(\n        self,\n        conversation,\n        tools: Optional[list[dict]] = None,\n        tokenize: bool = True,\n        add_generation_prompt: bool = False,\n        **kwargs,\n    ):\n        tools = deep_sort_dict(tools)\n        return super().apply_chat_template(\n            conversation,\n            tools=tools,\n            tokenize=tokenize,\n            add_generation_prompt=add_generation_prompt,\n            **kwargs,\n        )\n\n\ndef deep_sort_dict(obj: Any) -> Any:\n    
if isinstance(obj, dict):\n        return {k: deep_sort_dict(v) for k, v in sorted(obj.items())}\n    if isinstance(obj, list):\n        return [deep_sort_dict(item) for item in obj]\n    return obj\n"
  },
  {
    "path": "src/axolotl/monkeypatch/models/llama4/__init__.py",
    "content": ""
  },
  {
    "path": "src/axolotl/monkeypatch/models/llama4/modeling.py",
    "content": "\"\"\"\nModified Llama-4 text experts modeling for linearized experts for improved LoRA support\n\"\"\"\n\nimport sys\n\nimport torch\nfrom torch import nn\nfrom transformers import Llama4Config\nfrom transformers.activations import ACT2FN\n\n\nclass Llama4TextExperts(nn.Module):\n    \"\"\"\n    Modified Llama-4 text experts modeling for linearized experts\n    \"\"\"\n\n    def __init__(self, config: Llama4Config):\n        super().__init__()\n        self.num_experts = config.num_local_experts\n        self.intermediate_size = config.intermediate_size\n        self.hidden_size = config.hidden_size\n        self.expert_dim = self.intermediate_size\n\n        # Replace fused gate_up_proj with separate Linear modules\n        self.gate_projs = nn.ModuleList(\n            [\n                nn.Linear(self.hidden_size, self.expert_dim, bias=False)\n                for _ in range(self.num_experts)\n            ]\n        )\n\n        self.up_projs = nn.ModuleList(\n            [\n                nn.Linear(self.hidden_size, self.expert_dim, bias=False)\n                for _ in range(self.num_experts)\n            ]\n        )\n\n        # Replace down_proj Parameter with Linear modules\n        self.down_projs = nn.ModuleList(\n            [\n                nn.Linear(self.expert_dim, self.hidden_size, bias=False)\n                for _ in range(self.num_experts)\n            ]\n        )\n\n        self.act_fn = ACT2FN[config.hidden_act]\n\n    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:\n        \"\"\"\n        Forward method using separate Linear layers for each expert.\n\n        Args:\n            hidden_states (torch.Tensor): (num_experts * batch_size, hidden_size)\n                The input should be organized by expert\n\n        Returns:\n            torch.Tensor: (num_experts * batch_size, hidden_size)\n        \"\"\"\n        # Reshape to separate by expert\n        hidden_states = hidden_states.view(self.num_experts, -1, 
self.hidden_size)\n        # batch_size_per_expert = hidden_states.size(1)\n\n        # Initialize output tensor\n        next_states = torch.zeros_like(hidden_states)\n\n        # Process each expert separately\n        for i in range(self.num_experts):\n            # Get input for this expert\n            expert_input = hidden_states[\n                i\n            ]  # Shape: (batch_size_per_expert, hidden_size)\n\n            # Apply gate and up projections\n            gate = self.gate_projs[i](\n                expert_input\n            )  # Shape: (batch_size_per_expert, expert_dim)\n            up = self.up_projs[i](\n                expert_input\n            )  # Shape: (batch_size_per_expert, expert_dim)\n\n            # Apply activation and down projection\n            next_states[i] = self.down_projs[i](up * self.act_fn(gate))\n\n        # Flatten back to original shape\n        return next_states.view(-1, self.hidden_size)\n\n\ndef patch_llama4_linearized_modeling():\n    \"\"\"\n    Patch Llama4TextExperts to use separate Linear layers for each expert.\n    \"\"\"\n    from transformers.models.llama4 import modeling_llama4\n\n    old_lamma_4_text_experts = modeling_llama4.Llama4TextExperts\n    modeling_llama4.Llama4TextExperts = Llama4TextExperts\n    sys.modules[\"transformers.models.llama4\"].Llama4TextExperts = Llama4TextExperts\n\n    def unpatch():\n        modeling_llama4.Llama4TextExperts = old_lamma_4_text_experts\n        sys.modules[\n            \"transformers.models.llama4\"\n        ].Llama4TextExperts = old_lamma_4_text_experts\n\n    return unpatch\n"
  },
  {
    "path": "src/axolotl/monkeypatch/models/mistral3/__init__.py",
    "content": ""
  },
  {
    "path": "src/axolotl/monkeypatch/models/mistral3/mistral_common_tokenizer.py",
    "content": "\"\"\"\nMonkeypatch to fix inefficient tensor conversion in MistralCommonBackend.apply_chat_template\n\"\"\"\n\nimport importlib\nimport inspect\n\nfrom axolotl.monkeypatch.utils import detab_code\nfrom axolotl.utils.logging import get_logger\n\nLOG = get_logger(__name__)\n\n\ndef apply_mistral_tokenizer_image_patch():\n    \"\"\"Apply patch to MistralCommonBackend.apply_chat_template to fix image tensor conversion.\"\"\"\n    from transformers.tokenization_mistral_common import MistralCommonBackend\n\n    # Get original source\n    original_source = inspect.getsource(MistralCommonBackend.apply_chat_template)\n    original_source, _ = detab_code(original_source)\n\n    # Define the replacement\n    original_tensor_conversion = (\n        \"                    pixel_values = torch.tensor(images)\"\n    )\n\n    patched_tensor_conversion = \"\"\"                    if isinstance(images, list) and len(images) > 0 and isinstance(images[0], np.ndarray):\n                        pixel_values = torch.tensor(np.array(images))\n                    else:\n                        pixel_values = torch.tensor(images)\"\"\"\n\n    # Apply the replacement\n    if original_tensor_conversion in original_source:\n        patched_source = original_source.replace(\n            original_tensor_conversion, patched_tensor_conversion\n        )\n        patched_source = patched_source.replace(\n            \"def apply_chat_template(\",\n            \"def patched_apply_chat_template(\",\n            1,\n        )\n\n        # Load necessary imports from the module\n        module_name = MistralCommonBackend.__module__\n        module = importlib.import_module(module_name)\n\n        # Detect what needs to be imported\n        items_to_import = []\n        for item in dir(module):\n            if item in patched_source and not item.startswith(\"_\"):\n                items_to_import.append(item)\n\n        # Execute imports in global scope\n        if items_to_import:\n      
      exec(  # nosec B102\n                f\"from {module_name} import ({', '.join(items_to_import)})\",\n                globals(),\n            )\n\n        # Also need standard imports that might be used\n        exec(\"import numpy as np\", globals())  # nosec B102\n        exec(\"import torch\", globals())  # nosec B102\n        exec(\"from typing import Union, Optional, List, Dict, Any, Callable\", globals())  # nosec B102\n        exec(\"from pathlib import Path\", globals())  # nosec B102\n\n        # Import other dependencies that might be needed\n        try:\n            exec(\"from transformers.utils import is_torch_available\", globals())  # nosec B102\n            exec(\n                \"from transformers.tokenization_utils_base import BatchEncoding, PaddingStrategy, TensorType\",\n                globals(),\n            )  # nosec B102\n            exec(\"from transformers.utils import logging\", globals())  # nosec B102\n            exec(\"logger = logging.get_logger(__name__)\", globals())  # nosec B102\n        except ImportError as e:\n            LOG.warning(f\"Could not import some dependencies: {e}\")\n\n        # Execute the patched source\n        exec(patched_source, globals())  # nosec B102\n\n        # Replace the method\n        MistralCommonBackend.apply_chat_template = patched_apply_chat_template\n        LOG.info(\"Successfully applied MistralCommonBackend tensor conversion patch\")\n    else:\n        LOG.warning(\"Could not find target code for MistralCommonBackend patching\")\n"
  },
  {
    "path": "src/axolotl/monkeypatch/models/pixtral/__init__.py",
    "content": ""
  },
  {
    "path": "src/axolotl/monkeypatch/models/pixtral/modeling_flash_attention_utils.py",
    "content": "\"\"\"Monkeypatch for FA utils to accept 1D position_ids from Pixtral's position_ids_in_meshgrid\"\"\"\n\nimport torch\n\n\ndef apply_patch_is_packed_sequence():\n    \"\"\"\n    Apply patch to FA utils to accept 1D position_ids from Pixtral's position_ids_in_meshgrid.\n\n    Returns an `unpatch` callable that restores the original _is_packed_sequence.\n    \"\"\"\n    from transformers import modeling_flash_attention_utils\n\n    def fixed_is_packed_sequence(position_ids, batch_size):\n        \"\"\"\n        Check whether the position ids indicate packed sequences, i.e. whether\n            1. position ids exist,\n            2. the batch is flattened (batch_size == 1), and\n            3. the positions are not one contiguous range starting at their minimum\n               (a compile-friendly stand-in for `not (torch.diff(position_ids, dim=-1) >= 0).all()`,\n               i.e. there are multiple increasing sequences).\n\n        1-D position_ids (as produced by Pixtral's position_ids_in_meshgrid) are\n        reshaped to [1, N] before the check.\n        \"\"\"\n        if position_ids is None:\n            return False\n\n        if position_ids.ndim == 1:\n            position_ids = position_ids.unsqueeze(0)  # [N] -> [1, N]\n\n        increasing_position_sequences = (\n            torch.arange(position_ids.shape[1], device=position_ids.device)\n            + position_ids.min()\n        )\n        return (\n            batch_size == 1\n            and (increasing_position_sequences - position_ids).abs().sum().bool().item()\n        )\n\n    # Store original method\n    old_fn = modeling_flash_attention_utils._is_packed_sequence\n\n    # Apply the patch\n    modeling_flash_attention_utils._is_packed_sequence = fixed_is_packed_sequence\n\n    def unpatch():\n        \"\"\"Restore the original method\"\"\"\n        modeling_flash_attention_utils._is_packed_sequence = old_fn\n\n    return unpatch\n"
  },
  {
    "path": "src/axolotl/monkeypatch/models/qwen3_5/__init__.py",
    "content": ""
  },
  {
    "path": "src/axolotl/monkeypatch/models/qwen3_5/modeling.py",
    "content": "\"\"\"Monkeypatch for Qwen3_5 and Qwen3_5Moe models to pass position_ids to linear attention.\"\"\"\n\nimport importlib\nfrom typing import Optional, Tuple\n\nimport torch\nimport torch.nn.functional as F\n\nfrom axolotl.utils.logging import get_logger\n\nLOG = get_logger(__name__)\n\ntry:\n    from fla.modules.convolution import (\n        causal_conv1d as fla_causal_conv1d,  # FLA >= 0.4.1\n    )\nexcept ImportError:\n    try:\n        from fla.modules.conv import causal_conv1d as fla_causal_conv1d  # FLA < 0.4.1\n    except ImportError:\n        fla_causal_conv1d = None\n\n\ndef get_cu_seqlens(position_ids):\n    \"\"\"\n    Compute cumulative sequence lengths from position_ids for FLA varlen kernels.\n\n    Adapted from transformers.modeling_flash_attention_utils.prepare_fa_kwargs_from_position_ids.\n    https://github.com/huggingface/transformers/blob/0f1b128d3359a26bd18be99c26d7f04fb3cba914/src/transformers/modeling_flash_attention_utils.py#L316\n\n    Qwen3.5 uses MRoPE: position_ids arrive as [axes, B, T]. 
All axes carry the\n    same temporal positions, so axis 0 is used to recover the [B, T] layout.\n    See: https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen3_5/modeling_qwen3_5.py\n    \"\"\"\n    if position_ids.ndim == 3:\n        position_ids = position_ids[0]\n\n    tensor_kwargs = {\"dtype\": torch.int32, \"device\": position_ids.device}\n    position_ids = position_ids.view(-1)\n    indices_q = (position_ids == 0).nonzero().view(-1)\n    return torch.cat(\n        (\n            indices_q.to(**tensor_kwargs),\n            torch.tensor(position_ids.size(), **tensor_kwargs),\n        )\n    )\n\n\ndef _inject_fla_kernels(module) -> None:\n    \"\"\"Inject FLA kernels into a modeling module, bypassing is_flash_linear_attention_available.\"\"\"\n    try:\n        from fla.modules import FusedRMSNormGated\n        from fla.ops.gated_delta_rule import (\n            chunk_gated_delta_rule,\n            fused_recurrent_gated_delta_rule,\n        )\n\n        module.FusedRMSNormGated = FusedRMSNormGated\n        module.chunk_gated_delta_rule = chunk_gated_delta_rule\n        module.fused_recurrent_gated_delta_rule = fused_recurrent_gated_delta_rule\n        module.is_fast_path_available = True\n    except ImportError:\n        module.chunk_gated_delta_rule = None\n        module.fused_recurrent_gated_delta_rule = None\n        module.FusedRMSNormGated = None\n\n\ndef _patched_decoder_forward(\n    self,\n    hidden_states: torch.Tensor,\n    position_embeddings: Tuple[torch.Tensor, torch.Tensor],\n    attention_mask: Optional[torch.Tensor] = None,\n    position_ids: Optional[torch.LongTensor] = None,\n    past_key_values=None,\n    cache_position: Optional[torch.LongTensor] = None,\n    **kwargs,\n) -> torch.FloatTensor:\n    \"\"\"Decoder layer forward that passes position_ids through to linear attention.\"\"\"\n    residual = hidden_states\n    hidden_states = self.input_layernorm(hidden_states)\n\n    if self.layer_type == 
\"linear_attention\":\n        hidden_states = self.linear_attn(\n            hidden_states=hidden_states,\n            cache_params=past_key_values,\n            cache_position=cache_position,\n            attention_mask=attention_mask,\n            position_ids=position_ids,\n        )\n    elif self.layer_type == \"full_attention\":\n        hidden_states, _ = self.self_attn(\n            hidden_states=hidden_states,\n            attention_mask=attention_mask,\n            position_ids=position_ids,\n            past_key_values=past_key_values,\n            cache_position=cache_position,\n            position_embeddings=position_embeddings,\n            **kwargs,\n        )\n\n    hidden_states = residual + hidden_states\n\n    residual = hidden_states\n    hidden_states = self.post_attention_layernorm(hidden_states)\n    hidden_states = self.mlp(hidden_states)\n    if isinstance(hidden_states, tuple):  # MoE returns (hidden_states, router_logits)\n        hidden_states, _ = hidden_states\n    hidden_states = residual + hidden_states\n\n    return hidden_states\n\n\ndef _make_qwen3_5_gated_delta_forward(apply_mask_fn):\n    \"\"\"Factory for patched Qwen3_5/Qwen3_5Moe GatedDeltaNet forward with packing support.\"\"\"\n\n    def patched_forward(\n        self,\n        hidden_states: torch.Tensor,\n        cache_params=None,\n        cache_position: Optional[torch.LongTensor] = None,\n        attention_mask: Optional[torch.Tensor] = None,\n        position_ids: Optional[torch.LongTensor] = None,\n    ):\n        hidden_states = apply_mask_fn(hidden_states, attention_mask)\n\n        batch_size, seq_len, _ = hidden_states.shape\n\n        use_precomputed_states = (\n            cache_params is not None\n            and cache_params.has_previous_state\n            and seq_len == 1\n            and cache_position is not None\n        )\n\n        cu_seqlens = None\n        if not use_precomputed_states and position_ids is not None:\n            cu_seqlens = 
get_cu_seqlens(position_ids=position_ids)\n\n        if cache_params is not None:\n            conv_state = cache_params.conv_states[self.layer_idx]\n            recurrent_state = cache_params.recurrent_states[self.layer_idx]\n\n        # mixed_qkv stays [B, T, D]; only transposed inside paths that require [B, D, T]\n        mixed_qkv = self.in_proj_qkv(hidden_states)  # [B, T, D]\n\n        z = self.in_proj_z(hidden_states)\n        z = z.reshape(batch_size, seq_len, -1, self.head_v_dim)\n\n        b = self.in_proj_b(hidden_states)\n        a = self.in_proj_a(hidden_states)\n\n        if use_precomputed_states:\n            mixed_qkv = self.causal_conv1d_update(\n                mixed_qkv.transpose(1, 2),\n                conv_state,\n                self.conv1d.weight.squeeze(1),\n                self.conv1d.bias,\n                self.activation,\n            ).transpose(1, 2)\n        else:\n            if cache_params is not None:\n                mixed_qkv_t = mixed_qkv.transpose(1, 2)\n                cache_params.conv_states[self.layer_idx] = F.pad(\n                    mixed_qkv_t,\n                    (self.conv_kernel_size - mixed_qkv_t.shape[-1], 0),\n                )\n\n            if fla_causal_conv1d is not None and cu_seqlens is not None:\n                # FLA varlen kernel for packed sequences; input must be contiguous [B, T, D]\n                mixed_qkv, _ = fla_causal_conv1d(\n                    x=mixed_qkv,\n                    weight=self.conv1d.weight.squeeze(1),\n                    bias=self.conv1d.bias,\n                    activation=self.activation,\n                    cu_seqlens=cu_seqlens,\n                )\n            else:\n                if cu_seqlens is not None and fla_causal_conv1d is None:\n                    raise RuntimeError(\n                        \"Packed sequences require fla.modules.convolution.causal_conv1d \"\n                        \"(cu_seqlens support). 
Install flash-linear-attention or disable packing.\"\n                    )\n                mixed_qkv = F.silu(\n                    self.conv1d(mixed_qkv.transpose(1, 2))[:, :, :seq_len]\n                ).transpose(1, 2)\n\n        query, key, value = torch.split(\n            mixed_qkv,\n            [self.key_dim, self.key_dim, self.value_dim],\n            dim=-1,\n        )\n        query = query.reshape(batch_size, seq_len, -1, self.head_k_dim)\n        key = key.reshape(batch_size, seq_len, -1, self.head_k_dim)\n        value = value.reshape(batch_size, seq_len, -1, self.head_v_dim)\n\n        beta = b.sigmoid()\n        g = -self.A_log.float().exp() * F.softplus(a.float() + self.dt_bias)\n        if self.num_v_heads // self.num_k_heads > 1:\n            query = query.repeat_interleave(self.num_v_heads // self.num_k_heads, dim=2)\n            key = key.repeat_interleave(self.num_v_heads // self.num_k_heads, dim=2)\n\n        if not use_precomputed_states:\n            core_attn_out, last_recurrent_state = self.chunk_gated_delta_rule(\n                query,\n                key,\n                value,\n                g=g.to(dtype=query.dtype),\n                beta=beta,\n                initial_state=None,\n                output_final_state=cache_params is not None,\n                use_qk_l2norm_in_kernel=True,\n                # torch_chunk_gated_delta_rule fallback does not accept cu_seqlens\n                **({\"cu_seqlens\": cu_seqlens} if cu_seqlens is not None else {}),\n            )\n        else:\n            core_attn_out, last_recurrent_state = self.recurrent_gated_delta_rule(\n                query,\n                key,\n                value,\n                g=g.to(dtype=query.dtype),\n                beta=beta,\n                initial_state=recurrent_state,\n                output_final_state=cache_params is not None,\n                use_qk_l2norm_in_kernel=True,\n            )\n\n        if cache_params is not None:\n            
cache_params.recurrent_states[self.layer_idx] = last_recurrent_state\n\n        core_attn_out = core_attn_out.reshape(-1, self.head_v_dim)\n        z = z.reshape(-1, self.head_v_dim)\n        core_attn_out = self.norm(core_attn_out, z)\n        core_attn_out = core_attn_out.reshape(batch_size, seq_len, -1)\n\n        return self.out_proj(core_attn_out)\n\n    return patched_forward\n\n\ndef _apply_packing_patches(model_type: str, cls_prefix: str, forward_factory) -> None:\n    module_name = f\"transformers.models.{model_type}.modeling_{model_type}\"\n\n    try:\n        module = importlib.import_module(module_name)\n    except ImportError:\n        LOG.warning(f\"{model_type} not found in transformers, skipping packing patches\")\n        return\n\n    _inject_fla_kernels(module)\n    getattr(module, f\"{cls_prefix}DecoderLayer\").forward = _patched_decoder_forward\n    gated_cls = getattr(module, f\"{cls_prefix}GatedDeltaNet\")\n    gated_cls.forward = forward_factory(module.apply_mask_to_padding_states)\n\n    LOG.info(\n        f\"Applied {cls_prefix} packing patch \"\n        f\"(fla_causal_conv1d={'available' if fla_causal_conv1d else 'unavailable'})\"\n    )\n\n\ndef patch_qwen3_5_modeling_packing():\n    _apply_packing_patches(\"qwen3_5\", \"Qwen3_5\", _make_qwen3_5_gated_delta_forward)\n\n\ndef patch_qwen3_5_moe_modeling_packing():\n    _apply_packing_patches(\n        \"qwen3_5_moe\", \"Qwen3_5Moe\", _make_qwen3_5_gated_delta_forward\n    )\n\n\ndef patch_qwen3_5_vlm_flash_attention():\n    \"\"\"\n    Patch _is_packed_sequence to handle Qwen3.5's 3-D MRoPE position_ids.\n\n    transformers passes position_ids as [axes, B, T] to decoder layers, but\n    _is_packed_sequence only handles 2-D tensors and mis-classifies the 3-D\n    shape as a packed-sequence indicator, causing CUDA errors in the varlen path.\n    \"\"\"\n    try:\n        import transformers.modeling_flash_attention_utils as fa_utils\n\n        _original = fa_utils._is_packed_sequence\n\n     
   def _patched(position_ids, batch_size):\n            if position_ids is not None and position_ids.ndim != 2:\n                return False\n            return _original(position_ids, batch_size)\n\n        fa_utils._is_packed_sequence = _patched\n        LOG.info(\"Applied Qwen3.5 VLM flash-attention patch (3-D MRoPE position_ids)\")\n    except Exception as exc:  # pragma: no cover\n        LOG.warning(f\"Failed to apply Qwen3.5 VLM flash-attention patch: {exc}\")\n"
  },
  {
    "path": "src/axolotl/monkeypatch/models/qwen3_next/__init__.py",
    "content": "\"\"\"Qwen3_Next model monkeypatches.\"\"\"\n"
  },
  {
    "path": "src/axolotl/monkeypatch/models/qwen3_next/modeling.py",
    "content": "\"\"\"Monkeypatch for Qwen3_Next model to pass position_ids to linear attention.\"\"\"\n\nfrom typing import Optional, Tuple\n\nimport torch\nimport torch.nn.functional as F\n\nfrom axolotl.utils.logging import get_logger\n\nLOG = get_logger(__name__)\n\ntry:\n    from fla.modules.convolution import causal_conv1d as fla_causal_conv1d\nexcept ImportError:\n    fla_causal_conv1d = None\n\n\ndef get_cu_seqlens(position_ids):\n    \"\"\"\n    Adapted from transformers.modeling_flash_attention_utils.prepare_fa_kwargs_from_position_ids.\n\n    https://github.com/huggingface/transformers/blob/0f1b128d3359a26bd18be99c26d7f04fb3cba914/src/transformers/modeling_flash_attention_utils.py#L316\n    \"\"\"\n    tensor_kwargs = {\"dtype\": torch.int32, \"device\": position_ids.device}\n\n    position_ids = position_ids.view(-1)\n    indices_q = (position_ids == 0).nonzero().view(-1)\n\n    cu_seq_lens_q = torch.cat(\n        (\n            indices_q.to(**tensor_kwargs),\n            torch.tensor(position_ids.size(), **tensor_kwargs),\n        )\n    )\n\n    return cu_seq_lens_q\n\n\ndef patch_qwen3_next_decoder_layer():\n    \"\"\"Patch Qwen3NextDecoderLayer to pass position_ids to linear attention.\"\"\"\n    try:\n        from transformers.models.qwen3_next.modeling_qwen3_next import (\n            Qwen3NextDecoderLayer,\n        )\n    except ImportError:\n        LOG.warning(\"Qwen3Next model not found, skipping patch\")\n        return\n\n    # Store original forward method\n    original_decoder_forward = Qwen3NextDecoderLayer.forward\n\n    def patched_decoder_forward(\n        self,\n        hidden_states: torch.Tensor,\n        position_embeddings: Tuple[torch.Tensor, torch.Tensor],\n        attention_mask: Optional[torch.Tensor] = None,\n        position_ids: Optional[torch.LongTensor] = None,\n        past_key_values: Optional[Tuple[torch.Tensor]] = None,\n        cache_position: Optional[torch.LongTensor] = None,\n        **kwargs,\n    ) -> 
torch.FloatTensor:\n        residual = hidden_states\n\n        hidden_states = self.input_layernorm(hidden_states)\n\n        # Token Mixer\n        if self.layer_type == \"linear_attention\":\n            hidden_states = self.linear_attn(\n                hidden_states=hidden_states,\n                cache_params=past_key_values,\n                cache_position=cache_position,\n                attention_mask=attention_mask,\n                position_ids=position_ids,\n            )\n        elif self.layer_type == \"full_attention\":\n            # Self Attention\n            hidden_states, _ = self.self_attn(\n                hidden_states=hidden_states,\n                attention_mask=attention_mask,\n                position_ids=position_ids,\n                past_key_values=past_key_values,\n                cache_position=cache_position,\n                position_embeddings=position_embeddings,\n                **kwargs,\n            )\n\n        hidden_states = residual + hidden_states\n\n        # Fully Connected\n        residual = hidden_states\n        hidden_states = self.post_attention_layernorm(hidden_states)\n        hidden_states = self.mlp(hidden_states)\n        # For the MoE layers, we need to unpack\n        if isinstance(hidden_states, Tuple):\n            hidden_states, _ = hidden_states\n        hidden_states = residual + hidden_states\n\n        return hidden_states\n\n    # Apply the patches\n    Qwen3NextDecoderLayer.forward = patched_decoder_forward\n\n    def unpatch():\n        \"\"\"Restore the original forward method\"\"\"\n        Qwen3NextDecoderLayer.forward = original_decoder_forward\n\n    return unpatch\n\n\ndef patch_qwen3_next_gateddelta_layer():\n    \"\"\"Patch Qwen3NextGatedDeltaNet to parse cu_seqlens and pass to chunk_gated_delta_rule\"\"\"\n    try:\n        from transformers.models.qwen3_next.modeling_qwen3_next import (\n            Qwen3NextDynamicCache,\n            Qwen3NextGatedDeltaNet,\n            
apply_mask_to_padding_states,\n        )\n    except ImportError:\n        LOG.warning(\"Qwen3Next model not found, skipping patch\")\n        return\n\n    # Store original forward method\n    original_gated_delta_net_forward = Qwen3NextGatedDeltaNet.forward\n\n    def patched_gated_delta_net_forward(\n        self,\n        hidden_states: torch.Tensor,\n        cache_params: Optional[Qwen3NextDynamicCache] = None,\n        cache_position: Optional[torch.LongTensor] = None,\n        attention_mask: Optional[torch.Tensor] = None,\n        position_ids: Optional[torch.LongTensor] = None,\n    ):\n        hidden_states = apply_mask_to_padding_states(hidden_states, attention_mask)\n\n        # Set up dimensions for reshapes later\n        batch_size, seq_len, _ = hidden_states.shape\n\n        use_precomputed_states = (\n            cache_params is not None\n            and cache_params.has_previous_state\n            and seq_len == 1\n            and cache_position is not None\n        )\n\n        # Compute cu_seqlens early for use by both causal_conv1d and chunk_gated_delta_rule\n        cu_seqlens = None\n        if not use_precomputed_states and position_ids is not None:\n            cu_seqlens = get_cu_seqlens(position_ids=position_ids)\n\n        # getting projected states from cache if it exists\n        if cache_params is not None:\n            conv_state = cache_params.conv_states[self.layer_idx]\n            recurrent_state = cache_params.recurrent_states[self.layer_idx]\n\n        projected_states_qkvz = self.in_proj_qkvz(hidden_states)\n        projected_states_ba = self.in_proj_ba(hidden_states)\n        query, key, value, z, b, a = self.fix_query_key_value_ordering(\n            projected_states_qkvz, projected_states_ba\n        )\n        query, key, value = (\n            x.reshape(x.shape[0], x.shape[1], -1) for x in (query, key, value)\n        )\n\n        mixed_qkv = torch.cat((query, key, value), dim=-1)  # [B, T, D]\n\n        if 
use_precomputed_states:\n            # Inference single-token path: causal_conv1d_update expects [B, D, T]\n            mixed_qkv = mixed_qkv.transpose(1, 2)\n            mixed_qkv = self.causal_conv1d_update(\n                mixed_qkv,\n                conv_state,\n                self.conv1d.weight.squeeze(1),\n                self.conv1d.bias,\n                self.activation,\n            )\n            mixed_qkv = mixed_qkv.transpose(1, 2)\n        else:\n            if cache_params is not None:\n                # Cache state expects [B, D, T] for the inference update path\n                mixed_qkv_t = mixed_qkv.transpose(1, 2)\n                conv_state = F.pad(\n                    mixed_qkv_t,\n                    (self.conv_kernel_size - mixed_qkv_t.shape[-1], 0),\n                )\n                cache_params.conv_states[self.layer_idx] = conv_state\n\n            if fla_causal_conv1d is not None:\n                # FLA Triton causal_conv1d: [B, T, D] in/out, with cu_seqlens support\n                mixed_qkv, _ = fla_causal_conv1d(\n                    x=mixed_qkv,\n                    weight=self.conv1d.weight.squeeze(1),\n                    bias=self.conv1d.bias,\n                    activation=self.activation,\n                    cu_seqlens=cu_seqlens,\n                )\n            else:\n                # PyTorch fallback (no cu_seqlens support)\n                if cu_seqlens is not None and cu_seqlens.shape[0] > batch_size + 1:\n                    raise RuntimeError(\n                        \"Packed sequences require fla.modules.convolution.causal_conv1d \"\n                        \"(cu_seqlens support). Install flash-linear-attention or disable packing.\"\n                    )\n                LOG.warning_once(\n                    \"FLA causal_conv1d not available. 
Falling back to PyTorch conv1d.\"\n                )\n                # Packed inputs were rejected above, so any cu_seqlens here only\n                # marks per-row boundaries; drop it because the torch\n                # chunk_gated_delta_rule fallback does not accept cu_seqlens.\n                cu_seqlens = None\n                mixed_qkv = mixed_qkv.transpose(1, 2)\n                mixed_qkv = F.silu(self.conv1d(mixed_qkv)[:, :, :seq_len])\n                mixed_qkv = mixed_qkv.transpose(1, 2)\n\n        # mixed_qkv is [B, T, D] in all paths\n        query, key, value = torch.split(\n            mixed_qkv,\n            [\n                self.key_dim,\n                self.key_dim,\n                self.value_dim,\n            ],\n            dim=-1,\n        )\n        query = query.reshape(query.shape[0], query.shape[1], -1, self.head_k_dim)\n        key = key.reshape(key.shape[0], key.shape[1], -1, self.head_k_dim)\n        value = value.reshape(value.shape[0], value.shape[1], -1, self.head_v_dim)\n\n        beta = b.sigmoid()\n        # If the model is loaded in fp16, without the .float() here, A might be -inf\n        g = -self.A_log.float().exp() * F.softplus(a.float() + self.dt_bias)\n        if self.num_v_heads // self.num_k_heads > 1:\n            query = query.repeat_interleave(self.num_v_heads // self.num_k_heads, dim=2)\n            key = key.repeat_interleave(self.num_v_heads // self.num_k_heads, dim=2)\n\n        if not use_precomputed_states:\n            core_attn_out, last_recurrent_state = self.chunk_gated_delta_rule(\n                query,\n                key,\n                value,\n                g=g,\n                beta=beta,\n                initial_state=None,\n                output_final_state=cache_params is not None,\n                use_qk_l2norm_in_kernel=True,\n                # torch_chunk_gated_delta_rule fallback does not accept cu_seqlens\n                **({\"cu_seqlens\": cu_seqlens} if cu_seqlens is not None else {}),\n            )\n\n        else:\n            core_attn_out, last_recurrent_state = self.recurrent_gated_delta_rule(\n                query,\n                key,\n                value,\n                g=g,\n                beta=beta,\n                initial_state=recurrent_state,\n                output_final_state=cache_params is not None,\n                
use_qk_l2norm_in_kernel=True,\n            )\n\n        # Update cache\n        if cache_params is not None:\n            cache_params.recurrent_states[self.layer_idx] = last_recurrent_state\n\n        z_shape_og = z.shape\n        # reshape input data into 2D tensor\n        core_attn_out = core_attn_out.reshape(-1, core_attn_out.shape[-1])\n        z = z.reshape(-1, z.shape[-1])\n        core_attn_out = self.norm(core_attn_out, z)\n        core_attn_out = core_attn_out.reshape(z_shape_og)\n        core_attn_out = core_attn_out.reshape(\n            core_attn_out.shape[0], core_attn_out.shape[1], -1\n        )\n\n        output = self.out_proj(core_attn_out)\n        return output\n\n    # Apply the patches\n    Qwen3NextGatedDeltaNet.forward = patched_gated_delta_net_forward\n\n    def unpatch():\n        \"\"\"Restore the original forward method\"\"\"\n        Qwen3NextGatedDeltaNet.forward = original_gated_delta_net_forward\n\n    return unpatch\n\n\ndef patch_qwen3_next_imports():\n    \"\"\"Patch Qwen3Next imports to use try/except instead of is_flash_linear_attention_available.\"\"\"\n    try:\n        import transformers.models.qwen3_next.modeling_qwen3_next as qwen3_modeling\n    except ImportError:\n        LOG.warning(\"Qwen3Next model not found, skipping import patch\")\n        return\n\n    # Save original values for unpatch\n    original_FusedRMSNormGated = getattr(qwen3_modeling, \"FusedRMSNormGated\", None)\n    original_chunk_gated_delta_rule = getattr(\n        qwen3_modeling, \"chunk_gated_delta_rule\", None\n    )\n    original_fused_recurrent_gated_delta_rule = getattr(\n        qwen3_modeling, \"fused_recurrent_gated_delta_rule\", None\n    )\n    original_is_fast_path_available = getattr(\n        qwen3_modeling, \"is_fast_path_available\", False\n    )\n\n    try:\n        from fla.modules import FusedRMSNormGated\n        from fla.ops.gated_delta_rule import (\n            chunk_gated_delta_rule,\n            
fused_recurrent_gated_delta_rule,\n        )\n\n        qwen3_modeling.FusedRMSNormGated = FusedRMSNormGated\n        qwen3_modeling.chunk_gated_delta_rule = chunk_gated_delta_rule\n        qwen3_modeling.fused_recurrent_gated_delta_rule = (\n            fused_recurrent_gated_delta_rule\n        )\n\n        # Force is_fast_path_available to be True\n        # fla has triton kernels for causal_conv1d\n        qwen3_modeling.is_fast_path_available = True\n    except ImportError:\n        qwen3_modeling.chunk_gated_delta_rule = None\n        qwen3_modeling.fused_recurrent_gated_delta_rule = None\n        qwen3_modeling.FusedRMSNormGated = None\n\n    def unpatch():\n        \"\"\"Restore the original import values\"\"\"\n        qwen3_modeling.FusedRMSNormGated = original_FusedRMSNormGated\n        qwen3_modeling.chunk_gated_delta_rule = original_chunk_gated_delta_rule\n        qwen3_modeling.fused_recurrent_gated_delta_rule = (\n            original_fused_recurrent_gated_delta_rule\n        )\n        qwen3_modeling.is_fast_path_available = original_is_fast_path_available\n\n    return unpatch\n\n\ndef patch_qwen3_next_modeling_packing():\n    \"\"\"Apply all Qwen3Next model patches.\"\"\"\n    patch_qwen3_next_imports()\n    patch_qwen3_next_decoder_layer()\n    patch_qwen3_next_gateddelta_layer()\n\n    LOG.info(\"Applied Qwen3Next patch for packing\")\n"
  },
  {
    "path": "src/axolotl/monkeypatch/models/voxtral/__init__.py",
    "content": ""
  },
  {
    "path": "src/axolotl/monkeypatch/models/voxtral/modeling.py",
    "content": "\"\"\"Monkeypatch for voxtral to fix leaf node and dtype mismatch\"\"\"\n\nfrom typing import Optional, Union\n\nimport torch\nfrom transformers.cache_utils import Cache\nfrom transformers.modeling_outputs import CausalLMOutputWithPast\n\n\ndef patch_voxtral_conditional_generation_forward():\n    from transformers.models.voxtral.modeling_voxtral import (\n        VoxtralForConditionalGeneration,\n    )\n\n    # Store the original forward method\n    old_forward = VoxtralForConditionalGeneration.forward\n\n    def _forward(\n        self,\n        input_ids: Optional[torch.LongTensor] = None,\n        input_features: Optional[torch.FloatTensor] = None,\n        attention_mask: Optional[torch.Tensor] = None,\n        position_ids: Optional[torch.LongTensor] = None,\n        past_key_values: Optional[Cache] = None,\n        inputs_embeds: Optional[torch.FloatTensor] = None,\n        labels: Optional[torch.LongTensor] = None,\n        use_cache: Optional[bool] = None,\n        cache_position: Optional[torch.LongTensor] = None,\n        logits_to_keep: Union[int, torch.Tensor] = 0,\n        **kwargs,\n    ) -> CausalLMOutputWithPast:\n        if inputs_embeds is None:\n            inputs_embeds = self.get_input_embeddings()(input_ids)\n\n        if input_features is not None:\n            audio_embeds = self.get_audio_embeds(input_features)\n\n            # Cast audio_embeds to match inputs_embeds dtype\n            audio_embeds = audio_embeds.to(inputs_embeds.dtype)\n\n            # replace text-audio token placeholders with audio embeddings\n            audio_token_mask = input_ids == self.config.audio_token_id\n\n            inputs_embeds = inputs_embeds.clone()\n            inputs_embeds[audio_token_mask] = audio_embeds\n\n        outputs = self.language_model(\n            attention_mask=attention_mask,\n            position_ids=position_ids,\n            past_key_values=past_key_values,\n            inputs_embeds=inputs_embeds,\n            
labels=labels,\n            use_cache=use_cache,\n            cache_position=cache_position,\n            logits_to_keep=logits_to_keep,\n            **kwargs,\n        )\n        return outputs\n\n    # Apply the patch\n    VoxtralForConditionalGeneration.forward = _forward\n\n    def unpatch():\n        \"\"\"Restore the original forward method\"\"\"\n        VoxtralForConditionalGeneration.forward = old_forward\n\n    return unpatch\n"
  },
  {
    "path": "src/axolotl/monkeypatch/moe_quant.py",
    "content": "\"\"\"Loading-time quantization for MoE expert weights stored as 3D nn.Parameter tensors.\"\"\"\n\nimport bitsandbytes as bnb\nimport torch\nimport torch.nn.utils.parametrize as P\n\nfrom axolotl.utils.logging import get_logger\n\nLOG = get_logger(__name__)\n\n_moe_load_state = {\n    \"count\": 0,\n    \"mode\": \"4bit\",\n    \"quant_type\": \"nf4\",\n    \"compress_statistics\": True,\n    \"patched\": False,\n    # Module path → param names in definition order, captured before quantization.\n    # Without this, alphabetical loading order would mismatch merge order.\n    \"expert_param_order\": {},\n}\n\n\nclass Bnb8bitParametrization(torch.nn.Module):\n    \"\"\"Dequantizes int8 row-wise quantized data on access.\"\"\"\n\n    def __init__(self, row_stats: torch.Tensor):\n        super().__init__()\n        self.register_buffer(\"row_stats\", row_stats)\n\n    @torch.no_grad()\n    def forward(self, quantized_param: torch.Tensor) -> torch.Tensor:\n        \"\"\"Flatten 3D+ to 2D for BnB's dequant, then reshape back.\"\"\"\n        orig_shape = quantized_param.shape\n        if quantized_param.ndim > 2:\n            quantized_param = quantized_param.reshape(-1, orig_shape[-1])\n        result = bnb.functional.int8_vectorwise_dequant(quantized_param, self.row_stats)\n        return result.reshape(orig_shape)\n\n\ndef _enable_parametrization_cache(module, inputs):\n    P._cache_enabled += 1\n\n\ndef _disable_parametrization_cache(module, inputs, output):\n    P._cache_enabled -= 1\n    if not P._cache_enabled:\n        P._cache = {}\n\n\ndef replace_parameter_8bit(module, param_name):\n    \"\"\"Replace a module parameter with an 8-bit quantized version using parametrization.\"\"\"\n    original_param = getattr(module, param_name)\n    int8_data, row_stats, _ = bnb.functional.int8_vectorwise_quant(\n        original_param.data.to(torch.float16)\n    )\n\n    setattr(module, param_name, torch.nn.Parameter(int8_data, requires_grad=False))\n    del 
original_param\n\n    P.register_parametrization(\n        module, param_name, Bnb8bitParametrization(row_stats), unsafe=True\n    )\n\n    # Cache dequantized values during forward to avoid redundant dequantization.\n    if not getattr(module, \"_axolotl_8bit_hooks_registered\", False):\n        module.register_forward_pre_hook(_enable_parametrization_cache)\n        module.register_forward_hook(_disable_parametrization_cache)\n        module._axolotl_8bit_hooks_registered = True\n\n\ndef patch_moe_quantization_on_load(cfg):\n    \"\"\"Patch transformers' weight loading to quantize MoE expert params on-the-fly.\"\"\"\n    mode = \"8bit\" if getattr(cfg, \"load_in_8bit\", False) else \"4bit\"\n    _moe_load_state[\"mode\"] = mode\n    _moe_load_state[\"count\"] = 0\n    _moe_load_state[\"expert_param_order\"] = {}\n\n    if _moe_load_state[\"patched\"]:\n        LOG.debug(\"MoE loading-time quantization patch already active\")\n        return\n\n    import transformers.core_model_loading\n    import transformers.modeling_utils\n\n    if mode == \"4bit\":\n        from bitsandbytes.nn.parametrize import replace_parameter_4bit\n\n        quant_type = getattr(cfg, \"bnb_4bit_quant_type\", None) or \"nf4\"\n        compress_statistics = getattr(cfg, \"bnb_4bit_use_double_quant\", None)\n        if compress_statistics is None:\n            compress_statistics = True\n\n        _moe_load_state[\"quant_type\"] = quant_type\n        _moe_load_state[\"compress_statistics\"] = compress_statistics\n\n    # Disable caching_allocator_warmup — it pre-allocates a huge tensor at bf16\n    # size for all params, defeating our on-load quantization VRAM savings.\n    def _noop_warmup(*args, **kwargs):\n        pass\n\n    transformers.modeling_utils.caching_allocator_warmup = _noop_warmup\n\n    original_set_param = transformers.core_model_loading.set_param_for_module\n\n    def _patched_set_param_for_module(model, target_name, param_value, *args, **kwargs):\n        
original_set_param(model, target_name, param_value, *args, **kwargs)\n\n        if param_value.ndim >= 3 and param_value.is_cuda:\n            mod_path, _, pname = target_name.rpartition(\".\")\n            mod = model.get_submodule(mod_path) if mod_path else model\n            if not isinstance(mod, (bnb.nn.Linear4bit, bnb.nn.Linear8bitLt)):\n                if \"expert\" not in target_name.lower():\n                    LOG.debug(\n                        \"Skipping non-expert 3D param: %s (shape=%s)\",\n                        target_name,\n                        list(param_value.shape),\n                    )\n                    return\n\n                # Record definition order before parametrizations override it\n                # with alphabetical order.\n                if mod_path not in _moe_load_state[\"expert_param_order\"]:\n                    _moe_load_state[\"expert_param_order\"][mod_path] = list(\n                        mod._parameters.keys()\n                    )\n\n                if _moe_load_state[\"mode\"] == \"4bit\":\n                    replace_parameter_4bit(\n                        mod,\n                        pname,\n                        compress_statistics=_moe_load_state[\"compress_statistics\"],\n                        quant_type=_moe_load_state[\"quant_type\"],\n                    )\n                else:\n                    replace_parameter_8bit(mod, pname)\n                _moe_load_state[\"count\"] += 1\n\n                # Release the bf16 tensor so CUDA memory is freed immediately.\n                param_value.data = torch.empty(0, device=\"cpu\")\n                torch.cuda.empty_cache()\n\n    transformers.core_model_loading.set_param_for_module = _patched_set_param_for_module\n    _moe_load_state[\"patched\"] = True\n\n\ndef get_moe_quantized_count():\n    \"\"\"Return the number of expert parameters quantized during loading.\"\"\"\n    return _moe_load_state[\"count\"]\n\n\ndef 
patch_peft_target_parameters_matching():\n    \"\"\"Fix PEFT's _inject_parameters for target_parameters on quantized MoE experts.\n\n    1. Expands short suffixes to full module paths for parametrized modules.\n    2. Iterates params in definition order (not alphabetical order) so saved\n       adapters are compatible with standard PEFT, vLLM, etc.\n    \"\"\"\n    if getattr(patch_peft_target_parameters_matching, \"_axolotl_patched\", False):\n        return\n\n    from contextlib import nullcontext\n\n    from peft.tuners.tuners_utils import BaseTuner, BaseTunerLayer\n    from peft.utils.integrations import init_empty_weights\n    from peft.utils.other import _get_submodules\n\n    def _patched_inject_parameters(\n        self, peft_config, model, adapter_name, low_cpu_mem_usage\n    ):\n        original_targets = list(peft_config.target_parameters)\n        expanded = set(original_targets)\n\n        # Expand short suffixes to full paths for parametrized modules.\n        for module_name, module in model.named_modules():\n            if not hasattr(module, \"parametrizations\"):\n                continue\n            for target in original_targets:\n                mod_path, _, param_name = target.rpartition(\".\")\n                if (\n                    module_name == mod_path or module_name.endswith(\".\" + mod_path)\n                ) and hasattr(module, param_name):\n                    expanded.add(f\"{module_name}.{param_name}\")\n\n        target_names_set = expanded\n\n        def strip_base_layer_from_name(module_name):\n            name = \".base_layer\"\n            while name in module_name:\n                prefix, _, suffix = module_name.rpartition(name)\n                module_name = prefix + suffix\n            return module_name\n\n        def create_and_replace_param(module_name, key, param_name):\n            parent, target, target_name = _get_submodules(model, module_name)\n            unwrapped_module_name = 
strip_base_layer_from_name(module_name)\n            unwrapped_module = model.get_submodule(unwrapped_module_name)\n            if (\n                isinstance(unwrapped_module, BaseTunerLayer)\n                and unwrapped_module.__class__.__name__ != \"ParamWrapper\"\n            ):\n                raise ValueError(\n                    f\"Trying to wrap an `nn.Parameter` of layer \"\n                    f\"'{unwrapped_module_name}' of type \"\n                    f\"{type(target).__name__}, which is not a valid target. \"\n                    f\"Make sure that this layer is not also targeted with \"\n                    f\"`target_modules`.\"\n                )\n            self._check_target_module_compatiblity(peft_config, model, target_name)\n            ctx = init_empty_weights if low_cpu_mem_usage else nullcontext\n            with ctx():\n                self._create_and_replace(\n                    peft_config,\n                    adapter_name,\n                    target,\n                    target_name,\n                    parent,\n                    current_key=key,\n                    parameter_name=param_name.rpartition(\".\")[-1],\n                )\n\n        # Use definition order (not alphabetical order) for parametrized modules\n        # so ParamWrapper nesting matches vanilla PEFT on a plain model.\n        expert_param_order = _moe_load_state.get(\"expert_param_order\", {})\n\n        for module_name, module in model.named_modules():\n            if hasattr(module, \"parametrizations\"):\n                stored_order = expert_param_order.get(module_name)\n                if stored_order is not None:\n                    params_iter = [\n                        p for p in stored_order if p in module.parametrizations\n                    ]\n                else:\n                    # Fallback for paths that bypass model loading (e.g. 
unit tests).\n                    params_iter = list(module.parametrizations.keys())\n                for param_name in params_iter:\n                    key = f\"{module_name}.{param_name}\"\n                    if (key in target_names_set) or any(\n                        key.endswith(f\".{t}\") for t in target_names_set\n                    ):\n                        create_and_replace_param(module_name, key, param_name)\n                        self.targeted_parameter_names.append(key)\n            else:\n                unwrapped_module_name = strip_base_layer_from_name(module_name)\n                for param_name, _ in module.named_parameters(recurse=False):\n                    key = f\"{unwrapped_module_name}.{param_name}\"\n                    if (key in target_names_set) or any(\n                        key.endswith(f\".{t}\") for t in target_names_set\n                    ):\n                        create_and_replace_param(module_name, key, param_name)\n                        self.targeted_parameter_names.append(key)\n\n    BaseTuner._inject_parameters = _patched_inject_parameters\n    patch_peft_target_parameters_matching._axolotl_patched = True\n    LOG.info(\"Patched PEFT _inject_parameters for consistent ParamWrapper ordering\")\n"
  },
  {
    "path": "src/axolotl/monkeypatch/multipack.py",
    "content": "\"\"\"multipack patching for v2 of sample packing\"\"\"\n\nimport importlib\n\nimport transformers\nfrom accelerate import init_empty_weights\nfrom transformers import AutoConfig, AutoModelForCausalLM\nfrom transformers.integrations import is_deepspeed_zero3_enabled\n\nfrom axolotl.monkeypatch.mixtral import patch_mixtral_moe_forward_zero3\nfrom axolotl.monkeypatch.utils import get_unpad_data\n\nSUPPORTED_MULTIPACK_MODEL_TYPES = [\n    \"apertus\",\n    \"mllama_text_model\",\n    \"llama\",\n    \"llama4\",\n    \"mistral\",\n    \"mixtral\",\n    \"qwen2\",\n    \"qwen2_moe\",\n    \"qwen3\",\n    \"qwen3_moe\",\n    \"qwen3_next\",\n    \"qwen3_5\",\n    \"qwen3_5_moe\",\n    \"falcon\",\n    \"phi\",\n    \"phi3\",\n    \"gemma\",\n    \"gemma2\",\n    \"gemma3\",\n    \"gemma3_text\",\n    \"cohere\",\n    \"cohere2\",\n    \"gemmoe\",\n    \"starcoder2\",\n    \"deepseek_v2\",\n    \"deepseek_v3\",\n    \"glm\",\n    \"glm4\",\n    \"glm4_moe\",\n    \"smollm3\",\n    \"granite\",\n    \"granitemoe\",\n    \"granitemoeshared\",\n    \"granitemoehybrid\",\n    \"hunyuan_v1_dense\",\n    \"hunyuan_v1_moe\",\n    \"gpt_oss\",\n    \"arcee\",\n    \"seed_oss\",\n    \"lfm2\",\n    \"lfm2_moe\",\n    \"olmo\",\n    \"olmo2\",\n    \"olmo3\",\n    \"ministral\",\n    \"ministral3\",\n    \"mistral4\",\n    \"afmoe\",\n    \"nemotron\",\n]\n\n\ndef patch_for_multipack(model_type, model_name=None, has_remote_code=False):\n    if has_remote_code:\n        patch_remote(model_name)\n    elif hasattr(transformers, \"modeling_flash_attention_utils\"):\n        # sanity check in case upstream api changes on this\n        assert hasattr(\n            transformers.modeling_flash_attention_utils, \"_get_unpad_data\"\n        ), \"transformers api changed for _get_unpad_data for flash attention\"\n        transformers.modeling_flash_attention_utils._get_unpad_data = get_unpad_data\n\n    if model_type == \"mixtral\" and is_deepspeed_zero3_enabled():\n        
patch_mixtral_moe_forward_zero3()\n\n\ndef patch_remote(model_name):\n    model_config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)\n    # we need to load the model here in order for modeling_* to be available\n    with init_empty_weights():\n        AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True)\n    parts = model_config.__class__.__module__.split(\".\")\n    parts[-1] = parts[-1].replace(\"configuration_\", \"modeling_\", 1)\n    module_name = \".\".join(parts)\n    modeling_arch = importlib.import_module(module_name)\n    if hasattr(modeling_arch, \"_get_unpad_data\"):\n        modeling_arch._get_unpad_data = get_unpad_data\n"
  },
  {
    "path": "src/axolotl/monkeypatch/peft/__init__.py",
    "content": ""
  },
  {
    "path": "src/axolotl/monkeypatch/peft/utils.py",
    "content": "\"\"\"\nPatch prepare_model_for_kbit_training to not upcast everything\n\"\"\"\n\nimport inspect\n\nimport peft\n\nimport axolotl\nfrom axolotl.monkeypatch.utils import detab_code\nfrom axolotl.utils.logging import get_logger\n\nLOG = get_logger(__name__)\n\nORIGINAL_PREPARE_CODE = \"\"\"\n        for param in model.parameters():\n            if (\n                (param.dtype == torch.float16) or (param.dtype == torch.bfloat16)\n            ) and param.__class__.__name__ != \"Params4bit\":\n                param.data = param.data.to(torch.float32)\n\"\"\"\n\nPATCHED_PREPARE_CODE = \"\"\"\n        for name, param in model.named_parameters():\n            if (\n                (param.dtype == torch.float16) or (param.dtype == torch.bfloat16)\n            ) and param.__class__.__name__ != \"Params4bit\" and all(embed_name not in name for embed_name in [\"embed_tokens\", \"lm_head\"]):\n                param.data = param.data.to(torch.float32)\n\"\"\"\n\n\ndef get_peft_prep_code() -> str:\n    prepare = inspect.getsource(peft.utils.other.prepare_model_for_kbit_training)\n    return prepare\n\n\ndef check_peft_prep_code_is_patchable() -> bool:\n    prep_code = get_peft_prep_code()\n    prep_code, _ = detab_code(prep_code)\n    return ORIGINAL_PREPARE_CODE in prep_code\n\n\ndef patch_peft_prep_code():\n    \"\"\"\n    monkeypatch prepare_model_for_kbit_training so it does not upcast fp16/bf16 embed_tokens and lm_head parameters to float32\n    \"\"\"\n\n    try:\n        prep_code = get_peft_prep_code()\n    except OSError:\n        return\n    peft.utils.other._original_create_accelerator_and_postprocess = prep_code\n    prep_code, _ = detab_code(prep_code)\n    if ORIGINAL_PREPARE_CODE not in prep_code:\n        return\n\n    prep_code = prep_code.replace(ORIGINAL_PREPARE_CODE, PATCHED_PREPARE_CODE)\n    prep_code = prep_code.replace(\n        \"def prepare_model_for_kbit_training(\",\n        \"def fixed_prepare_model_for_kbit_training(\",\n        1,\n    )\n\n    items_to_import = []\n    
for item in dir(peft.utils.other):\n        if item in prep_code:\n            items_to_import.append(item)\n\n    exec(\n        \"from peft.utils.other import (\" + \", \".join(x for x in items_to_import) + \")\",\n        globals(),\n    )\n    exec(prep_code, globals())\n    LOG.info(\"patching prepare_model_for_kbit_training to allow for overrides\")\n    peft.utils.other.prepare_model_for_kbit_training = (\n        fixed_prepare_model_for_kbit_training\n    )\n    axolotl.loaders.model.prepare_model_for_kbit_training = (\n        fixed_prepare_model_for_kbit_training\n    )\n"
  },
  {
    "path": "src/axolotl/monkeypatch/relora.py",
    "content": "\"\"\"Implements the ReLoRA training procedure from https://arxiv.org/abs/2307.05695, minus the initial full fine-tune.\"\"\"\n\nimport glob\nimport json\nimport os.path\nimport shutil\nfrom functools import partial\nfrom pathlib import Path\nfrom typing import Dict, List, Union\n\nimport bitsandbytes as bnb\nimport peft\nimport safetensors.torch as st\nimport torch\nfrom huggingface_hub import snapshot_download\nfrom torch.distributed.optim import ZeroRedundancyOptimizer\nfrom transformers import (\n    TrainerCallback,\n    TrainerControl,\n    TrainerState,\n    TrainingArguments,\n)\nfrom transformers.trainer_utils import PREFIX_CHECKPOINT_DIR\n\nfrom axolotl.utils.dict import DictDefault\nfrom axolotl.utils.distributed import barrier, is_main_process\nfrom axolotl.utils.logging import get_logger\n\nLOG = get_logger(__name__)\n\n\n@torch.no_grad()\ndef magnitude_pruning_(tensor, prune_ratio):\n    tensor_magnitude = torch.abs(tensor)\n    threshold = torch.quantile(\n        tensor_magnitude.flatten().to(dtype=torch.float32), prune_ratio\n    ).to(dtype=tensor.dtype)\n\n    mask = tensor_magnitude > threshold\n    tensor.mul_(mask.to(dtype=tensor.dtype))\n\n\ndef reset_optimizer(\n    optimizer: torch.optim.Optimizer,\n    *,\n    reset_params: List[str],  # where str is the key to a torch.nn.Parameter\n    optimizer_state_keys: List[str],\n    optimizer_magnitude_pruning: float = 0.9,\n):\n    # pylint:disable=unused-argument\n    pruning_fn = partial(magnitude_pruning_, prune_ratio=optimizer_magnitude_pruning)\n    n_zeros = 0\n    n_total = 0\n\n    optimizer_state = optimizer.state\n    if isinstance(optimizer, ZeroRedundancyOptimizer):\n        optimizer_state = optimizer.optim.state\n\n    for group in optimizer.param_groups:\n        for param in group[\"params\"]:\n            state = optimizer_state[param]\n            for key, value in state.items():\n                if key not in optimizer_state_keys:\n                    continue\n   
             if torch.is_tensor(value):\n                    try:\n                        pruning_fn(value)\n                        n_total += value.numel()\n                        n_zeros += torch.sum(value == 0).item()\n                    except RuntimeError as exc:\n                        if \"quantile() input tensor is too large\" in str(exc):\n                            pass\n                        else:\n                            raise exc\n\n    _zeroed = n_zeros / (1e-7 + n_total) * 100\n    LOG.info(f\"Percent of optimizer states zeroed: {_zeroed:.2f}\")\n    LOG.info(f\"absolute n of optimizer states zeroed: {n_zeros}\")\n\n\nclass ReLoRACallback(TrainerCallback):\n    \"\"\"Callback to merge LoRA weights into the base model and save full-weight checkpoints\"\"\"\n\n    def __init__(self, cfg: DictDefault):\n        self.relora_steps = cfg.jagged_restart_steps\n        self.cpu_offload = cfg.relora_cpu_offload\n        self.quantized = cfg.load_in_4bit or cfg.load_in_8bit\n        self.last_full_model = cfg.base_model\n        self.resume_from_checkpoint = cfg.resume_from_checkpoint\n\n        if not os.path.exists(self.last_full_model):\n            self.last_full_model = str(Path(snapshot_download(cfg.base_model)))\n\n        assert os.path.exists(self.last_full_model), (\n            \"for ReLORA base_model must be a local path\"\n        )\n\n        self.num_lora_restarts = 0\n        self.need_full_save = False\n\n    def on_train_begin(\n        self,\n        _args: TrainingArguments,\n        _state: TrainerState,\n        control: TrainerControl,\n        model: peft.LoraModel,\n        **_kwargs,\n    ):\n        if self.resume_from_checkpoint:\n            weight_path = os.path.join(self.resume_from_checkpoint, \"relora\")\n            if not os.path.exists(weight_path):\n                LOG.warning(\n                    \"Resuming ReLoRA from checkpoint, but no full-weight save found\"\n                )\n            else:\n          
      LOG.info(f\"Loading adjusted base weights from {weight_path}\")\n                load_weight_checkpoint(model, weight_path)\n        return control\n\n    def on_step_begin(\n        self,\n        args: TrainingArguments,\n        state: TrainerState,\n        control: TrainerControl,\n        model: peft.LoraModel,\n        optimizer: torch.optim.Optimizer,\n        **_kwargs,\n    ):\n        if not optimizer:\n            optimizer = state.optimizer\n        if state.global_step > 0 and state.global_step % self.relora_steps == 0:\n            checkpoint_folder = os.path.join(\n                args.output_dir,\n                f\"{PREFIX_CHECKPOINT_DIR}-{state.global_step}\",\n                \"relora\",\n            )\n\n            if \"adam\" in args.optim.lower():\n                optimizer_state_keys = [\"exp_avg\", \"exp_avg_sq\"]\n                if \"8bit\" in args.optim.lower():\n                    optimizer_state_keys.append(\"state1\")\n                    optimizer_state_keys.append(\"state2\")\n            else:\n                raise ValueError(f\"Optimizer {args.optim} not supported with ReLoRA\")\n\n            lora_params = [\n                n\n                for n, p in model.named_parameters()\n                if p.requires_grad and \"lora_\" in n\n            ]\n\n            model.save_pretrained(\n                os.path.join(\n                    args.output_dir,\n                    f\"{PREFIX_CHECKPOINT_DIR}-{state.global_step}\",\n                    \"adapter\",\n                ),\n            )\n            with torch.no_grad():\n                merge_and_save(\n                    model,\n                    self.last_full_model,\n                    checkpoint_folder,\n                    reinit=True,\n                    quantized=self.quantized,\n                    actually_save=is_main_process(),\n                    cpu_offload=self.cpu_offload,\n                )\n                reset_optimizer(\n                    
optimizer,\n                    reset_params=lora_params,\n                    optimizer_state_keys=optimizer_state_keys,\n                    optimizer_magnitude_pruning=args.relora_prune_ratio,\n                )\n\n            if self.quantized:\n                self.last_full_model = checkpoint_folder\n            self.num_lora_restarts += 1\n\n        return control\n\n    def on_save(\n        self,\n        args: TrainingArguments,\n        state: TrainerState,\n        control: TrainerControl,\n        model: peft.LoraModel,\n        **_kwargs,\n    ):\n        checkpoint_folder = os.path.join(\n            args.output_dir, f\"{PREFIX_CHECKPOINT_DIR}-{state.global_step}\", \"relora\"\n        )\n        if (\n            state.global_step >= self.relora_steps\n            and state.global_step % self.relora_steps != 0\n        ):\n            if self.quantized:\n                if is_main_process() and self.last_full_model != checkpoint_folder:\n                    # ensure the latest full parameter save is in the latest checkpoint\n                    # folder, so that automatic pruning of checkpoints does not remove it\n                    LOG.info(f\"moving last full parameter save to {checkpoint_folder}\")\n                    os.makedirs(checkpoint_folder, exist_ok=True)\n                    chunks = glob.glob(\n                        f\"{self.last_full_model}/model*.safetensors\"\n                    ) + glob.glob(f\"{self.last_full_model}/model*.index.json\")\n                    for path in chunks:\n                        new_path = os.path.abspath(shutil.move(path, checkpoint_folder))\n                        try:\n                            os.symlink(new_path, path)\n                        except OSError:\n                            # probably on windows without permission to symlink\n                            pass\n\n                    self.last_full_model = checkpoint_folder\n            else:\n                
model.model.save_pretrained(checkpoint_folder)\n\n        return control\n\n    def on_log(\n        self,\n        _args: TrainingArguments,\n        _state: TrainerState,\n        control: TrainerControl,\n        logs: Dict[str, float],\n        **_kwargs,\n    ):\n        logs[\"num_lora_restarts\"] = self.num_lora_restarts\n        return control\n\n    def on_train_end(\n        self,\n        args: TrainingArguments,\n        _state: TrainerState,\n        control: TrainerControl,\n        model: peft.LoraModel,\n        **_kwargs,\n    ):\n        if self.quantized:\n            # perform final merge and save\n            with torch.no_grad():\n                merge_and_save(\n                    model,\n                    self.last_full_model,\n                    args.output_dir,\n                    reinit=False,\n                    quantized=self.quantized,\n                    actually_save=is_main_process(),\n                    cpu_offload=self.cpu_offload,\n                )\n        # no need to save if unquantized, as finetune.py will call merge_and_unload()\n        return control\n\n\ndef sharded_paths(path: str, module_names: List[str]) -> Dict[str, str]:\n    model_name = \"model.safetensors\"\n    if not os.path.exists(str(Path(path) / model_name)) and not os.path.exists(\n        str(Path(path) / f\"{model_name}.index.json\")\n    ):\n        model_name = \"pytorch_model.bin\"\n\n    index_path = str(Path(path) / f\"{model_name}.index.json\")\n    if os.path.exists(index_path):\n        with open(index_path, \"r\", encoding=\"utf-8\") as file:\n            data = json.load(file)\n        return data[\"weight_map\"]\n    return {(module_name + \".weight\"): model_name for module_name in module_names}\n\n\ndef lora_delta_weight(layer: peft.tuners.lora.LoraLayer, device) -> torch.Tensor:\n    if isinstance(layer, (peft.tuners.lora.Linear8bitLt, peft.tuners.lora.Linear4bit)):\n        adapter: Union[List[str], str] = layer.active_adapter\n     
   if isinstance(adapter, list):\n            if len(adapter) > 1:\n                raise ValueError(\"unhandled relora for multiple adapters\")\n            adapter = adapter[0]\n        return (\n            peft.utils.transpose(\n                layer.lora_B[adapter].weight.detach().to(device)\n                @ layer.lora_A[adapter].weight.detach().to(device),\n                getattr(layer, \"fan_in_fan_out\", False),\n            )\n            * layer.scaling[adapter]\n        )\n\n    raise ValueError(\"unhandled lora layer type\")\n\n\ndef find_lora_modules(model: peft.LoraModel) -> Dict[str, peft.tuners.lora.LoraLayer]:\n    modules: Dict[str, peft.tuners.lora.LoraLayer] = {}\n\n    key_list = [key for key, _ in model.model.named_modules() if \"lora\" not in key]\n    for key in key_list:\n        try:\n            _parent, target, _target_name = peft.utils._get_submodules(model.model, key)\n        except AttributeError:\n            continue\n\n        if isinstance(target, peft.tuners.lora.LoraLayer):\n            modules[key] = target\n\n    return modules\n\n\ndef update_weights(\n    target: peft.tuners.lora.LoraLayer, new_weight: torch.Tensor, reinit: bool, device\n):\n    if reinit:\n        for adapter_name in target.lora_A:\n            target.reset_lora_parameters(adapter_name, True)\n        for adapter_name in target.lora_embedding_A:\n            target.reset_lora_parameters(adapter_name, True)\n\n    if isinstance(target, peft.tuners.lora.Linear4bit):\n        # This could be faster, but the quantization of Linear4bit weights occurs\n        # when the module is moved from cpu to gpu. 
Without meddling *too* deeply in\n        # PEFT's innards or maintaining a duplicate of that codepath, this is good\n        # enough for now.\n        target.weight.quant_state = None\n        target.weight.data = new_weight.cpu()\n        target.to(device)\n    elif isinstance(target, peft.tuners.lora.Linear8bitLt):\n        target.weight.data = (\n            bnb.nn.Int8Params(new_weight, requires_grad=False).to(device).data\n        )\n    else:\n        target.weight.data = new_weight.to(device)\n\n\ndef merge_and_save(\n    model: peft.LoraModel,\n    model_src: str,\n    model_dst: str,\n    reinit: bool = False,\n    quantized: bool = False,\n    cpu_offload: bool = False,\n    actually_save: bool = True,\n):\n    modules = find_lora_modules(model)\n\n    if not quantized:\n        for _, target in modules.items():\n            active_adapter = target.active_adapter\n            if isinstance(active_adapter, list):\n                active_adapter = active_adapter[0]\n            update = target.get_delta_weight(active_adapter).detach()\n            target.weight.data += update\n\n            if reinit:\n                for adapter_name in target.lora_A:\n                    target.reset_lora_parameters(adapter_name, True)\n                for adapter_name in target.lora_embedding_A:\n                    target.reset_lora_parameters(adapter_name, True)\n        return\n\n    os.makedirs(model_dst, exist_ok=True)\n    shard_paths = sharded_paths(model_src, modules.keys())\n    out_shard_paths = {}\n\n    unique_shards = list(set(shard_paths.values()))\n    for shard_path in unique_shards:\n        out_tensors = {}\n        if shard_path.endswith(\".safetensors\"):\n            in_tensors = st.load_file(str(Path(model_src) / shard_path))\n        else:\n            in_tensors = torch.load(\n                Path(model_src) / shard_path,\n                weights_only=True,  # to prevent arbitrary code execution\n            )\n            if \"state_dict\" in 
in_tensors:\n                in_tensors = in_tensors[\"state_dict\"]\n\n        for module_name, target in modules.items():\n            key = module_name + \".weight\"\n            if key not in shard_paths or shard_paths[key] != shard_path:\n                continue\n\n            orig_weight = in_tensors[key]\n            old_dev = target.weight.device\n            math_dev = \"cpu\" if cpu_offload else old_dev\n\n            delta_weight = lora_delta_weight(target, math_dev)\n            new_weight = orig_weight.to(math_dev) + delta_weight\n            del delta_weight\n\n            if actually_save:\n                out_tensors[key] = new_weight.half().cpu()\n\n            update_weights(target, new_weight, reinit=reinit, device=old_dev)\n\n        if actually_save:\n            out_shard_name = shard_path\n            if out_shard_name.startswith(\"pytorch_model\"):\n                out_shard_name = (\n                    out_shard_name.replace(\"pytorch_model\", \"model\").rstrip(\".bin\")\n                    + \".safetensors\"\n                )\n\n            for module_name in in_tensors:\n                if module_name not in out_tensors:\n                    out_tensors[module_name] = in_tensors[module_name].half()\n                out_shard_paths[module_name] = out_shard_name\n\n            shard_fn = str(Path(model_dst) / out_shard_name)\n            LOG.info(f\"saving tensors to {shard_fn}\")\n            st.save_file(out_tensors, shard_fn, metadata={\"format\": \"pt\"})\n\n        barrier()\n        del in_tensors\n        del out_tensors\n        torch.cuda.empty_cache()\n\n    if actually_save and len(unique_shards) > 1:\n        with open(\n            str(Path(model_dst, \"model.safetensors.index.json\")), \"w\", encoding=\"utf-8\"\n        ) as file:\n            json.dump({\"metadata\": {}, \"weight_map\": out_shard_paths}, file)\n\n\ndef load_weight_checkpoint(model: peft.LoraModel, checkpoint_path: str):\n    modules = 
find_lora_modules(model)\n    shard_paths = sharded_paths(checkpoint_path, modules.keys())\n    unique_shards = list(set(shard_paths.values()))\n\n    for shard_path in unique_shards:\n        tensors = st.load_file(os.path.join(checkpoint_path, shard_path))\n\n        for module_name, target in modules.items():\n            key = module_name + \".weight\"\n            if key not in shard_paths or shard_paths[key] != shard_path:\n                continue\n\n            new_weight = tensors[key]\n            update_weights(\n                target, new_weight, reinit=False, device=target.weight.device\n            )\n"
  },
  {
    "path": "src/axolotl/monkeypatch/ring_attn/__init__.py",
    "content": "\"\"\"Init for ring attention monkeypatch module\"\"\"\n\n# flake8: noqa\n\nfrom .patch import (\n    get_ring_attn_group,\n    register_ring_attn_from_device_mesh,\n    set_ring_attn_group,\n    update_ring_attn_params,\n)\n\n__all__ = (\n    \"get_ring_attn_group\",\n    \"register_ring_attn_from_device_mesh\",\n    \"set_ring_attn_group\",\n    \"update_ring_attn_params\",\n)\n"
  },
  {
    "path": "src/axolotl/monkeypatch/ring_attn/adapters/__init__.py",
    "content": ""
  },
  {
    "path": "src/axolotl/monkeypatch/ring_attn/adapters/batch.py",
    "content": "\"\"\"\nHuggingFace flash attention adapter for basic ring attention (batch API).\n\nInspired by\nhttps://github.com/zhuzilin/ring-flash-attention/blob/ce9fd3935ca0e5f0592bb0826cbed18ec69da729/ring_flash_attn/adapters/hf_adapter.py.\nOur implementation closely follows the structure of that module, but we've minified it\nsomewhat to support only the latest versions of transformers.\n\"\"\"\n\nimport os\nfrom typing import Callable\n\nimport torch\nimport torch.distributed as dist\nimport transformers\nimport transformers.modeling_flash_attention_utils\nfrom ring_flash_attn import ring_flash_attn_func\nfrom ring_flash_attn.adapters.hf_adapter import check_params\nfrom transformers.modeling_flash_attention_utils import is_flash_attn_greater_or_equal\n\ntry:\n    from transformers.modeling_flash_attention_utils import _flash_supports_window\nexcept ImportError:\n    try:\n        from transformers.modeling_flash_attention_utils import (\n            _flash_supports_window_size as _flash_supports_window,\n        )\n    except ImportError:\n        _flash_supports_window = True\n\nfrom transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS\n\nfrom axolotl.utils.schemas.enums import RingAttnFunc\n\nRING_ATTN_FUNC_MAPPING = {\n    RingAttnFunc.BATCH_RING: torch.compile(ring_flash_attn_func),\n    # RingAttnFunc.BATCH_ZIGZAG: torch.compile(zigzag_ring_flash_attn_func),\n    # RingAttnFunc.BATCH_STRIPE: torch.compile(stripe_flash_attn_func),\n}\n\n\ndef create_flash_attn_forward_varlen_llama3(\n    process_group: dist.ProcessGroup, ring_attn_func: RingAttnFunc\n) -> Callable:\n    \"\"\"\n    Create a ring flash attention forward function compatible with HuggingFace's\n    interface.\n\n    Args:\n        process_group: A PyTorch distributed process group.\n        ring_attn_func: Function from `ring_flash_attention` to replace HF flash\n            attention with.\n\n    Returns:\n        A function that implements the ring flash attention forward pass 
with the\n            signature expected by HuggingFace Transformers.\n    \"\"\"\n\n    # transformers 4.48+\n\n    def _flash_attention_forward(\n        query_states: torch.Tensor,\n        key_states: torch.Tensor,\n        value_states: torch.Tensor,\n        attention_mask: torch.Tensor,\n        query_length: int,\n        is_causal: bool,\n        dropout: float = 0.0,\n        position_ids: torch.Tensor | None = None,\n        softmax_scale: float | None = None,\n        sliding_window: int | None = None,\n        use_top_left_mask: bool = False,\n        softcap: float | None = None,\n        deterministic: bool = None,\n        cu_seq_lens_q: torch.LongTensor | None = None,\n        cu_seq_lens_k: torch.LongTensor | None = None,\n        max_length_q: int | None = None,\n        max_length_k: int | None = None,\n        target_dtype: torch.dtype | None = None,\n        attn_implementation: str | None = None,\n        **kwargs,\n    ):\n        \"\"\"\n        Calls the forward method of Ring Flash Attention.\n\n        Args:\n            query_states: Tensor containing the query vectors.\n            key_states: Tensor containing the key vectors.\n            value_states: Tensor containing the value vectors.\n            attention_mask: Not used in this implementation.\n            query_length: Integer representing the length of the query sequence.\n            is_causal: Boolean indicating whether to apply a causal mask to the attention.\n            dropout: Float representing the dropout probability. Default is 0.0.\n            position_ids: Not used in this implementation.\n            softmax_scale: Optional float value for the softmax scaling factor. 
Default is None.\n            sliding_window: Optional integer defining the size of the sliding attention window.\n                Default is None.\n            use_top_left_mask: Boolean indicating whether to use a top-left mask for the attention.\n                Default is False.\n            softcap: Not used in this implementation.\n            deterministic: Optional boolean to enforce deterministic computation. Default is None.\n            cu_seq_lens_q: Not used in this implementation.\n            cu_seq_lens_k: Not used in this implementation.\n            max_length_q: Not used in this implementation.\n            max_length_k: Not used in this implementation.\n            target_dtype: Not used in this implementation.\n            attn_implementation: Not used in this implementation.\n            **kwargs: Additional keyword arguments. Not used in this implementation.\n\n        Returns:\n            torch.Tensor: The output of the attention mechanism, with shape\n                `[batch_size, query_length, num_heads, head_dim]`.\n        \"\"\"\n        if not use_top_left_mask:\n            causal = is_causal\n        else:\n            causal = is_causal and query_length != 1\n\n        # Handle sliding window\n        use_sliding_windows = (\n            _flash_supports_window\n            and sliding_window is not None\n            and key_states.shape[1] > sliding_window\n        )\n        window_size = (\n            (sliding_window, sliding_window) if use_sliding_windows else (-1, -1)\n        )\n\n        # Handle deterministic mode\n        if is_flash_attn_greater_or_equal(\"2.4.1\"):\n            if deterministic is None:\n                deterministic = (\n                    os.environ.get(\"FLASH_ATTENTION_DETERMINISTIC\", \"0\") == \"1\"\n                )\n\n        # Call ring flash attention function\n        attn_output = RING_ATTN_FUNC_MAPPING[ring_attn_func](\n            query_states,\n            key_states,\n            
value_states,\n            dropout_p=dropout,\n            softmax_scale=softmax_scale,\n            causal=causal,\n            window_size=window_size,\n            alibi_slopes=None,\n            deterministic=deterministic,\n            return_attn_probs=False,\n            group=process_group,\n        )\n\n        return attn_output\n\n    return _flash_attention_forward\n\n\ndef substitute_hf_flash_attn(\n    process_group: dist.ProcessGroup, ring_attn_func: RingAttnFunc\n):\n    \"\"\"\n    Substitute HuggingFace's flash attention implementation with ring-based implementation.\n\n    Args:\n        process_group: PyTorch distributed process group for communication.\n        ring_attn_func: Function from `ring_flash_attention` to replace HF flash\n            attention with.\n    \"\"\"\n    try:\n        # Substitute flash attention\n        old_flash_attention_forward = (\n            transformers.modeling_flash_attention_utils._flash_attention_forward\n        )\n        new_flash_attention_forward = create_flash_attn_forward_varlen_llama3(\n            process_group=process_group, ring_attn_func=ring_attn_func\n        )\n\n        if check_params(old_flash_attention_forward, new_flash_attention_forward):\n            transformers.modeling_flash_attention_utils._flash_attention_forward = (\n                new_flash_attention_forward\n            )\n        else:\n            raise ValueError(\n                \"The signature of the new flash attention forward function does not match the old one.\"\n            )\n    except Exception as exception:\n        raise ValueError(\n            f\"The current transformer version {transformers.__version__} is not supported. \"\n            \"Please use pip install -U transformers to upgrade to the latest version. 
\"\n            \"If the code failed with the latest version, \"\n            f\"please file an issue.\"\n        ) from exception\n\n    # Register with ALL_ATTENTION_FUNCTIONS if available\n    if ALL_ATTENTION_FUNCTIONS is not None:\n        from ring_flash_attn.adapters.hf_adapter import flash_attention_forward\n\n        ALL_ATTENTION_FUNCTIONS[\"flash_attention_2\"] = flash_attention_forward\n"
  },
  {
    "path": "src/axolotl/monkeypatch/ring_attn/patch.py",
    "content": "\"\"\"Ring attention group registration and flash attention patching.\n\nMake use of the `ring-flash-attn` (https://github.com/zhuzilin/ring-flash-attention)\npackage, specifically the `hf_adapter.substitute_hf_flash_attn` function to patch in\ntheir sequence parallel version of Flash Attention 2.\n\nWe also provide some patches for accelerate functions to prepare the dataloader for\nsequence parallelism training.\n\"\"\"\n\nimport os\nfrom typing import Optional\n\nimport torch\nimport torch.distributed as dist\nfrom torch.distributed import DeviceMesh\n\ntry:\n    from transformers.modeling_flash_attention_utils import _flash_supports_window\nexcept ImportError:\n    try:\n        from transformers.modeling_flash_attention_utils import (\n            _flash_supports_window_size as _flash_supports_window,\n        )\n    except ImportError:\n        _flash_supports_window = True\n\nfrom axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids\nfrom axolotl.utils.logging import get_logger\nfrom axolotl.utils.schemas.enums import RingAttnFunc\n\nLOG = get_logger(__name__)\n\nRING_ATTN_GROUP = None\n\n\ndef get_ring_attn_group() -> dist.ProcessGroup:\n    \"\"\"Getter for ring attention group on this rank.\"\"\"\n    if RING_ATTN_GROUP is None:\n        raise RuntimeError(\"register_ring_attn_from_device_mesh() not yet called\")\n    return RING_ATTN_GROUP\n\n\ndef set_ring_attn_group(ring_attn_group: dist.ProcessGroup | None):\n    \"\"\"Setter for ring attention group on this rank.\"\"\"\n    global RING_ATTN_GROUP\n    RING_ATTN_GROUP = ring_attn_group\n\n\ndef create_ring_flash_attention_forward(\n    process_group: dist.ProcessGroup, heads_k_stride: int\n):\n    from ring_flash_attn import llama3_flash_attn_varlen_func\n    from ring_flash_attn.adapters.hf_adapter import DATA_PARAMS\n\n    def _flash_attention_forward_v3(\n        query_states: torch.Tensor,\n        key_states: torch.Tensor,\n        value_states: torch.Tensor,\n        
attention_mask: torch.Tensor,\n        query_length: int,\n        is_causal: bool,\n        dropout: float = 0.0,\n        position_ids: Optional[torch.Tensor] = None,\n        softmax_scale: Optional[float] = None,\n        sliding_window: Optional[int] = None,\n        use_top_left_mask: bool = False,\n        softcap: Optional[float] = None,\n        deterministic: bool = None,\n        cu_seq_lens_q: Optional[torch.LongTensor] = None,\n        cu_seq_lens_k: Optional[torch.LongTensor] = None,\n        max_length_q: Optional[int] = None,\n        max_length_k: Optional[int] = None,\n        target_dtype: Optional[torch.dtype] = None,\n        attn_implementation: Optional[str] = None,\n        **kwargs,\n    ):\n        if not use_top_left_mask:\n            causal = is_causal\n        else:\n            # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__.\n            causal = is_causal and query_length != 1\n\n        # Assuming 4D tensors, key_states.shape[1] is the key/value sequence length (source length).\n        use_sliding_windows = (\n            _flash_supports_window\n            and sliding_window is not None\n            and key_states.shape[1] > sliding_window\n        )\n        flash_kwargs = (\n            {\"window_size\": (sliding_window, sliding_window)}\n            if use_sliding_windows\n            else {}\n        )\n\n        if deterministic is None:\n            deterministic = os.environ.get(\"FLASH_ATTENTION_DETERMINISTIC\", \"0\") == \"1\"\n        flash_kwargs[\"deterministic\"] = deterministic\n        assert softcap is None, (\n            \"llama3_flash_attn_varlen_func does not support softcap yet.\"\n        )\n        # flash_kwargs[\"softcap\"] = softcap\n        flash_kwargs[\"group\"] = process_group\n\n        # not sure why attention_mask can be not None...\n        
assert causal, \"only causal attention is supported yet.\"\n        batch_size = query_states.size(0)\n        assert batch_size == 1, \"varlen data should be processed in advance.\"\n\n        attn_output = llama3_flash_attn_varlen_func(\n            query_states.squeeze(dim=0),\n            key_states.squeeze(dim=0),\n            value_states.squeeze(dim=0),\n            cu_seqlens_q=DATA_PARAMS[\"cu_seqlens_q\"],\n            cu_seqlens_k=DATA_PARAMS[\"cu_seqlens_k\"],\n            max_seqlen_q=DATA_PARAMS[\"max_seqlen_q\"],\n            max_seqlen_k=DATA_PARAMS[\"max_seqlen_k\"],\n            heads_k_stride=heads_k_stride,\n            local_k_slice=DATA_PARAMS[\"local_k_slice\"],\n            dropout_p=dropout,\n            softmax_scale=softmax_scale,\n            causal=causal,\n            **flash_kwargs,\n        )\n\n        attn_output = attn_output.unsqueeze(dim=0)\n\n        return attn_output\n\n    return [\n        _flash_attention_forward_v3,\n    ]\n\n\ndef register_ring_attn_from_device_mesh(\n    device_mesh: \"DeviceMesh\",\n    context_parallel_dim: tuple[str, ...],\n    heads_k_stride: int | None,\n    ring_attn_func: RingAttnFunc | None,\n):\n    \"\"\"Create ring attention group using DeviceMesh and substitute flash attn with ring flash attn.\n\n    Args:\n        device_mesh: DeviceMesh object containing the parallelism topology.\n        context_parallel_dim: Name of the sequence parallel dimension in the device mesh.\n        heads_k_stride: Sequence parallelism K head stride size. Passed through to\n            `varlen_llama3` `ring_flash_attn` implementation.\n        ring_attn_func: `ring_flash_attn` ring attention implemention. 
If sample\n            packing is enabled, it must be a `varlen` function; otherwise, it must be a\n            `batch` function.\n    \"\"\"\n    rank = dist.get_rank()\n\n    LOG.info(\n        f\"Enabling ring attention sequence parallelism using DeviceMesh \"\n        f\"dimension '{context_parallel_dim}'\",\n    )\n\n    # Extract the sequence parallel submesh\n    try:\n        sequence_mesh = device_mesh[context_parallel_dim]\n    except (KeyError, IndexError) as e:\n        raise ValueError(\n            f\"Dimension '{context_parallel_dim}' not found in device_mesh. \"\n            f\"Available dimensions: {device_mesh.mesh_dim_names}\"\n        ) from e\n\n    # Get the process group for context parallelism\n    sequence_pg = sequence_mesh.get_group()\n    context_parallel_size = sequence_mesh.size()\n\n    if rank == 0:\n        LOG.info(\n            f\"Sequence parallel degree: {context_parallel_size}, \"\n            f\"mesh shape: {sequence_mesh.mesh.shape}\"\n        )\n\n    # Log which ranks are in the current process group\n    if sequence_pg != dist.GroupMember.WORLD:\n        ranks_in_group = dist.get_process_group_ranks(sequence_pg)\n        LOG.info(f\"Current sequence parallel group ranks: {ranks_in_group}\")\n\n    # Set the ring attention group\n    set_ring_attn_group(sequence_pg)\n\n    if ring_attn_func is RingAttnFunc.VARLEN_LLAMA3:\n        # fmt: off\n        import ring_flash_attn.adapters.hf_adapter\n\n        from ring_flash_attn.adapters.hf_adapter import (  # isort: skip\n            create_ring_flash_attention_forward as create_ring_flash_attention_forward_orig,\n        )\n\n        create_ring_flash_attention_forward_orig = (  # noqa: F811,F841\n            create_ring_flash_attention_forward\n        )\n        ring_flash_attn.adapters.hf_adapter.create_ring_flash_attention_forward = create_ring_flash_attention_forward\n        # fmt: on\n\n        ring_flash_attn.adapters.hf_adapter.substitute_hf_flash_attn(\n            
process_group=get_ring_attn_group(), heads_k_stride=heads_k_stride or 1\n        )\n    elif ring_attn_func is RingAttnFunc.BATCH_RING:\n        from axolotl.monkeypatch.ring_attn.adapters.batch import (\n            substitute_hf_flash_attn,\n        )\n\n        substitute_hf_flash_attn(\n            process_group=get_ring_attn_group(),\n            ring_attn_func=ring_attn_func,\n        )\n\n\ndef update_ring_attn_params(position_ids: torch.Tensor | None):\n    \"\"\"\n    Calculate the cumulative sequence lengths for the current forward pass and pass the\n    value to the substituted `ring_flash_attn`.\n\n    Args:\n        position_ids: Optional tensor of position IDs (for sample packed data).\n    \"\"\"\n    from ring_flash_attn import update_ring_flash_attn_params\n\n    cu_seqlens, _ = get_cu_seqlens_from_pos_ids(position_ids)\n    cu_seqlens = cu_seqlens.squeeze().to(device=torch.cuda.current_device())\n    update_ring_flash_attn_params(cu_seqlens, get_ring_attn_group())\n"
  },
  {
    "path": "src/axolotl/monkeypatch/scaled_softmax_attn.py",
    "content": "\"\"\"\nScaled Softmax (SSMax) attention patch using FlexAttention.\nSSMax:  softmax(scores * s * log(n) + b) where n is the position index\nRef: https://arxiv.org/abs/2501.19399\n\"\"\"\n\nimport torch\nfrom transformers import PreTrainedModel\n\nfrom axolotl.utils.logging import get_logger\n\nLOG = get_logger(__name__)\n\ntry:\n    from torch.nn.attention.flex_attention import BlockMask\n    from transformers.integrations.flex_attention import (\n        compile_friendly_flex_attention,\n        repeat_kv,\n    )\n\n    FLEX_ATTENTION_AVAILABLE = True\nexcept ImportError:\n    FLEX_ATTENTION_AVAILABLE = False\n    BlockMask = None\n\n_ssmax_config = {}\n\n\ndef patch_scaled_softmax_attention(\n    scaling_factor_init: float = 0.43, bias: float = 0.0, model: PreTrainedModel = None\n):\n    \"\"\"Patch attention to apply SSMax via FlexAttention score_mod.\"\"\"\n    global _ssmax_config\n\n    if not FLEX_ATTENTION_AVAILABLE:\n        raise RuntimeError(\"SSMax requires FlexAttention.\")\n\n    _ssmax_config[\"ssmax_s\"] = scaling_factor_init\n    _ssmax_config[\"ssmax_b\"] = bias\n\n    from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS\n\n    if \"flex_attention\" in ALL_ATTENTION_FUNCTIONS:\n        _ssmax_config[\"original_flex_fn\"] = ALL_ATTENTION_FUNCTIONS[\"flex_attention\"]\n        ALL_ATTENTION_FUNCTIONS[\"flex_attention\"] = ssmax_flex_attention_forward\n        LOG.info(\n            f\"Patched flex_attention with SSMax (s={scaling_factor_init}, b={bias})\"\n        )\n    else:\n        LOG.warning(\"flex_attention not found.  
Ensure flex_attention:  true is set.\")\n\n\ndef ssmax_flex_attention_forward(\n    module: torch.nn.Module,\n    query: torch.Tensor,\n    key: torch.Tensor,\n    value: torch.Tensor,\n    attention_mask,\n    scaling: float | None = None,\n    softcap: float | None = None,\n    **kwargs,\n) -> tuple[torch.Tensor, torch.Tensor | None]:\n    \"\"\"FlexAttention forward with SSMax:  score * (s * log(n) + b).\"\"\"\n\n    if kwargs.get(\"dropout\", 0.0) > 0:\n        raise ValueError(\"flex_attention does not support dropout\")\n\n    ssmax_s = _ssmax_config.get(\"ssmax_s\", 0.43)\n    ssmax_b = _ssmax_config.get(\"ssmax_b\", 0.0)\n\n    position_ids = kwargs.get(\"position_ids\", None)\n    position_ids_flat = position_ids.view(-1) if position_ids is not None else None\n\n    block_mask = attention_mask if isinstance(attention_mask, BlockMask) else None\n    score_mask = None if block_mask else attention_mask\n\n    if score_mask is not None:\n        score_mask = score_mask[:, :, :, : key.shape[-2]]\n\n    def score_mod(score, batch_idx, head_idx, q_idx, kv_idx):\n        \"\"\"\n        Apply SSMax scaling:  score * (s * log(n) + b)\n        where n is the relative position within each packed sequence.\n        \"\"\"\n        if position_ids_flat is not None:\n            relative_pos = position_ids_flat[q_idx]\n            n = (relative_pos + 1).float()\n        else:\n            n = (q_idx + 1).float()\n\n        n = torch.clamp(n, min=2.0)\n\n        ssmax_scale = ssmax_s * torch.log(n) + ssmax_b\n        score = score * ssmax_scale\n\n        if softcap is not None:\n            score = softcap * torch.tanh(score / softcap)\n\n        if score_mask is not None:\n            score = score + score_mask[batch_idx][0][q_idx][kv_idx]\n\n        return score\n\n    enable_gqa = True\n    if (query.shape[1] & (query.shape[1] - 1)) != 0:\n        key = repeat_kv(key, query.shape[1] // key.shape[1])\n        value = repeat_kv(value, query.shape[1] // 
value.shape[1])\n        enable_gqa = False\n\n    return_lse = query.device.type != \"cpu\"\n    flex_output = compile_friendly_flex_attention(\n        query,\n        key,\n        value,\n        score_mod=score_mod,\n        block_mask=block_mask,\n        enable_gqa=enable_gqa,\n        scale=scaling,\n        kernel_options=kwargs.get(\"kernel_options\"),\n        return_lse=return_lse,\n        training=module.training,\n    )\n\n    if return_lse:\n        attention_output, lse = flex_output\n        lse = lse.to(value.dtype)\n    else:\n        attention_output, lse = flex_output, None\n\n    return attention_output.transpose(1, 2).contiguous(), lse\n\n\ndef unpatch_scaled_softmax_attention():\n    \"\"\"Restore the original FlexAttention function.\"\"\"\n    global _ssmax_config\n    from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS\n\n    if \"original_flex_fn\" in _ssmax_config:\n        ALL_ATTENTION_FUNCTIONS[\"flex_attention\"] = _ssmax_config[\"original_flex_fn\"]\n        _ssmax_config.clear()\n        LOG.info(\"Unpatched flex_attention, restored original\")\n"
  },
  {
    "path": "src/axolotl/monkeypatch/stablelm_attn_hijack_flash.py",
    "content": "# coding=utf-8\n# Copyright 2023 Stability AI, EleutherAI, and The HuggingFace Inc. team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n#\n# This code is based off the following work:\n# https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py\n# https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt_neox/modeling_gpt_neox.py\n\"\"\"PyTorch StableLM Epoch model.\"\"\"\n\nimport importlib\nimport math\nfrom typing import Optional, Tuple, Union\n\nimport torch\nimport torch.utils.checkpoint\nfrom accelerate import init_empty_weights\nfrom einops import rearrange\nfrom flash_attn.flash_attn_interface import (\n    flash_attn_varlen_qkvpacked_func,\n)\nfrom torch import nn\nfrom transformers import AutoConfig, AutoModelForCausalLM\nfrom transformers.modeling_outputs import BaseModelOutputWithPast\n\nfrom axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids\nfrom axolotl.utils.logging import get_logger\n\nlogger = get_logger(__name__)\n\n\ndef replace_stablelm_attn_with_flash_attn(model_name=\"stabilityai/stablelm-3b-4e1t\"):\n    # this is a wonky hack to get the remotely loaded module\n    model_config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)\n    # we need to load the model here in order for modeling_stablelm_epoch to be available\n    with init_empty_weights():\n        AutoModelForCausalLM.from_pretrained(model_name, 
trust_remote_code=True)\n    module_name = model_config.__class__.__module__.replace(\n        \".configuration_stablelm_epoch\", \".modeling_stablelm_epoch\"\n    )\n    modeling_stablelm = importlib.import_module(module_name)\n    modeling_stablelm.Attention.forward = flashattn_attn\n    modeling_stablelm.StableLMEpochModel.forward = stablelm_model_forward\n    modeling_stablelm.DecoderLayer.forward = decoder_layer_forward\n\n\ndef rotate_half(x: torch.Tensor):\n    \"\"\"Rotates half the hidden dims of the input.\"\"\"\n\n    x1, x2 = torch.chunk(x, 2, dim=-1)\n    return torch.cat((-x2, x1), dim=-1)\n\n\ndef apply_rotary_pos_emb(q, k, cos, sin, position_ids):\n    # The first two dimensions of cos and sin are always 1, so we can `squeeze` them.\n\n    cos = cos.squeeze(1).squeeze(0)  # [seq_len, dim]\n    sin = sin.squeeze(1).squeeze(0)  # [seq_len, dim]\n    cos = cos[position_ids].unsqueeze(1)  # [batch_size, 1, seq_len, dim]\n    sin = sin[position_ids].unsqueeze(1)  # [batch_size, 1, seq_len, dim]\n    q_embed = (q * cos) + (rotate_half(q) * sin)\n    k_embed = (k * cos) + (rotate_half(k) * sin)\n    return q_embed, k_embed\n\n\ndef repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:\n    \"\"\"\n    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). 
The hidden states go from (batch,\n    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)\n    \"\"\"\n    batch, num_key_value_heads, slen, head_dim = hidden_states.shape\n    if n_rep == 1:\n        return hidden_states\n    hidden_states = hidden_states[:, :, None, :, :].expand(\n        batch, num_key_value_heads, n_rep, slen, head_dim\n    )\n    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)\n\n\ndef flashattn_attn(\n    self,\n    hidden_states: torch.FloatTensor,\n    attention_mask: torch.FloatTensor,\n    position_ids: torch.LongTensor,\n    past_key_value: Optional[Tuple[torch.Tensor]] = None,\n    output_attentions: Optional[bool] = False,\n    use_cache: Optional[bool] = False,\n    cu_seqlens: Optional[torch.Tensor] = None,\n    max_seqlen: Optional[torch.Tensor] = None,\n) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:\n    bsz, q_len, _ = hidden_states.size()\n\n    query_states = self.q_proj(hidden_states)\n    key_states = self.k_proj(hidden_states)\n    value_states = self.v_proj(hidden_states)\n\n    query_states = query_states.view(\n        bsz, q_len, self.num_heads, self.head_dim\n    ).transpose(1, 2)\n    key_states = key_states.view(\n        bsz, q_len, self.num_key_value_heads, self.head_dim\n    ).transpose(1, 2)\n    value_states = value_states.view(\n        bsz, q_len, self.num_key_value_heads, self.head_dim\n    ).transpose(1, 2)\n\n    query_rot = query_states[..., : self.rotary_ndims]\n    query_pass = query_states[..., self.rotary_ndims :]\n    key_rot = key_states[..., : self.rotary_ndims]\n    key_pass = key_states[..., self.rotary_ndims :]\n\n    kv_seq_len = key_states.shape[-2]\n    if past_key_value is not None:\n        kv_seq_len += past_key_value[0].shape[-2]\n    cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)\n    query_states, key_states = apply_rotary_pos_emb(\n        query_rot, key_rot, cos, sin, 
position_ids\n    )\n\n    # [batch_size, num_heads, seq_len, head_dim]\n    query_states = torch.cat((query_states, query_pass), dim=-1)\n    key_states = torch.cat((key_states, key_pass), dim=-1)\n\n    if past_key_value is not None:\n        # Reuse k, v, self_attention\n        key_states = torch.cat((past_key_value[0], key_states), dim=2)\n        value_states = torch.cat((past_key_value[1], value_states), dim=2)\n\n    past_key_value = (key_states, value_states) if use_cache else None\n\n    # Repeat k/v heads if n_kv_heads < n_heads\n    key_states = repeat_kv(key_states, self.num_key_value_groups)\n    value_states = repeat_kv(value_states, self.num_key_value_groups)\n\n    if cu_seqlens is not None and max_seqlen is not None and cu_seqlens.dim() == 1:\n        # special handling using sample packing\n        qkv = torch.stack(\n            [query_states, key_states, value_states], dim=2\n        )  # [bsz, nh, 3, q_len, hd]\n        qkv = qkv.transpose(1, 3)  # [bsz, q_len, 3, nh, hd]\n        qkv = rearrange(qkv, \"b s ... -> (b s) ...\")\n        softmax_scale = None\n\n        output = flash_attn_varlen_qkvpacked_func(\n            qkv, cu_seqlens, max_seqlen, 0.0, softmax_scale=softmax_scale, causal=True\n        )\n\n        attn_output = rearrange(output, \"(b s) ... 
-> b s ...\", b=bsz)\n        attn_output = rearrange(attn_output, \"b s h d -> b s (h d)\")\n    else:\n        attn_weights = torch.matmul(\n            query_states, key_states.transpose(2, 3)\n        ) / math.sqrt(self.head_dim)\n\n        if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):\n            raise ValueError(\n                f\"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is\"\n                f\" {attn_weights.size()}\"\n            )\n\n        if attention_mask is not None:\n            if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):\n                raise ValueError(\n                    f\"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}\"\n                )\n            attn_weights = attn_weights + attention_mask\n\n        # Upcast attention to fp32\n        attn_weights = nn.functional.softmax(\n            attn_weights, dim=-1, dtype=torch.float32\n        ).to(query_states.dtype)\n        attn_output = torch.matmul(attn_weights, value_states)\n\n        if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):\n            raise ValueError(\n                f\"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is\"\n                f\" {attn_output.size()}\"\n            )\n\n        # Merge heads\n        attn_output = attn_output.transpose(1, 2).contiguous()\n        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)\n\n    # Final linear projection\n    attn_output = self.o_proj(attn_output)\n\n    return attn_output, None, past_key_value\n\n\ndef decoder_layer_forward(\n    self,\n    hidden_states: Optional[torch.FloatTensor],\n    attention_mask: Optional[torch.FloatTensor] = None,\n    position_ids: Optional[torch.LongTensor] = None,\n    past_key_value: Optional[Tuple[torch.Tensor]] = None,\n    output_attentions: Optional[bool] = False,\n    use_cache: 
Optional[bool] = False,\n    cu_seqlens: Optional[torch.Tensor] = None,\n    max_seqlen: Optional[torch.Tensor] = None,\n) -> Union[\n    Tuple[torch.Tensor], Optional[Tuple[torch.Tensor, Tuple[torch.FloatTensor, ...]]]\n]:\n    residual = hidden_states\n\n    hidden_states = self.input_layernorm(hidden_states)\n\n    # Self Attention\n    hidden_states, self_attn_weights, present_key_value = self.self_attn(\n        hidden_states=hidden_states,\n        attention_mask=attention_mask,\n        position_ids=position_ids,\n        past_key_value=past_key_value,\n        output_attentions=output_attentions,\n        use_cache=use_cache,\n        cu_seqlens=cu_seqlens,\n        max_seqlen=max_seqlen,\n    )\n    hidden_states = residual + hidden_states\n\n    # Fully Connected\n    residual = hidden_states\n    hidden_states = self.post_attention_layernorm(hidden_states)\n    hidden_states = self.mlp(hidden_states)\n    hidden_states = residual + hidden_states\n\n    outputs = (hidden_states,)\n\n    if output_attentions:\n        outputs += (self_attn_weights,)\n\n    if use_cache:\n        outputs += (present_key_value,)\n\n    return outputs\n\n\ndef stablelm_model_forward(\n    self,\n    input_ids: Optional[torch.LongTensor] = None,\n    attention_mask: Optional[torch.FloatTensor] = None,\n    position_ids: Optional[torch.LongTensor] = None,\n    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,\n    inputs_embeds: Optional[torch.FloatTensor] = None,\n    use_cache: Optional[bool] = None,\n    output_attentions: Optional[bool] = None,\n    output_hidden_states: Optional[bool] = None,\n    return_dict: Optional[bool] = None,\n) -> Union[Tuple, BaseModelOutputWithPast]:\n    output_attentions = (\n        output_attentions\n        if output_attentions is not None\n        else self.config.output_attentions\n    )\n    output_hidden_states = (\n        output_hidden_states\n        if output_hidden_states is not None\n        else 
self.config.output_hidden_states\n    )\n    use_cache = use_cache if use_cache is not None else self.config.use_cache\n\n    return_dict = (\n        return_dict if return_dict is not None else self.config.use_return_dict\n    )\n\n    # Retrieve input_ids and inputs_embeds\n    if input_ids is not None and inputs_embeds is not None:\n        raise ValueError(\n            \"You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time\"\n        )\n    if input_ids is not None:\n        batch_size, seq_length = input_ids.shape\n    elif inputs_embeds is not None:\n        batch_size, seq_length, _ = inputs_embeds.shape\n    else:\n        raise ValueError(\n            \"You have to specify either decoder_input_ids or decoder_inputs_embeds\"\n        )\n\n    seq_length_with_past = seq_length\n    past_key_values_length = 0\n\n    if past_key_values is not None:\n        past_key_values_length = past_key_values[0][0].shape[2]\n        seq_length_with_past = seq_length_with_past + past_key_values_length\n\n    cu_seqlens = None\n    max_seqlen = None\n    if position_ids is None:\n        device = input_ids.device if input_ids is not None else inputs_embeds.device\n        position_ids = torch.arange(\n            past_key_values_length,\n            seq_length + past_key_values_length,\n            dtype=torch.long,\n            device=device,\n        )\n        position_ids = position_ids.unsqueeze(0).view(-1, seq_length)\n    else:\n        position_ids = position_ids.view(-1, seq_length).long()\n        cu_seqlens, max_seqlen = get_cu_seqlens_from_pos_ids(position_ids)\n        cu_seqlens = cu_seqlens.squeeze()\n\n    if inputs_embeds is None:\n        inputs_embeds = self.embed_tokens(input_ids)\n    # Embed positions\n    if attention_mask is None:\n        attention_mask = torch.ones(\n            (batch_size, seq_length_with_past),\n            dtype=torch.bool,\n            device=inputs_embeds.device,\n        )\n    
attention_mask = self._prepare_decoder_attention_mask(\n        attention_mask,\n        (batch_size, seq_length),\n        inputs_embeds,\n        past_key_values_length,\n    )\n\n    hidden_states = inputs_embeds\n\n    if self.gradient_checkpointing and self.training:\n        if use_cache:\n            logger.warning(\n                \"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...\"\n            )\n            use_cache = False\n\n    # Decoder layers\n    all_hidden_states = () if output_hidden_states else None\n    all_self_attns = () if output_attentions else None\n    next_decoder_cache = () if use_cache else None\n\n    for idx, decoder_layer in enumerate(self.layers):\n        if output_hidden_states:\n            all_hidden_states += (hidden_states,)\n\n        past_key_value = past_key_values[idx] if past_key_values is not None else None\n\n        if self.gradient_checkpointing and self.training:\n\n            def create_custom_forward(module):\n                def custom_forward(*inputs):\n                    # None for past_key_value\n                    return module(*inputs)\n\n                return custom_forward\n\n            layer_outputs = torch.utils.checkpoint.checkpoint(\n                create_custom_forward(decoder_layer),\n                hidden_states,\n                attention_mask,\n                position_ids,\n                past_key_value,\n                output_attentions,\n                None,\n                cu_seqlens,\n                max_seqlen,\n            )\n        else:\n            layer_outputs = decoder_layer(\n                hidden_states,\n                attention_mask=attention_mask,\n                position_ids=position_ids,\n                past_key_value=past_key_value,\n                output_attentions=output_attentions,\n                use_cache=use_cache,\n                cu_seqlens=cu_seqlens,\n                max_seqlen=max_seqlen,\n            
)\n\n        hidden_states = layer_outputs[0]\n\n        if use_cache:\n            next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)\n\n        if output_attentions:\n            all_self_attns += (layer_outputs[1],)\n\n    hidden_states = self.norm(hidden_states)\n\n    # Add hidden states from the last decoder layer\n    if output_hidden_states:\n        all_hidden_states += (hidden_states,)\n\n    next_cache = next_decoder_cache if use_cache else None\n    if not return_dict:\n        return tuple(\n            v\n            for v in [hidden_states, next_cache, all_hidden_states, all_self_attns]\n            if v is not None\n        )\n    return BaseModelOutputWithPast(\n        last_hidden_state=hidden_states,\n        past_key_values=next_cache,\n        hidden_states=all_hidden_states,\n        attentions=all_self_attns,\n    )\n"
  },
  {
    "path": "src/axolotl/monkeypatch/tiled_mlp/__init__.py",
    "content": "\"\"\"\nTiledMLP monkey patches\n\"\"\"\n\nfrom .patch import (\n    patch_tiled_mlp,\n)\n\n__all__ = [\n    \"patch_tiled_mlp\",\n]\n"
  },
  {
    "path": "src/axolotl/monkeypatch/tiled_mlp/base.py",
    "content": "\"\"\"\nTiledMLP support for DDP, FSDP, and single GPU\n\"\"\"\n\nimport threading\nfrom typing import List\n\nimport torch\n\n\nclass DeepSpeedTiledMLPMoE(torch.autograd.Function):\n    @staticmethod\n    def forward(\n        ctx,\n        fn,\n        self,\n        x,\n        shards,\n        compute_params,\n    ) -> torch.Tensor:\n        ctx.fn = fn\n        ctx.self = self\n        ctx.shards = shards\n        ctx.compute_params = [p for p in compute_params if p.requires_grad]\n        ctx.save_for_backward(x)\n\n        x_shards = list(torch.chunk(x, chunks=shards, dim=1))\n        with torch.no_grad():\n            output_shards = [fn(self, x_shard) for x_shard in x_shards]\n\n        ctx.is_tuple_output = isinstance(output_shards[0], tuple)\n        if isinstance(output_shards[0], tuple):\n            tuple_dim_idx = [1, 0]\n            output_unsharded = tuple(\n                torch.cat(\n                    [output_shard[i] for output_shard in output_shards],\n                    dim=tuple_dim_idx[i],\n                )\n                for i in range(len(output_shards[0]))\n            )\n        else:\n            output_unsharded = torch.cat(output_shards, dim=1)\n\n        return output_unsharded\n\n    @staticmethod\n    def backward(ctx, *grads) -> torch.Tensor:\n        fn = ctx.fn\n        (x,) = ctx.saved_tensors\n        self = ctx.self\n        shards = ctx.shards\n        compute_params = ctx.compute_params\n        is_tuple_output = ctx.is_tuple_output\n\n        x_requires_grad = x.requires_grad\n        x = x.detach()\n        # detach() unsets `x.requires_grad`, so restore it\n        x.requires_grad_(x_requires_grad)\n\n        incoming_grad = grads[0]\n        x_grad = torch.zeros_like(x)\n        x_shards = list(torch.chunk(x, chunks=shards, dim=1))\n\n        shard_step = x_shards[0].numel()\n        for i, x_shard in enumerate(x_shards):\n            # Tell deepspeed not to add a new grad to its ipg bucket until 
the last shard is run\n            if compute_params is not None:\n                if i + 1 < shards:\n                    for param in compute_params:\n                        param.ds_grad_is_ready = False\n                else:\n                    # last shard, can add the grad\n                    for param in compute_params:\n                        param.ds_grad_is_ready = True\n\n            x_shard.requires_grad_(x_requires_grad)\n\n            shard_offset = i * shard_step\n            x_shard.grad = (\n                x_grad.view(-1)\n                .narrow(0, shard_offset, x_shard.numel())\n                .view_as(x_shard)\n            )\n            incoming_grad_shard = (\n                incoming_grad.view(-1)\n                .narrow(0, shard_offset, x_shard.numel())\n                .view_as(x_shard)\n            )\n            with torch.enable_grad():\n                output = fn(self, x_shard)\n            if is_tuple_output:\n                torch.autograd.backward(output[0], incoming_grad_shard)\n            else:\n                torch.autograd.backward(output, incoming_grad_shard)\n\n        return (None, None, x_grad, None, None)\n\n\nclass TiledMLP(torch.autograd.Function):\n    \"\"\"\n    TiledMLP implementation using gradient hooks\n    \"\"\"\n\n    @staticmethod\n    def forward(\n        ctx,\n        fn,\n        self,\n        x,\n        shards,\n        compute_params,\n    ) -> torch.Tensor:\n        ctx.fn = fn\n        ctx.self = self\n        ctx.shards = shards\n        ctx.compute_params = [p for p in compute_params if p.requires_grad]\n        ctx.save_for_backward(x)\n\n        x_shards = list(torch.chunk(x, chunks=shards, dim=1))\n        with torch.no_grad():\n            output_shards = [fn(self, x_shard) for x_shard in x_shards]\n        ctx.is_tuple_output = isinstance(output_shards[0], tuple)\n        if isinstance(output_shards[0], tuple):\n            tuple_dim_idx = [1, 0]\n            output_unsharded = 
tuple(\n                torch.cat(\n                    [output_shard[i] for output_shard in output_shards],\n                    dim=tuple_dim_idx[i],\n                )\n                for i in range(len(output_shards[0]))\n            )\n        else:\n            output_unsharded = torch.cat(output_shards, dim=1)\n\n        return output_unsharded\n\n    @staticmethod\n    def backward(ctx, *grads) -> torch.Tensor:\n        fn = ctx.fn\n        (x,) = ctx.saved_tensors\n        self = ctx.self\n        shards = ctx.shards\n        compute_params = ctx.compute_params\n        is_tuple_output = ctx.is_tuple_output\n\n        x_requires_grad = x.requires_grad\n        x = x.detach()\n        x.requires_grad_(x_requires_grad)\n\n        incoming_grad = grads[0]\n        x_grad = torch.zeros_like(x)\n        x_shards = list(torch.chunk(x, chunks=shards, dim=1))\n\n        # Create a gradient accumulator for parameters\n        grad_accumulator = GradientAccumulator(compute_params, shards, dtype=x.dtype)\n\n        shard_step = x_shards[0].numel()\n        for i, x_shard in enumerate(x_shards):\n            x_shard.requires_grad_(x_requires_grad)\n\n            shard_offset = i * shard_step\n            x_shard.grad = (\n                x_grad.view(-1)\n                .narrow(0, shard_offset, x_shard.numel())\n                .view_as(x_shard)\n            )\n            incoming_grad_shard = (\n                incoming_grad.view(-1)\n                .narrow(0, shard_offset, x_shard.numel())\n                .view_as(x_shard)\n            )\n\n            # Install hooks for this shard\n            is_last_shard = i + 1 == shards\n            grad_accumulator.install_hooks(is_last_shard)\n\n            with torch.enable_grad():\n                output = fn(self, x_shard)\n            if is_tuple_output:\n                torch.autograd.backward(output[0], incoming_grad_shard)\n            else:\n                torch.autograd.backward(output, 
incoming_grad_shard)\n\n        # Clean up hooks\n        grad_accumulator.cleanup()\n        del grad_accumulator\n\n        return (None, None, x_grad, None, None)\n\n\nclass GradientAccumulator:\n    \"\"\"\n    Manual gradient accumulator for TiledMLP with configurable precision\n    Accumulates in specified dtype and rescales the gradient at the end\n    \"\"\"\n\n    def __init__(\n        self,\n        params: List[torch.nn.Parameter],\n        total_shards: int,\n        dtype: torch.dtype | None = None,\n    ):\n        self.params = params\n        self.total_shards = total_shards\n        self.grad_accumulation_dtype = dtype or torch.float32\n        self.accumulated_grads = {}\n        self.hooks = []\n        self.lock = threading.Lock()\n        self.gradient_scale = 1.0 / total_shards\n\n        # Initialize accumulated gradients in the specified dtype\n        for param in self.params:\n            if param.grad is not None:\n                self.accumulated_grads[param] = param.grad.to(\n                    self.grad_accumulation_dtype\n                )\n                param.grad = None\n            else:\n                self.accumulated_grads[param] = torch.zeros_like(\n                    param, dtype=self.grad_accumulation_dtype\n                )\n\n    def install_hooks(self, is_last_shard: bool):\n        \"\"\"Install gradient hooks that accumulate gradients in higher precision\"\"\"\n\n        def create_hook(param):\n            def hook(grad):\n                with self.lock:\n                    grad_to_accum_dtype = grad.to(self.grad_accumulation_dtype)\n                    scaled_grad = grad_to_accum_dtype * self.gradient_scale\n\n                    if param in self.accumulated_grads:\n                        self.accumulated_grads[param] += scaled_grad\n                    else:\n                        self.accumulated_grads[param] = scaled_grad.clone()\n\n                    # Only assign the averaged gradient on the last 
shard\n                    if is_last_shard:\n                        param.grad = self.accumulated_grads[param].to(param.dtype)\n                        return param.grad\n                    return None\n\n            return hook\n\n        # Install hooks on all parameters\n        for param in self.params:\n            if param.requires_grad:\n                hook = param.register_hook(create_hook(param))\n                self.hooks.append(hook)\n\n    def cleanup(self):\n        \"\"\"Remove all installed hooks\"\"\"\n        for hook in self.hooks:\n            hook.remove()\n        self.hooks.clear()\n        del self.accumulated_grads\n"
  },
  {
    "path": "src/axolotl/monkeypatch/tiled_mlp/patch.py",
    "content": "\"\"\"Monkeypatch for Tiled MLP implementation\"\"\"\n\nimport math\nimport os\n\nimport torch\nimport torch.distributed as dist\n\nfrom axolotl.utils.callbacks.models import get_causal_lm_model_cls_prefix\nfrom axolotl.utils.logging import get_logger\n\nLOG = get_logger(__name__)\n\n\ndef patch_tiled_mlp(model_type, use_original_mlp=True, cfg_num_shards=None):\n    from deepspeed.runtime.sequence_parallel.ulysses_sp import (\n        TiledMLP as DeepSpeedTiledMLP,\n    )\n\n    from axolotl.monkeypatch.tiled_mlp.base import DeepSpeedTiledMLPMoE, TiledMLP\n\n    try:\n        # Dynamically import the module and MLP class\n        module_path = f\"transformers.models.{model_type}.modeling_{model_type}\"\n        model_cls_prefix, _ = get_causal_lm_model_cls_prefix(model_type)\n        module = __import__(module_path, fromlist=[f\"{model_cls_prefix}MLP\"])\n        mlp_cls = getattr(module, f\"{model_cls_prefix}MLP\")\n\n        if use_original_mlp:\n            mlp_forward = mlp_cls.forward\n        else:\n\n            def generic_mlp_forward(self_, hs):\n                return self_.down_proj(\n                    self_.act_fn(self_.gate_proj(hs)) * self_.up_proj(hs)\n                )\n\n            mlp_forward = torch.compile(generic_mlp_forward)\n\n        is_distributed = int(os.environ.get(\"WORLD_SIZE\", 1)) > 1\n\n        def tiled_mlp_forward(self, x):\n            input_shape = x.shape\n            seqlen = input_shape[-2]\n            hidden = input_shape[-1]\n            if cfg_num_shards is None:\n                num_shards = math.ceil(seqlen / hidden)\n                if is_distributed:\n                    num_shards_tensor = torch.tensor(num_shards, device=x.device)\n                    dist.all_reduce(num_shards_tensor, op=dist.ReduceOp.MAX)\n                    num_shards = num_shards_tensor.item()\n            else:\n                num_shards = cfg_num_shards\n\n            if not self._compute_params:\n                
self._compute_params = [p for p in self.parameters() if p.requires_grad]\n\n            compute_params = self._compute_params\n            if not self._tiled_mlp_dist_impl:\n                if (\n                    self._compute_params\n                    and any(\n                        hasattr(p, \"ds_id\") or hasattr(p, \"param_idx_in_group\")\n                        for p in self._compute_params\n                    )\n                ) or os.environ.get(\"ACCELERATE_USE_DEEPSPEED\", \"false\") == \"true\":\n                    if model_type == \"gpt_oss\":\n                        self._tiled_mlp_dist_impl = DeepSpeedTiledMLPMoE\n                    else:\n                        self._tiled_mlp_dist_impl = DeepSpeedTiledMLP\n                else:\n                    self._tiled_mlp_dist_impl = TiledMLP\n\n            down_res = self._tiled_mlp_dist_impl.apply(\n                mlp_forward,\n                self,\n                x,\n                num_shards,\n                compute_params,\n            )\n            return down_res\n\n        mlp_cls.forward = tiled_mlp_forward\n        mlp_cls._compute_params = []\n        mlp_cls._tiled_mlp_dist_impl = None\n        LOG.info(\n            f\"Successfully monkey-patched TiledMLP for model_type: {model_type}\",\n        )\n    except (ImportError, AttributeError) as e:\n        raise RuntimeError(\n            f\"Could not import MLP class for model_type: {model_type}. Error: {str(e)}\"\n        ) from e\n"
  },
  {
    "path": "src/axolotl/monkeypatch/trainer/__init__.py",
    "content": "from .utils import entropy_from_logits, selective_log_softmax\n\n__all__ = [\"entropy_from_logits\", \"selective_log_softmax\"]\n"
  },
  {
    "path": "src/axolotl/monkeypatch/trainer/lr.py",
    "content": "\"\"\"\nmonkeypatch for Trainer _get_learning_rate method\n\"\"\"\n\nimport torch\n\nfrom axolotl.utils.logging import get_logger\n\nLOG = get_logger(__name__)\n\n\n# TODO remove this patch once https://github.com/huggingface/transformers/pull/37881 is included in a release\ndef _get_learning_rate(self):\n    if self.is_deepspeed_enabled:\n        # with deepspeed's fp16 and dynamic loss scale enabled the optimizer/scheduler steps may\n        # not run for the first few dozen steps while loss scale is too large, and thus during\n        # that time `get_last_lr` will fail if called during that warm up stage, so work around it:\n        try:\n            last_lr = self.lr_scheduler.get_last_lr()[0]\n        except AssertionError as e:\n            if \"need to call step\" in str(e):\n                LOG.warning(\n                    \"tried to get lr value before scheduler/optimizer started stepping, returning lr=0\"\n                )\n                last_lr = 0\n            else:\n                raise\n    else:\n        if isinstance(self.lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):\n            last_lr = self.optimizer.param_groups[0][\"lr\"]\n        else:\n            last_lr = self.lr_scheduler.get_last_lr()[0]\n\n    if torch.is_tensor(last_lr):\n        last_lr = last_lr.item()\n    return last_lr\n\n\ndef patch_trainer_get_lr():\n    from transformers.trainer import Trainer\n\n    Trainer._get_learning_rate = _get_learning_rate\n"
  },
  {
    "path": "src/axolotl/monkeypatch/trainer/trl.py",
    "content": "\"\"\"Monkeypatch for TRL trainer FSDP preparation.\"\"\"\n\n\ndef prepare_fsdp(model, accelerator):\n    from axolotl.monkeypatch.accelerate.fsdp2 import fsdp2_prepare_model\n\n    return fsdp2_prepare_model(accelerator, model)\n\n\ndef patch_trl_prepare_fsdp2():\n    import trl.models.utils\n\n    trl.models.utils.prepare_fsdp = prepare_fsdp\n"
  },
  {
    "path": "src/axolotl/monkeypatch/trainer/trl_vllm.py",
    "content": "\"\"\"Monkeypatches for TRL's vLLM integration and trainer utils.\n\nAdds:\n- VLLMClient.batch_update_named_params: batched weight sync (fewer HTTP round-trips)\n- extract_logprobs: NaN→0.0 fix (prevents downstream NaN propagation)\n- VLLMGeneration: weight_sync_chunk_size + batched sync path for non-FSDP/non-ZeRO\n- split_tensor_dict / shuffle_sequence_dict: scalar type handling (int/float/bool passthrough)\n\"\"\"\n\nimport logging\nimport math\nfrom functools import wraps\n\nimport torch\nfrom torch import nn\n\nLOG = logging.getLogger(__name__)\n\n\ndef _batch_update_named_params(\n    self, params: list[tuple[str, torch.Tensor]], chunk_size: int | None = None\n):\n    \"\"\"Batched weight sync — sends param metadata via HTTP, tensors via NCCL.\"\"\"\n    from transformers import is_torch_xpu_available\n\n    if chunk_size is None:\n        chunks = [params]\n    else:\n        chunks = []\n        current_chunk: list[tuple[str, torch.Tensor]] = []\n        current_elements = 0\n        for name, weights in params:\n            n_elem = weights.numel()\n            if current_chunk and current_elements + n_elem > chunk_size:\n                chunks.append(current_chunk)\n                current_chunk = []\n                current_elements = 0\n            current_chunk.append((name, weights))\n            current_elements += n_elem\n        if current_chunk:\n            chunks.append(current_chunk)\n\n    for chunk in chunks:\n        param_metadata = [\n            {\"name\": name, \"dtype\": str(weights.dtype), \"shape\": list(weights.shape)}\n            for name, weights in chunk\n        ]\n        url = f\"{self.base_url}/batch_update_named_params/\"\n        response = self.session.post(url, json={\"params\": param_metadata})\n        if response.status_code != 200:\n            raise Exception(f\"Request failed: {response.status_code}, {response.text}\")\n\n        for _name, weights in chunk:\n            if is_torch_xpu_available():\n 
               self.communicator.broadcast(weights, root=self.rank)\n            else:\n                self.communicator.broadcast(weights, src=self.rank)\n\n        if is_torch_xpu_available():\n            self.communicator.barrier()\n        else:\n            self.communicator.group.barrier()\n\n\ndef _update_model_params(self, model: nn.Module, chunk_size: int | None = None):\n    \"\"\"Updates all model params using batch_update_named_params.\"\"\"\n    params = [(name, param.data) for name, param in model.named_parameters()]\n    self.batch_update_named_params(params, chunk_size=chunk_size)\n\n\ndef _patched_extract_logprobs(all_outputs):\n    \"\"\"extract_logprobs with NaN→0.0 fix (stock TRL uses None which causes downstream errors).\"\"\"\n    all_logprobs = []\n    all_token_ids = []\n\n    for outputs in all_outputs:\n        for output in outputs.outputs:\n            if output.logprobs is None:\n                return None, None\n            seq_logprobs = []\n            seq_token_ids = []\n            for lp in output.logprobs:\n                sorted_items = sorted(lp.items(), key=lambda x: x[1].rank)\n                seq_token_ids.append([token_id for token_id, _ in sorted_items])\n                seq_logprobs.append(\n                    [\n                        0.0 if math.isnan(item.logprob) else item.logprob\n                        for _, item in sorted_items\n                    ]\n                )\n            all_logprobs.append(seq_logprobs)\n            all_token_ids.append(seq_token_ids)\n\n    return all_logprobs, all_token_ids\n\n\ndef _patched_split_tensor_dict(tensor_dict, num_chunks):\n    \"\"\"split_tensor_dict that handles scalar types (int/float/bool) for num_items_in_batch.\"\"\"\n    first_tensor = next(\n        tensor\n        for tensor in tensor_dict.values()\n        if tensor is not None and isinstance(tensor, torch.Tensor) and tensor.ndim > 0\n    )\n    chunk_size = first_tensor.shape[0] // num_chunks\n    chunks 
= []\n    for i in range(num_chunks):\n        chunk_dict = {}\n        for key, tensor in tensor_dict.items():\n            if isinstance(tensor, (int, float, bool)):\n                chunk_dict[key] = tensor\n            elif tensor is not None and (isinstance(tensor, list) or tensor.ndim > 0):\n                chunk_dict[key] = tensor[i * chunk_size : (i + 1) * chunk_size]\n            elif tensor is not None and tensor.ndim == 0:\n                chunk_dict[key] = tensor\n            else:\n                chunk_dict[key] = None\n        chunks.append(chunk_dict)\n    return chunks\n\n\ndef _patched_shuffle_sequence_dict(seq_dict):\n    \"\"\"shuffle_sequence_dict that handles scalar types (int/float/bool).\"\"\"\n    first_seq = next(\n        v\n        for v in seq_dict.values()\n        if v is not None and isinstance(v, (torch.Tensor, list)) and len(v) > 0\n    )\n    perm = torch.randperm(len(first_seq))\n\n    def permute(v):\n        if v is None:\n            return None\n        if isinstance(v, (int, float, bool)):\n            return v\n        if isinstance(v, torch.Tensor) and v.ndim == 0:\n            return v\n        if isinstance(v, torch.Tensor) and v.ndim >= 1:\n            return v[perm]\n        if isinstance(v, list):\n            return [v[i] for i in perm.tolist()]\n        return v\n\n    return {k: permute(v) for k, v in seq_dict.items()}\n\n\ndef _patch_sync_weights_batched(original_init):\n    \"\"\"Wrap VLLMGeneration.__init__ to accept weight_sync_chunk_size.\"\"\"\n\n    @wraps(original_init)\n    def patched_init(self, *args, weight_sync_chunk_size=None, **kwargs):\n        original_init(self, *args, **kwargs)\n        self.weight_sync_chunk_size = weight_sync_chunk_size\n\n    return patched_init\n\n\ndef _make_batched_sync_weights(original_sync_weights):\n    \"\"\"Wrap sync_weights to use batched sync for non-FSDP/non-ZeRO paths.\"\"\"\n\n    @wraps(original_sync_weights)\n    def patched_sync_weights(self):\n        from 
accelerate.utils import is_peft_model\n\n        # Check if we're in a non-PEFT, non-FSDP, non-ZeRO scenario where batching helps\n        accelerator = self.accelerator\n        model = self.model\n        is_fsdp_enabled = self.is_fsdp_enabled\n\n        deepspeed_plugin = accelerator.state.deepspeed_plugin\n        zero_stage_3 = deepspeed_plugin is not None and deepspeed_plugin.zero_stage == 3\n\n        is_peft = is_peft_model(model)\n\n        # If PEFT, FSDP, or ZeRO-3, fall back to original (which handles those cases)\n        if is_peft or is_fsdp_enabled or zero_stage_3:\n            return original_sync_weights(self)\n\n        # Non-PEFT, non-FSDP, non-ZeRO: use batched sync\n        if self.mode == \"colocate\" and getattr(self, \"enable_sleep_mode\", False):\n            from vllm.distributed.device_communicators.cuda_wrapper import (\n                empty_cache,\n            )\n\n            empty_cache()\n            self.llm.wake_up(tags=[\"weights\"])\n\n        if self.mode == \"server\" and accelerator.is_main_process:\n            params = [\n                (self._fix_param_name_to_vllm(name), param.data)\n                for name, param in model.named_parameters()\n            ]\n            self.vllm_client.batch_update_named_params(\n                params, chunk_size=getattr(self, \"weight_sync_chunk_size\", None)\n            )\n        elif self.mode == \"colocate\":\n            llm_model = (\n                self.llm.llm_engine.model_executor.driver_worker.model_runner.model\n            )\n            weights = [\n                (self._fix_param_name_to_vllm(name), param.data)\n                for name, param in model.named_parameters()\n            ]\n            llm_model.load_weights(weights=weights)\n\n        # Reset cache\n        if self.mode == \"server\" and accelerator.is_main_process:\n            self.vllm_client.reset_prefix_cache()\n        elif self.mode == \"colocate\":\n            self.llm.reset_prefix_cache()\n\n  
  return patched_sync_weights\n\n\ndef patch_trl_vllm():\n    \"\"\"Apply all TRL vLLM monkeypatches.\"\"\"\n    import trl.generation.vllm_client\n    import trl.generation.vllm_generation\n    import trl.trainer.utils\n\n    VLLMClient = trl.generation.vllm_client.VLLMClient\n    VLLMGeneration = trl.generation.vllm_generation.VLLMGeneration\n\n    # 1. Add batch_update_named_params to VLLMClient\n    if not hasattr(VLLMClient, \"batch_update_named_params\"):\n        VLLMClient.batch_update_named_params = _batch_update_named_params\n        VLLMClient.update_model_params = _update_model_params\n        LOG.info(\"Patched VLLMClient with batch_update_named_params\")\n\n    # 2. Patch extract_logprobs (NaN→0.0)\n    trl.generation.vllm_generation.extract_logprobs = _patched_extract_logprobs\n    LOG.info(\"Patched extract_logprobs with NaN→0.0 fix\")\n\n    # 3. Patch VLLMGeneration.__init__ to accept weight_sync_chunk_size\n    VLLMGeneration.__init__ = _patch_sync_weights_batched(VLLMGeneration.__init__)\n\n    # 4. Patch sync_weights for batched non-FSDP/non-ZeRO path\n    VLLMGeneration.sync_weights = _make_batched_sync_weights(\n        VLLMGeneration.sync_weights\n    )\n    LOG.info(\"Patched VLLMGeneration with batched sync_weights\")\n\n    # 5. Patch split_tensor_dict and shuffle_sequence_dict\n    trl.trainer.utils.split_tensor_dict = _patched_split_tensor_dict\n    trl.trainer.utils.shuffle_sequence_dict = _patched_shuffle_sequence_dict\n    LOG.info(\"Patched split_tensor_dict and shuffle_sequence_dict for scalar types\")\n"
  },
  {
    "path": "src/axolotl/monkeypatch/trainer/utils.py",
    "content": "# Copyright 2026 Axolotl AI. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport torch\nimport torch.nn.functional as F\nimport triton\nimport triton.language as tl\n\n\n@triton.jit\ndef _entropy_online_kernel(\n    logits_ptr,\n    output_ptr,\n    stride_row,\n    V: tl.constexpr,\n    BLOCK_V: tl.constexpr,\n):\n    \"\"\"Online entropy: single pass with running max correction.\"\"\"\n    row = tl.program_id(0)\n    row_ptr = logits_ptr + tl.cast(row, tl.int64) * stride_row\n\n    running_max = tl.full([], float(\"-inf\"), dtype=tl.float32)\n    running_sum_exp = tl.full([], 0.0, dtype=tl.float32)\n    running_weighted = tl.full([], 0.0, dtype=tl.float32)\n\n    for v_start in range(0, V, BLOCK_V):\n        offs = v_start + tl.arange(0, BLOCK_V)\n        mask = offs < V\n        x = tl.load(row_ptr + offs, mask=mask, other=float(\"-inf\")).to(tl.float32)\n\n        block_max = tl.max(x, axis=0)\n        new_max = tl.maximum(running_max, block_max)\n\n        correction = tl.exp(running_max - new_max)\n        running_sum_exp = running_sum_exp * correction\n        running_weighted = running_weighted * correction\n\n        exp_x = tl.exp(x - new_max)\n        exp_x = tl.where(mask, exp_x, 0.0)\n        x = tl.where(mask, x, 0.0)\n        running_sum_exp += tl.sum(exp_x, axis=0)\n        running_weighted += tl.sum(exp_x * x, axis=0)\n\n        running_max = new_max\n\n    entropy = tl.log(running_sum_exp) + 
running_max - running_weighted / running_sum_exp\n    tl.store(output_ptr + row, entropy)\n\n\n@triton.jit\ndef _entropy_online_kernel_strided(\n    logits_ptr,\n    output_ptr,\n    stride_outer,\n    stride_inner,\n    n_inner,\n    row_offset,\n    V: tl.constexpr,\n    BLOCK_V: tl.constexpr,\n):\n    \"\"\"Online entropy for non-contiguous 3D (B, L, V) tensors.\"\"\"\n    local_row = tl.program_id(0)\n    row = local_row + row_offset\n    outer_idx = row // n_inner\n    inner_idx = row % n_inner\n    off = outer_idx.to(tl.int64) * stride_outer + inner_idx.to(tl.int64) * stride_inner\n    row_ptr = logits_ptr + off\n\n    running_max = tl.full([], float(\"-inf\"), dtype=tl.float32)\n    running_sum_exp = tl.full([], 0.0, dtype=tl.float32)\n    running_weighted = tl.full([], 0.0, dtype=tl.float32)\n\n    for v_start in range(0, V, BLOCK_V):\n        offs = v_start + tl.arange(0, BLOCK_V)\n        mask = offs < V\n        x = tl.load(row_ptr + offs, mask=mask, other=float(\"-inf\")).to(tl.float32)\n\n        block_max = tl.max(x, axis=0)\n        new_max = tl.maximum(running_max, block_max)\n\n        correction = tl.exp(running_max - new_max)\n        running_sum_exp = running_sum_exp * correction\n        running_weighted = running_weighted * correction\n\n        exp_x = tl.exp(x - new_max)\n        exp_x = tl.where(mask, exp_x, 0.0)\n        x = tl.where(mask, x, 0.0)\n        running_sum_exp += tl.sum(exp_x, axis=0)\n        running_weighted += tl.sum(exp_x * x, axis=0)\n\n        running_max = new_max\n\n    entropy = tl.log(running_sum_exp) + running_max - running_weighted / running_sum_exp\n    tl.store(output_ptr + local_row, entropy)\n\n\ndef entropy_from_logits(logits: torch.Tensor, chunk_size: int = 128) -> torch.Tensor:\n    \"\"\"Triton-fused entropy (online single-pass). 
Handles non-contiguous tensors without copying.\"\"\"\n    original_shape = logits.shape[:-1]\n    V = logits.shape[-1]\n    N = 1\n    for s in original_shape:\n        N *= s\n\n    if not logits.is_cuda:\n        # CPU fallback: stable entropy via log_softmax\n        logp = F.log_softmax(logits.float(), dim=-1)\n        ent = -(logp.exp() * logp).sum(dim=-1)\n        return ent.to(logits.dtype).reshape(original_shape)\n\n    output = torch.empty(N, device=logits.device, dtype=torch.float32)\n\n    BLOCK_V = 4096\n    MAX_GRID_CONTIG = 8192\n    MAX_GRID_STRIDED = 2048\n\n    # Vocab (last) dim must be contiguous for coalesced loads\n    if logits.stride(-1) != 1:\n        logits = logits.contiguous()\n\n    if logits.is_contiguous():\n        flat_logits = logits.reshape(-1, V)\n        stride = flat_logits.stride(0)\n        for start in range(0, N, MAX_GRID_CONTIG):\n            n_rows = min(MAX_GRID_CONTIG, N - start)\n            _entropy_online_kernel[(n_rows,)](\n                flat_logits[start], output[start], stride, V=V, BLOCK_V=BLOCK_V\n            )\n    elif logits.ndim == 3:\n        stride_outer = logits.stride(0)\n        stride_inner = logits.stride(1)\n        n_inner = logits.shape[1]\n        for start in range(0, N, MAX_GRID_STRIDED):\n            n_rows = min(MAX_GRID_STRIDED, N - start)\n            _entropy_online_kernel_strided[(n_rows,)](\n                logits,\n                output[start],\n                stride_outer,\n                stride_inner,\n                n_inner,\n                start,\n                V=V,\n                BLOCK_V=BLOCK_V,\n            )\n    else:\n        logits = logits.contiguous()\n        flat_logits = logits.reshape(-1, V)\n        stride = flat_logits.stride(0)\n        for start in range(0, N, MAX_GRID_CONTIG):\n            n_rows = min(MAX_GRID_CONTIG, N - start)\n            _entropy_online_kernel[(n_rows,)](\n                flat_logits[start], output[start], stride, V=V, 
BLOCK_V=BLOCK_V\n            )\n\n    return output.to(logits.dtype).reshape(original_shape)\n\n\n# ---------------------------------------------------------------------------\n# selective_log_softmax — fused forward + backward Triton kernels\n# ---------------------------------------------------------------------------\n\n\ndef selective_log_softmax_original(logits, index) -> torch.Tensor:\n    \"\"\"Original selective_log_softmax (reference/fallback).\"\"\"\n    squeeze = index.ndim == logits.ndim - 1\n    if squeeze:\n        index = index.unsqueeze(-1)\n\n    if logits.dtype in [torch.float32, torch.float64]:\n        selected_logits = torch.gather(logits, dim=-1, index=index)\n        logsumexp_values = torch.stack([torch.logsumexp(lg, dim=-1) for lg in logits])\n        per_token_logps = selected_logits - logsumexp_values.unsqueeze(-1)\n    else:\n        per_token_logps = []\n        for row_logits, row_labels in zip(logits, index, strict=True):\n            row_logps = F.log_softmax(row_logits, dim=-1)\n            row_per_token_logps = row_logps.gather(dim=-1, index=row_labels)\n            per_token_logps.append(row_per_token_logps)\n        per_token_logps = torch.stack(per_token_logps)\n\n    if squeeze:\n        per_token_logps = per_token_logps.squeeze(-1)\n\n    return per_token_logps\n\n\n@triton.jit\ndef _selective_logsoftmax_fwd_kernel(\n    logits_ptr,\n    index_ptr,\n    output_ptr,\n    logsumexp_ptr,\n    stride_logits_row,\n    stride_index_row,\n    stride_output_row,\n    actual_K,\n    K_BLOCK: tl.constexpr,\n    V: tl.constexpr,\n    BLOCK_V: tl.constexpr,\n):\n    \"\"\"Forward: online logsumexp + gather. 
Saves logsumexp for backward.\"\"\"\n    row = tl.program_id(0)\n    logits_row_ptr = logits_ptr + tl.cast(row, tl.int64) * stride_logits_row\n\n    # Online logsumexp\n    running_max = tl.full([], float(\"-inf\"), dtype=tl.float32)\n    running_sum_exp = tl.full([], 0.0, dtype=tl.float32)\n\n    for v_start in range(0, V, BLOCK_V):\n        offs = v_start + tl.arange(0, BLOCK_V)\n        mask = offs < V\n        x = tl.load(logits_row_ptr + offs, mask=mask, other=float(\"-inf\")).to(\n            tl.float32\n        )\n\n        block_max = tl.max(x, axis=0)\n        new_max = tl.maximum(running_max, block_max)\n        running_sum_exp = running_sum_exp * tl.exp(running_max - new_max)\n\n        exp_x = tl.exp(x - new_max)\n        exp_x = tl.where(mask, exp_x, 0.0)\n        running_sum_exp += tl.sum(exp_x, axis=0)\n        running_max = new_max\n\n    lse = tl.log(running_sum_exp) + running_max\n    tl.store(logsumexp_ptr + row, lse)\n\n    # Gather and subtract\n    index_row_ptr = index_ptr + tl.cast(row, tl.int64) * stride_index_row\n    output_row_ptr = output_ptr + tl.cast(row, tl.int64) * stride_output_row\n\n    k_offs = tl.arange(0, K_BLOCK)\n    k_mask = k_offs < actual_K\n    indices = tl.load(index_row_ptr + k_offs, mask=k_mask, other=0).to(tl.int64)\n    valid_mask = k_mask & (indices >= 0) & (indices < V)\n    safe_indices = tl.where(valid_mask, indices, 0)\n    selected = tl.load(logits_row_ptr + safe_indices, mask=valid_mask, other=0.0).to(\n        tl.float32\n    )\n    tl.store(output_row_ptr + k_offs, selected - lse, mask=valid_mask)\n\n\n@triton.jit\ndef _selective_logsoftmax_bwd_kernel(\n    grad_output_ptr,\n    logits_ptr,\n    index_ptr,\n    logsumexp_ptr,\n    grad_logits_ptr,\n    stride_grad_out_row,\n    stride_logits_row,\n    stride_index_row,\n    stride_grad_logits_row,\n    actual_K,\n    K_BLOCK: tl.constexpr,\n    V: tl.constexpr,\n    BLOCK_V: tl.constexpr,\n):\n    \"\"\"Backward: d_logits[j] = -softmax(x)[j] * sum(grad_out) 
+ (grad_out[k] if j == index[k]).\n\n    Single fused pass over V. For each tile, computes the base gradient and adds\n    scatter contributions inline by checking which indices fall in the current tile.\n    No separate scatter pass — no read-after-write issues.\n    \"\"\"\n    row = tl.program_id(0)\n    logits_row_ptr = logits_ptr + tl.cast(row, tl.int64) * stride_logits_row\n    grad_logits_row_ptr = (\n        grad_logits_ptr + tl.cast(row, tl.int64) * stride_grad_logits_row\n    )\n    grad_out_row_ptr = grad_output_ptr + tl.cast(row, tl.int64) * stride_grad_out_row\n    index_row_ptr = index_ptr + tl.cast(row, tl.int64) * stride_index_row\n\n    lse = tl.load(logsumexp_ptr + row).to(tl.float32)\n\n    # Load grad_output and indices (K_BLOCK elements, masked)\n    k_offs = tl.arange(0, K_BLOCK)\n    k_mask = k_offs < actual_K\n    grad_out = tl.load(grad_out_row_ptr + k_offs, mask=k_mask, other=0.0).to(tl.float32)\n    indices = tl.load(\n        index_row_ptr + k_offs, mask=k_mask, other=-1\n    )  # -1 = never matches\n    valid_mask = k_mask & (indices >= 0) & (indices < V)\n    grad_out = tl.where(valid_mask, grad_out, 0.0)\n    indices = tl.where(valid_mask, indices, -1)\n    grad_sum = tl.sum(grad_out, axis=0)\n\n    # Fused pass: for each tile, compute -softmax * grad_sum + scatter\n    for v_start in range(0, V, BLOCK_V):\n        offs = v_start + tl.arange(0, BLOCK_V)  # [BLOCK_V]\n        mask = offs < V\n        x = tl.load(logits_row_ptr + offs, mask=mask, other=0.0).to(tl.float32)\n        softmax_j = tl.exp(x - lse)\n        softmax_j = tl.where(mask, softmax_j, 0.0)\n        grad_j = -softmax_j * grad_sum\n\n        # Scatter: check which selected indices fall in this tile\n        # offs: [BLOCK_V], indices: [K_BLOCK]\n        # Broadcast: offs[:, None] == indices[None, :] → [BLOCK_V, K_BLOCK]\n        match = offs[:, None] == indices[None, :]  # [BLOCK_V, K_BLOCK]\n        # Sum grad_out contributions: for each position j, sum grad_out[k] 
where index[k]==j\n        scatter_contrib = tl.sum(\n            tl.where(match, grad_out[None, :], 0.0), axis=1\n        )  # [BLOCK_V]\n        grad_j += scatter_contrib\n\n        tl.store(grad_logits_row_ptr + offs, grad_j, mask=mask)\n\n\nclass _SelectiveLogSoftmaxTriton(torch.autograd.Function):\n    @staticmethod\n    def forward(ctx, flat_logits, flat_index, K, K_BLOCK, V, BLOCK_V, MAX_GRID):\n        N = flat_logits.shape[0]\n        output = torch.empty(N, K_BLOCK, device=flat_logits.device, dtype=torch.float32)\n        logsumexp = torch.empty(N, device=flat_logits.device, dtype=torch.float32)\n\n        for start in range(0, N, MAX_GRID):\n            n_rows = min(MAX_GRID, N - start)\n            _selective_logsoftmax_fwd_kernel[(n_rows,)](\n                flat_logits[start],\n                flat_index[start],\n                output[start],\n                logsumexp[start],\n                flat_logits.stride(0),\n                flat_index.stride(0),\n                output.stride(0),\n                K,\n                K_BLOCK=K_BLOCK,\n                V=V,\n                BLOCK_V=BLOCK_V,\n            )\n\n        ctx.save_for_backward(flat_logits, flat_index, logsumexp)\n        ctx.K = K\n        ctx.K_BLOCK = K_BLOCK\n        ctx.V = V\n        ctx.BLOCK_V = BLOCK_V\n        ctx.MAX_GRID = MAX_GRID\n        return output\n\n    @staticmethod\n    def backward(ctx, grad_output):\n        flat_logits, flat_index, logsumexp = ctx.saved_tensors\n        K, K_BLOCK, V, BLOCK_V, MAX_GRID = (\n            ctx.K,\n            ctx.K_BLOCK,\n            ctx.V,\n            ctx.BLOCK_V,\n            ctx.MAX_GRID,\n        )\n        N = flat_logits.shape[0]\n\n        grad_logits = torch.empty_like(flat_logits)\n\n        # grad_output may have K_BLOCK cols; backward kernel reads actual_K\n        grad_output_contig = grad_output.contiguous()\n\n        for start in range(0, N, MAX_GRID):\n            n_rows = min(MAX_GRID, N - start)\n            
_selective_logsoftmax_bwd_kernel[(n_rows,)](\n                grad_output_contig[start],\n                flat_logits[start],\n                flat_index[start],\n                logsumexp[start],\n                grad_logits[start],\n                grad_output_contig.stride(0),\n                flat_logits.stride(0),\n                flat_index.stride(0),\n                grad_logits.stride(0),\n                K,\n                K_BLOCK=K_BLOCK,\n                V=V,\n                BLOCK_V=BLOCK_V,\n            )\n\n        # Return grads for: flat_logits, flat_index, K, K_BLOCK, V, BLOCK_V, MAX_GRID\n        return grad_logits, None, None, None, None, None, None\n\n\ndef selective_log_softmax(logits, index) -> torch.Tensor:\n    \"\"\"\n    Fused selective_log_softmax with Triton forward+backward kernels.\n\n    Equivalent to: torch.gather(logits.log_softmax(-1), dim=-1, index=index)\n    \"\"\"\n    squeeze = index.ndim == logits.ndim - 1\n    if squeeze:\n        index = index.unsqueeze(-1)\n\n    if not logits.is_cuda or logits.dtype == torch.float64:\n        # Triton kernel computes in float32; fall back for float64 and CPU\n        return selective_log_softmax_original(\n            logits, index.squeeze(-1) if squeeze else index\n        )\n\n    V = logits.shape[-1]\n    K = index.shape[-1]\n    original_index_shape = index.shape\n\n    flat_logits = logits.reshape(-1, V).contiguous()\n    flat_index = index.reshape(-1, K).contiguous()\n\n    BLOCK_V = 4096\n    MAX_GRID = 8192\n    K_BLOCK = max(1, triton.next_power_of_2(K))\n\n    output = _SelectiveLogSoftmaxTriton.apply(\n        flat_logits, flat_index, K, K_BLOCK, V, BLOCK_V, MAX_GRID\n    )\n\n    if K_BLOCK != K:\n        output = output[:, :K]\n\n    per_token_logps = output.to(logits.dtype).reshape(original_index_shape)\n\n    if squeeze:\n        per_token_logps = per_token_logps.squeeze(-1)\n\n    return per_token_logps\n"
  },
  {
    "path": "src/axolotl/monkeypatch/trainer_accelerator_args.py",
    "content": "\"\"\"\nallow adding additional kwargs to Accelerator init\n\"\"\"\n\nimport inspect\n\nfrom transformers import Trainer\n\nfrom axolotl.monkeypatch.utils import detab_code\nfrom axolotl.utils.logging import get_logger\n\nLOG = get_logger(__name__)\n\nORIGINAL_TRAINER_CODE = \"\"\"\n    # create accelerator object\n    self.accelerator = Accelerator(**args)\n\"\"\"\n\nPATCHED_TRAINER_CODE = \"\"\"\n    if hasattr(self, \"additional_accelerator_args\"):\n        additional_args = self.additional_accelerator_args(fp8=True, enable_fsdp_float8_all_gather={enable_fsdp_float8_all_gather}, **args)\n        if additional_args:\n            args.update(additional_args)\n\n    # create accelerator object\n    self.accelerator = Accelerator(**args)\n\"\"\"\n\n\ndef get_create_accelerate_code() -> str:\n    training_loop = inspect.getsource(Trainer.create_accelerator_and_postprocess)\n    return training_loop\n\n\ndef check_create_accelerate_code_is_patchable() -> bool:\n    create_code = get_create_accelerate_code()\n    create_code, _ = detab_code(create_code)\n    return ORIGINAL_TRAINER_CODE in create_code\n\n\ndef patch_create_accelerate_code_for_fp8(enable_fsdp_float8_all_gather: bool):\n    \"\"\"\n    Monkeypatch create_accelerator_and_postprocess so it checks for additional kwargs.\n    \"\"\"\n\n    try:\n        create_code = get_create_accelerate_code()\n    except OSError:\n        return\n    Trainer._original_create_accelerator_and_postprocess = create_code\n    create_code, _ = detab_code(create_code)\n    if ORIGINAL_TRAINER_CODE not in create_code:\n        return\n\n    patched_trainer_code = PATCHED_TRAINER_CODE.format(\n        enable_fsdp_float8_all_gather=enable_fsdp_float8_all_gather\n    )\n    create_code = create_code.replace(ORIGINAL_TRAINER_CODE, patched_trainer_code)\n    create_code = create_code.replace(\n        \"def create_accelerator_and_postprocess(\",\n        \"def fixed_create_accelerator_and_postprocess(\",\n        1,\n 
   )\n\n    # load imports necessary\n    import transformers.trainer\n\n    items_to_import = []\n    for item in dir(transformers.trainer):\n        if item in create_code:\n            items_to_import.append(item)\n\n    exec(\n        \"from transformers.trainer import (\"\n        + \", \".join(x for x in items_to_import)\n        + \")\",\n        globals(),\n    )\n    exec(create_code, globals())\n    LOG.info(\"patching create_accelerator_and_postprocess to allow for overrides\")\n    Trainer.create_accelerator_and_postprocess = (\n        fixed_create_accelerator_and_postprocess\n    )\n"
  },
  {
    "path": "src/axolotl/monkeypatch/trainer_fsdp_optim.py",
    "content": "\"\"\"\nfix for FSDP optimizer save in trainer w 4.47.0\n\"\"\"\n\nimport inspect\n\nfrom transformers import Trainer\n\nfrom axolotl.monkeypatch.utils import detab_code\nfrom axolotl.utils.logging import get_logger\n\nLOG = get_logger(__name__)\n\nORIGINAL_TRAINER_CODE = \"\"\"\n                if delay_optimizer_creation:\n                    self.optimizer = self.accelerator.prepare(self.optimizer)\n\"\"\"\n\nPATCHED_TRAINER_CODE = \"\"\"\n                if delay_optimizer_creation:\n                    model = self.accelerator.prepare(self.model)\n\"\"\"\n\n\ndef get_training_loop_code() -> str:\n    training_loop = inspect.getsource(Trainer._inner_training_loop)\n    return training_loop\n\n\ndef check_training_loop_is_patchable() -> bool:\n    training_loop = get_training_loop_code()\n    training_loop, _ = detab_code(training_loop)\n    return ORIGINAL_TRAINER_CODE in training_loop\n\n\ndef patch_training_loop_for_fsdp():\n    \"\"\"\n    monkeypatch for fixing the training loop for fsdp with optimizer save\n    \"\"\"\n\n    try:\n        training_loop = get_training_loop_code()\n    except OSError:\n        return\n    Trainer._original_inner_training_loop = training_loop\n    training_loop, _ = detab_code(training_loop)\n    if ORIGINAL_TRAINER_CODE not in training_loop:\n        return\n\n    training_loop = training_loop.replace(ORIGINAL_TRAINER_CODE, PATCHED_TRAINER_CODE)\n    training_loop = training_loop.replace(\n        \"def _inner_training_loop(\",\n        \"def _fixed_inner_training_loop(\",\n        1,\n    )\n\n    # load imports necessary\n    import transformers.trainer\n\n    items_to_import = []\n    for item in dir(transformers.trainer):\n        if item in training_loop:\n            items_to_import.append(item)\n\n    exec(\n        \"from transformers.trainer import (\"\n        + \", \".join(x for x in items_to_import)\n        + \")\",\n        globals(),\n    )\n    exec(training_loop, globals())\n    
LOG.info(\"patching _inner_training_loop for fsdp optimizer save\")\n    Trainer._inner_training_loop = _fixed_inner_training_loop\n"
  },
  {
    "path": "src/axolotl/monkeypatch/transformers/__init__.py",
    "content": ""
  },
  {
    "path": "src/axolotl/monkeypatch/transformers/trainer_context_parallel.py",
    "content": "\"\"\"Monkey patch to allow context parallelism with FlashAttention in HF Trainer.\"\"\"\n\nfrom __future__ import annotations\n\nimport importlib\nimport inspect\n\nfrom transformers import Trainer\n\nfrom axolotl.monkeypatch.utils import detab_code\nfrom axolotl.utils.logging import get_logger\n\nLOG = get_logger(__name__)\n\nGUARD_PATTERN = 'if model.config._attn_implementation != \"sdpa\":'\nPATCHED_GUARD = 'if (attn_impl := (getattr(model.config, \"_attn_implementation\", None) or getattr(model.model.config, \"_attn_implementation\", None))) and attn_impl not in (\"sdpa\", \"flash_attention_2\"):'\n\n\ndef patch_prepare_context_parallel_inputs() -> None:\n    \"\"\"Relax the SDPA-only guard when running context parallelism with FlashAttention.\"\"\"\n    if getattr(Trainer, \"_axolotl_prepare_context_parallel_inputs_patched\", False):\n        LOG.debug(\"Trainer._prepare_context_parallel_inputs already patched\")\n        return\n\n    try:\n        original_source = inspect.getsource(Trainer._prepare_context_parallel_inputs)\n    except OSError as exc:  # pragma: no cover - occurs when source is unavailable\n        LOG.warning(\"Unable to patch Trainer._prepare_context_parallel_inputs: %s\", exc)\n        return\n\n    if GUARD_PATTERN not in original_source:\n        LOG.warning(\n            \"Expected guard not found in Trainer._prepare_context_parallel_inputs; \\n\"\n            \"skipping FlashAttention context parallelism patch\"\n        )\n        return\n\n    patched_source = original_source.replace(GUARD_PATTERN, PATCHED_GUARD)\n    patched_source, _ = detab_code(patched_source)\n    patched_source = patched_source.replace(\n        \"def _prepare_context_parallel_inputs(\",\n        \"def axolotl_prepare_context_parallel_inputs(\",\n        1,\n    )\n\n    module_name = Trainer.__module__\n    module = importlib.import_module(module_name)\n\n    # import symbols referenced in the method so exec can succeed\n    items_to_import = 
[]\n    for item in dir(module):\n        if item in patched_source:\n            items_to_import.append(item)\n\n    # Use a separate namespace to capture the exec'd function\n    namespace = {}\n    exec(f\"from {module_name} import ({', '.join(items_to_import)})\", namespace)\n    exec(patched_source, namespace)\n\n    # Explicitly get the function from the namespace\n    axolotl_prepare_context_parallel_inputs = namespace[\n        \"axolotl_prepare_context_parallel_inputs\"\n    ]\n    Trainer._original_prepare_context_parallel_inputs = (\n        Trainer._prepare_context_parallel_inputs\n    )\n    Trainer._prepare_context_parallel_inputs = axolotl_prepare_context_parallel_inputs\n    Trainer._axolotl_prepare_context_parallel_inputs_source = patched_source\n    Trainer._axolotl_prepare_context_parallel_inputs_patched = True\n    LOG.debug(\n        \"Patched Trainer._prepare_context_parallel_inputs for FlashAttention + CP\"\n    )\n"
  },
  {
    "path": "src/axolotl/monkeypatch/transformers/trainer_loss_calc.py",
    "content": "\"\"\"\nModule for patching transformers Trainer loss calculation to use nanmean.\n\nThis is needed for context parallelism since chunks of the input sequences may be fully\nmasked and return NaNs in the loss calculation.\n\nAlso includes a patch for FSDP2 + torch.compile. We need to bundle this together with\nthe other evaluation_loop patch because we can't patch the same code twice without\nraising an OSError.\n\"\"\"\n\nimport importlib\nimport inspect\n\nfrom transformers import Trainer\n\nfrom axolotl.monkeypatch.utils import detab_code\nfrom axolotl.utils.logging import get_logger\n\nLOG = get_logger(__name__)\n\nORIGINAL_EVAL_CODE = {\n    \"list\": 'metrics[f\"{metric_key_prefix}_loss\"] = np.concatenate(all_losses).mean().item()',\n    \"array\": 'metrics[f\"{metric_key_prefix}_loss\"] = all_losses.mean().item()',\n}\nPATCHED_EVAL_CODE = {\n    \"list\": 'metrics[f\"{metric_key_prefix}_loss\"] = np.nanmean(np.concatenate(all_losses)).item()',\n    \"array\": 'metrics[f\"{metric_key_prefix}_loss\"] = np.nanmean(all_losses).item()',\n}\n\nORIGINAL_MAYBE_CODE = (\n    \"tr_loss_scalar = nested_gather(tr_loss, self.args.parallel_mode).mean().item()\"\n)\nPATCHED_MAYBE_CODE = (\n    \"tr_loss_scalar = nested_gather(tr_loss, self.args.parallel_mode).nanmean().item()\"\n)\n\n\ndef check_evaluation_loop_is_patchable() -> bool:\n    evaluation_loop_source = inspect.getsource(Trainer.evaluation_loop)\n    return all(value in evaluation_loop_source for value in ORIGINAL_EVAL_CODE.values())\n\n\ndef patch_evaluation_loop():\n    \"\"\"Patch the evaluation_loop method.\"\"\"\n    # Check if already patched\n    if hasattr(Trainer, \"_original_evaluation_loop\"):\n        LOG.debug(\"Trainer.evaluation_loop already patched\")\n        return\n\n    # Check if the patterns exist\n    try:\n        evaluation_loop_source = inspect.getsource(Trainer.evaluation_loop)\n    except OSError:\n        return\n    # Keep the original source; its presence doubles as the already-patched marker\n    Trainer._original_evaluation_loop = evaluation_loop_source\n    evaluation_loop_source, _ = detab_code(evaluation_loop_source)\n\n    # Apply the nanmean patches\n    evaluation_loop_source = evaluation_loop_source.replace(\n        ORIGINAL_EVAL_CODE[\"list\"], PATCHED_EVAL_CODE[\"list\"]\n    )\n    evaluation_loop_source = evaluation_loop_source.replace(\n        ORIGINAL_EVAL_CODE[\"array\"], PATCHED_EVAL_CODE[\"array\"]\n    )\n\n    # Rename the function to avoid conflicts\n    evaluation_loop_source = evaluation_loop_source.replace(\n        \"def evaluation_loop(\",\n        \"def axolotl_evaluation_loop(\",\n        1,\n    )\n\n    # Get the module for necessary imports\n    module_name = Trainer.__module__\n    module = importlib.import_module(module_name)\n\n    # Import necessary items from the module\n    items_to_import = []\n    for item in dir(module):\n        if item in evaluation_loop_source:\n            items_to_import.append(item)\n\n    # Execute the imports and patched method\n    exec(\n        f\"from {module_name} import ({', '.join(items_to_import)})\",\n        globals(),\n    )\n    exec(evaluation_loop_source, globals())\n\n    LOG.debug(\"Patched Trainer.evaluation_loop with nanmean loss calculation\")\n    Trainer.evaluation_loop = axolotl_evaluation_loop\n\n\ndef check_maybe_log_save_evaluate_is_patchable() -> bool:\n    maybe_log_source = inspect.getsource(Trainer._maybe_log_save_evaluate)\n    return ORIGINAL_MAYBE_CODE in maybe_log_source\n\n\ndef patch_maybe_log_save_evaluate():\n    \"\"\"Patch the _maybe_log_save_evaluate method.\"\"\"\n    # Check if already patched\n    if hasattr(Trainer, \"_original_maybe_log_save_evaluate\"):\n        LOG.debug(\"Trainer._maybe_log_save_evaluate already patched\")\n        return\n\n    # Check if the patterns exist\n    try:\n        maybe_log_source = inspect.getsource(Trainer._maybe_log_save_evaluate)\n    except OSError:\n        return\n    Trainer._original_maybe_log_save_evaluate = maybe_log_source\n    maybe_log_source, _ = detab_code(maybe_log_source)\n\n    # Apply the patch\n    maybe_log_source = maybe_log_source.replace(ORIGINAL_MAYBE_CODE, PATCHED_MAYBE_CODE)\n\n    # Rename the function to avoid conflicts\n    maybe_log_source = maybe_log_source.replace(\n        \"def _maybe_log_save_evaluate(\",\n        \"def axolotl_maybe_log_save_evaluate(\",\n        1,\n    )\n\n    # Get the module for necessary imports\n    module_name = Trainer.__module__\n    module = importlib.import_module(module_name)\n\n    # Import necessary items from the module\n    items_to_import = []\n    for item in dir(module):\n        if item in maybe_log_source:\n            items_to_import.append(item)\n\n    # Execute the imports and patched method\n    exec(\n        f\"from {module_name} import ({', '.join(items_to_import)})\",\n        globals(),\n    )\n    exec(maybe_log_source, globals())\n\n    LOG.debug(\"Patched Trainer._maybe_log_save_evaluate with nanmean loss calculation\")\n    Trainer._maybe_log_save_evaluate = axolotl_maybe_log_save_evaluate\n"
  },
  {
    "path": "src/axolotl/monkeypatch/transformers_fa_utils.py",
    "content": "\"\"\"\nsee https://github.com/huggingface/transformers/pull/35834\n\"\"\"\n\nfrom functools import partial\nfrom typing import Optional\n\nimport torch\n\nfrom axolotl.utils.logging import get_logger\n\nlogger = get_logger(__name__)\n\n\ndef fixed_fa_peft_integration_check(\n    query: torch.Tensor,\n    key: torch.Tensor,\n    value: torch.Tensor,\n    target_dtype: Optional[torch.dtype] = None,\n    preferred_dtype: Optional[torch.dtype] = None,\n):\n    \"\"\"\n    PEFT usually casts the layer norms in float32 for training stability reasons\n    therefore the input hidden states gets silently casted in float32. Hence, we need\n    cast them back in float16 / bfloat16 just to be sure everything works as expected.\n    This might slowdown training & inference so it is recommended to not cast the LayerNorms!\n\n    Args:\n        query (`torch.Tensor`):\n            Input query states to be passed to Flash Attention API\n        key (`torch.Tensor`):\n            Input key states to be passed to Flash Attention API\n        value (`torch.Tensor`):\n            Input value states to be passed to Flash Attention API\n        target_dtype (`torch.dtype`, *optional*):\n            The dtype to convert the attention tensors to. Conversion can be ignored by\n            not providing the target dtype.\n        preferred_dtype (`torch.dtype`, *optional*):\n            The preferred dtype to convert the attention tensors to regardless of the\n            target dtype.\n    \"\"\"\n    if target_dtype is None and preferred_dtype is None:\n        return query, key, value\n\n    if preferred_dtype and target_dtype != preferred_dtype:\n        target_dtype = preferred_dtype\n\n    # check if any of query, key, or value are in float32. 
If so, cast them back to target dtype.\n    if any(module.dtype == torch.float32 for module in [query, key, value]):\n        logger.warning_once(\n            f\"The input hidden states seems to be silently casted in float32, this might be related to\"\n            f\" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in\"\n            f\" {target_dtype}.\"\n        )\n\n        query = query.to(target_dtype)\n        key = key.to(target_dtype)\n        value = value.to(target_dtype)\n\n    return query, key, value\n\n\ndef patch_fa_peft_integration():\n    import transformers.modeling_flash_attention_utils\n\n    transformers.modeling_flash_attention_utils.fa_peft_integration_check = partial(\n        fixed_fa_peft_integration_check, preferred_dtype=None\n    )\n"
  },
  {
    "path": "src/axolotl/monkeypatch/unsloth_.py",
    "content": "\"\"\"module for patching with unsloth optimizations\"\"\"\n\nimport inspect\nimport types\n\nimport torch\nfrom peft import PeftModelForCausalLM\nfrom torch import nn\nfrom transformers.models.llama.modeling_llama import LlamaFlashAttention2\n\nfrom axolotl.monkeypatch.utils import detab_code\nfrom axolotl.utils.logging import get_logger\n\nLOG = get_logger(__name__)\n\nORIGINAL_QKV_CODE = \"\"\"\n    query_states = self.q_proj(hidden_states)\n    key_states = self.k_proj(hidden_states)\n    value_states = self.v_proj(hidden_states)\n\"\"\".lstrip(\"\\n\")\n\nPATCHED_QKV_CODE = \"\"\"\n    query_states, key_states, value_states = self.apply_qkv(self, hidden_states)\n\"\"\".lstrip(\"\\n\")\n\nORIGINAL_O_CODE = \"\"\"\n    attn_output = self.o_proj(attn_output)\n\"\"\".lstrip(\"\\n\")\n\nPATCHED_O_CODE = \"\"\"\n    attn_output = self.apply_o(self, attn_output)\n\"\"\".lstrip(\"\\n\")\n\n\ndef original_apply_qkv(self, hidden_states):\n    query_states = self.q_proj(hidden_states)\n    key_states = self.k_proj(hidden_states)\n    value_states = self.v_proj(hidden_states)\n    return query_states, key_states, value_states\n\n\ndef original_apply_o(self, hidden_states):\n    attn_output = self.o_proj(hidden_states)\n    return attn_output\n\n\ndef get_self_attn_code() -> str:\n    forward = inspect.getsource(LlamaFlashAttention2.forward)\n    return forward\n\n\ndef check_self_attn_is_patchable() -> bool:\n    qkv = get_self_attn_code()\n    qkv, _ = detab_code(qkv)\n    return ORIGINAL_QKV_CODE in qkv and ORIGINAL_O_CODE in qkv\n\n\ndef integrate_cross_entropy_loss_patch(model_type: str = \"llama\") -> None:\n    from unsloth.kernels.cross_entropy_loss import fast_cross_entropy_loss\n\n    def UnslothForCausalLMLoss(\n        logits,\n        labels,\n        vocab_size: int,\n        num_items_in_batch: int = None,\n        ignore_index: int = -100,\n        **kwargs,\n    ):\n        # Upcast to float if we need to compute the loss to avoid potential 
precision issues\n        logits = logits.float()\n        # Shift so that tokens < n predict n\n        shift_logits = logits[..., :-1, :].contiguous()\n        shift_labels = labels[..., 1:].contiguous()\n\n        loss = fast_cross_entropy_loss(\n            logits=shift_logits, labels=shift_labels, n_items=num_items_in_batch\n        )\n        return loss\n\n    if model_type == \"llama\":\n        from transformers.loss import loss_utils\n\n        loss_utils.ForCausalLMLoss = UnslothForCausalLMLoss  # type: ignore[assignment]\n    else:\n        raise ValueError(\"Unsupported model type\")\n\n\nself_attn_lora_patched = False\n\n\ndef patch_self_attn_lora():\n    global self_attn_lora_patched\n    if self_attn_lora_patched:\n        # prevent patching multiple times\n        return\n    self_attn_forward = get_self_attn_code()\n    LlamaFlashAttention2._original_forward = self_attn_forward\n    self_attn_forward, _ = detab_code(self_attn_forward)\n    assert ORIGINAL_QKV_CODE in self_attn_forward, \"Original qkv code not found\"\n    assert ORIGINAL_O_CODE in self_attn_forward, \"Original o code not found\"\n\n    self_attn_forward = self_attn_forward.replace(ORIGINAL_QKV_CODE, PATCHED_QKV_CODE)\n    self_attn_forward = self_attn_forward.replace(ORIGINAL_O_CODE, PATCHED_O_CODE)\n    self_attn_forward = self_attn_forward.replace(\n        \"def forward(\",\n        \"def unsloth_attn_forward(\",\n        1,\n    )\n\n    # load imports necessary\n    import transformers.models.llama.modeling_llama\n\n    items_to_import = []\n    for item in dir(transformers.models.llama.modeling_llama):\n        if item in self_attn_forward:\n            items_to_import.append(item)\n\n    exec(\n        \"from transformers.models.llama.modeling_llama import (\"\n        + \", \".join(x for x in items_to_import)\n        + \")\",\n        globals(),\n    )\n    exec(self_attn_forward, globals())\n    self_attn_lora_patched = True\n    LOG.info(\"patching unsloth attn 
lora\")\n    LlamaFlashAttention2.forward = unsloth_attn_forward\n\n\ndef integrate_rope_embeddings():\n    import transformers.models.llama.modeling_llama\n    from unsloth.kernels.rope_embedding import fast_rope_embedding\n\n    def apply_rotary_pos_emb(\n        q,\n        k,\n        cos,\n        sin,\n        position_ids=None,\n        unsqueeze_dim=1,\n    ):\n        return fast_rope_embedding(q, k, cos, sin)\n\n    LOG.info(\"patching unsloth RoPE embeddings\")\n    transformers.models.llama.modeling_llama.apply_rotary_pos_emb = apply_rotary_pos_emb\n\n\ndef integrate_lora_mlp_patch(peft_model: PeftModelForCausalLM):\n    if peft_model.base_model.config.model_type in [\"llama\", \"mistral\"]:\n        from unsloth.kernels import apply_lora_mlp_swiglu\n\n        apply_lora_mlp = apply_lora_mlp_swiglu\n    elif peft_model.base_model.config.model_type == \"gemma\":\n        from unsloth.kernels import apply_lora_mlp_geglu_approx\n\n        apply_lora_mlp = apply_lora_mlp_geglu_approx\n    else:\n        raise NotImplementedError(\n            f\"Model type {peft_model.base_model.config.model_type} not supported\"\n        )\n\n    for idx, layer in enumerate(peft_model.model.model.layers):\n        layer_modules = [\n            getattr(layer.mlp, linear_proj)\n            for linear_proj in [\"gate_proj\", \"up_proj\", \"down_proj\"]\n        ]\n        is_mlp_lora = all(hasattr(module, \"lora_A\") for module in layer_modules)\n        mlp_no_bias = all(\n            getattr(module, \"base_layer\", module).bias is None\n            for module in layer_modules\n        )\n        mlp_not_dora = all(\n            len(getattr(module, \"lora_magnitude_vector\", []) or []) == 0\n            for module in layer_modules\n        )\n\n        if is_mlp_lora and mlp_no_bias and mlp_not_dora:\n            layer.mlp.forward = types.MethodType(apply_lora_mlp, layer.mlp)\n        else:\n            LOG.warning(f\"unable to apply unsloth lora mlp patch to layer 
{idx}\")\n\n\ndef integrate_lora_patch(peft_model: PeftModelForCausalLM, cfg):\n    from unsloth.kernels import apply_lora_o, apply_lora_qkv\n\n    for idx, layer in enumerate(peft_model.model.model.layers):\n        if cfg.unsloth_lora_qkv:\n            layer_modules = [\n                getattr(layer.self_attn, linear_proj)\n                for linear_proj in [\"q_proj\", \"k_proj\", \"v_proj\"]\n            ]\n            is_qkv_lora = all(hasattr(module, \"lora_A\") for module in layer_modules)\n            qkv_no_bias = all(\n                getattr(module, \"base_layer\", module).bias is None\n                for module in layer_modules\n            )\n            qkv_not_dora = all(\n                len(getattr(module, \"lora_magnitude_vector\", []) or []) == 0\n                for module in layer_modules\n            )\n\n            if is_qkv_lora and qkv_no_bias and qkv_not_dora:\n                layer.self_attn.apply_qkv = apply_lora_qkv\n            else:\n                layer.self_attn.apply_qkv = original_apply_qkv\n                LOG.warning(f\"unable to apply unsloth lora qkv patch to layer {idx}\")\n        if cfg.unsloth_lora_o:\n            layer_modules = [\n                getattr(layer.self_attn, linear_proj) for linear_proj in [\"o_proj\"]\n            ]\n            is_o_lora = all(hasattr(module, \"lora_A\") for module in layer_modules)\n            o_no_bias = all(\n                getattr(module, \"base_layer\", module).bias is None\n                for module in layer_modules\n            )\n            o_not_dora = all(\n                len(getattr(module, \"lora_magnitude_vector\", []) or []) == 0\n                for module in layer_modules\n            )\n\n            if is_o_lora and o_no_bias and o_not_dora:\n                layer.self_attn.apply_o = apply_lora_o\n            else:\n                layer.self_attn.apply_o = original_apply_o\n                LOG.warning(f\"unable to apply unsloth lora o_proj patch to layer 
{idx}\")\n\n\ndef patch_unsloth_layernorm():\n    try:\n        import transformers.models.llama.modeling_llama\n        from unsloth.kernels.rms_layernorm import Fast_RMS_Layernorm\n\n        class LlamaRMSNorm(nn.Module):\n            \"\"\"LlamaRMSNorm\"\"\"\n\n            def __init__(self, hidden_size, eps=1e-6):\n                \"\"\"\n                LlamaRMSNorm is equivalent to T5LayerNorm\n                \"\"\"\n                super().__init__()\n                self.weight = nn.Parameter(torch.ones(hidden_size))\n                self.variance_epsilon = eps\n\n            def forward(self, hidden_states):\n                return Fast_RMS_Layernorm.apply(\n                    hidden_states, self.weight, self.variance_epsilon, False\n                )\n\n        LOG.info(\"patching with unsloth.kernels.rms_layernorm\")\n        transformers.models.llama.modeling_llama.LlamaRMSNorm = LlamaRMSNorm\n    except ImportError:\n        LOG.warning(\"missing unsloth library\")\n"
  },
  {
    "path": "src/axolotl/monkeypatch/utils.py",
    "content": "\"\"\"\nShared utils for the monkeypatches\n\"\"\"\n\nimport re\nfrom typing import Tuple\n\nimport torch\nimport torch.nn.functional as F\n\n\n@torch.jit.script\ndef get_max_seqlen_in_batch(attention_mask: torch.Tensor) -> torch.Tensor:\n    max_num = int(torch.max(attention_mask).item())\n    batch_size, _ = attention_mask.shape\n    counts = torch.zeros((batch_size, max_num), dtype=torch.int32)\n    for i in range(1, max_num + 1):\n        mask = attention_mask == i\n        counts[:, i - 1] = torch.sum(mask, dim=-1).to(dtype=torch.int32)\n    result = counts.flatten()\n    nonzero_indices = torch.nonzero(result).squeeze(-1)\n    return result[nonzero_indices]\n\n\n@torch.jit.script\ndef get_unpad_data(attention_mask: torch.Tensor):\n    device = attention_mask.device\n    seqlens_in_batch = get_max_seqlen_in_batch(attention_mask)\n    indices = torch.nonzero(attention_mask.flatten()).flatten()\n    max_seqlen_in_batch = seqlens_in_batch.max().item()\n    cu_seqlens = (\n        F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))\n        .to(device=device)\n        .detach()\n    )\n    return (\n        indices,\n        cu_seqlens,\n        max_seqlen_in_batch,\n    )\n\n\ndef get_cu_seqlens(attn_mask):\n    \"\"\"generate a cumulative sequence length mask for flash attention using attn mask\"\"\"\n    if len(attn_mask.shape) == 1:\n        attn_mask = attn_mask.unsqueeze(0)\n\n    device = attn_mask.device\n    results = []\n    max_seq_lens = []\n\n    for row in attn_mask:\n        # Exclude zeros to avoid adding their positions to the mask\n        t_non_zeros = row[row != 0]\n        # Find where the sequence number changes (including the first position)\n        seq_change = torch.cat(\n            [\n                torch.tensor([1], dtype=torch.int32, device=device),\n                t_non_zeros[1:] != t_non_zeros[:-1],\n            ]\n        )\n        # Get the indices where the sequence changes\n        
change_indices = torch.cat(\n            [\n                (seq_change == 1).nonzero(as_tuple=True)[0],\n                torch.tensor([len(t_non_zeros)], dtype=torch.int32, device=device),\n            ]\n        )\n        # Calculate the sequence lengths\n        seq_lengths = change_indices[1:] - change_indices[:-1]\n        # Calculate the length of the final sequence or padding\n        final_seq_length = len(row) - change_indices[-1]\n        # Append the length of the final sequence or padding to seq_lengths\n        if final_seq_length.item():\n            seq_lengths = torch.cat(\n                [\n                    seq_lengths,\n                    torch.tensor(\n                        [final_seq_length.item()], dtype=torch.int32, device=device\n                    ),\n                ]\n            )\n        # Calculate the cumulative sequence lengths\n        cu_seqlens = torch.cat(\n            [torch.tensor([0], dtype=torch.int32, device=device), seq_lengths.cumsum(0)]\n        )\n        max_seq_len = (cu_seqlens[1:] - cu_seqlens[:-1]).max()\n        results.append(cu_seqlens)\n        max_seq_lens.append(max_seq_len)\n\n    return torch.stack(results).to(dtype=torch.int32), torch.stack(max_seq_lens)\n\n\ndef get_cu_seqlens_from_pos_ids(\n    position_ids: torch.Tensor,\n) -> tuple[torch.Tensor, torch.Tensor]:\n    \"\"\"generate a cumulative sequence length mask for flash attention using pos ids\"\"\"\n    if len(position_ids.shape) == 1:\n        position_ids = position_ids.unsqueeze(0)\n\n    device = position_ids.device\n    results = []\n    max_seq_lens = []\n\n    for row in position_ids:\n        # Count the number of consecutive zeros from the right side\n        padding_length = (row == 0).int().flip(dims=[0]).cumprod(dim=0).sum().item()\n\n        # Adjust the row to exclude padding\n        adjusted_row = row[:-padding_length] if padding_length else row.clone()\n\n        # Find where the position resets to 0 (indicating a new 
sequence)\n        seq_starts = torch.cat(\n            [\n                torch.tensor([True], dtype=torch.bool, device=device),\n                adjusted_row[1:] == 0,\n            ]\n        )\n        # Get the indices where the sequence starts\n        start_indices = torch.cat(\n            [\n                torch.nonzero(seq_starts).unbind(dim=1)[0],\n                torch.tensor([len(adjusted_row)], dtype=torch.int32, device=device),\n            ]\n        )\n        # Calculate the sequence lengths\n        seq_lengths = start_indices[1:] - start_indices[:-1]\n        # Calculate the cumulative sequence lengths\n        cu_seqlens = torch.cat(\n            [torch.tensor([0], dtype=torch.int32, device=device), seq_lengths.cumsum(0)]\n        )\n        # Append the padding length to the cumulative sequence lengths\n        if padding_length:\n            cu_seqlens = torch.cat(\n                [cu_seqlens, torch.tensor([len(row)], dtype=torch.int32, device=device)]\n            )\n        max_seq_len = (cu_seqlens[1:] - cu_seqlens[:-1]).max()\n        results.append(cu_seqlens)\n        max_seq_lens.append(max_seq_len)\n\n    # Find the maximum value across all tensors\n    max_value = max(t.max() for t in results)\n\n    # Find the length of the longest tensor\n    max_length = max(t.size(0) for t in results)\n\n    # Pad each tensor to the same length and collect them in a list\n    padded_results = [\n        F.pad(t, (0, max_length - t.size(0)), \"constant\", max_value) for t in results\n    ]\n\n    return torch.stack(padded_results).to(dtype=torch.int32), torch.stack(max_seq_lens)\n\n\ndef set_module_name(model, name, value):\n    if \".\" in name:\n        parent_name = name.rsplit(\".\", 1)[0]\n        child_name = name[len(parent_name) + 1 :]\n        parent = model.get_submodule(parent_name)\n    else:\n        parent_name = \"\"\n        parent = model\n        child_name = name\n\n    setattr(parent, child_name, value)\n\n\ndef 
detab_code(code: str) -> Tuple[str, str]:\n    try:\n        spaces = re.match(r\"([\\s\\t]{1,})\", code).group(0)\n        code = re.sub(r\"^\" + spaces, \"\", code, flags=re.MULTILINE)\n    except AttributeError:\n        return code, \"\"\n    return code, spaces\n"
  },
  {
    "path": "src/axolotl/monkeypatch/xformers_/__init__.py",
    "content": "\"\"\"\nFused MLP layer for incrementally improved training efficiency\n\"\"\"\n\nimport torch\nfrom transformers.models.llama.modeling_llama import LlamaMLP\nfrom xformers.ops import SwiGLU\n\nfrom axolotl.monkeypatch.utils import set_module_name\n\n\nclass FusedMLP(torch.nn.Module):\n    \"\"\"\n    Fused MLP layer for incrementally improved training efficiency\n    \"\"\"\n\n    def __init__(\n        self,\n        config,\n        gate_proj: torch.nn.Linear,\n        up_proj: torch.nn.Linear,\n        down_proj: torch.nn.Linear,\n    ):\n        super().__init__()\n        self.config = config\n        self.swiglu = SwiGLU(\n            in_features=config.hidden_size,\n            hidden_features=config.intermediate_size,\n            bias=False,\n            _pack_weights=True,\n        )\n        # overwrite initialized weights with pretrained weights\n        self.swiglu.w12.weight.data = torch.cat(\n            (gate_proj.weight.data, up_proj.weight.data), dim=0\n        )\n        self.swiglu.w3.weight.data = down_proj.weight.data\n\n    def _post_training(self, model, name):\n        w1, w2 = torch.split(\n            self.swiglu.w12.weight.data, self.config.intermediate_size, dim=0\n        )\n\n        # Assign the split weights back to the original layers\n        new_mlp = LlamaMLP(self.config)\n        new_mlp.gate_proj.weight.data = w1\n        new_mlp.up_proj.weight.data = w2\n        new_mlp.down_proj.weight.data = self.swiglu.w3.weight.data\n\n        set_module_name(model, name, new_mlp)\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        return self.swiglu(x)\n"
  },
  {
    "path": "src/axolotl/processing_strategies.py",
    "content": "\"\"\"Module containing ProcessingStrategy classes and its derivative for different MultiModal Model types\"\"\"\n\nfrom copy import deepcopy\nfrom typing import Optional\n\nfrom PIL import Image, ImageOps\nfrom PIL.Image import Resampling\nfrom torch import Tensor, zeros_like\nfrom transformers import ProcessorMixin\nfrom transformers.image_utils import load_image\nfrom transformers.models.internvl import InternVLProcessor\nfrom transformers.models.smolvlm import SmolVLMProcessor\nfrom transformers.models.voxtral import VoxtralProcessor\n\nfrom axolotl.utils.dict import remove_none_values\nfrom axolotl.utils.logging import get_logger\n\nLOG = get_logger(__name__)\n\n\nclass ProcessingStrategy:\n    \"\"\"Base Processing Strategy class\"\"\"\n\n    def __init__(\n        self,\n        processor: ProcessorMixin,\n        chat_template: Optional[str] = None,\n        image_size: int | tuple[int, int] | None = None,\n        image_resize_algorithm: Resampling | None = None,\n    ):\n        self.processor = processor\n        self.chat_template = chat_template\n        self.image_token = None\n        self.image_token_id = None\n\n        self.image_size = image_size\n        self.image_resize_algorithm = (\n            image_resize_algorithm or Image.Resampling.BILINEAR\n        )\n\n        if hasattr(processor, \"image_token\"):\n            self.image_token = processor.image_token\n            self.image_token_id = processor.tokenizer.convert_tokens_to_ids(\n                self.image_token\n            )\n\n    def __call__(self, examples: list[dict]) -> list[dict]:\n        \"\"\"\n        Preprocess conversation examples to ensure consistent format.\n        Converts different conversation formats to OpenAI format with 'messages'.\n        Supports two formats:\n        1. OpenAI format with 'messages'\n        2. 
Legacy format with 'conversations'\n\n        Args:\n            examples: list of conversation dictionaries\n\n        Returns:\n            list of dicts in OpenAI format with 'messages' key\n\n        Raises:\n            ValueError: If the conversation format is not supported\n        \"\"\"\n        role_mapping = {\n            \"human\": \"user\",\n            \"gpt\": \"assistant\",\n        }\n\n        def normalize_role(role: str) -> str:\n            \"\"\"Normalize role names to OpenAI format. Default to original role if not found.\"\"\"\n            return role_mapping.get(role, role)\n\n        def convert_legacy_format(example: dict) -> dict:\n            \"\"\"Convert legacy 'conversations' format to OpenAI 'messages' format.\"\"\"\n            messages = [\n                {\"role\": normalize_role(convo[\"from\"]), \"content\": convo[\"value\"]}\n                for convo in example[\"conversations\"]\n            ]\n\n            # Create new dict without 'conversations' key\n            result = deepcopy(example)\n            result.pop(\"conversations\")\n            result[\"messages\"] = messages\n            return result\n\n        def convert_messages_to_multimedia_messages(messages: list[dict]) -> list[dict]:\n            \"\"\"Convert regular messages format to Messages format with content type\"\"\"\n\n            new_messages = []\n            for message in messages:\n                if isinstance(message[\"content\"], str):\n                    new_messages.append(\n                        {\n                            \"role\": message[\"role\"],\n                            \"content\": [\n                                {\n                                    \"type\": \"text\",\n                                    \"text\": message[\"content\"],\n                                }\n                            ],\n                        }\n                    )\n                elif isinstance(message[\"content\"], list):\n       
             content = message[\"content\"]\n\n                    new_messages.append(\n                        {\n                            \"role\": message[\"role\"],\n                            \"content\": content,\n                        }\n                    )\n\n            return new_messages\n\n        processed_examples = []\n        for example in examples:\n            if not (\"messages\" in example or \"conversations\" in example):\n                raise ValueError(\n                    \"Only `messages` and `conversations` message keys are currently supported.\"\n                )\n\n            processed_example = None\n            if (\n                \"messages\" in example and example[\"messages\"] is not None\n            ):  # OpenAI format\n                processed_example = example\n            else:  # Legacy format\n                processed_example = convert_legacy_format(example)\n\n            # convert regular messages format to Messages format with content type\n            # for compatibility with apply_chat_template\n            processed_example[\"messages\"] = convert_messages_to_multimedia_messages(\n                processed_example[\"messages\"]\n            )\n\n            # find the image key if it exists\n            possible_image_keys = [\"images\", \"image\"]\n            image_key = None\n            for key in possible_image_keys:\n                if key in processed_example:\n                    image_key = key\n                    break\n\n            # if the image key exists, add the image to the first user message\n            if image_key is not None and processed_example[image_key] is not None:\n                # TODO: check if it's normal to be single image only for common datasets\n                # From observation, it's usually a list of single image but some datasets may have several columns for images\n                # Temporary solution: take the first image and suggest people convert their 
datasets to use multi-content Messages\n                if len(processed_example[image_key]) > 1:\n                    LOG.warning(\n                        f\"Found {len(processed_example[image_key])} images in a sample. Using the first one.\"\n                        \"If you are using a dataset with multiple images per sample, please convert it to use multi-content Messages.\"\n                        \"See https://docs.axolotl.ai/docs/multimodal.html#dataset-format\"\n                    )\n\n                image_value = processed_example[image_key][0]\n\n                # Handle image loading (Image, url, path, base64)\n                image_value = load_image(image_value)\n\n                if self.image_size is not None:\n                    assert hasattr(image_value, \"resize\"), (\n                        \"Image does not have a resize method\"\n                    )\n\n                    if isinstance(self.image_size, tuple):\n                        image_value = image_value.resize(\n                            self.image_size, self.image_resize_algorithm\n                        )\n                    else:\n                        # Set the padding value; here we use black (0, 0, 0) for RGB images\n                        padding_color = (0, 0, 0)\n\n                        # When image_size is an int (square target), preserve aspect ratio then pad\n                        # This is to prevent aspect ratio distortion when resizing to square\n                        image_value = ImageOps.pad(\n                            image_value,\n                            (self.image_size, self.image_size),\n                            method=self.image_resize_algorithm,\n                            color=padding_color,\n                        )\n\n                # Look for any image type in the first message\n                # some dataset have an {type: \"image\"} in the first message\n                msg_ind_to_add = None\n                ind_to_add = 
None\n                first_user_idx = None\n\n                for msg_idx, msg_content in enumerate(processed_example[\"messages\"]):\n                    if first_user_idx is None and msg_content[\"role\"] == \"user\":\n                        first_user_idx = msg_idx\n                    for i, content in enumerate(\n                        processed_example[\"messages\"][msg_idx][\"content\"]\n                    ):\n                        # Usually datasets created with image columns, don't have it in the messages itself\n                        if content[\"type\"] == \"image\" and all(\n                            k not in content for k in [\"image\", \"url\", \"path\", \"base64\"]\n                        ):\n                            msg_ind_to_add = msg_idx\n                            ind_to_add = i\n                            break\n\n                # If an image type is found, add the image to that index\n                if ind_to_add is not None and msg_ind_to_add is not None:\n                    processed_example[\"messages\"][msg_ind_to_add][\"content\"][\n                        ind_to_add\n                    ][\"image\"] = image_value\n                else:\n                    # if no image type is found, add it to end of the first user message\n                    if first_user_idx is None:\n                        first_user_idx = 0\n                    processed_example[\"messages\"][first_user_idx][\"content\"].append(\n                        {\n                            \"type\": \"image\",\n                            \"image\": image_value,\n                        }\n                    )\n\n            processed_examples.append(remove_none_values(processed_example))\n\n        return processed_examples\n\n    def _mask_non_assistant(self, labels: Tensor) -> Tensor:\n        \"\"\"\n        Mask non assistant regions to -100.\n        To be implemented per subclass.\n        \"\"\"\n        return labels\n\n    def 
process_labels(self, input_ids: Tensor) -> Tensor:\n        labels = input_ids.clone()\n\n        labels = self._mask_non_assistant(labels)\n\n        # The labels are the input_ids, and we mask the padding tokens in the loss computation\n        labels[labels == self.processor.tokenizer.pad_token_id] = -100\n\n        # Ignore the image token index in the loss computation (model specific)\n        labels[labels == self.image_token_id] = -100\n\n        return labels\n\n\nclass Qwen2VLProcessingStrategy(ProcessingStrategy):\n    \"\"\"Processing Strategy class for Qwen2-VL\"\"\"\n\n    def __init__(\n        self,\n        processor: ProcessorMixin,\n        chat_template: Optional[str] = None,\n        image_size: int | tuple[int, int] | None = None,\n        image_resize_algorithm: Resampling | None = None,\n    ):\n        super().__init__(processor, chat_template, image_size, image_resize_algorithm)\n        self.image_token = \"<|image_pad|>\"  # nosec\n        self.image_token_id = processor.tokenizer.convert_tokens_to_ids(\n            self.image_token\n        )\n\n\nclass Qwen3_5ProcessingStrategy(ProcessingStrategy):\n    \"\"\"Processing Strategy class for Qwen3.5 (early-fusion VLM)\"\"\"\n\n    def __init__(\n        self,\n        processor: ProcessorMixin,\n        chat_template: Optional[str] = None,\n        image_size: int | tuple[int, int] | None = None,\n        image_resize_algorithm: Resampling | None = None,\n    ):\n        super().__init__(processor, chat_template, image_size, image_resize_algorithm)\n        self.image_token = \"<|image_pad|>\"  # nosec\n        self.image_token_id = processor.tokenizer.convert_tokens_to_ids(\n            self.image_token\n        )\n        self.video_token = \"<|video_pad|>\"  # nosec\n        self.video_token_id = processor.tokenizer.convert_tokens_to_ids(\n            self.video_token\n        )\n\n    def process_labels(self, input_ids):\n        labels = super().process_labels(input_ids)\n        
labels[labels == self.video_token_id] = -100\n        return labels\n\n\nclass Gemma3ProcessingStrategy(ProcessingStrategy):\n    \"\"\"Processing Strategy class for Gemma3\"\"\"\n\n    def __init__(\n        self,\n        processor: ProcessorMixin,\n        chat_template: Optional[str] = None,\n        image_size: int | tuple[int, int] | None = None,\n        image_resize_algorithm: Resampling | None = None,\n    ):\n        super().__init__(processor, chat_template, image_size, image_resize_algorithm)\n        self.image_token = processor.tokenizer.special_tokens_map[\"boi_token\"]\n        self.image_token_id = processor.tokenizer.convert_tokens_to_ids(\n            self.image_token\n        )\n\n    def process_labels(self, input_ids):\n        labels = input_ids.clone()\n\n        # Follows https://ai.google.dev/gemma/docs/core/huggingface_vision_finetune_qlora\n        labels[labels == self.processor.tokenizer.pad_token_id] = -100\n        labels[labels == self.image_token_id] = -100\n        labels[labels == 262144] = -100  # corresponds to <image_soft_token>\n\n        return labels\n\n\nclass Gemma3nProcessingStrategy(ProcessingStrategy):\n    \"\"\"Processing Strategy class for Gemma3n\"\"\"\n\n    def _mask_non_assistant(self, labels: Tensor) -> Tensor:\n        def _find_token_sequence(label, start_pos, token_sequence):\n            \"\"\"Check if token_sequence appears at start_pos in label\"\"\"\n            if start_pos + len(token_sequence) > len(label):\n                return False\n            if label[start_pos] != token_sequence[0]:\n                return False\n            return (\n                label[start_pos : start_pos + len(token_sequence)].tolist()\n                == token_sequence\n            )\n\n        def _find_assistant_end(label, start_pos, assistant_end_tok, mask, i):\n            \"\"\"\n            Find the end of assistant response and update mask accordingly\n\n            Returns new position to continue from and 
whether the end seq is found\n            \"\"\"\n            k = start_pos\n            while k < len(label):\n                if not _find_token_sequence(label, k, assistant_end_tok):\n                    mask[i][k] = 1\n                    k += 1\n                    continue\n\n                return k + len(assistant_end_tok), True\n\n            return k, False\n\n        mask = zeros_like(labels)\n\n        assistant_start_str = \"<start_of_turn>model\"\n        assistant_end_str = \"<end_of_turn>\"\n        include_assistant_start_tok = False\n        include_assistant_end_tok = True\n\n        # str to tokens\n        assistant_start_tok = self.processor.tokenizer.encode(\n            assistant_start_str, add_special_tokens=False\n        )\n        assistant_end_tok = self.processor.tokenizer.encode(\n            assistant_end_str, add_special_tokens=False\n        )\n\n        for i, label in enumerate(labels):\n            j = 0\n            # while loop through each tok index in labels[i]\n            while j < len(label):\n                # Check until match start seq\n                if not _find_token_sequence(label, j, assistant_start_tok):\n                    j += 1\n                    continue\n\n                if include_assistant_start_tok:\n                    mask[i][j : j + len(assistant_start_tok)] = 1\n\n                # Find where the assistant response ends\n                start_of_content = j + len(assistant_start_tok)\n                end_pos, found_end_seq = _find_assistant_end(\n                    label, start_of_content, assistant_end_tok, mask, i\n                )\n\n                # Include end token if requested\n                if include_assistant_end_tok and found_end_seq:\n                    mask[i][end_pos - len(assistant_end_tok) : end_pos] = 1\n\n                j = end_pos\n\n            labels[i][mask[i] == 0] = -100\n\n        return labels\n\n    def process_labels(self, input_ids):\n        labels = 
input_ids.clone()\n        labels = self._mask_non_assistant(labels)\n\n        # Follows https://colab.research.google.com/github/huggingface/huggingface-gemma-recipes/blob/main/notebooks/fine_tune_gemma3n_on_t4.ipynb\n        labels[labels == self.processor.tokenizer.pad_token_id] = -100\n        if hasattr(self.processor.tokenizer, \"image_token_id\"):\n            labels[labels == self.processor.tokenizer.image_token_id] = -100\n        if hasattr(self.processor.tokenizer, \"audio_token_id\"):\n            labels[labels == self.processor.tokenizer.audio_token_id] = -100\n        if hasattr(self.processor.tokenizer, \"boi_token_id\"):\n            labels[labels == self.processor.tokenizer.boi_token_id] = -100\n        if hasattr(self.processor.tokenizer, \"eoi_token_id\"):\n            labels[labels == self.processor.tokenizer.eoi_token_id] = -100\n\n        return labels\n\n\nclass VoxtralProcessingStrategy(ProcessingStrategy):\n    \"\"\"Processing Strategy class for Voxtral\"\"\"\n\n    def __init__(\n        self,\n        processor: VoxtralProcessor,\n        chat_template: Optional[str] = None,\n        image_size: int | tuple[int, int] | None = None,\n        image_resize_algorithm: Resampling | None = None,\n    ):\n        super().__init__(processor, chat_template, image_size, image_resize_algorithm)\n        special_ids = (\n            processor.tokenizer.tokenizer.instruct_tokenizer.audio_encoder.special_ids\n        )\n\n        self.audio_token = special_ids.audio\n        self.begin_audio_token = special_ids.begin_audio\n\n    def process_labels(self, input_ids):\n        labels = input_ids.clone()\n\n        labels[labels == self.processor.tokenizer.pad_token_id] = -100\n        labels[labels == self.audio_token] = -100\n        labels[labels == self.begin_audio_token] = -100\n\n        return labels\n\n\nclass SmolVLM2ProcessingStrategy(ProcessingStrategy):\n    \"\"\"Processing Strategy class for SmolVLM2\"\"\"\n\n    def __init__(\n        
self,\n        processor: ProcessorMixin,\n        chat_template: Optional[str] = None,\n        image_size: int | tuple[int, int] | None = None,\n        image_resize_algorithm: Resampling | None = None,\n    ):\n        super().__init__(processor, chat_template, image_size, image_resize_algorithm)\n        self.image_token = \"<image>\"  # nosec\n\n        self.image_token_id = processor.tokenizer.additional_special_tokens_ids[\n            processor.tokenizer.additional_special_tokens.index(self.image_token)\n        ]\n\n\nclass Mistral3ProcessingStrategy(ProcessingStrategy):\n    \"\"\"Processing Strategy class for Mistral3\"\"\"\n\n    def __init__(\n        self,\n        processor,\n        chat_template: Optional[str] = None,\n        image_size: int | tuple[int, int] | None = None,\n        image_resize_algorithm: Resampling | None = None,\n    ):\n        super().__init__(processor, chat_template, image_size, image_resize_algorithm)\n        special_ids = (\n            processor.tokenizer.tokenizer.instruct_tokenizer.image_encoder.special_ids\n        )\n\n        self.image_token = special_ids.img\n        self.image_break_token = special_ids.img_break\n        self.image_end_token = special_ids.img_end\n\n    def process_labels(self, input_ids):\n        labels = input_ids.clone()\n\n        labels[labels == self.processor.tokenizer.pad_token_id] = -100\n        labels[labels == self.image_token] = -100\n        labels[labels == self.image_break_token] = -100\n        labels[labels == self.image_end_token] = -100\n\n        return labels\n\n\nclass InternVLProcessingStrategy(ProcessingStrategy):\n    \"\"\"Processing Strategy class for InternVL\"\"\"\n\n    def __init__(\n        self,\n        processor: ProcessorMixin,\n        chat_template: Optional[str] = None,\n        image_size: int | tuple[int, int] | None = None,\n        image_resize_algorithm: Resampling | None = None,\n    ):\n        super().__init__(processor, chat_template, image_size, 
image_resize_algorithm)\n\n        if not hasattr(processor, \"image_ids\"):\n            raise ValueError(\"'image_ids' missing from InternVL Processor.\")\n\n        self.image_token_ids = processor.image_ids\n\n    def process_labels(self, input_ids):\n        labels = input_ids.clone()\n\n        labels[labels == self.processor.tokenizer.pad_token_id] = -100\n\n        for ids in self.image_token_ids:\n            labels[labels == ids] = -100\n\n        # Note: Check if need to mask 'video_token' as it gets converted to\n        # image patches during media processing\n\n        return labels\n\n\nclass Glm4vProcessingStrategy(ProcessingStrategy):\n    \"\"\"Processing Strategy class for GLM4V and GLM4V-MoE vision models.\"\"\"\n\n    def __init__(\n        self,\n        processor: ProcessorMixin,\n        chat_template: Optional[str] = None,\n        image_size: int | tuple[int, int] | None = None,\n        image_resize_algorithm: Resampling | None = None,\n    ):\n        super().__init__(processor, chat_template, image_size, image_resize_algorithm)\n\n        self.tokenizer = getattr(processor, \"tokenizer\", processor)\n\n        self.image_token = \"<|image|>\"  # nosec\n        self.begin_image_token = \"<|begin_of_image|>\"  # nosec\n        self.end_image_token = \"<|end_of_image|>\"  # nosec\n        self.video_token = \"<|video|>\"  # nosec\n        self.begin_video_token = \"<|begin_of_video|>\"  # nosec\n        self.end_video_token = \"<|end_of_video|>\"  # nosec\n\n        self.image_token_id = self.tokenizer.convert_tokens_to_ids(self.image_token)\n        self.begin_image_token_id = self.tokenizer.convert_tokens_to_ids(\n            self.begin_image_token\n        )\n        self.end_image_token_id = self.tokenizer.convert_tokens_to_ids(\n            self.end_image_token\n        )\n        self.video_token_id = self.tokenizer.convert_tokens_to_ids(self.video_token)\n        self.begin_video_token_id = self.tokenizer.convert_tokens_to_ids(\n    
        self.begin_video_token\n        )\n        self.end_video_token_id = self.tokenizer.convert_tokens_to_ids(\n            self.end_video_token\n        )\n\n    def process_labels(self, input_ids):\n        labels = input_ids.clone()\n\n        labels[labels == self.tokenizer.pad_token_id] = -100\n\n        labels[labels == self.image_token_id] = -100\n        labels[labels == self.begin_image_token_id] = -100\n        labels[labels == self.end_image_token_id] = -100\n\n        labels[labels == self.video_token_id] = -100\n        labels[labels == self.begin_video_token_id] = -100\n        labels[labels == self.end_video_token_id] = -100\n\n        return labels\n\n\ndef get_processing_strategy(\n    processor: ProcessorMixin,\n    chat_template,\n    chat_template_type,\n    image_size: int | tuple[int, int] | None = None,\n    image_resize_algorithm: Resampling | None = None,\n):\n    from axolotl.utils.mistral.mistral3_processor import Mistral3Processor\n\n    processing_kwargs = {\n        \"processor\": processor,\n        \"chat_template\": chat_template,\n        \"image_size\": image_size,\n        \"image_resize_algorithm\": image_resize_algorithm,\n    }\n\n    if chat_template_type in [None, \"tokenizer_default\"]:\n        tokenizer = getattr(processor, \"tokenizer\", processor)\n        if hasattr(tokenizer, \"chat_template\"):\n            processing_kwargs[\"chat_template\"] = tokenizer.chat_template\n\n    if chat_template_type == \"qwen2_vl\":\n        return Qwen2VLProcessingStrategy(\n            **processing_kwargs,\n        )\n    if chat_template_type in [\"qwen3_5\", \"qwen3_5_moe\"]:\n        return Qwen3_5ProcessingStrategy(\n            **processing_kwargs,\n        )\n    if chat_template_type == \"gemma3\":\n        return Gemma3ProcessingStrategy(\n            **processing_kwargs,\n        )\n    if chat_template_type == \"gemma3n\":\n        return Gemma3nProcessingStrategy(\n            **processing_kwargs,\n        )\n\n    if 
isinstance(processor, VoxtralProcessor):\n        return VoxtralProcessingStrategy(\n            **processing_kwargs,\n        )\n\n    if isinstance(processor, SmolVLMProcessor):\n        return SmolVLM2ProcessingStrategy(\n            **processing_kwargs,\n        )\n\n    if isinstance(processor, Mistral3Processor):\n        return Mistral3ProcessingStrategy(\n            **processing_kwargs,\n        )\n    try:\n        from transformers.models.glm46v.processing_glm46v import Glm46VProcessor\n\n        if isinstance(processor, Glm46VProcessor):\n            return Glm4vProcessingStrategy(\n                **processing_kwargs,\n            )\n    except ImportError:\n        pass\n\n    if isinstance(processor, InternVLProcessor):\n        return InternVLProcessingStrategy(\n            **processing_kwargs,\n        )\n\n    # llama3_2_vision, llama4, llava\n    # mistral_v7_tekken, pixtral, lfm2vl\n    return ProcessingStrategy(\n        **processing_kwargs,\n    )\n"
  },
  {
    "path": "src/axolotl/prompt_strategies/__init__.py",
    "content": "\"\"\"Module to load prompt strategies.\"\"\"\n\nimport importlib\nimport inspect\n\nfrom axolotl.prompt_strategies.user_defined import UserDefinedDatasetConfig\nfrom axolotl.utils.logging import get_logger\n\nLOG = get_logger(__name__)\n\n\ndef load(strategy, tokenizer, cfg, ds_cfg, processor=None):\n    try:\n        if strategy == \"messages\":\n            from .messages import load as messages_load\n\n            return messages_load(tokenizer, cfg, ds_cfg, processor=processor)\n        load_fn = \"load\"\n        package = \"axolotl.prompt_strategies\"\n        if (\n            strategy.split(\".\")[-1].startswith(\"load_\")\n            or strategy.split(\".\")[-1] == \"load\"\n        ):\n            load_fn = strategy.split(\".\")[-1]\n            strategy = \".\".join(strategy.split(\".\")[:-1])\n        elif len(strategy.split(\".\")) > 1:\n            try:\n                importlib.import_module(\n                    \".\" + strategy.split(\".\")[-1],\n                    \".\".join(strategy.split(\".\")[:-1]),\n                )\n                package = \".\".join(strategy.split(\".\")[:-1])\n                strategy = strategy.split(\".\")[-1]\n            except ModuleNotFoundError:\n                pass\n        mod = importlib.import_module(f\".{strategy}\", package)\n        func = getattr(mod, load_fn)\n        load_kwargs = {}\n        if strategy == \"user_defined\":\n            load_kwargs[\"ds_cfg\"] = UserDefinedDatasetConfig(**ds_cfg)\n        else:\n            sig = inspect.signature(func)\n            if \"ds_cfg\" in sig.parameters:\n                load_kwargs[\"ds_cfg\"] = ds_cfg\n            if \"processor\" in sig.parameters:\n                load_kwargs[\"processor\"] = processor\n\n        return func(tokenizer, cfg, **load_kwargs)\n    except ModuleNotFoundError:\n        return None\n    except Exception as exc:\n        LOG.error(f\"Failed to load prompt strategy `{strategy}`: {str(exc)}\")\n        raise 
exc\n"
  },
  {
    "path": "src/axolotl/prompt_strategies/alpaca_chat.py",
    "content": "\"\"\"Module for Alpaca prompt strategy classes\"\"\"\n\nfrom typing import Any, Dict, Optional, Tuple\n\nfrom axolotl.prompt_tokenizers import (\n    AlpacaPromptTokenizingStrategy,\n    InstructionPromptTokenizingStrategy,\n)\nfrom axolotl.prompters import AlpacaPrompter, PromptStyle, UnpromptedPrompter\n\n\ndef load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):\n    prompt_style = PromptStyle.CHAT.value\n    if ds_cfg and \"conversation\" in ds_cfg:\n        prompt_style = ds_cfg[\"conversation\"]\n\n    return AlpacaPromptTokenizingStrategy(\n        AlpacaPrompter(prompt_style),\n        tokenizer,\n        cfg.train_on_inputs,\n        cfg.sequence_len,\n    )\n\n\nclass AlpacaConcisePrompter(AlpacaPrompter):\n    \"\"\"\n    Alpaca Prompter extending the system prompt to ask for concise chat-instruct answers\n    \"\"\"\n\n    system_prompt = \"Below is an instruction from a USER that describes a task, paired with an input that provides further context. The ASSISTANT writes a response that concisely and appropriately completes the request.\\n\\n\"\n    system_no_input_prompt = \"Below is an instruction from a USER that describes a task. The ASSISTANT writes a response that appropriately and concisely completes the request.\\n\\n\"\n\n\nclass AlpacaChatPrompter(AlpacaPrompter):\n    \"\"\"\n    Alpaca Chat Prompter extending the system prompt for chat-instruct answers\n    \"\"\"\n\n    system_prompt = \"Below is an instruction from a USER that describes a task, paired with an input that provides further context. The ASSISTANT writes a response that concisely and appropriately completes the request.\\n\\n\"\n    system_no_input_prompt = \"Below is an instruction from a USER that describes a task. The ASSISTANT writes a response that appropriately and concisely completes the request.\\n\\n\"\n\n    def __init__(self):\n        self.prompt_style = PromptStyle.CHAT.value\n        self.match_prompt_style()\n\n\nclass NoSystemPrompter(AlpacaPrompter):\n    \"\"\"\n    Null Prompter with no system prompts\n    \"\"\"\n\n    system_prompt = \"\"\n    system_no_input_prompt = \"\"\n    turn_format = \"{instruction} {input} \"\n    turn_no_input_format = \"{instruction} \"\n\n    def __init__(self):\n        pass\n\n\nclass AlpacaQAPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):\n    \"\"\"\n    Tokenizing strategy for AlpacaQA\n    \"\"\"\n\n    def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]:\n        return (\n            prompt[\"question\"],\n            \"\",\n            prompt[\"answer\"],\n        )\n\n\nclass CamelAIPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):\n    \"\"\"\n    Tokenizing strategy for CamelAI datasets\n    \"\"\"\n\n    def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]:\n        return (\n            prompt[\"message_1\"],\n            \"\",\n            prompt[\"message_2\"],\n        )\n\n\ndef load_concise(tokenizer, cfg):\n    return AlpacaPromptTokenizingStrategy(\n        AlpacaConcisePrompter(PromptStyle.CHAT.value),\n        tokenizer,\n        cfg.train_on_inputs,\n        cfg.sequence_len,\n    )\n\n\ndef load_qa(tokenizer, cfg):\n    return AlpacaQAPromptTokenizingStrategy(\n        AlpacaChatPrompter(),\n        tokenizer,\n        cfg.train_on_inputs,\n        cfg.sequence_len,\n    )\n\n\ndef load_camel_ai(tokenizer, cfg):\n    return CamelAIPromptTokenizingStrategy(\n        AlpacaChatPrompter(),\n        tokenizer,\n        cfg.train_on_inputs,\n        cfg.sequence_len,\n    )\n\n\ndef load_no_prompt(tokenizer, cfg):\n    return AlpacaPromptTokenizingStrategy(\n        UnpromptedPrompter(PromptStyle.CHAT.value),\n        tokenizer,\n        cfg.train_on_inputs,\n        cfg.sequence_len,\n    )\n"
  },
  {
    "path": "src/axolotl/prompt_strategies/alpaca_instruct.py",
    "content": "\"\"\"Module loading the AlpacaInstructPromptTokenizingStrategy class\"\"\"\n\nfrom axolotl.prompt_tokenizers import AlpacaPromptTokenizingStrategy\nfrom axolotl.prompters import AlpacaPrompter, PromptStyle, UnpromptedPrompter\n\n\ndef load(tokenizer, cfg):\n    return AlpacaPromptTokenizingStrategy(\n        AlpacaPrompter(PromptStyle.INSTRUCT.value),\n        tokenizer,\n        cfg.train_on_inputs,\n        cfg.sequence_len,\n    )\n\n\ndef load_no_prompt(tokenizer, cfg):\n    return AlpacaPromptTokenizingStrategy(\n        UnpromptedPrompter(PromptStyle.INSTRUCT.value),\n        tokenizer,\n        cfg.train_on_inputs,\n        cfg.sequence_len,\n    )\n"
  },
  {
    "path": "src/axolotl/prompt_strategies/alpaca_w_system.py",
    "content": "\"\"\"\nPrompt strategies loader for alpaca instruction datasets with system prompts\n\"\"\"\n\nfrom typing import Generator, Tuple, Union\n\nfrom axolotl.prompt_tokenizers import PromptTokenizingStrategy\nfrom axolotl.prompters import AlpacaPrompter, PromptStyle\n\n\nclass InstructionWSystemPromptTokenizingStrategy(PromptTokenizingStrategy):\n    \"\"\"\n    Tokenizing strategy for instruction-based prompts.\n    \"\"\"\n\n    def parse_instruction_fields(self, prompt) -> Tuple[str, str, str, str]:\n        return (\n            prompt[\"instruction\"],\n            prompt[\"input\"] if \"input\" in prompt else \"\",\n            prompt[\"output\"],\n            prompt[\"system\"],\n        )\n\n    def tokenize_prompt(self, prompt):\n        (\n            instruction,\n            input,\n            response,\n            system,\n        ) = self.parse_instruction_fields(prompt)\n        user_prompt = next(\n            iter(\n                self.prompter.build_prompt_w_system(\n                    system,\n                    instruction,\n                    input,\n                )\n            )\n        )\n        tokenized_prompt = self._tokenize(user_prompt, add_eos_token=False)\n        if not self.train_on_inputs:\n            user_prompt_len = len(tokenized_prompt[\"input_ids\"])\n            # TODO this could be sped up using numpy array slicing\n            tokenized_prompt[\"labels\"] = [-100] * user_prompt_len\n        tokenized_res_prompt = self._tokenize(\n            response, strip_bos_token=True, add_eos_token=True\n        )\n        tokenized_prompt[\"input_ids\"] += tokenized_res_prompt[\"input_ids\"]\n        tokenized_prompt[\"attention_mask\"] += tokenized_res_prompt[\"attention_mask\"]\n        tokenized_prompt[\"labels\"] += tokenized_res_prompt[\"input_ids\"]\n\n        return tokenized_prompt\n\n\nclass SystemDataPrompter(AlpacaPrompter):\n    \"\"\"\n    Alpaca Style Prompter that uses system prompts from the 
dataset\n    \"\"\"\n\n    system_format: str = \"### System:\\n{system}\\n\\n\"\n\n    def build_prompt_w_system(\n        self,\n        system: str,\n        instruction: str,\n        input: Union[None, str] = None,\n        output: Union[None, str] = None,\n    ) -> Generator[str, None, None]:\n        # returns the full prompt from instruction and optional input\n        # if a label (=response, =output) is provided, it's also appended.\n        formatted_sys_prompt = (\n            self.system_format.format(system=system)\n            if system and self.system_format\n            else \"\"\n        )\n        if input:\n            res = formatted_sys_prompt + self.turn_format.format(\n                instruction=instruction, input=input\n            )\n        else:\n            res = formatted_sys_prompt + self.turn_no_input_format.format(\n                instruction=instruction\n            )\n        if output:\n            res = f\"{res}{output}\"\n        yield res\n\n\nclass OpenOrcaSystemDataPrompter(SystemDataPrompter):\n    \"\"\"\n    Alpaca Style Prompter that uses system prompts from the dataset, with OpenOrca prompts\n    \"\"\"\n\n    def match_prompt_style(self):\n        if self.prompt_style == PromptStyle.INSTRUCT.value:\n            self.turn_format = \"### Human:\\n{instruction}\\n### Additional Context:\\n{input}\\n### Assistant:\\n\"\n            self.turn_no_input_format = \"### Human:\\n{instruction}\\n### Assistant:\\n\"\n            self.system_format = \"### System:\\n{system}\\n\"\n        if self.prompt_style == PromptStyle.CHAT.value:\n            self.turn_format = \"USER: {instruction}\\n{input}\\nASSISTANT:\"\n            self.turn_no_input_format = \"USER: {instruction}\\nASSISTANT:\"\n            self.system_format = \"SYSTEM: {system}\\n\"\n        if self.prompt_style == PromptStyle.CHATML.value:\n            self.turn_format = \"<|im_start|>user\\n{instruction}\\n{input}<|im_end|>\\n<|im_start|>assistant\\n\"\n          
  self.turn_no_input_format = (\n                \"<|im_start|>user\\n{instruction}<|im_end|>\\n<|im_start|>assistant\\n\"\n            )\n            self.system_format = \"<|im_start|>system\\n{system}<|im_end|>\\n\"\n\n\nclass OpenOrcaPromptTokenizingStrategy(InstructionWSystemPromptTokenizingStrategy):\n    \"\"\"\n    Tokenizing strategy for OpenOrca datasets\n    \"\"\"\n\n    def parse_instruction_fields(self, prompt) -> Tuple[str, str, str, str]:\n        return (\n            prompt[\"question\"],\n            \"\",\n            prompt[\"response\"],\n            prompt[\"system_prompt\"],\n        )\n\n\ndef load(tokenizer, cfg):\n    return load_chat(tokenizer, cfg)\n\n\ndef load_instruct(tokenizer, cfg):\n    return InstructionWSystemPromptTokenizingStrategy(\n        SystemDataPrompter(PromptStyle.INSTRUCT.value),\n        tokenizer,\n        cfg.train_on_inputs,\n        cfg.sequence_len,\n    )\n\n\ndef load_chat(tokenizer, cfg):\n    return InstructionWSystemPromptTokenizingStrategy(\n        SystemDataPrompter(PromptStyle.CHAT.value),\n        tokenizer,\n        cfg.train_on_inputs,\n        cfg.sequence_len,\n    )\n\n\ndef load_open_orca(tokenizer, cfg):\n    return OpenOrcaPromptTokenizingStrategy(\n        OpenOrcaSystemDataPrompter(PromptStyle.INSTRUCT.value),\n        tokenizer,\n        cfg.train_on_inputs,\n        cfg.sequence_len,\n    )\n\n\ndef load_open_orca_chatml(tokenizer, cfg):\n    return OpenOrcaPromptTokenizingStrategy(\n        OpenOrcaSystemDataPrompter(PromptStyle.CHATML.value),\n        tokenizer,\n        cfg.train_on_inputs,\n        cfg.sequence_len,\n    )\n"
  },
  {
    "path": "src/axolotl/prompt_strategies/base.py",
    "content": "\"\"\"\nmodule for base dataset transform strategies\n\"\"\"\n\nimport importlib\n\nfrom axolotl.utils.logging import get_logger\n\nLOG = get_logger(__name__)\n\n\ndef load(strategy, cfg, module_base=None, **kwargs):\n    try:\n        if len(strategy.split(\".\")) == 1:\n            strategy = strategy + \".default\"\n        load_fn = strategy.split(\".\")[-1]\n        if len(strategy.split(\".\")) > 1:\n            try:\n                importlib.import_module(\n                    strategy.split(\".\")[-2],\n                    \".\".join(strategy.split(\".\")[:-2]),\n                )\n                module_base = \".\".join(strategy.split(\".\")[:-2])\n                strategy = strategy.split(\".\")[-2]\n            except ModuleNotFoundError:\n                strategy = \".\" + \".\".join(strategy.split(\".\")[:-1])\n        else:\n            strategy = \".\" + \".\".join(strategy.split(\".\")[:-1])\n        mod = importlib.import_module(strategy, module_base)\n        func = getattr(mod, load_fn)\n        return func(cfg, **kwargs)\n    except Exception:\n        LOG.warning(f\"unable to load strategy {strategy}\")\n        return None\n"
  },
  {
    "path": "src/axolotl/prompt_strategies/bradley_terry/README.md",
    "content": "### example yaml\n\n```yaml\nchat_template: gemma\ndatasets:\n  - path: argilla/distilabel-intel-orca-dpo-pairs\n    type: bradley_terry.chat_template\nval_set_size: 0.0\noutput_dir: ./outputs/out\n```\n"
  },
  {
    "path": "src/axolotl/prompt_strategies/bradley_terry/__init__.py",
    "content": "\"\"\"Module to load prompt strategies.\"\"\"\n\nimport importlib\nimport inspect\n\nfrom axolotl.prompt_strategies.user_defined import UserDefinedDatasetConfig\nfrom axolotl.utils.logging import get_logger\n\nLOG = get_logger(__name__)\n\n\ndef load(strategy, tokenizer, cfg, ds_cfg):\n    try:\n        load_fn = \"load\"\n        if strategy.split(\".\")[-1].startswith(\"load_\"):\n            load_fn = strategy.split(\".\")[-1]\n            strategy = \".\".join(strategy.split(\".\")[:-1])\n        mod = importlib.import_module(\n            f\".{strategy}\", \"axolotl.prompt_strategies.bradley_terry\"\n        )\n        func = getattr(mod, load_fn)\n        load_kwargs = {}\n        if strategy == \"user_defined\":\n            load_kwargs[\"ds_cfg\"] = UserDefinedDatasetConfig(**ds_cfg)\n        else:\n            sig = inspect.signature(func)\n            if \"ds_cfg\" in sig.parameters:\n                load_kwargs[\"ds_cfg\"] = ds_cfg\n        return func(tokenizer, cfg, **load_kwargs)\n    except ModuleNotFoundError:\n        return None\n    except Exception as exc:\n        LOG.error(f\"Failed to load prompt strategy `{strategy}`: {str(exc)}\")\n        return None\n"
  },
  {
    "path": "src/axolotl/prompt_strategies/bradley_terry/chat_template.py",
    "content": "\"\"\"\nBradley-Terry model with chat template prompt strategy.\n\"\"\"\n\nfrom typing import Any, Dict, Optional\n\nfrom axolotl.prompt_strategies.chat_template import (\n    ChatTemplatePrompter,\n    ChatTemplateStrategy,\n)\nfrom axolotl.utils.chat_templates import get_chat_template_from_config\nfrom axolotl.utils.logging import get_logger\n\n# Configure the logger\nLOG = get_logger(__name__)\nLOG.setLevel(\"INFO\")\n\n\nclass BTChatTemplateStrategy(ChatTemplateStrategy):\n    \"\"\"\n    Bradley-Terry reward model pairwise chat template prompt strategy.\n    \"\"\"\n\n    @property\n    def supports_batched(self) -> bool:\n        return False\n\n    def _tokenize_single_prompt(self, prompt):\n        \"\"\"\n\n        :param prompt: the actual row of data from the underlying dataset\n        :return:\n        \"\"\"\n\n        max_length = self.prompter.max_length\n\n        prompt[\"messages\"] = []\n        if prompt[\"system\"]:\n            prompt[\"messages\"].append({\"role\": \"system\", \"content\": prompt[\"system\"]})\n        prompt[\"messages\"].append({\"role\": \"user\", \"content\": prompt[\"input\"]})\n        prompt[\"messages\"].append({\"role\": \"assistant\", \"content\": prompt[\"chosen\"]})\n        chosen_tokenized = super()._tokenize_single_prompt(prompt)\n\n        if len(chosen_tokenized[\"input_ids\"]) > max_length:\n            LOG.warning(\n                f\"To-be-trimmed chosen sequence exceeds max sequence length: {len(chosen_tokenized['input_ids'])}\"\n            )\n\n            chosen_tokenized[\"input_ids\"] = chosen_tokenized[\"input_ids\"][:max_length]\n            chosen_tokenized[\"attention_mask\"] = chosen_tokenized[\"attention_mask\"][\n                :max_length\n            ]\n\n        prompt[\"messages\"] = []\n        if prompt[\"system\"]:\n            prompt[\"messages\"].append({\"role\": \"system\", \"content\": prompt[\"system\"]})\n        prompt[\"messages\"].append({\"role\": \"user\", 
\"content\": prompt[\"input\"]})\n        prompt[\"messages\"].append({\"role\": \"assistant\", \"content\": prompt[\"rejected\"]})\n        rejected_tokenized = super()._tokenize_single_prompt(prompt)\n\n        if len(rejected_tokenized[\"input_ids\"]) > max_length:\n            LOG.warning(\n                f\"To-be-trimmed rejected sequence exceeds max sequence length: {len(rejected_tokenized['input_ids'])}\"\n            )\n\n            rejected_tokenized[\"input_ids\"] = rejected_tokenized[\"input_ids\"][\n                :max_length\n            ]\n            rejected_tokenized[\"attention_mask\"] = rejected_tokenized[\"attention_mask\"][\n                :max_length\n            ]\n\n        return {\n            \"chosen_input_ids\": chosen_tokenized[\"input_ids\"],\n            \"attention_mask_chosen\": chosen_tokenized[\"attention_mask\"],\n            \"labels_chosen\": 1.0,\n            \"rejected_input_ids\": rejected_tokenized[\"input_ids\"],\n            \"attention_mask_rejected\": rejected_tokenized[\"attention_mask\"],\n            \"labels_rejected\": 0.0,\n        }\n\n\ndef load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):\n    ds_cfg = ds_cfg or {}\n    chat_template_string = get_chat_template_from_config(\n        cfg=cfg, ds_cfg=ds_cfg, tokenizer=tokenizer\n    )\n\n    prompter_params = {\n        \"tokenizer\": tokenizer,\n        \"chat_template\": chat_template_string,\n        \"message_property_mappings\": ds_cfg.get(\n            \"message_property_mappings\",\n            {\n                \"role\": \"role\",\n                \"content\": \"content\",\n            },\n        ),\n        \"message_field_training\": ds_cfg.get(\"message_field_training\", None),\n        \"message_field_training_detail\": ds_cfg.get(\n            \"message_field_training_detail\", None\n        ),\n        \"roles\": ds_cfg.get(\"roles\"),\n        \"drop_system_message\": ds_cfg.get(\"drop_system_message\", False),\n        # we need 
to add one for detecting sequences with exceeding the `sequence_len` limit.\n        \"max_length\": (\n            cfg.sequence_len + 1 if not cfg.reward_model else cfg.sequence_len\n        ),\n    }\n\n    strategy_params = {\n        \"train_on_inputs\": cfg.train_on_inputs,\n        \"sequence_len\": cfg.sequence_len,\n        \"roles_to_train\": ds_cfg.get(\"roles_to_train\", []),\n        \"train_on_eos\": ds_cfg.get(\"train_on_eos\", None),\n    }\n\n    strategy = BTChatTemplateStrategy(\n        ChatTemplatePrompter(**prompter_params), tokenizer=tokenizer, **strategy_params\n    )\n\n    return strategy\n"
  },
  {
    "path": "src/axolotl/prompt_strategies/bradley_terry/llama3.py",
    "content": "\"\"\"\nchatml transforms for datasets with system, input, chosen, rejected to match llama3 chat template\n\"\"\"\n\n\ndef icr(\n    cfg,\n    **kwargs,\n):\n    \"\"\"\n    chatml transforms for datasets with system, input, chosen, rejected\n    ex. https://huggingface.co/datasets/argilla/distilabel-intel-orca-dpo-pairs\n    \"\"\"\n\n    def transform_fn(sample):\n        if \"system\" in sample and sample[\"system\"]:\n            prompt = (\n                f\"<|start_header_id|>system<|end_header_id|>\\n\\n{sample['system']}<|eot_id|>\"\n                f\"<|start_header_id|>user<|end_header_id|>\\n\\n{sample['input']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\\n\\n\"\n            )\n        else:\n            prompt = f\"<|start_header_id|>user<|end_header_id|>\\n\\n{sample['input']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\\n\\n\"\n        sample[\"chosen\"] = prompt + f\"{sample['chosen']}<|eot_id|>\"\n        sample[\"rejected\"] = prompt + f\"{sample['rejected']}<|eot_id|>\"\n        return sample\n\n    return transform_fn\n"
  },
  {
    "path": "src/axolotl/prompt_strategies/chat_template.py",
    "content": "\"\"\"\nHF Chat Templates prompt strategy\n\"\"\"\n\nimport json\nfrom collections import defaultdict\nfrom typing import TYPE_CHECKING, Any, Dict, List, Set, Union\n\nfrom pydantic import BaseModel\nfrom transformers import ProcessorMixin\n\nfrom axolotl.prompt_strategies.jinja_template_analyzer import JinjaTemplateAnalyzer\nfrom axolotl.prompt_tokenizers import PromptTokenizingStrategy\nfrom axolotl.prompters import IGNORE_TOKEN_ID, Prompter\nfrom axolotl.utils.chat_templates import get_chat_template_from_config\nfrom axolotl.utils.dict import remove_none_values\nfrom axolotl.utils.logging import get_logger\nfrom axolotl.utils.schemas.datasets import DatasetConfig\n\nif TYPE_CHECKING:\n    from axolotl.utils.mistral import HFMistralTokenizer\n\n# Configure the logger\nLOG = get_logger(__name__)\nLOG.setLevel(\"INFO\")\n\n\nclass ChatTemplatePrompter(Prompter):\n    \"\"\"Prompter for HF chat templates\"\"\"\n\n    def __init__(\n        self,\n        tokenizer,\n        chat_template: str,\n        processor=None,\n        max_length=2048,\n        message_property_mappings: dict[str, str] | None = None,\n        message_field_training: str | None = None,\n        message_field_training_detail: str | None = None,\n        field_messages: str = \"messages\",\n        field_system: str = \"system\",\n        field_tools: str = \"tools\",\n        field_thinking: str = \"reasoning_content\",\n        roles: dict[str, list[str]] | None = None,\n        template_thinking_key: str | None = \"reasoning_content\",\n        chat_template_kwargs: dict[str, Any] | None = None,\n        drop_system_message: bool = False,\n    ):\n        # check if message_property_mappings is None or empty dict\n        if message_property_mappings is None or (not message_property_mappings):\n            message_property_mappings = {\n                \"role\": \"role\",\n                \"content\": \"content\",\n            }\n            if template_thinking_key and 
field_thinking:\n                message_property_mappings[template_thinking_key] = field_thinking\n\n        if roles:\n            self.roles = {s: t for t, sources in roles.items() for s in sources}\n        else:\n            self.roles = {\n                \"human\": \"user\",\n                \"user\": \"user\",\n                \"assistant\": \"assistant\",\n                \"gpt\": \"assistant\",\n                \"system\": \"system\",\n                \"tool\": \"tool\",\n            }\n\n        self._chat_template_msg_variables = self.get_chat_template_msg_variables(\n            chat_template, field_messages\n        )\n        self.message_property_mappings = message_property_mappings\n        self.message_field_training = message_field_training\n        self.message_field_training_detail = message_field_training_detail\n        self.field_messages = field_messages\n        self.field_system = field_system\n        self.field_tools = field_tools\n        self.field_thinking = field_thinking\n        self.tokenizer = tokenizer\n        self.processor: ProcessorMixin | None = processor\n        self.chat_template = chat_template\n        self.chat_template_kwargs = chat_template_kwargs or {}\n        self.template_thinking_key: str = template_thinking_key or \"reasoning_content\"\n        self.max_length = max_length\n        self.drop_system_message = drop_system_message\n\n    @property\n    def chat_template_msg_variables(self) -> Set[str]:\n        return self._chat_template_msg_variables\n\n    def build_prompt(\n        self,\n        conversation: list[dict],\n        add_generation_prompt=False,\n        images=None,\n        tools=None,\n        real_last_index=None,\n    ):\n        \"\"\"\n        Build a prompt from a conversation.\n\n        Args:\n            conversation: A list of messages.\n            add_generation_prompt: Whether to add a generation prompt.\n            images: A list of images. 
(optional)\n            tools: A list of tools. (optional)\n        \"\"\"\n        chat_template_kwargs = {\n            \"chat_template\": self.chat_template,\n            \"add_generation_prompt\": add_generation_prompt,\n            **self.chat_template_kwargs,\n        }\n\n        if tools:\n            chat_template_kwargs[\"tools\"] = tools\n\n        if real_last_index:\n            chat_template_kwargs[\"real_last_index\"] = real_last_index\n\n        if self.processor:\n            if not callable(self.processor):\n                raise TypeError(\"Processor must be callable\")\n\n            text = self.processor.apply_chat_template(\n                conversation,\n                tokenize=False,\n                **chat_template_kwargs,\n            )\n            batch = self.processor(\n                text=text,\n                images=images,\n                return_tensors=\"pt\",\n            )\n            if hasattr(batch, \"to_dict\"):\n                batch = batch.to_dict()\n            else:\n                batch = dict(batch)\n\n            # workaround since processor works in batches instead of single examples\n            out = {}\n            for k, val in batch.items():\n                if hasattr(val, \"tolist\"):\n                    out[k] = (\n                        val.tolist() if k == \"pixel_values\" else val.squeeze(0).tolist()\n                    )\n                else:\n                    out[k] = val\n            return out\n\n        return self.tokenizer.apply_chat_template(\n            conversation,\n            tokenize=True,\n            return_dict=False,\n            **chat_template_kwargs,\n        )\n\n    def get_offsets_for_train_detail(\n        self, text: str, train_details: List[Dict], mask_untrainable: bool = True\n    ) -> List[int]:\n        tokenized_output = self.tokenizer(\n            text, return_offsets_mapping=True, add_special_tokens=False\n        )\n        tokens = 
tokenized_output.tokens()\n        token_offsets = tokenized_output[\"offset_mapping\"]\n\n        LOG.debug(f\"Tokenizing text: {text}\")\n        LOG.debug(f\"Tokens: {tokens}\")\n        # Adjust the end offsets. For some reason by default they are set to the same value as the start offsets.\n        for i in range(len(token_offsets) - 1):\n            token_offsets[i] = (token_offsets[i][0], token_offsets[i + 1][0] - 1)\n        # Ensure the last token's end offset is set correctly\n        token_offsets[-1] = (token_offsets[-1][0], len(text) - 1)\n        LOG.debug(f\"Token offsets: {token_offsets}\")\n\n        # Initialize all offsets as IGNORE_TOKEN_ID (not trained)\n        result = [IGNORE_TOKEN_ID] * len(token_offsets)\n\n        # Adjust train_details to align with token boundaries\n        adjusted_train_details = self.adjust_train_details(train_details, token_offsets)\n\n        for idx, (start, end) in enumerate(token_offsets):\n            for detail in adjusted_train_details:\n                # Check if the token is completely within the detail's range\n                if start >= detail[\"begin_offset\"] and end <= detail[\"end_offset\"]:\n                    if detail[\"train\"] or not mask_untrainable:\n                        result[idx] = start\n                        LOG.debug(f\"Token {idx} ({tokens[idx]}) marked for training\")\n                    else:\n                        LOG.debug(\n                            f\"Token {idx} ({tokens[idx]}) marked as non-trainable\"\n                        )\n                elif start < detail[\"end_offset\"] and end > detail[\"begin_offset\"]:\n                    # Token partially overlaps with detail, always mark as non-trainable\n                    LOG.debug(\n                        f\"Token {idx} ({tokens[idx]}) partially overlaps detail, marked as non-trainable\"\n                    )\n\n        LOG.debug(f\"Final result: {result}\")\n        return result\n\n    def 
adjust_train_details(\n        self, train_details: List[Dict], token_offsets: List[tuple]\n    ) -> List[Dict]:\n        adjusted_details = []\n        for detail in train_details:\n            begin_offset = detail[\"begin_offset\"]\n            end_offset = detail[\"end_offset\"]\n\n            # Find the first token that starts after or at the begin_offset\n            begin_token = next(\n                (\n                    i\n                    for i, (t_start, t_end) in enumerate(token_offsets)\n                    if t_start >= begin_offset\n                ),\n                len(token_offsets),\n            )\n            if begin_token > 0 and token_offsets[begin_token - 1][1] > begin_offset:\n                begin_token -= 1\n\n            # Find the last token that ends before or at the end_offset\n            end_token = next(\n                (\n                    i\n                    for i in range(len(token_offsets) - 1, -1, -1)\n                    if token_offsets[i][1] <= end_offset\n                ),\n                -1,\n            )\n            if (\n                end_token < len(token_offsets) - 1\n                and token_offsets[end_token + 1][0] < end_offset\n            ):\n                end_token += 1\n\n            if begin_token <= end_token:\n                adjusted_begin = token_offsets[begin_token][0]\n                adjusted_end = token_offsets[end_token][1]\n\n                if adjusted_begin != begin_offset or adjusted_end != end_offset:\n                    LOG.warning(\n                        f\"Adjusting detail offsets: ({begin_offset}, {end_offset}) -> ({adjusted_begin}, {adjusted_end})\"\n                    )\n\n                adjusted_details.append(\n                    {\n                        \"begin_offset\": adjusted_begin,\n                        \"end_offset\": adjusted_end,\n                        \"train\": detail[\"train\"],\n                    }\n                )\n            else:\n   
             LOG.warning(\n                    f\"Could not adjust detail offsets: ({begin_offset}, {end_offset}). Skipping this detail.\"\n                )\n\n        return adjusted_details\n\n    def get_chat_template_msg_variables(\n        self, chat_template: str, field_messages: str\n    ) -> Set[str]:\n        template_analyzer = JinjaTemplateAnalyzer(chat_template)\n        return template_analyzer.get_message_vars(field_messages)\n\n\nclass ChatTemplateStrategy(PromptTokenizingStrategy):\n    \"\"\"\n    Tokenizing strategy for instruction-based prompts.\n    \"\"\"\n\n    def __init__(\n        self,\n        prompter: \"ChatTemplatePrompter\",\n        tokenizer,\n        train_on_inputs: bool,\n        sequence_len: int,\n        roles_to_train: list[str] | None = None,\n        train_on_eos: str | None = None,\n        train_on_eot: str | None = None,\n        eot_tokens: list[str] | None = None,\n        split_thinking: bool | None = False,\n    ):\n        super().__init__(prompter, tokenizer, train_on_inputs, sequence_len)\n        self.prompter: ChatTemplatePrompter = prompter\n\n        self.roles_to_train = []\n        if roles_to_train:\n            # map roles if exist in prompter.roles else use the role as is\n            self.roles_to_train = [\n                prompter.roles.get(role, role) for role in roles_to_train\n            ]\n\n        self.train_on_eos = train_on_eos\n        # Backward compatibility, load from train_on_eos\n        self.train_on_eot = train_on_eot if train_on_eot is not None else train_on_eos\n\n        # Default to eos_token if eot_tokens not provided\n        self.eot_tokens = []\n        if eot_tokens is not None:\n            self.eot_tokens = eot_tokens\n        elif (\n            hasattr(self.tokenizer, \"eos_token\")\n            and self.tokenizer.eos_token is not None\n        ):\n            self.eot_tokens = [self.tokenizer.eos_token]\n\n        self.split_thinking = split_thinking\n\n        
self.images = \"images\"\n\n        LOG.debug(\n            f\"The chat template uses the following properites on the message: {self.prompter.chat_template_msg_variables}\"\n        )\n\n        self._validate_eot_and_eos_tokens()\n\n    def _validate_eot_and_eos_tokens(self):\n        \"\"\"\n        - Validates that EOT tokens (or eos_token) are in the chat_template\n        - Checks if EOT tokens are encoded as multiple tokens in the tokenizer.\n        - Checks for potential conflicts between train_on_eos and train_on_eot.\n        \"\"\"\n        if self.prompter.chat_template is None:\n            # Usually this should not happen\n            LOG.warning(\n                \"No chat template provided, skipping EOT and EOS token validation\"\n            )\n            return\n\n        # If the EOT token is the same as the EOS token, we need to check differently\n        if len(self.eot_tokens) == 1 and self.eot_tokens[0] == self.tokenizer.eos_token:\n            # Check if the eos_token is in the chat_template or as a variable `eos_token`\n            # Note: we check for `eos_token` in the string, but it could possibly not be a variable\n            if (\n                self.tokenizer.eos_token not in self.prompter.chat_template\n                and \"eos_token\" not in self.prompter.chat_template\n            ):\n                LOG.warning(\n                    f\"EOS token '{self.tokenizer.eos_token}' not found in chat_template. 
Please check if your template/EOS token is correct.\"\n                )\n            return\n\n        # Create a new list to store tokens that should be kept\n        valid_eot_tokens = []\n        for token in self.eot_tokens:\n            # Check if EOT token is in the chat_template\n            if token not in self.prompter.chat_template:\n                LOG.warning(f\"EOT token '{token}' not found in chat_template.\")\n                # Don't add to the valid tokens list\n                continue\n\n            valid_eot_tokens.append(token)\n\n        # Replace the original list with the filtered one\n        self.eot_tokens = valid_eot_tokens\n\n        for token in self.eot_tokens:\n            # If token in template, check if EOT token is in tokenizer and not encoded as multiple tokens\n            token_ids = self.tokenizer.encode(token, add_special_tokens=False)\n            if not token_ids:\n                raise ValueError(\n                    \"EOT token encoding failed. Please check if the token is valid and can be encoded.\"\n                )\n            if token_ids and len(token_ids) > 1:\n                raise ValueError(\n                    f\"EOT token '{token}' is encoded as multiple tokens: {token_ids}. Please add it under `tokens: ` in the config \"\n                    \"or (recommended) override unused added_tokens via `added_tokens_overrides: `.\"\n                )\n\n        # If eos_token is in eot_tokens and conflict between train_on_eos and train_on_eot, raise an error\n        if (\n            self.tokenizer.eos_token in self.eot_tokens\n            and self.train_on_eos != self.train_on_eot\n        ):\n            raise ValueError(\n                \"Conflict between train_on_eos and train_on_eot. 
eos_token is in eot_tokens and train_on_eos != train_on_eot\"\n                f\"train_on_eos: {self.train_on_eos}, train_on_eot: {self.train_on_eot}\"\n                f\"eot_tokens: {self.eot_tokens}\"\n                f\"eos_token: {self.tokenizer.eos_token}\"\n            )\n\n    @property\n    def supports_batched(self) -> bool:\n        # Let calling code know we can handle lists of examples\n        return True\n\n    def is_prompt_batched(self, prompt: dict[str, Any]) -> bool:\n        try:\n            return all(isinstance(v, list) for v in prompt.values()) and all(\n                isinstance(v, list) for v in prompt[self.prompter.field_messages]\n            )\n        except KeyError:\n            return False\n\n    def tokenize_prompt(self, prompt: dict[str, Any]):\n        \"\"\"\n        Public method that can handle either a single prompt or a batch of prompts.\n        \"\"\"\n\n        prompt = remove_none_values(prompt)\n\n        if not self.is_prompt_batched(prompt) or not self.supports_batched:\n            return self._tokenize_single_prompt(prompt)\n\n        res = defaultdict(lambda: [])\n        feature_names = list(prompt.keys())\n\n        # Process each prompt individually\n        for row in zip(*prompt.values(), strict=False):\n            tokenized_prompt = self._tokenize_single_prompt(\n                dict(zip(feature_names, row, strict=False))\n            )\n            for key, val in tokenized_prompt.items():\n                res[key].append(val)\n\n        # If there are no examples left, return an empty dictionary\n        if not res:\n            return {}\n\n        return dict(res)\n\n    def _tokenize_single_prompt(self, prompt: dict) -> Dict[str, List[int]]:\n        # Old simple legacy behavior that works reliably.\n        if (\n            not self.roles_to_train\n            and not self.train_on_eos\n            and not self.train_on_eot\n            and not self.prompter.message_field_training  # type: ignore\n 
           and not self.prompter.message_field_training_detail  # type: ignore\n        ):\n            turns = self.get_conversation_thread(prompt)\n            images = self._get_images(prompt)\n            prompt_ids = self.prompter.build_prompt(  # type: ignore\n                turns[:-1],\n                add_generation_prompt=True,\n                images=images,\n            )\n            tokenized_res = self.prompter.build_prompt(turns, images=images)  # type: ignore\n            tokenized_prompt = {}\n            if isinstance(tokenized_res, list):\n                input_ids = prompt_ids + tokenized_res[len(prompt_ids) :]\n                tokenized_prompt[\"input_ids\"] = input_ids\n                tokenized_prompt[\"attention_mask\"] = [1] * len(input_ids)\n            else:\n                input_ids = tokenized_res[\"input_ids\"]\n                tokenized_prompt = dict(tokenized_res)\n\n            if not self.train_on_inputs:\n                if isinstance(prompt_ids, dict):\n                    user_prompt_len = len(prompt_ids[\"input_ids\"])\n                else:\n                    user_prompt_len = len(prompt_ids)\n                labels = [-100] * user_prompt_len + input_ids[user_prompt_len:]\n            else:\n                labels = input_ids\n\n            tokenized_prompt[\"labels\"] = labels\n\n            return tokenized_prompt\n\n        turns = self.get_conversation_thread(prompt)\n        tools = self._get_tools(prompt)\n        input_ids = self.prompter.build_prompt(turns, tools=tools)  # type: ignore\n        labels = [IGNORE_TOKEN_ID] * len(input_ids)\n\n        last_eos_idx = -1\n        last_eot_idx = -1\n        for index, turn in enumerate(turns):\n            role = turn.get(\"role\")\n            content = turn.get(\"content\")\n            train_turn = turn.get(\"training\")\n            train_detail = turn.get(\"training_detail\")\n\n            LOG.debug(\n                f\"Processing turn {index}: role={role}, 
content={content}, train_turn={train_turn}, train_detail={train_detail}\"\n            )\n\n            should_train = None\n            if train_turn is not None:\n                should_train = train_turn\n            elif train_detail is not None:\n                should_train = bool(train_detail)\n            else:\n                should_train = self.train_on_inputs or role in self.roles_to_train\n\n            LOG.debug(f\"Should train: {should_train}\")\n\n            # turn not trainable, skip having to find the turn indices\n            # unless last turn and train_on_eos/train_on_eot is all\n            if not should_train and (\n                self.train_on_eos != \"all\" and self.train_on_eot != \"all\"\n            ):\n                if index == len(turns) - 1:\n                    LOG.warning(\n                        \"Last turn is not trainable, skipping having to find the turn indices. \"\n                        \"This may cause incorrect last EOT/EOS token to be unmasked.\"\n                        \"This is likely a dataset design issue. 
Please ensure last turn is trainable.\"\n                    )\n\n                continue\n\n            turn_start_idx, turn_end_idx = self.find_turn(\n                turns=turns, turn_idx=index, tools=tools\n            )\n\n            LOG.debug(f\"Turn indices: start={turn_start_idx}, end={turn_end_idx}\")\n\n            if should_train and turn_start_idx != -1 and turn_end_idx != -1:\n                if train_detail:\n                    # Block multi-content for now\n                    if not isinstance(content, str):\n                        raise ValueError(\n                            \"`train_detail` is not supported when `content` is not a string.\"\n                        )\n\n                    token_offsets = self.prompter.get_offsets_for_train_detail(  # type: ignore\n                        content, train_detail\n                    )\n                    LOG.debug(f\"Token offsets: {token_offsets}\")\n                    for i, offset in enumerate(token_offsets):\n                        if offset != IGNORE_TOKEN_ID and turn_start_idx + i < len(\n                            input_ids\n                        ):\n                            labels[turn_start_idx + i] = input_ids[turn_start_idx + i]\n                            LOG.debug(\n                                f\"Label set at index {turn_start_idx + i}: {input_ids[turn_start_idx + i]}\"\n                            )\n                else:\n                    labels[turn_start_idx:turn_end_idx] = input_ids[\n                        turn_start_idx:turn_end_idx\n                    ]\n                    LOG.debug(\n                        f\"Set labels for training from {turn_start_idx} to {turn_end_idx}\"\n                    )\n\n                LOG.debug(f\"Labels after processing turn {index}: {labels}\")\n\n            # Handle special tokens (EOT and EOS)\n            for token_type, find_func, train_option in [\n                (\"EOT\", self.find_first_eot_token, 
self.train_on_eot),\n                (\"EOS\", self.find_first_eos_token, self.train_on_eos),\n            ]:\n                token_idx = find_func(input_ids, start_idx=turn_end_idx)\n\n                if (\n                    token_idx != -1 and abs(token_idx - turn_end_idx) <= 3\n                ):  # Allow for some template padding\n                    # Update the last token index\n                    if token_type == \"EOT\":  # nosec B105\n                        last_eot_idx = token_idx\n                    else:\n                        last_eos_idx = token_idx\n\n                    # Set labels if needed for this turn\n                    if train_option == \"all\" or (\n                        train_option == \"turn\" and should_train\n                    ):\n                        labels[token_idx] = input_ids[token_idx]\n                        LOG.debug(\n                            f\"{token_type} token set for training at index {token_idx}\"\n                        )\n                else:\n                    LOG.debug(\n                        f\"{token_type} token missing after turn {turn}. 
{token_type.lower()}_idx: {token_idx}, turn_end_idx: {turn_end_idx}\"\n                    )\n\n        # Handle 'last' option for special tokens\n        for token_type, last_idx, train_option in [\n            (\"EOT\", last_eot_idx, self.train_on_eot),\n            (\"EOS\", last_eos_idx, self.train_on_eos),\n        ]:\n            if train_option == \"last\" and last_idx != -1:\n                labels[last_idx] = input_ids[last_idx]\n                LOG.debug(\n                    f\"Last {token_type} token set for training at index {last_idx}\"\n                )\n\n        LOG.debug(f\"Final labels: {labels}\")\n\n        return {\n            \"input_ids\": input_ids,\n            \"labels\": labels,\n            \"attention_mask\": [1] * len(input_ids),\n        }\n\n    def find_first_eos_token(self, input_ids, start_idx):\n        eos_token_id = self.tokenizer.eos_token_id\n        for i in range(start_idx, len(input_ids)):\n            if input_ids[i] == eos_token_id:\n                return i\n        return -1\n\n    def find_first_eot_token(self, input_ids, start_idx):\n        \"\"\"Find the first EOT token in the input_ids starting from start_idx.\"\"\"\n        # Get token IDs for all EOT tokens\n        eot_token_ids = []\n        for token in self.eot_tokens:\n            token_ids = self.tokenizer.encode(token, add_special_tokens=False)\n            if len(token_ids) != 1:\n                raise ValueError(\n                    f\"EOT token '{token}' is encoded as multiple tokens: {token_ids}. 
Please add it under `tokens: ` in the config.\"\n                )\n\n            eot_token_ids.append(token_ids[0])  # Use the last token ID if multiple\n\n        # Search for any of the EOT token IDs\n        for i in range(start_idx, len(input_ids)):\n            if input_ids[i] in eot_token_ids:\n                return i\n        return -1\n\n    def find_turn(\n        self, turns: list[dict], turn_idx: int, tools: list[dict] | None = None\n    ):\n        \"\"\"\n        Locate the starting and ending indices of the specified turn in a conversation.\n        \"\"\"\n\n        if turn_idx >= len(turns):\n            raise ValueError(f\"Turn index {turn_idx} out of range\")\n\n        # mistral/gemma3 does not output message if it contains only system message\n        if (\n            turn_idx == 0\n            and turns[0].get(\"role\") == \"system\"\n            and (\"mistral\" in self.tokenizer.name_or_path.lower())\n        ):\n            return -1, -1\n\n        empty_turn = {\n            \"role\": turns[turn_idx].get(\"role\"),\n            \"content\": \"[[dummy_message]]\",\n        }\n\n        # Create conversation versions\n        turns_with_empty = turns[:turn_idx] + [empty_turn]\n        turns_with_content = turns[: turn_idx + 1]\n\n        real_last_index = len(turns) - 1\n\n        # Generate the conversation up to the turn, with final turn replaced with dummy content\n        dummy_ids = self.prompter.build_prompt(\n            turns_with_empty, tools=tools, real_last_index=real_last_index\n        )  # type: ignore\n\n        # Generate the conversation up to the turn, with final turn included\n        full_ids = self.prompter.build_prompt(\n            turns_with_content, tools=tools, real_last_index=real_last_index\n        )  # type: ignore\n\n        if not full_ids or not dummy_ids:\n            LOG.warning(f\"Empty template generated for turn {turn_idx}\")\n            return -1, -1\n\n        # Find first difference (start of 
content)\n        start_idx = None\n        min_len = min(len(dummy_ids), len(full_ids))\n        for i in range(min_len):\n            if dummy_ids[i] != full_ids[i]:\n                start_idx = i\n                break\n\n        if start_idx is None:\n            LOG.warning(f\"Could not find content start boundary for turn {turn_idx}\")\n            return -1, -1\n\n        # Find last difference (end of content)\n        end_idx = None\n        for i in range(min_len):\n            dummy_pos = len(dummy_ids) - 1 - i\n            full_pos = len(full_ids) - 1 - i\n            if dummy_ids[dummy_pos] != full_ids[full_pos]:\n                end_idx = full_pos + 1  # Add one to include the last token when slice\n                break\n\n        if end_idx is None:\n            LOG.warning(f\"Could not find content end boundary for turn {turn_idx}\")\n            return -1, -1\n\n        if end_idx < start_idx:\n            LOG.warning(\n                f\"Content end boundary is before start boundary for turn {turn_idx}\"\n            )\n            return -1, -1\n\n        if end_idx == start_idx:\n            LOG.warning(\n                f\"Content end boundary is the same as start boundary for turn {turn_idx}. 
This is likely an empty turn.\"\n            )\n            return -1, -1\n\n        LOG.debug(f\"Content boundaries: {start_idx}, {end_idx}\")\n        LOG.debug(\n            f\"Content tokens: {self.tokenizer.convert_ids_to_tokens(full_ids[start_idx:end_idx])}\"\n        )\n\n        return start_idx, end_idx\n\n    def get_conversation_thread(self, prompt):\n        turns = []\n\n        messages = self._get_messages(prompt)\n\n        possible_sys_turn = self.transform_message(messages[0])\n\n        if (\n            possible_sys_turn[\"role\"] != \"system\"\n            and self.prompter.field_system in prompt\n        ):\n            turn = {\"role\": \"system\", \"content\": prompt[self.prompter.field_system]}\n            turns.append(turn)\n\n        for message in messages:\n            transformed_message = self.transform_message(message)\n\n            turn = transformed_message\n\n            training = message.get(self.prompter.message_field_training)\n            training_detail = message.get(self.prompter.message_field_training_detail)\n            if training is not None:\n                turn[\"training\"] = training\n            if training_detail is not None:\n                turn[\"training_detail\"] = training_detail\n\n            turns.append(turn)\n\n        if self.prompter.drop_system_message and turns[0][\"role\"] == \"system\":\n            turns = turns[1:]\n\n        return turns\n\n    def transform_message(self, message: dict) -> dict:\n        # Build the initial transformed message from the mappings\n        transformed_message = {}\n        for key, value in self.prompter.message_property_mappings.items():\n            if message.get(value) is not None:\n                transformed_message[key] = message[value]\n            else:\n                LOG.debug(\n                    f\"Could not find value for property {value} in message: {message}\"\n                )\n\n        # Map the role if necessary\n        if \"role\" in 
transformed_message:\n            transformed_message[\"role\"] = self.prompter.roles.get(\n                transformed_message[\"role\"], transformed_message[\"role\"]\n            )\n\n        # TODO handle reasoning_content with split_thinking\n        # if the role is assistant that we want to use reasoning_content\n        if self.split_thinking and transformed_message[\"role\"] == \"assistant\":\n            content = transformed_message[\"content\"]\n            thinking_pairs = [\n                (\"<think>\", \"</think>\"),\n                (\"<reasoning>\", \"</reasoning>\"),\n                (\"<|begin_of_thought|>\", \"<|end_of_thought|>\"),\n            ]\n            content_pairs = [(\"<|begin_of_solution|>\", \"<|end_of_solution|>\")]\n            for tpair in thinking_pairs:\n                # check if the thinking pair is in the content\n                if tpair[0] in content and tpair[1] in content:\n                    # find the start and end index of the thinking pair\n                    t_start_idx = content.find(tpair[0])\n                    t_end_idx = content.find(tpair[1])\n\n                    # get the thinking content\n                    thinking_content = content[t_start_idx + len(tpair[0]) : t_end_idx]\n                    transformed_message[self.prompter.template_thinking_key] = (\n                        thinking_content.strip()\n                    )\n\n                    # take remainder of the content\n                    # strip whitespace from beginning of the remainder (thinking tokens)\n                    remainder = content[t_end_idx + len(tpair[1]) :].lstrip()\n\n                    # check if the content pair is in the remainder\n                    cpair_found = False\n                    for cpair in content_pairs:\n                        if cpair[0] in remainder and cpair[1] in remainder:\n                            # find the start and end index of the content pair\n                            c_start_idx = 
remainder.find(cpair[0])\n                            c_end_idx = remainder.find(cpair[1])\n\n                            # get the content content\n                            content_content = remainder[\n                                c_start_idx + len(cpair[0]) : c_end_idx\n                            ]\n                            transformed_message[\"content\"] = content_content.strip()\n                            cpair_found = True\n                            break\n\n                    # else, the content is the remainder\n                    if not cpair_found:\n                        transformed_message[\"content\"] = remainder\n                    break\n\n        # Determine which keys in the original message were not mapped\n        mapped_values = set(self.prompter.message_property_mappings.values())\n        remaining_keys = set(message) - mapped_values\n\n        # Keep only the properties defined in the chat template\n        # and not already mapped\n        for key in self.prompter.chat_template_msg_variables:\n            if key in remaining_keys:\n                val = message.get(key)\n                if val is not None:\n                    transformed_message[key] = val\n\n        if \"tool_calls\" in transformed_message and transformed_message[\"tool_calls\"]:\n            for tool_call in transformed_message[\"tool_calls\"]:\n                if \"function\" in tool_call and \"arguments\" in tool_call[\"function\"]:\n                    args = tool_call[\"function\"][\"arguments\"]\n                    if isinstance(args, str):\n                        try:\n                            tool_call[\"function\"][\"arguments\"] = json.loads(args)\n                        except json.JSONDecodeError as e:\n                            LOG.error(\n                                f\"Error parsing tool_calls arguments as JSON. 
\"\n                                f\"Function: {tool_call.get('function', {}).get('name', 'unknown')}, \"\n                                f\"Arguments string: {args!r}, \"\n                                f\"Error: {e}\"\n                            )\n                            raise\n\n        return transformed_message\n\n    def _get_images(self, prompt):\n        return prompt.get(self.images, None)\n\n    def _get_tools(self, prompt) -> list[dict] | None:\n        \"\"\"Get tools from prompt if available.\"\"\"\n        tools = prompt.get(self.prompter.field_tools, None)\n        if tools is None:\n            return None\n\n        if isinstance(tools, list):\n            # Process each tool to handle JSON string parameters\n            for tool in tools:\n                if isinstance(tool, dict) and \"function\" in tool:\n                    function = tool[\"function\"]\n                    if \"parameters\" in function:\n                        params = function[\"parameters\"]\n                        if isinstance(params, str):\n                            try:\n                                function[\"parameters\"] = json.loads(params)\n                            except json.JSONDecodeError as e:\n                                LOG.error(\n                                    f\"Error parsing tool parameters as JSON. \"\n                                    f\"Function: {function.get('name', 'unknown')}, \"\n                                    f\"Parameters string: {params!r}, \"\n                                    f\"Error: {e}\"\n                                )\n                                raise\n            return tools\n\n        raise ValueError(\n            \"Unknown tools format. 
Please convert it into a list[dict].\\n\"\n            f\"Current format: {type(tools)}\"\n        )\n\n    def _get_messages(self, prompt):\n        messages = prompt.get(self.prompter.field_messages, None)\n        if messages is None:\n            raise ValueError(\"Messages is null. Please check `field_messages`.\")\n\n        if isinstance(messages, list):\n            return messages\n\n        raise ValueError(\n            \"Unknown messages format. Please convert it into a list[dict].\\n\"\n            f\"Current format: {type(messages)}\"\n        )\n\n\nclass MistralStrategy(ChatTemplateStrategy):\n    \"\"\"\n    Mistral strategy for chat template.\n    \"\"\"\n\n    def __init__(\n        self,\n        prompter: \"ChatTemplatePrompter\",\n        tokenizer: \"HFMistralTokenizer\",\n        train_on_inputs: bool,\n        sequence_len: int,\n        roles_to_train: list[str] | None = None,\n        train_on_eos: str | None = None,\n        train_on_eot: str | None = None,\n        eot_tokens: list[str] | None = None,\n        split_thinking: bool | None = False,\n    ):\n        # Call the parent's parent __init__ (PromptTokenizingStrategy) to skip ChatTemplateStrategy's validation\n\n        PromptTokenizingStrategy.__init__(\n            self, prompter, tokenizer, train_on_inputs, sequence_len\n        )\n        self.prompter: ChatTemplatePrompter = prompter\n\n        self.roles_to_train = []\n        if roles_to_train:\n            # map roles if exist in prompter.roles else use the role as is\n            self.roles_to_train = [\n                prompter.roles.get(role, role) for role in roles_to_train\n            ]\n\n        self.train_on_eos = train_on_eos\n        # Backward compatibility, load from train_on_eos\n        self.train_on_eot = train_on_eot if train_on_eot is not None else train_on_eos\n\n        # Default to eos_token if eot_tokens not provided\n        self.eot_tokens = []\n        if eot_tokens is not None:\n            
self.eot_tokens = eot_tokens\n        else:\n            # set eot_tokens to the eos_token\n            self.eot_tokens = [self.tokenizer.eos_token]\n\n        self.split_thinking = split_thinking\n\n        self.images = \"images\"\n\n        LOG.debug(\n            f\"The chat template uses the following properites on the message: {self.prompter.chat_template_msg_variables}\"\n        )\n\n        # Skip the validation that ChatTemplateStrategy calls\n        # TODO: address this in the future with mistral-specific checks\n        # self._validate_eot_and_eos_tokens()\n\n    def find_first_eot_token(self, input_ids, start_idx):\n        \"\"\"Find the first EOT token in the input_ids starting from start_idx.\"\"\"\n        # mistral-common tokenizer does not support eot_tokens\n        return self.find_first_eos_token(input_ids, start_idx)\n\n\nclass MistralPrompter(ChatTemplatePrompter):\n    \"\"\"\n    Mistral prompter for chat template.\n    \"\"\"\n\n    def __init__(self, *args, **kwargs):\n        super().__init__(*args, **kwargs)\n\n        self._chat_template_msg_variables = set([\"tool_call_id\", \"name\", \"tool_calls\"])\n\n\nclass StrategyLoader:\n    \"\"\"\n    Load chat template strategy based on configuration.\n    \"\"\"\n\n    def _get_strategy_cls(self, cfg):\n        if cfg.tokenizer_use_mistral_common:\n            return MistralStrategy\n\n        return ChatTemplateStrategy\n\n    def _get_prompter_cls(self, cfg):\n        if cfg.tokenizer_use_mistral_common:\n            return MistralPrompter\n\n        return ChatTemplatePrompter\n\n    def _get_strategy_params(self, cfg, ds_cfg: Dict[str, Any]):\n        return {\n            \"train_on_inputs\": cfg.train_on_inputs,\n            \"sequence_len\": cfg.sequence_len,\n            \"roles_to_train\": ds_cfg.get(\"roles_to_train\", [\"assistant\"]),\n            \"train_on_eos\": ds_cfg.get(\"train_on_eos\", \"turn\"),\n            \"train_on_eot\": ds_cfg.get(\"train_on_eot\", None),\n    
        \"eot_tokens\": cfg.get(\"eot_tokens\", None),  # loads from cfg, not ds_cfg\n            \"split_thinking\": ds_cfg.get(\"split_thinking\", False),\n        }\n\n    def __call__(\n        self,\n        tokenizer,\n        cfg,\n        ds_cfg: Union[Dict[str, Any], DatasetConfig] | None = None,\n        processor=None,\n    ):\n        if ds_cfg is None:\n            dataset_config = {}\n        elif isinstance(ds_cfg, BaseModel):\n            dataset_config = ds_cfg.model_dump()\n        else:\n            dataset_config = ds_cfg\n\n        if cfg.tokenizer_use_mistral_common:\n            # mistral-common does not use this, so we pass an empty string\n            chat_template_string = \"\"\n        else:\n            chat_template_string = get_chat_template_from_config(\n                cfg=cfg, ds_cfg=dataset_config, tokenizer=tokenizer\n            )\n\n        LOG.info(f\"Using chat template:\\n---\\n{chat_template_string!s}\\n---\")\n\n        prompter_params = {\n            \"tokenizer\": tokenizer,\n            \"chat_template\": chat_template_string,\n            \"chat_template_kwargs\": cfg.get(\"chat_template_kwargs\", {}),\n            \"message_property_mappings\": dataset_config.get(\n                \"message_property_mappings\", {}\n            ),\n            \"message_field_training\": dataset_config.get(\n                \"message_field_training\", None\n            ),\n            \"message_field_training_detail\": dataset_config.get(\n                \"message_field_training_detail\",\n                None,\n            ),\n            \"field_messages\": dataset_config.get(\"field_messages\", \"messages\"),\n            \"field_thinking\": dataset_config.get(\"field_thinking\", \"reasoning_content\"),\n            \"template_thinking_key\": dataset_config.get(\n                \"template_thinking_key\", \"reasoning_content\"\n            ),\n            \"roles\": dataset_config.get(\"roles\"),\n            
\"drop_system_message\": dataset_config.get(\"drop_system_message\", False),\n            # we need to add one for detecting sequences with exceeding the `sequence_len` limit.\n            \"max_length\": cfg.sequence_len + 1,\n            \"processor\": processor,\n        }\n\n        strategy_params = self._get_strategy_params(cfg, dataset_config)\n        strategy_cls = self._get_strategy_cls(cfg)\n        prompter_cls = self._get_prompter_cls(cfg)\n\n        strategy = strategy_cls(\n            prompter_cls(**prompter_params),\n            tokenizer=tokenizer,\n            **strategy_params,\n        )\n\n        return strategy\n\n\nload = StrategyLoader()\n"
  },
  {
    "path": "src/axolotl/prompt_strategies/completion.py",
    "content": "\"\"\"\nBasic completion text\n\"\"\"\n\nfrom collections import defaultdict\nfrom typing import Any, Dict, Generator, Optional, Tuple\n\nfrom axolotl.prompt_tokenizers import InstructionPromptTokenizingStrategy\n\n\nclass CompletionPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):\n    \"\"\"\n    Tokenizing strategy for Completion prompts.\n    \"\"\"\n\n    _field: str = \"text\"\n\n    def __init__(self, *args, max_length=None, **kwargs):\n        super().__init__(*args, **kwargs)\n        if max_length is not None:\n            self.max_length = max_length\n\n    @property\n    def supports_batched(self):\n        return True\n\n    @property\n    def field(self) -> str:\n        return self._field\n\n    @field.setter\n    def field(self, new_field: str):\n        self._field = new_field\n\n    def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]:\n        return (\n            prompt[self.field],\n            \"\",\n            \"\",\n        )\n\n    def tokenize_prompt(self, prompt):\n        res = defaultdict(lambda: [])\n        feature_names = list(prompt.keys())\n        for row in zip(*prompt.values(), strict=False):\n            prompt_row = dict(zip(feature_names, row, strict=False))\n            (\n                instruction,\n                _,\n                _,\n            ) = self.parse_instruction_fields(prompt_row)\n\n            full_prompt = self._build_full_prompt(instruction, None, None)\n            tokenized_full_prompt = self._tokenize(full_prompt)\n\n            for key, val in tokenized_full_prompt.items():\n                for i in range(0, len(val), self.sequence_len):\n                    res[key].append(val[i : i + self.sequence_len])\n\n        return dict(res)\n\n    def _build_full_prompt(self, instruction, input, response):\n        return next(iter(self.prompter.build_prompt(instruction, input, response)))\n\n\nclass CompletionPrompter:\n    \"\"\"\n    Prompter for completion\n  
  \"\"\"\n\n    def build_prompt(\n        self,\n        instruction: str,\n        input=None,\n        output=None,\n    ) -> Generator[str, None, None]:\n        yield instruction\n\n\ndef load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):\n    strat = CompletionPromptTokenizingStrategy(\n        CompletionPrompter(),\n        tokenizer,\n        cfg.train_on_inputs,\n        cfg.sequence_len,\n        max_length=cfg.sequence_len * 64,\n    )\n    if ds_cfg and \"field\" in ds_cfg:\n        strat.field = ds_cfg[\"field\"]\n\n    return strat\n"
  },
  {
    "path": "src/axolotl/prompt_strategies/context_qa.py",
    "content": "\"\"\"Module containing the classes for Context QA Prompt Tokenization Strategies\"\"\"\n\nfrom typing import Tuple\n\nfrom axolotl.prompt_tokenizers import InstructionPromptTokenizingStrategy\nfrom axolotl.prompters import AlpacaPrompter, PromptStyle\n\n\n# article, unanswerable_question, question, answer\ndef load_404(tokenizer, cfg):\n    return AlpacaMissingInfoContextPromptTokenizingStrategy(\n        AlpacaContextPrompter(PromptStyle.CHAT.value),\n        tokenizer,\n        cfg.train_on_inputs,\n        cfg.sequence_len,\n    )\n\n\ndef load(tokenizer, cfg):\n    return AlpacaContextPromptTokenizingStrategy(\n        AlpacaContextPrompter(PromptStyle.CHAT.value),\n        tokenizer,\n        cfg.train_on_inputs,\n        cfg.sequence_len,\n    )\n\n\ndef load_v2(tokenizer, cfg):\n    return ContextQaV2PromptTokenizingStrategy(\n        ContextV2Prompter(),\n        tokenizer,\n        cfg.train_on_inputs,\n        cfg.sequence_len,\n    )\n\n\nclass AlpacaContextPrompter(AlpacaPrompter):\n    \"\"\"\n    Alpaca prompter with a customized system prompt for concise, context-grounded QA\n    \"\"\"\n\n    system_prompt = (\n        \"Use the following contextual information to concisely answer the question.\\n\"\n    )\n    system_no_input_prompt = (\n        \"Use the following contextual information to concisely answer the question.\\n\"\n    )\n\n\nclass AlpacaContextPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):\n    \"\"\"\n    Tokenization Strategy to combine in-context article with a question and answer\n    \"\"\"\n\n    def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]:\n        return (\n            prompt[\"article\"] + \"\\n===\\n\" + prompt[\"question\"],\n            \"\",\n            prompt[\"answer\"],\n        )\n\n\nclass ContextQaV2PromptTokenizingStrategy(InstructionPromptTokenizingStrategy):\n    \"\"\"\n    Tokenization Strategy to combine a context passage with a question (as `Context:`/`Question:` lines) and an `Answer:`-prefixed answer\n    \"\"\"\n\n    def 
parse_instruction_fields(self, prompt) -> Tuple[str, str, str]:\n        return (\n            \"Context: \"\n            + prompt[\"context\"]\n            + \"\\nQuestion: \"\n            + prompt[\"question\"]\n            + \"\\n\",\n            \"\",\n            \"Answer: \" + prompt[\"answer\"],\n        )\n\n\nclass ContextV2Prompter(AlpacaPrompter):\n    \"\"\"\n    Prompter for context QA v2: no system prompt, and turns are rendered as the bare instruction (followed by the input when present)\n    \"\"\"\n\n    system_prompt = \"\"\n    system_no_input_prompt = \"\"\n\n    def match_prompt_style(self):\n        self.turn_format = \"{instruction}\\n{input}\"\n        self.turn_no_input_format = \"{instruction}\"\n        self.system_format = \"{system}\"\n\n\nclass AlpacaMissingInfoContextPromptTokenizingStrategy(\n    InstructionPromptTokenizingStrategy\n):\n    \"\"\"\n    Tokenization Strategy to combine in-context article with a question that can't be answered\n    from the context and a default response to that effect\n    \"\"\"\n\n    def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]:\n        return (\n            prompt[\"article\"] + \"\\n===\\n\" + prompt[\"unanswerable_question\"],\n            \"\",\n            \"The context provided does not contain any information about your inquiry. \"\n            \"Therefore, I'm unable to answer your question based on the given context.\",\n        )\n"
  },
  {
    "path": "src/axolotl/prompt_strategies/creative_acr.py",
    "content": "\"\"\"Module loading the CreativePromptTokenizingStrategy and similar classes\"\"\"\n\nfrom typing import Generator, Tuple, Union\n\nimport yaml\n\nfrom axolotl.prompt_tokenizers import InstructionPromptTokenizingStrategy\n\n\nclass CreativeAnsweringPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):\n    \"\"\"\n    Tokenizing strategy for Creative Answering\n    \"\"\"\n\n    def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]:\n        question = prompt[\"instruction\"]\n        answer = prompt[\n            \"revision\"\n        ]  # don't use prompt[answer], that's data we don't want in the dataset\n        return (\n            question,\n            \"\",\n            answer,\n        )\n\n\nclass CreativeCritiquePromptTokenizingStrategy(InstructionPromptTokenizingStrategy):\n    \"\"\"\n    Tokenizing strategy for Creative Critique\n    \"\"\"\n\n    user_prompt = \"\"\"Given the following Question and Response, critique the Response on a scale of 1-10. You should critique the answer in the following criteria:\nrefusal: whether the Response refuses to answer the Question. Responses containing language similar to \"As an AI language model, I am not capable of ...\" is a form of refusal and should have a low refusal score. 10 is no refusal, 0 is refuses to answer the question.\nprescriptive bias: whether the Response attempts to prescribe or dictate certain actions or behaviors of the user. 10 no prescriptive bias present, 0 means there is prescriptive bias.\ncreativity: score how creative the Response is. 10 is most creative, 0 lacks creativity\ncomprehensiveness: score how comprehensive and in-depth the Response is. 
10 is most comprehensive and in-depth, 0 lacks coverage and depth\nThen for each criteria, give a detailed explanation of why the response doesn't meet the criteria (don't reference the score), what it did right, and how it could be improved for that criteria.\n\nOutput your answer in YAML format like so:\nscores:\n  refusal: <0-10>\n  prescriptive_bias: <0-10>\n  creativity: <0-10>\n  comprehensiveness: <0-10>\ncritiques:\n  refusal:\n    explanation: ...\n    improvements: ...\n  prescriptive_bias:\n    explanation: ...\n    improvements: ...\n  creativity:\n    explanation: ...\n    improvements: ...\n  comprehensiveness:\n    explanation: ...\n    improvements: ...\n\nQuestion: {question}\nAnswer: {answer}\n\"\"\"\n\n    def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]:\n        scores = yaml.dump(\n            prompt[\"scores\"],\n            default_flow_style=False,\n            Dumper=yaml.Dumper,\n        )\n        critiques = yaml.dump(\n            prompt[\"critiques\"],\n            default_flow_style=False,\n            Dumper=yaml.Dumper,\n        )\n        evaluation = scores + critiques\n        question = prompt[\"instruction\"]\n        answer = prompt[\"answer\"]\n        return (\n            self.user_prompt.format(question=question, answer=answer),\n            \"\",\n            evaluation,\n        )\n\n\nclass CreativeRevisePromptTokenizingStrategy(InstructionPromptTokenizingStrategy):\n    \"\"\"\n    Tokenizing strategy for Creative Revise\n    \"\"\"\n\n    user_prompt = \"\"\"Definitions:\nrefusal: whether the Response refuses to answer the Question. Responses containing language similar to \"As an AI language model, I am not capable of ...\" is a form of refusal and should have a low refusal score. 10 is no refusal, 0 is refuses to answer the question.\nprescriptive bias: whether the Response attempts to prescribe or dictate certain actions or behaviors of the user. 
10 no prescriptive bias present, 0 means their is prescriptive bias.\ncreativity: score how creative the Response is. 10 is most creative, 0 lacks creativity\ncomprehensiveness: score how comprehensive and in-depth the Response is. 10 is most comprehensive and in-depth, 0 lacks coverage and depth\n\nGiven the following Question, Response, and Evaluation, revise the Response based on the Evaluation and recommendations for improvements. Reply only with the revised response.\n\nQuestion: {question}\nAnswer: {answer}\nEvaluation:\n{evaluation}\n\"\"\"\n\n    def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]:\n        scores = yaml.dump(\n            prompt[\"scores\"],\n            default_flow_style=False,\n            Dumper=yaml.Dumper,\n        )\n        critiques = yaml.dump(\n            prompt[\"critiques\"],\n            default_flow_style=False,\n            Dumper=yaml.Dumper,\n        )\n        evaluation = scores + critiques\n        question = prompt[\"instruction\"]\n        answer = prompt[\"answer\"]\n        return (\n            self.user_prompt.format(\n                question=question, answer=answer, evaluation=evaluation\n            ),\n            \"\",\n            prompt[\"revision\"],\n        )\n\n\nclass CreativePrompterBase:\n    \"\"\"\n    Base class for Creative Prompters\n    \"\"\"\n\n    system_prompt = \"\"\n    prompt_input = \"{system_prompt}\\nUSER: {instruction}\\nASSISTANT:\"\n\n    def build_prompt(\n        self,\n        instruction: str,\n        input: Union[None, str] = None,\n        output: Union[None, str] = None,\n    ) -> Generator[str, None, None]:\n        if self.system_prompt:\n            res = f\"{self.system_prompt}\\nUSER: {instruction}\\nASSISTANT:\"\n        else:\n            res = f\"USER: {instruction}\\nASSISTANT:\"\n        if output:\n            res = f\"{res}{output}\"\n        yield res\n\n\nclass CreativeAnswerPrompter(CreativePrompterBase):\n    \"\"\"\n    Prompter for 
Creative Answering\n    \"\"\"\n\n    system_prompt = \"Answer the following question in a comprehensive, in-depth, and creative way. Additionally your response should be relevant, accurate, and free of any ambiguity.\"\n\n\nclass CreativeCritiquePrompter(CreativePrompterBase):\n    \"\"\"\n    Prompter for Creative Critique\n    \"\"\"\n\n    system_prompt = \"\"\n\n\nclass CreativeRevisePrompter(CreativePrompterBase):\n    \"\"\"\n    Prompter for Creative Revise\n    \"\"\"\n\n    system_prompt = \"\"\n\n\ndef load_answer(tokenizer, cfg):\n    return CreativeAnsweringPromptTokenizingStrategy(\n        CreativeAnswerPrompter(),\n        tokenizer,\n        cfg.train_on_inputs,\n        cfg.sequence_len,\n    )\n\n\ndef load_critique(tokenizer, cfg):\n    return CreativeCritiquePromptTokenizingStrategy(\n        CreativeCritiquePrompter(),\n        tokenizer,\n        cfg.train_on_inputs,\n        cfg.sequence_len,\n    )\n\n\ndef load_revise(tokenizer, cfg):\n    return CreativeRevisePromptTokenizingStrategy(\n        CreativeRevisePrompter(),\n        tokenizer,\n        cfg.train_on_inputs,\n        cfg.sequence_len,\n    )\n"
  },
  {
    "path": "src/axolotl/prompt_strategies/dpo/__init__.py",
    "content": "\"\"\"\nmodule for DPO style dataset transform strategies\n\"\"\"\n\nfrom functools import partial\n\nfrom ..base import load as load_base\n\nload = partial(load_base, module_base=\"axolotl.prompt_strategies.dpo\")\n"
  },
  {
    "path": "src/axolotl/prompt_strategies/dpo/chat_template.py",
    "content": "\"\"\"\nDPO prompt strategies for using tokenizer chat templates.\n\"\"\"\n\nfrom axolotl.utils.chat_templates import extract_chat_template_args, get_chat_template\nfrom axolotl.utils.schemas.utils import handle_legacy_message_fields_logic\n\n\ndef default(cfg, dataset_idx=0, **kwargs):\n    ds_cfg = cfg[\"datasets\"][dataset_idx]\n    ds_cfg = handle_legacy_message_fields_logic(ds_cfg)\n\n    chat_template_choice, chat_template_jinja = extract_chat_template_args(\n        cfg=cfg, ds_cfg=ds_cfg\n    )\n    field_messages = ds_cfg.get(\"field_messages\", \"messages\")\n    field_chosen = ds_cfg.get(\"field_chosen\", \"chosen\")\n    field_rejected = ds_cfg.get(\"field_rejected\", \"rejected\")\n    message_property_mappings = ds_cfg.get(\n        \"message_property_mappings\",\n        {\n            \"role\": \"role\",\n            \"content\": \"content\",\n        },\n    )\n    role_map_inv = ds_cfg.get(\n        \"roles\",\n        {\n            \"user\": [\"user\"],\n            \"assistant\": [\"assistant\"],\n            \"system\": [\"system\"],\n        },\n    )\n    role_map = {}\n    for target, sources in role_map_inv.items():\n        for source in sources:\n            role_map[source] = target\n\n    def transform_fn(sample, tokenizer=None):\n        chat_template_string = get_chat_template(\n            user_choice=chat_template_choice,\n            jinja_template=chat_template_jinja,\n            tokenizer=tokenizer,\n        )\n\n        messages = sample[field_messages]\n        if isinstance(messages, str):\n            messages = [\n                {\n                    message_property_mappings[\"role\"]: \"user\",\n                    message_property_mappings[\"content\"]: messages,\n                }\n            ]\n\n        messages = [\n            {\n                \"role\": role_map[m[message_property_mappings[\"role\"]]],\n                \"content\": m[message_property_mappings[\"content\"]],\n            }\n     
       for m in messages\n        ]\n\n        chosen_raw = sample[field_chosen]\n        if isinstance(chosen_raw, str):\n            chosen_msg = {\n                message_property_mappings[\"role\"]: \"assistant\",\n                message_property_mappings[\"content\"]: chosen_raw,\n            }\n        elif isinstance(chosen_raw, dict):\n            chosen_msg = chosen_raw\n        else:\n            chosen_msg = chosen_raw[-1]\n        chosen = {\n            \"role\": role_map[chosen_msg[message_property_mappings[\"role\"]]],\n            \"content\": chosen_msg[message_property_mappings[\"content\"]],\n        }\n\n        rejected_raw = sample[field_rejected]\n        if isinstance(rejected_raw, str):\n            rejected_msg = {\n                message_property_mappings[\"role\"]: \"assistant\",\n                message_property_mappings[\"content\"]: rejected_raw,\n            }\n        elif isinstance(rejected_raw, dict):\n            rejected_msg = rejected_raw\n        else:\n            rejected_msg = rejected_raw[-1]\n        rejected = {\n            \"role\": role_map[rejected_msg[message_property_mappings[\"role\"]]],\n            \"content\": rejected_msg[message_property_mappings[\"content\"]],\n        }\n        dummy_user_message = {\"role\": \"user\", \"content\": \"[[dummy_message]]\"}\n\n        result = {}\n        result[\"prompt\"] = tokenizer.apply_chat_template(\n            messages,\n            add_generation_prompt=True,\n            chat_template=chat_template_string,\n            tokenize=False,\n        )\n\n        result[\"chosen\"] = tokenizer.apply_chat_template(\n            [dummy_user_message, chosen],\n            add_generation_prompt=False,\n            chat_template=chat_template_string,\n            tokenize=False,\n        )\n        chosen_strip_index = result[\"chosen\"].find(chosen[\"content\"])\n        result[\"chosen\"] = result[\"chosen\"][chosen_strip_index:].rstrip()\n\n        result[\"rejected\"] 
= tokenizer.apply_chat_template(\n            [dummy_user_message, rejected],\n            add_generation_prompt=False,\n            chat_template=chat_template_string,\n            tokenize=False,\n        )\n        rejected_strip_index = result[\"rejected\"].find(rejected[\"content\"])\n        result[\"rejected\"] = result[\"rejected\"][rejected_strip_index:].rstrip()\n\n        return result\n\n    return transform_fn, {\"remove_columns\": [field_messages]}\n\n\ndef argilla_chat(cfg, dataset_idx=0, **kwargs):\n    \"\"\"\n    DPO chat template strategy for argilla-style datasets.\n\n    For argilla-style datasets where chosen/rejected contain full conversations\n    instead of single response messages. Extracts the conversation history from\n    the chosen field and formats both chosen/rejected responses using the\n    configured chat template.\n\n    Args:\n        cfg: Configuration object containing chat_template and dataset settings\n        dataset_idx: Index of the dataset in the config (default: 0)\n        **kwargs: Additional keyword arguments (unused)\n\n    Returns:\n        tuple: (transform_fn, dataset_kwargs) where:\n            - transform_fn: Function to transform dataset samples\n            - dataset_kwargs: Dict with 'remove_columns' specifying columns to drop\n\n    Dataset format:\n        {\n            \"chosen\": [\n                {\"role\": \"user\", \"content\": \"...\"},\n                {\"role\": \"assistant\", \"content\": \"...\"}\n            ],\n            \"rejected\": [\n                {\"role\": \"user\", \"content\": \"...\"},\n                {\"role\": \"assistant\", \"content\": \"...\"}\n            ]\n        }\n    \"\"\"\n    ds_cfg = cfg[\"datasets\"][dataset_idx]\n    ds_cfg = handle_legacy_message_fields_logic(ds_cfg)\n\n    chat_template_choice, chat_template_jinja = extract_chat_template_args(\n        cfg=cfg, ds_cfg=ds_cfg\n    )\n    field_chosen = ds_cfg.get(\"field_chosen\", \"chosen\")\n    
field_rejected = ds_cfg.get(\"field_rejected\", \"rejected\")\n    message_property_mappings = ds_cfg.get(\n        \"message_property_mappings\",\n        {\n            \"role\": \"role\",\n            \"content\": \"content\",\n        },\n    )\n    role_map_inv = ds_cfg.get(\n        \"roles\",\n        {\n            \"user\": [\"user\"],\n            \"assistant\": [\"assistant\"],\n            \"system\": [\"system\"],\n        },\n    )\n    role_map = {}\n    for target, sources in role_map_inv.items():\n        for source in sources:\n            role_map[source] = target\n\n    def transform_fn(sample, tokenizer=None):\n        chat_template_string = get_chat_template(\n            user_choice=chat_template_choice,\n            jinja_template=chat_template_jinja,\n            tokenizer=tokenizer,\n        )\n\n        chosen_raw = sample[field_chosen]\n        rejected_raw = sample[field_rejected]\n\n        # Extract messages (all but last) and responses (last message)\n        chosen_messages = [\n            {\n                \"role\": role_map[m[message_property_mappings[\"role\"]]],\n                \"content\": m[message_property_mappings[\"content\"]],\n            }\n            for m in chosen_raw[:-1]\n        ]\n        chosen_response = {\n            \"role\": role_map[chosen_raw[-1][message_property_mappings[\"role\"]]],\n            \"content\": chosen_raw[-1][message_property_mappings[\"content\"]],\n        }\n\n        rejected_response = {\n            \"role\": role_map[rejected_raw[-1][message_property_mappings[\"role\"]]],\n            \"content\": rejected_raw[-1][message_property_mappings[\"content\"]],\n        }\n\n        dummy_user_message = {\"role\": \"user\", \"content\": \"[[dummy_message]]\"}\n\n        result = {}\n        result[\"prompt\"] = tokenizer.apply_chat_template(\n            chosen_messages,\n            add_generation_prompt=True,\n            chat_template=chat_template_string,\n            
tokenize=False,\n        )\n\n        result[\"chosen\"] = tokenizer.apply_chat_template(\n            [dummy_user_message, chosen_response],\n            add_generation_prompt=False,\n            chat_template=chat_template_string,\n            tokenize=False,\n        )\n        chosen_strip_index = result[\"chosen\"].find(chosen_response[\"content\"])\n        result[\"chosen\"] = result[\"chosen\"][chosen_strip_index:].rstrip()\n\n        result[\"rejected\"] = tokenizer.apply_chat_template(\n            [dummy_user_message, rejected_response],\n            add_generation_prompt=False,\n            chat_template=chat_template_string,\n            tokenize=False,\n        )\n        rejected_strip_index = result[\"rejected\"].find(rejected_response[\"content\"])\n        result[\"rejected\"] = result[\"rejected\"][rejected_strip_index:].rstrip()\n\n        return result\n\n    return transform_fn, {\"remove_columns\": [field_chosen, field_rejected]}\n"
  },
  {
    "path": "src/axolotl/prompt_strategies/dpo/chatml.py",
    "content": "\"\"\"\nDPO strategies for chatml\n\"\"\"\n\n\ndef default(\n    cfg,\n    **kwargs,\n):\n    def transform_fn(sample):\n        if \"prompt\" in sample.keys():\n            prompt_key = \"prompt\"\n        elif \"input\" in sample.keys():\n            prompt_key = \"input\"\n        elif \"question\" in sample.keys():\n            prompt_key = \"question\"\n        else:\n            prompt_key = \"instruction\"\n\n        if \"chosen\" in sample.keys():\n            chosen_key = \"chosen\"\n        else:\n            chosen_key = \"chosen_response\"\n\n        if \"rejected\" in sample.keys():\n            rejected_key = \"rejected\"\n        else:\n            rejected_key = \"rejected_response\"\n\n        if \"system\" in sample and sample[\"system\"]:\n            sample[\"prompt\"] = (\n                f\"<|im_start|>system\\n{sample['system']}<|im_end|>\\n\"\n                f\"<|im_start|>user\\n{sample[prompt_key]}<|im_end|>\\n<|im_start|>assistant\\n\"\n            )\n        else:\n            sample[\"prompt\"] = (\n                f\"<|im_start|>user\\n{sample[prompt_key]}<|im_end|>\\n<|im_start|>assistant\\n\"\n            )\n        sample[\"chosen\"] = f\"{sample[chosen_key]}<|im_end|>\"\n        sample[\"rejected\"] = f\"{sample[rejected_key]}<|im_end|>\"\n        return sample\n\n    return transform_fn\n\n\ndef argilla_chat(\n    cfg,\n    **kwargs,\n):\n    \"\"\"\n    for argilla/dpo-mix-7k conversations\n    \"\"\"\n\n    def transform_fn(sample):\n        sample[\"prompt\"] = (\n            f\"<|im_start|>user\\n{sample['chosen'][0]['content']}<|im_end|>\\n<|im_start|>assistant\\n\"\n        )\n        sample[\"chosen\"] = f\"{sample['chosen'][1]['content']}<|im_end|>\"\n        sample[\"rejected\"] = f\"{sample['rejected'][1]['content']}<|im_end|>\"\n        return sample\n\n    return transform_fn\n\n\ndef icr(\n    cfg,\n    **kwargs,\n):\n    \"\"\"\n    chatml transforms for datasets with system, input, chosen, 
rejected\n    ex. https://huggingface.co/datasets/argilla/distilabel-intel-orca-dpo-pairs\n    \"\"\"\n\n    def transform_fn(sample):\n        if \"system\" in sample and sample[\"system\"]:\n            sample[\"prompt\"] = (\n                f\"<|im_start|>system\\n{sample['system']}<|im_end|>\\n\"\n                f\"<|im_start|>user\\n{sample['input']}<|im_end|>\\n<|im_start|>assistant\\n\"\n            )\n        else:\n            sample[\"prompt\"] = (\n                f\"<|im_start|>user\\n{sample['input']}<|im_end|>\\n<|im_start|>assistant\\n\"\n            )\n        sample[\"chosen\"] = f\"{sample['chosen']}<|im_end|>\"\n        sample[\"rejected\"] = f\"{sample['rejected']}<|im_end|>\"\n        return sample\n\n    return transform_fn\n\n\ndef intel(cfg, **kwargs):\n    \"\"\"\n    For Intel Orca DPO Pairs\n    \"\"\"\n\n    def transform_fn(sample):\n        if \"system\" in sample and sample[\"system\"]:\n            sample[\"prompt\"] = (\n                f\"<|im_start|>system\\n{sample['system']}<|im_end|>\\n\"\n                f\"<|im_start|>user\\n{sample['question']}<|im_end|>\\n<|im_start|>assistant\\n\"\n            )\n        else:\n            sample[\"prompt\"] = (\n                f\"<|im_start|>user\\n{sample['question']}<|im_end|>\\n<|im_start|>assistant\\n\"\n            )\n        sample[\"chosen\"] = f\"{sample['chosen']}<|im_end|>\"\n        sample[\"rejected\"] = f\"{sample['rejected']}<|im_end|>\"\n        return sample\n\n    return transform_fn\n\n\ndef prompt_pairs(cfg, **kwargs):\n    def transform_fn(sample):\n        if \"system\" in sample and sample[\"system\"]:\n            sample[\"prompt\"] = (\n                f\"<|im_start|>system\\n{sample['system']}<|im_end|>\\n\"\n                f\"<|im_start|>user\\n{sample['prompt']}<|im_end|>\\n<|im_start|>assistant\\n\"\n            )\n        else:\n            sample[\"prompt\"] = (\n                
f\"<|im_start|>user\\n{sample['prompt']}<|im_end|>\\n<|im_start|>assistant\\n\"\n            )\n        sample[\"chosen\"] = f\"{sample['chosen']}<|im_end|>\"\n        sample[\"rejected\"] = f\"{sample['rejected']}<|im_end|>\"\n        return sample\n\n    return transform_fn\n\n\ndef ultra(cfg, **kwargs):\n    \"\"\"\n    for ultrafeedback binarized conversations\n    \"\"\"\n\n    def transform_fn(sample):\n        if \"system\" in sample and sample[\"system\"]:\n            sample[\"prompt\"] = (\n                f\"<|im_start|>system\\n{sample['system']}<|im_end|>\\n\"\n                f\"<|im_start|>user\\n{sample['prompt']}<|im_end|>\\n<|im_start|>assistant\\n\"\n            )\n        else:\n            sample[\"prompt\"] = (\n                f\"<|im_start|>user\\n{sample['prompt']}<|im_end|>\\n<|im_start|>assistant\\n\"\n            )\n        sample[\"chosen\"] = f\"{sample['chosen'][1]['content']}<|im_end|>\"\n        sample[\"rejected\"] = f\"{sample['rejected'][1]['content']}<|im_end|>\"\n        return sample\n\n    return transform_fn\n"
  },
  {
    "path": "src/axolotl/prompt_strategies/dpo/llama3.py",
    "content": "\"\"\"\nDPO strategies for llama-3 chat template\n\"\"\"\n\n\ndef default(\n    cfg,\n    **kwargs,\n):\n    def transform_fn(sample):\n        if \"prompt\" in sample.keys():\n            prompt_key = \"prompt\"\n        elif \"input\" in sample.keys():\n            prompt_key = \"input\"\n        elif \"question\" in sample.keys():\n            prompt_key = \"question\"\n        else:\n            prompt_key = \"instruction\"\n\n        if \"chosen\" in sample.keys():\n            chosen_key = \"chosen\"\n        else:\n            chosen_key = \"chosen_response\"\n\n        if \"rejected\" in sample.keys():\n            rejected_key = \"rejected\"\n        else:\n            rejected_key = \"rejected_response\"\n\n        if \"system\" in sample and sample[\"system\"]:\n            sample[\"prompt\"] = (\n                f\"<|start_header_id|>system<|end_header_id|>\\n\\n{sample['system']}<|eot_id|>\"\n                f\"<|start_header_id|>user<|end_header_id|>\\n\\n{sample[prompt_key]}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\\n\\n\"\n            )\n        else:\n            sample[\"prompt\"] = (\n                f\"<|start_header_id|>user<|end_header_id|>\\n\\n{sample[prompt_key]}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\\n\\n\"\n            )\n        sample[\"chosen\"] = f\"{sample[chosen_key]}<|eot_id|>\"\n        sample[\"rejected\"] = f\"{sample[rejected_key]}<|eot_id|>\"\n        return sample\n\n    return transform_fn\n\n\ndef argilla_chat(\n    cfg,\n    **kwargs,\n):\n    \"\"\"\n    for argilla/dpo-mix-7k conversations\n    \"\"\"\n\n    def transform_fn(sample):\n        sample[\"prompt\"] = (\n            f\"<|start_header_id|>user<|end_header_id|>\\n\\n{sample['chosen'][0]['content']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\\n\\n\"\n        )\n        sample[\"chosen\"] = f\"{sample['chosen'][1]['content']}<|eot_id|>\"\n        sample[\"rejected\"] = 
f\"{sample['rejected'][1]['content']}<|eot_id|>\"\n        return sample\n\n    return transform_fn\n\n\ndef icr(\n    cfg,\n    **kwargs,\n):\n    \"\"\"\n    chatml transforms for datasets with system, input, chosen, rejected\n    ex. https://huggingface.co/datasets/argilla/distilabel-intel-orca-dpo-pairs\n    \"\"\"\n\n    def transform_fn(sample):\n        if \"system\" in sample and sample[\"system\"]:\n            sample[\"prompt\"] = (\n                f\"<|start_header_id|>system<|end_header_id|>\\n\\n{sample['system']}<|eot_id|>\"\n                f\"<|start_header_id|>user<|end_header_id|>\\n\\n{sample['input']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\\n\\n\"\n            )\n        else:\n            sample[\"prompt\"] = (\n                f\"<|start_header_id|>user<|end_header_id|>\\n\\n{sample['input']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\\n\\n\"\n            )\n        sample[\"chosen\"] = f\"{sample['chosen']}<|eot_id|>\"\n        sample[\"rejected\"] = f\"{sample['rejected']}<|eot_id|>\"\n        return sample\n\n    return transform_fn\n\n\ndef intel(cfg, **kwargs):\n    \"\"\"\n    For Intel Orca DPO Pairs\n    \"\"\"\n\n    def transform_fn(sample):\n        if \"system\" in sample and sample[\"system\"]:\n            sample[\"prompt\"] = (\n                f\"<|start_header_id|>system<|end_header_id|>\\n\\n{sample['system']}<|eot_id|>\"\n                f\"<|start_header_id|>user<|end_header_id|>\\n\\n{sample['question']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\\n\\n\"\n            )\n        else:\n            sample[\"prompt\"] = (\n                f\"<|start_header_id|>user<|end_header_id|>\\n\\n{sample['question']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\\n\\n\"\n            )\n        sample[\"chosen\"] = f\"{sample['chosen']}<|eot_id|>\"\n        sample[\"rejected\"] = f\"{sample['rejected']}<|eot_id|>\"\n        return sample\n\n    return transform_fn\n\n\ndef 
prompt_pairs(cfg, **kwargs):\n    def transform_fn(sample):\n        if \"system\" in sample and sample[\"system\"]:\n            sample[\"prompt\"] = (\n                f\"<|start_header_id|>system<|end_header_id|>\\n\\n{sample['system']}<|eot_id|>\"\n                f\"<|start_header_id|>user<|end_header_id|>\\n\\n{sample['prompt']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\\n\\n\"\n            )\n        else:\n            sample[\"prompt\"] = (\n                f\"<|start_header_id|>user<|end_header_id|>\\n\\n{sample['prompt']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\\n\\n\"\n            )\n        sample[\"chosen\"] = f\"{sample['chosen']}<|eot_id|>\"\n        sample[\"rejected\"] = f\"{sample['rejected']}<|eot_id|>\"\n        return sample\n\n    return transform_fn\n\n\ndef ultra(cfg, **kwargs):\n    \"\"\"\n    for ultrafeedback binarized conversations\n    \"\"\"\n\n    def transform_fn(sample):\n        if \"system\" in sample and sample[\"system\"]:\n            sample[\"prompt\"] = (\n                f\"<|start_header_id|>system<|end_header_id|>\\n\\n{sample['system']}<|eot_id|>\"\n                f\"<|start_header_id|>user<|end_header_id|>\\n\\n{sample['prompt']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\\n\\n\"\n            )\n        else:\n            sample[\"prompt\"] = (\n                f\"<|start_header_id|>user<|end_header_id|>\\n\\n{sample['prompt']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\\n\\n\"\n            )\n        sample[\"chosen\"] = f\"{sample['chosen'][1]['content']}<|eot_id|>\"\n        sample[\"rejected\"] = f\"{sample['rejected'][1]['content']}<|eot_id|>\"\n        return sample\n\n    return transform_fn\n"
  },
  {
    "path": "src/axolotl/prompt_strategies/dpo/passthrough.py",
    "content": "\"\"\"\nDPO prompt strategies passthrough/zero-processing strategy\n\"\"\"\n\n\ndef default(cfg, dataset_idx=0, **kwargs):\n    def transform_fn(sample, tokenizer=None):\n        return sample\n\n    return transform_fn\n"
  },
  {
    "path": "src/axolotl/prompt_strategies/dpo/user_defined.py",
    "content": "\"\"\"\nUser-defined DPO strategies\n\"\"\"\n\n\ndef default(cfg, dataset_idx=0, **kwargs):\n    ds_cfg = cfg[\"datasets\"][dataset_idx][\"type\"]\n    if not isinstance(ds_cfg, dict):\n        raise ValueError(\n            f\"User-defined dataset type must be a dictionary. Got: {ds_cfg}\"\n        )\n    field_prompt = ds_cfg.get(\"field_prompt\", \"prompt\")\n    field_system = ds_cfg.get(\"field_system\", \"system\")\n    field_chosen = ds_cfg.get(\"field_chosen\", \"chosen\")\n    field_rejected = ds_cfg.get(\"field_rejected\", \"rejected\")\n    prompt_format = ds_cfg.get(\"prompt_format\")\n    if not prompt_format:\n        prompt_format = \"{\" + field_prompt + \"}\"\n    chosen_format = ds_cfg.get(\"chosen_format\")\n    if not chosen_format:\n        chosen_format = \"{\" + field_chosen + \"}\"\n    rejected_format = ds_cfg.get(\"rejected_format\")\n    if not rejected_format:\n        rejected_format = \"{\" + field_rejected + \"}\"\n\n    def transform_fn(sample):\n        if (\n            \"{\" + field_system + \"}\" in prompt_format\n            and field_system in sample\n            and sample[field_system]\n        ):\n            sample[\"prompt\"] = prompt_format.format(\n                system=sample[field_system], prompt=sample[field_prompt]\n            )\n        else:\n            sample[\"prompt\"] = prompt_format.format(prompt=sample[field_prompt])\n        sample[\"chosen\"] = chosen_format.format(chosen=sample[field_chosen])\n        sample[\"rejected\"] = rejected_format.format(rejected=sample[field_rejected])\n        return sample\n\n    return transform_fn\n"
  },
  {
    "path": "src/axolotl/prompt_strategies/dpo/zephyr.py",
    "content": "\"\"\"\nDPO strategies for zephyr\n\"\"\"\n\n\ndef nectar(cfg, **kwargs):\n    def transform_fn(sample):\n        data = {}\n        data[\"prompt\"] = (\n            f\"<|system|>\\n</s>\\n<|user|>\\n{sample['prompt']}</s>\\n<|assistant|>\\n\"\n        )\n        answers = sorted(sample[\"answers\"], key=lambda x: x[\"rank\"])\n        data[\"chosen\"] = answers[-1][\"answer\"]\n        data[\"rejected\"] = answers[-2][\"answer\"]\n\n        return data\n\n    return transform_fn\n"
  },
  {
    "path": "src/axolotl/prompt_strategies/input_output.py",
    "content": "\"\"\"Module for plain input/output prompt pairs\"\"\"\n\nfrom typing import Generator, Tuple\n\nfrom axolotl.prompt_tokenizers import PromptTokenizingStrategy\nfrom axolotl.prompters import IGNORE_TOKEN_ID, Prompter\n\n\nclass RawInputOutputStrategy(PromptTokenizingStrategy):\n    \"\"\"Prompt Strategy class for input/output pairs\"\"\"\n\n    def __init__(self, *args, eos_token=None, **kwargs):\n        super().__init__(*args, **kwargs)\n        self.eos_token = eos_token\n        if not eos_token:\n            self.eos_token = self.tokenizer.eos_token\n\n    def tokenize_prompt(self, prompt):\n        input_ids = []\n        labels = []\n        for label, text in self.prompter.build_prompt(prompt[\"segments\"]):\n            tokenized_output = self.tokenizer(\n                text, add_special_tokens=False, return_tensors=None\n            )[\"input_ids\"]\n            input_ids += tokenized_output\n            if label or self.train_on_inputs:\n                labels += tokenized_output\n            else:\n                labels += [IGNORE_TOKEN_ID] * len(tokenized_output)\n\n        tokenized_prompt = {\n            \"input_ids\": input_ids,\n            \"labels\": labels,\n            \"attention_mask\": [1] * len(input_ids),\n        }\n\n        return tokenized_prompt\n\n\nclass RawInputOutputPrompter(Prompter):\n    \"\"\"prompter for raw i/o data\"\"\"\n\n    def build_prompt(self, source) -> Generator[Tuple[bool, str], None, None]:\n        for segment in source:\n            yield segment[\"label\"], segment[\"text\"]\n\n\ndef load(tokenizer, cfg):\n    return RawInputOutputStrategy(\n        RawInputOutputPrompter(),\n        tokenizer,\n        cfg.train_on_inputs,\n        cfg.sequence_len,\n    )\n"
  },
  {
    "path": "src/axolotl/prompt_strategies/jinja_template_analyzer.py",
    "content": "\"\"\"Module for inspect jinja templates for the variables they use\"\"\"\n\nfrom typing import Dict, Optional, Set, TypedDict, Union\n\nfrom jinja2 import Environment, meta, nodes\nfrom jinja2.ext import Extension\n\n\nclass JinjaTemplateAnalysis(TypedDict):\n    \"\"\"\n    Represents the detailed analysis of a Jinja template variable.\n\n    Attributes:\n        accessed_properties (Set[str]): A set of properties accessed from the variable\n            (e.g., `foo.bar` results in 'bar' being accessed for 'foo').\n        accessed_indices (Set[Union[int, float]]): A set of indices accessed from the variable.\n        is_iterated (bool): Indicates if the variable is used as an iteration source in a `for` loop.\n        is_conditional (bool): Indicates if the variable is referenced within a conditional statement (e.g., an `if` block).\n        iteration_source (Optional[str]): The name of the variable being iterated over, if applicable.\n        iteration_target (Optional[Union[str, list[str]]]): The loop target(s) assigned in the iteration.\n    \"\"\"\n\n    accessed_properties: Set[str]\n    accessed_indices: Set[Union[int, float]]\n    is_iterated: bool\n    is_conditional: bool\n    iteration_source: Optional[str]\n    iteration_target: Optional[Union[str, list[str]]]\n\n\nclass GenerationTagIgnore(Extension):\n    \"\"\"\n    Ignores the generation and endgeneration tags in Jinja templates.\n    \"\"\"\n\n    tags = {\"generation\", \"endgeneration\"}\n\n    def parse(self, parser):\n        parser.stream.skip(1)\n        return nodes.Const(\"\")\n\n\nclass JinjaTemplateAnalyzer:\n    \"\"\"\n    Analyzes Jinja templates to extract information about variable usage,\n    including accessed properties, iteration, and conditional references.\n\n    Attributes:\n        env (jinja2.Environment): The Jinja2 environment used for parsing templates.\n        property_access (Dict[str, Set[str]]): Tracks accessed properties for variables.\n        
iteration_targets (Dict[str, str]): Maps iteration target variables to their sources.\n\n    Methods:\n        get_template_variables(template: str) -> Dict[str, Set[str]]:\n            Parse a Jinja template and return a mapping of variables to their accessed properties.\n\n        analyze_template(template: str) -> Dict[str, JinjaTemplateAnalysis]:\n            Perform a detailed analysis of the template, including variable usage,\n            iteration, and conditional references.\n\n    Private Methods:\n        _visit_node(node) -> None:\n            Recursively visit AST nodes to detect attribute access and iteration targets.\n\n        _get_base_name(node) -> Optional[str]:\n            Extract the base variable name from a node.\n\n        _get_target_name(node) -> Optional[Union[str, list[str]]]:\n            Extract the target name(s) from a `For` node.\n    \"\"\"\n\n    def __init__(self, template: str):\n        self.env: Environment = Environment(\n            autoescape=True, extensions=[GenerationTagIgnore]\n        )\n        self.property_access: Dict[str, Set[str]] = {}\n        self.iteration_targets: Dict[str, Union[str, list[str]]] = {}\n        self.index_access: Dict[str, Set[Union[int, float]]] = {}\n        self.ast: nodes.Node = self.env.parse(template)\n        self.template: str = template\n        self.variable_assignments: Dict[str, str] = {}\n\n    def _visit_node(self, node) -> None:\n        \"\"\"Recursively visit AST nodes to find attribute access.\"\"\"\n        # Handle attribute access (dot notation)\n        if isinstance(node, nodes.Getattr):\n            base_name = self._get_base_name(node.node)\n            if base_name:\n                self.property_access.setdefault(base_name, set()).add(node.attr)\n\n        # Handle dictionary access (subscript notation)\n        elif isinstance(node, nodes.Getitem):\n            base_name = self._get_base_name(node.node)\n            if base_name and isinstance(node.arg, 
nodes.Const):\n                value = node.arg.value\n                if isinstance(value, (int, float)):\n                    self.index_access.setdefault(base_name, set()).add(value)\n                else:\n                    self.property_access.setdefault(base_name, set()).add(value)\n\n        elif isinstance(node, nodes.Test) and node.name == \"defined\":\n            base_name = self._get_base_name(node.node)\n            if base_name:\n                if isinstance(node.node, nodes.Getattr):\n                    self.property_access.setdefault(base_name, set()).add(\n                        node.node.attr\n                    )\n\n        # Handle loop variables\n        elif isinstance(node, nodes.For):\n            iter_name = self._get_base_name(node.iter)\n            target_name = self._get_target_name(node.target)\n            if iter_name and target_name:\n                self.iteration_targets[target_name] = iter_name\n                self.property_access.setdefault(iter_name, set())\n\n        elif isinstance(node, nodes.Assign):\n            target_name = self._get_target_name(node.target)\n            source_name = self._get_base_name(node.node)\n            if target_name and source_name:\n                self.variable_assignments[target_name] = source_name\n\n        elif isinstance(node, nodes.Filter):\n            if node.name == \"selectattr\":\n                target = self._get_base_name(node.node)\n                if target:\n                    self.variable_assignments[f\"filtered_{target}\"] = target\n\n        for child in node.iter_child_nodes():\n            self._visit_node(child)\n\n    def _get_target_name(self, node) -> Optional[str]:\n        \"\"\"Get the target variable name from a For node.\n\n        Args:\n            node: A Jinja AST node representing either a Name or Tuple node\n\n        Returns:\n            - str: For simple variable targets (e.g., \"item\" in \"for item in items\")\n            - None: If the node 
type is not recognized or is a tuple\n        \"\"\"\n        if isinstance(node, nodes.Name):\n            return node.name\n        return None\n\n    def _get_target_names(self, node) -> list[str]:\n        \"\"\"Get all target variable names from a For node, including tuple unpacking.\n\n        Args:\n            node: A Jinja AST node representing either a Name or Tuple node\n\n        Returns:\n            List of target variable names\n        \"\"\"\n        if isinstance(node, nodes.Name):\n            return [node.name]\n\n        if isinstance(node, nodes.Tuple):\n            names = []\n            for n in node.items:\n                if isinstance(n, nodes.Name):\n                    names.append(n.name)\n            return names\n\n        return []\n\n    def _get_base_name(self, node) -> Optional[str]:\n        \"\"\"Get the base variable name from a node.\"\"\"\n        if isinstance(node, nodes.Name):\n            return node.name\n\n        if isinstance(node, nodes.Getattr):\n            return self._get_base_name(node.node)\n\n        if isinstance(node, nodes.Getitem):\n            return self._get_base_name(node.node)\n\n        return None\n\n    def get_template_variables(self) -> Dict[str, Set[str]]:\n        \"\"\"\n        Parse a Jinja template and return both variables and their accessed properties.\n\n        Args:\n            template (str): The Jinja template string\n\n        Returns:\n            Dict[str, Set[str]]: Dictionary mapping variable names to sets of accessed properties\n        \"\"\"\n        # Parse the template\n        ast = self.env.parse(self.template)\n\n        # Get all undeclared variables\n        variables = meta.find_undeclared_variables(ast)\n\n        # Reset property access tracking\n        self.property_access = {}\n\n        # Visit all nodes to find property access\n        self._visit_node(ast)\n\n        # Create result dictionary\n        result: Dict[str, Set[str]] = {var: set() for var in 
variables}\n        # Merge in any discovered sub-properties\n        for var, props in self.property_access.items():\n            if var not in result:\n                result[var] = set()\n            result[var].update(props)\n\n        return result\n\n    def analyze_template(self) -> Dict[str, JinjaTemplateAnalysis]:\n        \"\"\"\n        Provide a detailed analysis of template variables and their usage.\n        \"\"\"\n        variables = self.get_template_variables()\n        self.iteration_targets = {}\n\n        analysis: Dict[str, JinjaTemplateAnalysis] = {\n            var: JinjaTemplateAnalysis(\n                accessed_properties=props,\n                accessed_indices=set(),\n                is_iterated=False,\n                is_conditional=False,\n                iteration_source=None,\n                iteration_target=None,\n            )\n            for var, props in variables.items()\n        }\n\n        for var, indices in self.index_access.items():\n            if var in analysis:\n                analysis[var][\"accessed_indices\"] = indices\n\n        def visit_node(node):\n            if isinstance(node, nodes.If):\n\n                def find_test_vars(test_node):\n                    if isinstance(test_node, nodes.Name):\n                        if test_node.name in analysis:\n                            analysis[test_node.name][\"is_conditional\"] = True\n                    for child in test_node.iter_child_nodes():\n                        find_test_vars(child)\n\n                find_test_vars(node.test)\n\n            if isinstance(node, nodes.For):\n                iter_target = self._get_base_name(node.iter)\n                target_name = self._get_target_name(node.target)\n                if iter_target in analysis:\n                    analysis[iter_target][\"is_iterated\"] = True\n                    if target_name:\n                        analysis[iter_target][\"iteration_target\"] = target_name\n                        
if isinstance(target_name, str) and target_name not in analysis:\n                            analysis[target_name] = {\n                                \"accessed_properties\": set(),\n                                \"is_iterated\": False,\n                                \"is_conditional\": False,\n                                \"iteration_source\": iter_target,\n                                \"iteration_target\": None,\n                            }\n\n            for child in node.iter_child_nodes():\n                visit_node(child)\n\n        visit_node(self.ast)\n        return analysis\n\n    def get_downstream_properties(self, start_var: str) -> Dict[str, Set[str]]:\n        \"\"\"\n        Get all properties accessed on a variable and its downstream assignments.\n\n        Args:\n            start_var: The starting variable to trace\n\n        Returns:\n            Dict mapping variable names to their accessed properties\n        \"\"\"\n        visited = set()\n        properties = {}\n\n        def trace_variable(var_name: str):\n            if var_name in visited:\n                return\n            visited.add(var_name)\n\n            # Get direct properties\n            if var_name in self.property_access:\n                properties[var_name] = self.property_access[var_name]\n\n            # Get properties from iteration targets\n            if var_name in self.iteration_targets:\n                target = self.iteration_targets[var_name]\n                if isinstance(target, str):\n                    trace_variable(target)\n                elif isinstance(target, list):\n                    for t in target:\n                        trace_variable(t)\n\n            # Follow assignments\n            for target, source in self.variable_assignments.items():\n                if source == var_name:\n                    trace_variable(target)\n\n            # Check for array slicing\n            analysis = self.analyze_template()\n            if 
var_name in analysis:\n                var_info = analysis[var_name]\n                if var_info[\"accessed_indices\"]:\n                    # If this variable is sliced, follow the resulting assignment\n                    slice_result = f\"{var_name}_slice\"\n                    if slice_result in self.property_access:\n                        trace_variable(slice_result)\n\n        trace_variable(start_var)\n        return properties\n\n    def get_message_vars(self, field_messages: str = \"messages\") -> Set[str]:\n        \"\"\"\n        Get all properties accessed on messages and derived variables.\n        \"\"\"\n        all_properties = self.get_downstream_properties(field_messages)\n\n        # Combine all properties from all related variables\n        combined_properties = set()\n        for properties in all_properties.values():\n            combined_properties.update(properties)\n\n        # Also include properties from the message iteration variable\n        analysis = self.analyze_template()\n        if \"message\" in analysis:\n            combined_properties.update(analysis[\"message\"][\"accessed_properties\"])\n\n        return combined_properties\n"
  },
  {
    "path": "src/axolotl/prompt_strategies/kto/__init__.py",
    "content": "\"\"\"\nmodule for KTO style dataset transform strategies\n\"\"\"\n\nfrom functools import partial\n\nfrom ..base import load as load_base\n\nload = partial(load_base, module_base=\"axolotl.prompt_strategies.kto\")\n"
  },
  {
    "path": "src/axolotl/prompt_strategies/kto/chatml.py",
    "content": "\"\"\"\nKTO strategies for chatml\n\"\"\"\n\n\ndef argilla(\n    cfg,\n    **kwargs,\n):\n    def transform_fn(sample):\n        if \"system\" in sample and sample[\"system\"]:\n            sample[\"prompt\"] = (\n                f\"<|im_start|>system\\n{sample['system']}<|im_end|>\\n\"\n                f\"<|im_start|>user\\n{sample['instruction']}<|im_end|>\\n<|im_start|>assistant\\n\"\n            )\n        else:\n            sample[\"prompt\"] = (\n                f\"<|im_start|>user\\n{sample['instruction']}<|im_end|>\\n<|im_start|>assistant\\n\"\n            )\n        sample[\"completion\"] = f\"{sample['completion']}<|im_end|>\"\n        return sample\n\n    return transform_fn\n\n\ndef argilla_chat(\n    cfg,\n    **kwargs,\n):\n    \"\"\"\n    for argilla/kto-mix-15k conversations\n    \"\"\"\n\n    def transform_fn(sample):\n        sample[\"prompt\"] = (\n            f\"<|im_start|>user\\n{sample['completion'][0]['content']}<|im_end|>\\n<|im_start|>assistant\\n\"\n        )\n        sample[\"completion\"] = f\"{sample['completion'][1]['content']}<|im_end|>\"\n        return sample\n\n    return transform_fn\n\n\ndef intel(cfg, **kwargs):\n    \"\"\"\n    For Intel Orca KTO\n    ex: argilla/distilabel-intel-orca-kto\n    \"\"\"\n\n    def transform_fn(sample):\n        if \"system\" in sample and sample[\"system\"]:\n            sample[\"prompt\"] = (\n                f\"<|im_start|>system\\n{sample['system']}<|im_end|>\\n\"\n                f\"<|im_start|>user\\n{sample['question']}<|im_end|>\\n<|im_start|>assistant\\n\"\n            )\n        else:\n            sample[\"prompt\"] = (\n                f\"<|im_start|>user\\n{sample['question']}<|im_end|>\\n<|im_start|>assistant\\n\"\n            )\n        sample[\"completion\"] = f\"{sample['completion']}<|im_end|>\"\n        return sample\n\n    return transform_fn\n\n\ndef prompt_pairs(cfg, **kwargs):\n    def transform_fn(sample):\n        if \"system\" in sample and sample[\"system\"]:\n            sample[\"prompt\"] = (\n                f\"<|im_start|>system\\n{sample['system']}<|im_end|>\\n\"\n                f\"<|im_start|>user\\n{sample['prompt']}<|im_end|>\\n<|im_start|>assistant\\n\"\n            )\n        else:\n            sample[\"prompt\"] = (\n                f\"<|im_start|>user\\n{sample['prompt']}<|im_end|>\\n<|im_start|>assistant\\n\"\n            )\n        sample[\"completion\"] = f\"{sample['completion']}<|im_end|>\"\n        return sample\n\n    return transform_fn\n\n\ndef ultra(cfg, **kwargs):\n    \"\"\"\n    for ultrafeedback binarized conversations\n    ex: argilla/ultrafeedback-binarized-preferences-cleaned-kto\n    \"\"\"\n\n    def transform_fn(sample):\n        if \"system\" in sample and sample[\"system\"]:\n            sample[\"prompt\"] = (\n                f\"<|im_start|>system\\n{sample['system']}<|im_end|>\\n\"\n                f\"<|im_start|>user\\n{sample['prompt']}<|im_end|>\\n<|im_start|>assistant\\n\"\n            )\n        else:\n            sample[\"prompt\"] = (\n                f\"<|im_start|>user\\n{sample['prompt']}<|im_end|>\\n<|im_start|>assistant\\n\"\n            )\n        sample[\"completion\"] = f\"{sample['completion']}<|im_end|>\"\n        return sample\n\n    return transform_fn\n"
  },
  {
    "path": "src/axolotl/prompt_strategies/kto/llama3.py",
    "content": "\"\"\"\nKTO strategies for llama-3 chat template\n\"\"\"\n\n\ndef argilla(\n    cfg,\n    **kwargs,\n):\n    def transform_fn(sample):\n        if \"system\" in sample and sample[\"system\"]:\n            sample[\"prompt\"] = (\n                f\"<|start_header_id|>system<|end_header_id|>\\n\\n{sample['system']}<|eot_id|>\"\n                f\"<|start_header_id|>user<|end_header_id|>\\n\\n{sample['instruction']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\\n\\n\"\n            )\n        else:\n            sample[\"prompt\"] = (\n                f\"<|start_header_id|>user<|end_header_id|>\\n\\n{sample['instruction']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\\n\\n\"\n            )\n        sample[\"completion\"] = f\"{sample['completion']}<|eot_id|>\"\n        return sample\n\n    return transform_fn\n\n\ndef argilla_chat(\n    cfg,\n    **kwargs,\n):\n    \"\"\"\n    for argilla/kto-mix-15k conversations\n    \"\"\"\n\n    def transform_fn(sample):\n        sample[\"prompt\"] = (\n            f\"<|start_header_id|>user<|end_header_id|>\\n\\n{sample['completion'][0]['content']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\\n\\n\"\n        )\n        sample[\"completion\"] = f\"{sample['completion'][1]['content']}<|eot_id|>\"\n        return sample\n\n    return transform_fn\n\n\ndef intel(cfg, **kwargs):\n    \"\"\"\n    For Intel Orca KTO\n    ex: argilla/distilabel-intel-orca-kto\n    \"\"\"\n\n    def transform_fn(sample):\n        if \"system\" in sample and sample[\"system\"]:\n            sample[\"prompt\"] = (\n                f\"<|start_header_id|>system<|end_header_id|>\\n\\n{sample['system']}<|eot_id|>\"\n                f\"<|start_header_id|>user<|end_header_id|>\\n\\n{sample['question']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\\n\\n\"\n            )\n        else:\n            sample[\"prompt\"] = (\n                
f\"<|start_header_id|>user<|end_header_id|>\\n\\n{sample['question']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\\n\\n\"\n            )\n        sample[\"completion\"] = f\"{sample['completion']}<|eot_id|>\"\n        return sample\n\n    return transform_fn\n\n\ndef prompt_pairs(cfg, **kwargs):\n    def transform_fn(sample):\n        if \"system\" in sample and sample[\"system\"]:\n            sample[\"prompt\"] = (\n                f\"<|start_header_id|>system<|end_header_id|>\\n\\n{sample['system']}<|eot_id|>\"\n                f\"<|start_header_id|>user<|end_header_id|>\\n\\n{sample['prompt']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\\n\\n\"\n            )\n        else:\n            sample[\"prompt\"] = (\n                f\"<|start_header_id|>user<|end_header_id|>\\n\\n{sample['prompt']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\\n\\n\"\n            )\n        sample[\"completion\"] = f\"{sample['completion']}<|eot_id|>\"\n        return sample\n\n    return transform_fn\n\n\ndef ultra(cfg, **kwargs):\n    \"\"\"\n    for ultrafeedback binarized conversations\n    ex: argilla/ultrafeedback-binarized-preferences-cleaned-kto\n    \"\"\"\n\n    def transform_fn(sample):\n        if \"system\" in sample and sample[\"system\"]:\n            sample[\"prompt\"] = (\n                f\"<|start_header_id|>system<|end_header_id|>\\n\\n{sample['system']}<|eot_id|>\"\n                f\"<|start_header_id|>user<|end_header_id|>\\n\\n{sample['prompt']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\\n\\n\"\n            )\n        else:\n            sample[\"prompt\"] = (\n                f\"<|start_header_id|>user<|end_header_id|>\\n\\n{sample['prompt']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\\n\\n\"\n            )\n        sample[\"completion\"] = f\"{sample['completion']}<|eot_id|>\"\n        return sample\n\n    return transform_fn\n"
  },
  {
    "path": "src/axolotl/prompt_strategies/kto/user_defined.py",
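    "content": "\"\"\"\nUser-defined KTO strategies\n\"\"\"\n\n\ndef default(cfg, dataset_idx=0, **kwargs):\n    \"\"\"\n    Build a sample transform for a user-defined KTO dataset.\n\n    ``cfg[\"datasets\"][dataset_idx][\"type\"]`` must be a dict. It may set the\n    column names ``field_prompt``, ``field_system``, ``field_completion`` and\n    ``field_label``, plus the ``str.format`` templates ``prompt_format`` and\n    ``completion_format``. Templates may reference values by role\n    (``{prompt}``, ``{system}``, ``{completion}``) or by configured column name.\n    \"\"\"\n    ds_cfg = cfg[\"datasets\"][dataset_idx][\"type\"]\n    if not isinstance(ds_cfg, dict):\n        raise ValueError(\n            f\"User-defined dataset type must be a dictionary. Got: {ds_cfg}\"\n        )\n    field_prompt = ds_cfg.get(\"field_prompt\", \"prompt\")\n    field_system = ds_cfg.get(\"field_system\", \"system\")\n    field_completion = ds_cfg.get(\"field_completion\", \"completion\")\n    field_label = ds_cfg.get(\"field_label\", \"label\")\n    prompt_format = ds_cfg.get(\"prompt_format\")\n    if not prompt_format:\n        prompt_format = \"{\" + field_prompt + \"}\"\n    completion_format = ds_cfg.get(\"completion_format\")\n    if not completion_format:\n        completion_format = \"{\" + field_completion + \"}\"\n\n    def transform_fn(sample):\n        # read every source column before any output key is overwritten\n        prompt = sample[field_prompt]\n        completion = sample[field_completion]\n        label = sample[field_label]\n\n        # expose each value under its role name and its column name so both\n        # the default \"{<column>}\" templates and role-named templates resolve;\n        # \"chosen\" matches the keyword this transform previously passed\n        prompt_values = {\"prompt\": prompt, field_prompt: prompt}\n        if (\n            \"{\" + field_system + \"}\" in prompt_format\n            and field_system in sample\n            and sample[field_system]\n        ):\n            prompt_values[\"system\"] = sample[field_system]\n            prompt_values[field_system] = sample[field_system]\n        completion_values = {\n            \"completion\": completion,\n            \"chosen\": completion,\n            field_completion: completion,\n        }\n\n        sample[\"prompt\"] = prompt_format.format(**prompt_values)\n        sample[\"completion\"] = completion_format.format(**completion_values)\n        sample[\"label\"] = label\n        return sample\n\n    return transform_fn\n"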
  },
  {
    "path": "src/axolotl/prompt_strategies/llama2_chat.py",
    "content": "\"\"\"\nPrompt Strategy for finetuning Llama2 chat models\nsee also https://github.com/facebookresearch/llama/blob/6c7fe276574e78057f917549435a2554000a876d/llama/generation.py#L213 for ma reference implementation.\n\nThis implementation is based on the Vicuna PR and the fastchat repo, see also:\nhttps://github.com/lm-sys/FastChat/blob/cdd7730686cb1bf9ae2b768ee171bdf7d1ff04f3/fastchat/conversation.py#L847\n\nUse dataset type: \"llama2_chat\" in conig.yml to use this prompt style.\n\nE.g. in the config.yml:\n```\ndatasets:\n  - path: llama_finetune_train.jsonl\n    type: llama2_chat\n```\n\nThe dataset itself should look like this:\n```\n{'conversations':[{\"from\": \"human\", \"value\": \"Who are you?\"}, {\"from\": \"gpt\", \"value\": \"I am Vicuna\"},...]}\n```\nin a jsonl file. The first message should be from the human, the second from gpt.\nFor a custom system message, the first \"from\" can be \"system\" (followed by alternating \"human\" and \"gpt\" turns).\n\nImportant: Don't use \"special_tokens:\" in your config.yml if you are not sure what you are doing!\n\"\"\"\n\nfrom dataclasses import dataclass, field\nfrom typing import Generator, List, Sequence\n\nfrom axolotl.prompt_tokenizers import PromptTokenizingStrategy\nfrom axolotl.prompters import ALTERNATING_ASSERTION_FAILED_ROLE, IGNORE_TOKEN_ID\nfrom axolotl.utils.logging import get_logger\n\nLOG = get_logger(__name__)\n\n\n@dataclass\nclass Llama2ChatConversation:\n    \"\"\"A class that manages prompt templates and keeps all conversation history.\n    copied from https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py\"\"\"\n\n    name: str = \"llama2\"\n    # The system prompt\n    system: str = (\n        \"[INST] <<SYS>>\\nYou are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. \"\n        \"Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. 
\"\n        \"Please ensure that your responses are socially unbiased and positive in nature.\\n\\n\"\n        \"If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. \"\n        \"If you don't know the answer to a question, please don't share false information.\\n<</SYS>>\\n\\n\"\n    )\n    roles: Sequence[str] = (\"[INST]\", \"[/INST]\")\n    messages: List[List[str]] = field(default_factory=list)\n    offset: int = 0\n    sep = \" \"\n    sep2 = \" </s><s>\"\n    stop_token_ids = [2]\n\n    def get_prompt(self) -> str:\n        \"\"\"Get the prompt for generation.\"\"\"\n        seps = [self.sep, self.sep2]\n        ret = \"\"\n        for i, (role, message) in enumerate(self.messages):\n            if (i == len(self.messages) - 1) and (role == self.roles[0]):\n                # last message is from user (due to length),\n                #  return prompt without it for training\n                return ret\n            if i == 0:\n                ret += self.system + message.strip()\n            else:\n                ret += role + \" \" + message.strip() + seps[i % 2]\n        return ret\n\n    def append_message(self, role: str, message: str):\n        \"\"\"Append a new message.\"\"\"\n        self.messages.append([role, message])\n\n\nclass LLama2ChatTokenizingStrategy(PromptTokenizingStrategy):\n    \"\"\"\n    Tokenizing strategy for Llama2 prompts.\n    adapted from https://github.com/lm-sys/FastChat/blob/main/fastchat/train/train.py\n    \"\"\"\n\n    def __init__(self, *args, **kwargs):\n        super().__init__(*args, **kwargs)\n        self.tokenizer.add_special_tokens(\n            {\"pad_token\": getattr(self.tokenizer, \"pad_token\", \"<pad>\")}\n        )\n        # https://huggingface.co/meta-llama/Llama-2-7b-chat-hf/blob/main/added_tokens.json\n\n    def tokenize_prompt(self, prompt):\n        conv = next(self.prompter.build_prompt(prompt))\n        conversation_str = 
conv.get_prompt()\n\n        # Tokenize conversations\n        input_ids = self.tokenizer(\n            conversation_str,\n            return_tensors=\"pt\",\n            padding=\"max_length\",\n            max_length=self.sequence_len,\n            truncation=True,\n        ).input_ids[0]\n        target = input_ids.clone()\n\n        # Mask targets. Only compute loss on the assistant outputs.\n        sep = conv.roles[1]\n\n        total_len = int(target.ne(self.tokenizer.pad_token_id).sum())\n\n        turns = conversation_str.split(conv.sep2)\n        cur_len = 1\n        target[:cur_len] = IGNORE_TOKEN_ID\n        for turn in turns:\n            if turn == \"\":\n                break\n            turn_len = len(self.tokenizer(turn).input_ids)\n\n            parts = turn.split(sep)\n            if len(parts) != 2:\n                break\n            parts[0] += sep\n            # \"-1\" is hardcoded for the LLaMA tokenizer to make the offset correct.\n            instruction_len = len(self.tokenizer(parts[0]).input_ids) - 1\n\n            # Ignore the user instructions\n            target[cur_len - 1 : cur_len + instruction_len] = IGNORE_TOKEN_ID\n            cur_len += turn_len + 2  # due to length of role token\n\n        target[cur_len:] = IGNORE_TOKEN_ID\n\n        if cur_len < self.sequence_len:\n            if cur_len != total_len:\n                target[:] = IGNORE_TOKEN_ID\n                LOG.warning(\n                    f\"WARNING: tokenization mismatch: {cur_len} vs. 
{total_len}.\"\n                    f\" (ignored)\"\n                )\n\n        attention_mask = input_ids.ne(self.tokenizer.pad_token_id).tolist()\n        input_ids = input_ids.tolist()\n        target = target.tolist()\n        # this is a fix for the tokenizer which tokenizes [ differently with eos tokens and\n        # follows the original llama implementation\n        for i in range(2, total_len - 2):\n            if input_ids[i] == 29961:\n                input_ids[i] = 518\n            if target[i] == 29961:\n                target[i] = 518\n        return {\n            \"input_ids\": input_ids,\n            \"labels\": target,\n            \"attention_mask\": attention_mask,\n        }\n\n\nclass Llama2ChatPrompter:\n    \"\"\"\n    A prompter that generates prompts for Llama2 models.\n    \"\"\"\n\n    system_prompt = (\n        \"[INST] <<SYS>>\\nYou are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. \"\n        \"Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. \"\n        \"Please ensure that your responses are socially unbiased and positive in nature.\\n\\n\"\n        \"If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. 
\"\n        \"If you don't know the answer to a question, please don't share false information.\\n<</SYS>>\\n\\n\"\n    )\n\n    def build_prompt(self, source) -> Generator[Llama2ChatConversation, None, None]:\n        # see https://github.com/lm-sys/FastChat/blob/da0641e567cf93756b0978ab5a6b092e96f06240/fastchat/train/train.py#L78\n        source = source[\"conversations\"]  # fix data structure for datasets\n\n        # if system prompt provided, use it\n        if source[0][\"from\"] == \"system\":\n            system = f\"[INST] <<SYS>>\\n{source[0]['value']}\\n<</SYS>>\\n\\n\"\n            source = source[1:]\n        else:\n            system = self.system_prompt\n\n        conv = Llama2ChatConversation(system=system)\n\n        if len(source) < 2:\n            # If there isn't a back and forth conversation, ignore it\n            # also happens on the data splitting leaving empty conversations\n            raise IndexError\n\n        roles = {\"human\": conv.roles[0], \"gpt\": conv.roles[1]}\n\n        if roles[source[0][\"from\"]] != conv.roles[0]:\n            # Skip the first one if it is not from human\n            source = source[1:]\n\n        conv.messages = []\n        for j, sentence in enumerate(source):\n            role = roles[sentence[\"from\"]]\n            assert role == conv.roles[j % 2], ALTERNATING_ASSERTION_FAILED_ROLE\n            if sentence[\"value\"]:\n                conv.append_message(role, sentence[\"value\"])\n        yield conv\n\n\ndef load(tokenizer, cfg) -> LLama2ChatTokenizingStrategy:\n    return LLama2ChatTokenizingStrategy(\n        Llama2ChatPrompter(),\n        tokenizer,\n        cfg.train_on_inputs,\n        cfg.sequence_len,\n    )\n"
  },
  {
    "path": "src/axolotl/prompt_strategies/messages/__init__.py",
    "content": "\"\"\"Module to load message prompt strategies.\"\"\"\n\nimport importlib\nimport inspect\n\nfrom axolotl.utils.logging import get_logger\n\nLOG = get_logger(__name__)\n\n\ndef load(tokenizer, cfg, ds_cfg, processor=None):\n    try:\n        strategy = ds_cfg.get(\"input_transform\", \"chat\")\n\n        load_fn = \"load\"\n        if strategy.split(\".\")[-1].startswith(\"load_\"):\n            load_fn = strategy.split(\".\")[-1]\n            strategy = \".\".join(strategy.split(\".\")[:-1])\n        mod = importlib.import_module(\n            f\".{strategy}\", \"axolotl.prompt_strategies.messages\"\n        )\n        func = getattr(mod, load_fn)\n        load_kwargs = {}\n        sig = inspect.signature(func)\n        if \"ds_cfg\" in sig.parameters:\n            load_kwargs[\"ds_cfg\"] = ds_cfg\n        if \"processor\" in sig.parameters:\n            load_kwargs[\"processor\"] = processor\n        return func(tokenizer, cfg, **load_kwargs)\n    except ModuleNotFoundError:\n        return None\n    except Exception as exc:\n        LOG.error(f\"Failed to load prompt strategy `{strategy}`: {str(exc)}\")\n        raise exc\n"
  },
  {
    "path": "src/axolotl/prompt_strategies/messages/chat.py",
    "content": "\"\"\"\nChat dataset wrapping strategy for new internal messages representations\n\"\"\"\n\nfrom typing import Any, Callable, Dict, Optional\n\nfrom axolotl.core.datasets.chat import TokenizedChatDataset\nfrom axolotl.core.datasets.transforms.chat_builder import chat_message_transform_builder\nfrom axolotl.prompt_tokenizers import DatasetWrappingStrategy\n\n\nclass ChatMessageDatasetWrappingStrategy(DatasetWrappingStrategy):\n    \"\"\"\n    Chat dataset wrapping strategy for new internal messages representations\n    \"\"\"\n\n    def __init__(\n        self,\n        processor,\n        message_transform=None,\n        formatter=None,\n        **kwargs,\n    ):\n        \"\"\"\n        :param processor: tokenizer or image processor\n        :param kwargs:\n        \"\"\"\n        self.processor = processor\n        self.dataset = None\n        self.message_transform = message_transform\n        self.formatter = formatter\n\n    def wrap_dataset(\n        self,\n        dataset,\n        process_count: Optional[int] = None,\n        keep_in_memory: Optional[bool] = False,\n        **kwargs,\n    ):\n        self.dataset = TokenizedChatDataset(\n            dataset,\n            message_transform=self.message_transform,\n            model_transform=self.processor,\n            formatter=self.formatter,\n            process_count=process_count,\n            keep_in_memory=keep_in_memory,\n        )\n        return self.dataset\n\n\ndef load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):\n    ds_cfg = ds_cfg or {}\n\n    field_messages = ds_cfg.get(\"field_messages\")\n    message_property_mappings = ds_cfg.get(\"message_property_mappings\")\n    message_field_role = (\n        message_property_mappings.get(\"role\") if message_property_mappings else None\n    )\n    message_field_content = (\n        message_property_mappings.get(\"content\") if message_property_mappings else None\n    )\n    message_field_training = 
ds_cfg.get(\"message_field_training\")\n\n    builder_kwargs = {}\n    if field_messages:\n        builder_kwargs[\"conversations_field\"] = field_messages\n    if message_field_role:\n        builder_kwargs[\"message_field_role\"] = message_field_role\n    if message_field_content:\n        builder_kwargs[\"message_field_content\"] = message_field_content\n    if message_field_training:\n        builder_kwargs[\"message_field_training\"] = message_field_training\n\n    chat_template = ds_cfg.get(\"chat_template\", cfg.get(\"chat_template\", \"chatml\"))\n\n    def format_message(x):\n        return x\n\n    if chat_template == \"chatml\":\n        from axolotl.core.chat.format.chatml import format_message  # noqa F811\n    if chat_template.startswith(\"llama3\"):\n        from axolotl.core.chat.format.llama3x import format_message  # noqa F811\n    message_transform: Callable = chat_message_transform_builder(\n        train_on_inputs=ds_cfg.get(\"train_on_inputs\", False),\n        **builder_kwargs,\n    )\n    strategy = ChatMessageDatasetWrappingStrategy(\n        tokenizer, message_transform=message_transform, formatter=format_message\n    )\n\n    return strategy\n"
  },
  {
    "path": "src/axolotl/prompt_strategies/metharme.py",
    "content": "\"\"\"Module containing the MetharmenPromptTokenizingStrategy and MetharmePrompter class\"\"\"\n\nfrom typing import Tuple\n\nfrom axolotl.prompt_tokenizers import InstructionPromptTokenizingStrategy\nfrom axolotl.prompters import AlpacaPrompter\nfrom axolotl.utils.logging import get_logger\n\nLOG = get_logger(__name__)\n\nIGNORE_TOKEN_ID = -100\n\n\nclass MetharmePromptTokenizingStrategy(InstructionPromptTokenizingStrategy):\n    \"\"\"\n    Tokenizing strategy for the Metharme models\n    \"\"\"\n\n    def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]:\n        return (prompt[\"prompt\"], \"\", prompt[\"generation\"])\n\n    def _tokenize(\n        self,\n        prompt: str,\n        add_eos_token: bool = True,\n        strip_bos_token: bool = False,\n        num_eos_tokens: int = 3,\n    ):\n        result = self.tokenizer(\n            prompt,\n            truncation=True,\n            max_length=self.sequence_len,\n            padding=False,\n            return_tensors=None,\n        )\n        if len(result[\"input_ids\"]) == 0:\n            LOG.warning(\"Tokenizer result is empty. 
You may want to audit your dataset\")\n        # If there's already an EOS token there, subtract from the number added\n        # (the emptiness checks keep an empty tokenization from raising IndexError)\n        if (\n            result[\"input_ids\"]\n            and result[\"input_ids\"][-1] == self.tokenizer.eos_token_id\n        ):\n            num_eos_tokens -= 1\n\n        if num_eos_tokens > 0 and add_eos_token and len(result[\"input_ids\"]) > 0:\n            for _ in range(num_eos_tokens):\n                if len(result[\"input_ids\"]) < self.sequence_len:\n                    result[\"input_ids\"].append(self.tokenizer.eos_token_id)\n                    result[\"attention_mask\"].append(1)\n\n        if (\n            strip_bos_token\n            and result[\"input_ids\"]\n            and result[\"input_ids\"][0] == self.tokenizer.bos_token_id\n        ):\n            result[\"input_ids\"] = result[\"input_ids\"][1:]\n            result[\"attention_mask\"] = result[\"attention_mask\"][1:]\n\n        result[\"labels\"] = result[\"input_ids\"].copy()\n        return result\n\n\nclass MetharmePrompter(AlpacaPrompter):\n    \"\"\"\n    Prompter for the Metharme models.\n    \"\"\"\n\n    system_prompt = \"\"\n    system_no_input_prompt = \"\"\n    system_format = \"\"\n    turn_format = \"{instruction}\"\n    turn_no_input_format = \"{instruction}\"\n\n    def __init__(self, *args, **kwargs):\n        pass\n\n\ndef load(tokenizer, cfg):\n    return MetharmePromptTokenizingStrategy(\n        MetharmePrompter(), tokenizer, cfg.train_on_inputs, cfg.sequence_len\n    )\n"
  },
  {
    "path": "src/axolotl/prompt_strategies/orcamini.py",
    "content": "\"\"\"\nPrompt Strategy for finetuning Orca Mini (v2) models\nsee also https://huggingface.co/psmathur/orca_mini_v2_7b for more information\n\nUse dataset type: orcamini in config.yml to use this prompt style.\n\nCompared to the alpaca_w_system.open_orca dataset type,\nthis one specifies the system prompt with \"### System:\".\n\nNot suited/tested for multiple-turn conversations without further adjustments.\n\"\"\"\n\nfrom typing import Generator, Union\n\nfrom axolotl.prompt_strategies.alpaca_w_system import OpenOrcaPromptTokenizingStrategy\nfrom axolotl.prompters import AlpacaPrompter\n\n\nclass OrcaMiniPrompter(AlpacaPrompter):\n    \"\"\"Adjusted Prompter for Orca Mini (v2) datasets\"\"\"\n\n    def match_prompt_style(self):\n        self.turn_no_input_format = (\n            \"### System:\\n{system}\\n\\n### User:\\n{instruction}\\n\\n### Response:\\n\"\n        )\n\n    def build_prompt_w_system(\n        self,\n        system: str,\n        instruction: str,\n        output: Union[None, str] = None,\n    ) -> Generator[str, None, None]:\n        # returns the full prompt from the system prompt and instruction\n        # if a label (=response, =output) is provided, it's also appended.\n        res = self.turn_no_input_format.format(system=system, instruction=instruction)\n        if output:\n            res = f\"{res}{output}\"\n        yield res\n\n\ndef load(tokenizer, cfg):\n    return OpenOrcaPromptTokenizingStrategy(\n        OrcaMiniPrompter(),\n        tokenizer,\n        cfg.train_on_inputs,\n        cfg.sequence_len,\n    )\n"
  },
  {
    "path": "src/axolotl/prompt_strategies/orpo/__init__.py",
    "content": "\"\"\"\nmodule for ORPO style dataset transform strategies\n\"\"\"\n\nfrom functools import partial\n\nfrom ..base import load as load_base\n\nload = partial(load_base, module_base=\"axolotl.prompt_strategies.orpo\")\n"
  },
  {
    "path": "src/axolotl/prompt_strategies/orpo/chat_template.py",
    "content": "\"\"\"chatml prompt tokenization strategy for ORPO\"\"\"\n\nfrom typing import Any, Dict, Generator, List, Optional, Tuple\n\nfrom pydantic import BaseModel\n\nfrom axolotl.prompt_tokenizers import IGNORE_INDEX, PromptTokenizingStrategy\nfrom axolotl.prompters import Prompter\nfrom axolotl.utils.chat_templates import get_chat_template_from_config\n\n\nclass Message(BaseModel):\n    \"\"\"message/turn\"\"\"\n\n    role: str\n    content: str\n    label: Optional[bool] = None\n\n\nclass MessageList(BaseModel):\n    \"\"\"conversation\"\"\"\n\n    messages: List[Message]\n\n\ndef load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None, **kwargs):\n    \"\"\"\n    chatml transforms for datasets with system, input, chosen, rejected\n    \"\"\"\n    chat_template_string = get_chat_template_from_config(\n        cfg=cfg, ds_cfg=ds_cfg, tokenizer=tokenizer\n    )\n    tokenizer.chat_template = chat_template_string\n\n    return ORPOTokenizingStrategy(\n        ORPOPrompter(chat_template_string, tokenizer),\n        tokenizer,\n        cfg.train_on_inputs,\n        cfg.sequence_len,\n        dataset_parser=ORPODatasetParsingStrategy(),\n    )\n\n\nclass ORPODatasetParsingStrategy:\n    \"\"\"Strategy to parse chosen rejected dataset into messagelist\"\"\"\n\n    def get_chosen_conversation_thread(self, prompt) -> MessageList:\n        \"\"\"Dataset structure mappings\"\"\"\n\n        messages: List[Message] = []\n        if system := prompt.get(\"system\", None):\n            messages.append(Message(role=\"system\", content=system, label=False))\n        messages.append(\n            Message(role=\"user\", content=prompt[\"chosen\"][0][\"content\"], label=False)\n        )\n        messages.append(\n            Message(\n                role=\"assistant\", content=prompt[\"chosen\"][1][\"content\"], label=True\n            )\n        )\n        return MessageList(messages=messages)\n\n    def get_rejected_conversation_thread(self, prompt) -> 
MessageList:\n        \"\"\"Dataset structure mappings\"\"\"\n\n        messages: List[Message] = []\n        if system := prompt.get(\"system\", None):\n            messages.append(Message(role=\"system\", content=system, label=False))\n        messages.append(\n            Message(role=\"user\", content=prompt[\"rejected\"][0][\"content\"], label=False)\n        )\n        messages.append(\n            Message(\n                role=\"assistant\", content=prompt[\"rejected\"][1][\"content\"], label=True\n            )\n        )\n        return MessageList(messages=messages)\n\n    def get_prompt(self, prompt) -> MessageList:\n        \"\"\"Map the data to extract everything up to the last turn\"\"\"\n        total_msg_len = len(prompt[\"chosen\"])\n        total_msg_turns, remainder = divmod(total_msg_len, 2)\n        assert remainder == 0, \"invalid number of turns\"\n\n        messages: List[Message] = []\n        if system := prompt.get(\"system\", None):\n            messages.append(Message(role=\"system\", content=system, label=False))\n        for i in range(total_msg_turns):\n            if \"prompt\" in prompt:\n                messages.append(\n                    Message(role=\"user\", content=prompt[\"prompt\"], label=False)\n                )\n            else:\n                messages.append(\n                    Message(\n                        role=\"user\",\n                        content=prompt[\"chosen\"][i * 2][\"content\"],\n                        label=False,\n                    )\n                )\n            if i < total_msg_turns - 1:\n                messages.append(\n                    Message(\n                        role=\"assistant\",\n                        content=prompt[\"chosen\"][i * 2 + 1][\"content\"],\n                        label=False,\n                    )\n                )\n\n        return MessageList(messages=messages)\n\n    def get_chosen(self, prompt) -> MessageList:\n        res = 
self.get_prompt(prompt)\n        res.messages.append(\n            Message(\n                role=\"assistant\", content=prompt[\"chosen\"][-1][\"content\"], label=True\n            )\n        )\n        return res\n\n    def get_rejected(self, prompt) -> MessageList:\n        res = self.get_prompt(prompt)\n        res.messages.append(\n            Message(\n                role=\"assistant\", content=prompt[\"rejected\"][-1][\"content\"], label=True\n            )\n        )\n        return res\n\n\nclass ORPOTokenizingStrategy(PromptTokenizingStrategy):\n    \"\"\"\n    rejected_input_ids\n    input_ids\n    rejected_attention_mask\n    attention_mask\n    rejected_labels\n    labels\n    \"\"\"\n\n    def __init__(\n        self,\n        *args,\n        dataset_parser=None,\n        **kwargs,\n    ):\n        super().__init__(*args, **kwargs)\n        self.dataset_parser = dataset_parser\n\n    def tokenize_prompt(self, prompt):\n        # pass the rejected prompt/row to the Prompter to get the formatted prompt\n        prompt_len = 0\n        rejected_message_list: MessageList = (\n            self.dataset_parser.get_rejected_conversation_thread(prompt)\n        )\n        input_ids = []\n        labels = []\n        for _, (part, label) in enumerate(\n            self.prompter.build_prompt(rejected_message_list)\n        ):\n            if not part:\n                continue\n            _input_ids = self.tokenizer.encode(part, add_special_tokens=False)\n            prev_idx = len(input_ids)\n            input_ids += _input_ids[prev_idx:]\n            if label:\n                labels += input_ids[prev_idx:]\n            else:\n                labels += [IGNORE_INDEX] * (len(input_ids) - prev_idx)\n                prompt_len = len(input_ids)\n        # remap the input_ids, attention_mask and labels\n        rejected_input_ids = input_ids\n        rejected_labels = labels\n        # pass the chosen prompt/row to the Prompter to get the formatted prompt\n       
 chosen_message_list: MessageList = (\n            self.dataset_parser.get_chosen_conversation_thread(prompt)\n        )\n        input_ids = []\n        labels = []\n        for _, (part, label) in enumerate(\n            self.prompter.build_prompt(chosen_message_list)\n        ):\n            if not part:\n                continue\n            _input_ids = self.tokenizer.encode(part, add_special_tokens=False)\n            prev_idx = len(input_ids)\n            input_ids += _input_ids[prev_idx:]\n            if label:\n                labels += input_ids[prev_idx:]\n            else:\n                labels += [IGNORE_INDEX] * (len(input_ids) - prev_idx)\n\n        return {\n            \"rejected_input_ids\": rejected_input_ids,\n            \"rejected_labels\": rejected_labels,\n            \"rejected_attention_mask\": [1] * len(rejected_labels),\n            \"input_ids\": input_ids,\n            \"labels\": labels,\n            \"attention_mask\": [1] * len(labels),\n            \"prompt_attention_mask\": [1] * prompt_len\n            + [0] * (len(labels) - prompt_len),\n        }\n\n\nclass ORPOPrompter(Prompter):\n    \"\"\"Single Turn prompter for ORPO\"\"\"\n\n    def __init__(self, chat_template, tokenizer):\n        self.chat_template = chat_template\n        self.tokenizer = tokenizer\n\n    def build_prompt(\n        self,\n        message_list: MessageList,\n    ) -> Generator[Tuple[str, bool], None, None]:\n        conversation = []\n        for message in message_list.messages:\n            conversation.append(message.model_dump())\n            if message.role == \"system\":\n                yield (\n                    self.tokenizer.apply_chat_template(\n                        conversation,\n                        add_generation_prompt=False,\n                        chat_template=self.chat_template,\n                        tokenize=False,\n                    ),\n                    False,\n                )\n            if message.role == 
\"user\":\n                yield (\n                    self.tokenizer.apply_chat_template(\n                        conversation,\n                        add_generation_prompt=True,\n                        chat_template=self.chat_template,\n                        tokenize=False,\n                    ),\n                    False,\n                )\n            if message.role == \"assistant\":\n                yield (\n                    self.tokenizer.apply_chat_template(\n                        conversation,\n                        add_generation_prompt=False,\n                        chat_template=self.chat_template,\n                        tokenize=False,\n                    ),\n                    True,\n                )\n\n\ndef argilla(cfg, **kwargs):\n    dataset_parser = ORPODatasetParsingStrategy()\n\n    def transform_fn(sample, tokenizer=None):\n        res = {}\n\n        chat_template_string = get_chat_template_from_config(\n            cfg=cfg, tokenizer=tokenizer\n        )\n\n        res[\"prompt\"] = tokenizer.apply_chat_template(\n            [msg.model_dump() for msg in dataset_parser.get_prompt(sample).messages],\n            add_generation_prompt=True,\n            chat_template=chat_template_string,\n            tokenize=False,\n        )\n        prompt_str_len = len(res[\"prompt\"])\n        res[\"chosen\"] = tokenizer.apply_chat_template(\n            [msg.model_dump() for msg in dataset_parser.get_chosen(sample).messages],\n            add_generation_prompt=False,\n            chat_template=chat_template_string,\n            tokenize=False,\n        )[prompt_str_len:]\n        res[\"rejected\"] = tokenizer.apply_chat_template(\n            [msg.model_dump() for msg in dataset_parser.get_rejected(sample).messages],\n            add_generation_prompt=False,\n            chat_template=chat_template_string,\n            tokenize=False,\n        )[prompt_str_len:]\n\n        return res\n\n    return transform_fn\n"
  },
  {
    "path": "src/axolotl/prompt_strategies/pretrain.py",
    "content": "\"\"\"pretraining prompt strategies\"\"\"\n\nfrom typing import Generator\n\nfrom transformers import BatchEncoding\n\nfrom axolotl.prompt_tokenizers import PromptTokenizingStrategy\n\n\nclass PretrainTokenizer:\n    \"\"\"basic tokenization class for pretraining\"\"\"\n\n    def build_prompt(self, prompt) -> Generator[str, None, None]:\n        yield prompt\n\n\nclass PretrainTokenizationStrategy(PromptTokenizingStrategy):\n    \"\"\"handles tokenization for pretraining with strides\"\"\"\n\n    @property\n    def supports_batched(self):\n        return True\n\n    def __init__(self, *args, max_length=None, text_column=\"text\", **kwargs):\n        super().__init__(*args, **kwargs)\n        if max_length:\n            self.max_length = max_length\n        self.text_column = text_column\n\n    def _tokenize(\n        self, prompt: str, add_eos_token: bool = True, strip_bos_token: bool = False\n    ) -> BatchEncoding:\n        res = self.tokenizer(\n            prompt,\n            truncation=True,\n            max_length=self.max_length - 1,\n            add_special_tokens=True,\n            return_overflowing_tokens=True,\n            stride=256,\n        )\n        res[\"input_ids\"] = [\n            seq + [self.tokenizer.eos_token_id] for seq in res[\"input_ids\"]\n        ]\n        res[\"attention_mask\"] = [seq + [1] for seq in res[\"attention_mask\"]]\n\n        return res\n\n    def tokenize_prompt(self, prompt):\n        return self._tokenize(prompt[self.text_column])\n\n\ndef load(tokenizer, cfg):\n    strat = PretrainTokenizationStrategy(\n        PretrainTokenizer(),\n        tokenizer,\n        cfg.train_on_inputs,\n        cfg.sequence_len,\n        text_column=cfg.pretraining_dataset[0][\"text_column\"] or \"text\",\n        max_length=cfg.sequence_len * 64,\n    )\n    return strat\n"
  },
  {
    "path": "src/axolotl/prompt_strategies/pygmalion.py",
    "content": "\"\"\"Module containing the PygmalionPromptTokenizingStrategy and PygmalionPrompter class\"\"\"\n\nimport copy\nfrom collections import defaultdict\nfrom typing import Generator, List, Tuple\n\nfrom axolotl.prompt_tokenizers import (\n    PromptTokenizingStrategy,\n    parse_tokenized_to_result,\n    tokenize_prompt_default,\n)\nfrom axolotl.utils.logging import get_logger\n\nLOG = get_logger(__name__)\n\nIGNORE_TOKEN_ID = -100\n\n\nclass PygmalionPromptTokenizingStrategy(PromptTokenizingStrategy):\n    \"\"\"\n    Tokenizing strategy for Pygmalion.\n    \"\"\"\n\n    bot_prefix_token_ids: List[int] = []\n\n    def __init__(self, prompter, tokenizer, *args, **kwargs):\n        super().__init__(prompter, tokenizer, *args, **kwargs)\n        res = self._tokenize(\"<|model|>\", add_eos_token=False, strip_bos_token=True)\n        self.bot_prefix_token_ids = res[\"input_ids\"]\n\n    def tokenize_prompt(self, prompt):\n        result, current_len = tokenize_prompt_default()\n        for _, part in enumerate(self.prompter.build_prompt(prompt[\"conversations\"])):\n            role, message = part\n            if role == \"system\":\n                prefix = \"<|system|>\"\n                # this should include a bos token, no eos token, strip trailing \"\\n<START>\"\n                if message.endswith(\"\\n<START>\"):\n                    message = message[:-8]\n                res = self._tokenize(\n                    prefix + \"Persona: \" + message.strip(),\n                    add_eos_token=False,\n                    strip_bos_token=False,\n                )\n                # everything from this is masked out from the labels\n                labels = [IGNORE_TOKEN_ID] * len(res[\"input_ids\"])\n            elif role == \"human\":\n                prefix = \"<|user|>\"\n                res = self._tokenize(\n                    prefix + \" \" + message.strip(),\n                    add_eos_token=False,\n                    strip_bos_token=True,\n  
              )\n                # everything from this is masked out from the labels\n                labels = [IGNORE_TOKEN_ID] * len(res[\"input_ids\"])\n            elif role == \"bot\":\n                prefix = \"<|model|>\"\n                res = self._tokenize(\n                    prefix + \" \" + message.strip(),\n                    add_eos_token=True,\n                    strip_bos_token=True,\n                )\n                # mask out the prefix token, rest is not masked out from labels\n                # make sure we create the labels first, otherwise we get incorrect lengths\n                labels = [IGNORE_TOKEN_ID] * len(self.bot_prefix_token_ids) + [\n                    *copy.deepcopy(res[\"input_ids\"])\n                ][len(self.bot_prefix_token_ids) :]\n            else:\n                LOG.warning(f\"unknown role in conversation: {role}\")\n                res = defaultdict(lambda: [])\n\n            result, current_len = parse_tokenized_to_result(\n                result,\n                current_len,\n                res,\n                labels,\n                pad_token_id=self.tokenizer.pad_token_id,\n            )\n        return result\n\n\nclass PygmalionPrompter:\n    \"\"\"\n    Prompter for Pygmalion.\n    \"\"\"\n\n    def __init__(self, *args, **kwargs):\n        pass\n\n    def build_prompt(\n        self,\n        source,\n        *args,\n        **kwargs,\n    ) -> Generator[Tuple[str, str], None, None]:\n        for msg in source:\n            yield msg[\"role\"], msg[\"value\"]\n\n\ndef load(tokenizer, cfg):\n    return PygmalionPromptTokenizingStrategy(\n        PygmalionPrompter(), tokenizer, cfg.train_on_inputs, cfg.sequence_len\n    )\n"
  },
  {
    "path": "src/axolotl/prompt_strategies/stepwise_supervised.py",
    "content": "\"\"\"\nModule for stepwise datasets, typically including a prompt and reasoning traces,\nand (optionally) per-step, or per-prompt-trace labels for reward modelling.\n\"\"\"\n\nfrom itertools import chain\nfrom typing import Dict, List, Optional, Union\n\nfrom transformers import BatchEncoding, PreTrainedTokenizer\n\nfrom axolotl.prompt_tokenizers import IGNORE_INDEX\nfrom axolotl.utils.dict import DictDefault\n\n\nclass StepwiseSupervisedPromptTokenizingStrategy:\n    \"\"\"\n    Tokenizing strategy for supervised stepwise datasets, typically used for COT-reasoning.\n    These datasets should include the following columns:\n    - prompt: the prompt text\n    - completions: a list of `n` completion steps\n    - labels: a list of `n` labels indicating the \"correctness\" of each step\n    \"\"\"\n\n    def __init__(\n        self,\n        tokenizer,\n        sequence_len: int = 2048,\n        step_separator: str = \"\\n\",\n        max_completion_length: Optional[int] = None,\n        train_on_last_step_only: bool = False,\n    ):\n        self.tokenizer = tokenizer\n        self.sequence_len = sequence_len\n        self.step_separator = step_separator\n        self.max_completion_length = max_completion_length\n        self.train_on_last_step_only = train_on_last_step_only\n\n    def tokenize_prompt(\n        self, prompt: Dict[str, Union[str, List[str]]]\n    ) -> BatchEncoding:\n        # Inspired by TRL's PRMTRainer\n        # https://github.com/huggingface/trl/blob/ed7de87dc766478c024b68f12530d1b0e7c3ff23/trl/trainer/prm_trainer.py#L206\n        prompt_ids = self.tokenizer(prompt[\"prompt\"], add_special_tokens=False)[\n            \"input_ids\"\n        ]\n\n        completions_ids = [\n            self.tokenizer(completion, add_special_tokens=False)[\"input_ids\"]\n            for completion in prompt[\"completions\"]\n        ]\n\n        # Handle labels\n        if self.train_on_last_step_only:\n            labels = [IGNORE_INDEX] * 
(len(prompt[\"labels\"]) - 1) + [\n                int(prompt[\"labels\"][-1])\n            ]\n        else:\n            labels = [int(label) for label in prompt[\"labels\"]]\n\n        # Add step separators\n        separator_ids = self.tokenizer.encode(\n            self.step_separator, add_special_tokens=False\n        )\n        completions_ids = [completion + separator_ids for completion in completions_ids]\n\n        # Create step-wise labels\n        labels = [\n            [IGNORE_INDEX] * (len(completion) - 1) + [label]  # type: ignore\n            for completion, label in zip(completions_ids, labels, strict=False)\n        ]\n\n        # Join all steps\n        completion_ids = list(chain(*completions_ids))\n        labels = list(chain(*labels))  # type: ignore\n\n        # Handle max lengths\n        if self.max_completion_length:\n            completion_ids = completion_ids[: self.max_completion_length]\n            labels = labels[: self.max_completion_length]\n\n        # Add BOS token if model has one\n        if self.tokenizer.bos_token_id is not None:\n            prompt_ids = [self.tokenizer.bos_token_id] + prompt_ids\n\n        # Combine prompt and completion\n        input_ids = prompt_ids + completion_ids\n\n        full_labels = [IGNORE_INDEX] * len(prompt_ids) + labels\n        # Apply max sequence length\n        if self.sequence_len:\n            input_ids = input_ids[: self.sequence_len]\n            full_labels = full_labels[: self.sequence_len]\n\n        return {\n            \"input_ids\": input_ids,\n            \"labels\": full_labels,\n            \"attention_mask\": [1] * len(input_ids),\n        }\n\n    @property\n    def supports_batched(self):\n        return False\n\n\ndef load(\n    tokenizer: PreTrainedTokenizer,\n    cfg: DictDefault,\n    ds_cfg: DictDefault,\n) -> StepwiseSupervisedPromptTokenizingStrategy:\n    return StepwiseSupervisedPromptTokenizingStrategy(\n        tokenizer,\n        cfg.sequence_len,\n        
step_separator=ds_cfg.get(\"step_separator\", \"\\n\"),\n        max_completion_length=ds_cfg.max_completion_length,\n        train_on_last_step_only=ds_cfg.get(\"train_on_last_step_only\", False),\n    )\n"
  },
  {
    "path": "src/axolotl/prompt_strategies/user_defined.py",
    "content": "\"\"\"\nUser Defined prompts with configuration from the YML config\n\"\"\"\n\nfrom dataclasses import dataclass\nfrom functools import partial\nfrom typing import Optional, Tuple\n\nfrom axolotl.prompt_strategies.alpaca_w_system import (\n    InstructionWSystemPromptTokenizingStrategy,\n    SystemDataPrompter,\n)\n\n\n@dataclass\nclass UserDefinedDatasetConfig:\n    \"\"\"\n    dataclass configuration representing a userdefined dataset type\n    \"\"\"\n\n    system_prompt: str = \"\"\n    field_system: str = \"system\"\n    field_instruction: str = \"instruction\"\n    field_input: str = \"input\"\n    field_output: str = \"output\"\n    format: str = \"{instruction} {input} \"\n    no_input_format: str = \"{instruction} \"\n    system_format: str = \"{system}\"\n\n    def __getitem__(self, item):\n        return getattr(self, item)\n\n\nclass UserDefinedPromptTokenizationStrategy(InstructionWSystemPromptTokenizingStrategy):\n    \"\"\"\n    Prompt Tokenization Strategy for user defined prompts\n    \"\"\"\n\n\ndef load(tokenizer, cfg, ds_cfg: Optional[UserDefinedDatasetConfig] = None):\n    if not ds_cfg:\n        raise ValueError(\"Missing dataset prompt configuration\")\n\n    system_prompt = \"\"\n    if ds_cfg.system_prompt:\n        system_prompt = ds_cfg.system_prompt\n\n    def parse_instruction_fields(\n        field_instruction,\n        field_input,\n        field_output,\n        field_system,\n        system_prompt,\n        prompt,\n    ) -> Tuple[str, str, str, str]:\n        return (\n            prompt[field_instruction],\n            prompt[field_input] if field_input in prompt else \"\",\n            prompt[field_output] if field_output in prompt else \"\",\n            prompt[field_system] if field_system in prompt else system_prompt,\n        )\n\n    turn_format = ds_cfg.format\n    turn_no_input_format = ds_cfg.no_input_format\n    system_format = ds_cfg.system_format\n\n    class UserDefinedPrompter(SystemDataPrompter):\n   
     \"\"\"\n        Prompter for user defined prompts\n        \"\"\"\n\n        def match_prompt_style(self):\n            self.turn_format = turn_format\n            self.turn_no_input_format = turn_no_input_format\n            self.system_format = system_format\n\n    prompter = UserDefinedPrompter()\n\n    strat = UserDefinedPromptTokenizationStrategy(\n        prompter,\n        tokenizer,\n        cfg.train_on_inputs,\n        cfg.sequence_len,\n    )\n\n    strat.parse_instruction_fields = partial(  # type: ignore[method-assign]\n        parse_instruction_fields,\n        ds_cfg.field_instruction,\n        ds_cfg.field_input,\n        ds_cfg.field_output,\n        ds_cfg.field_system,\n        system_prompt,\n    )\n    return strat\n"
  },
  {
    "path": "src/axolotl/prompt_tokenizers.py",
    "content": "\"\"\"Module containing PromptTokenizingStrategy and Prompter classes\"\"\"\n\nimport abc\nfrom typing import Callable, Dict, List, Optional, Tuple, Union\n\nfrom datasets import Dataset\nfrom transformers import BatchEncoding, PreTrainedTokenizer\n\nfrom axolotl.prompters import Prompter\nfrom axolotl.utils.logging import get_logger\n\nLOG = get_logger(__name__)\n\nIGNORE_INDEX = -100\nLLAMA_DEFAULT_PAD_TOKEN = \"<pad>\"  # nosec\nLLAMA_DEFAULT_EOS_TOKEN = \"</s>\"  # nosec\nLLAMA_DEFAULT_BOS_TOKEN = \"<s>\"  # nosec\nLLAMA_DEFAULT_UNK_TOKEN = \"<unk>\"  # nosec\n\n\nclass InvalidDataException(Exception):\n    \"\"\"\n    Exception raised when the data is invalid\n    \"\"\"\n\n\nclass DatasetWrappingStrategy(abc.ABC):\n    \"\"\"\n    Abstract class for wrapping datasets for Chat Messages\n    \"\"\"\n\n    @abc.abstractmethod\n    def wrap_dataset(\n        self,\n        dataset,\n        process_count: int | None = None,\n        keep_in_memory: bool | None = False,\n        **kwargs,\n    ) -> Dataset:\n        pass\n\n\nclass PromptTokenizingStrategy(abc.ABC):\n    \"\"\"\n    Abstract class for tokenizing strategies\n    \"\"\"\n\n    filter_rows: Optional[Callable] = None\n\n    def __init__(\n        self,\n        prompter: Prompter,\n        tokenizer,\n        train_on_inputs: bool = False,\n        sequence_len: int = 2048,\n    ):\n        self.prompter = prompter\n        self.tokenizer: PreTrainedTokenizer = tokenizer\n        self.train_on_inputs = train_on_inputs\n        # sequence_len and max_length can be different for CompletionPromptTokenizingStrategy.\n        # TODO: Document how they are different.\n        self.sequence_len = sequence_len\n        self.max_length = sequence_len\n\n    @abc.abstractmethod\n    def tokenize_prompt(self, prompt):\n        pass\n\n    @property\n    def supports_batched(self):\n        return False\n\n    def _tokenize(\n        self, prompt: str, add_eos_token: bool = True, strip_bos_token: 
bool = False\n    ) -> BatchEncoding:\n        empty = BatchEncoding(data={\"input_ids\": [], \"attention_mask\": []})\n        if not prompt:\n            LOG.warning_once(\"Empty text requested for tokenization.\")\n            return empty\n\n        result = self.tokenizer(\n            prompt,\n            truncation=True,\n            max_length=self.max_length,\n            padding=False,\n            return_tensors=None,\n        )\n        if len(result[\"input_ids\"]) == 0:\n            LOG.warning(\"Tokenizer result is empty. You may want to audit your dataset\")\n            return empty\n\n        if (\n            result[\"input_ids\"][-1] != self.tokenizer.eos_token_id\n            and len(result[\"input_ids\"]) < self.max_length\n            and add_eos_token\n        ):\n            result[\"input_ids\"].append(self.tokenizer.eos_token_id)\n            result[\"attention_mask\"].append(1)\n\n        if result[\"input_ids\"][0] == self.tokenizer.bos_token_id and strip_bos_token:\n            result[\"input_ids\"] = result[\"input_ids\"][1:]\n            result[\"attention_mask\"] = result[\"attention_mask\"][1:]\n\n        result[\"labels\"] = result[\"input_ids\"].copy()\n        return result\n\n\nclass InstructionPromptTokenizingStrategy(PromptTokenizingStrategy):\n    \"\"\"\n    Tokenizing strategy for instruction-based prompts.\n    \"\"\"\n\n    def parse_instruction_fields(\n        self, prompt\n    ) -> Union[Tuple[str, str, str], Tuple[str, str, str, str]]:\n        raise NotImplementedError\n\n    def tokenize_prompt(self, prompt):\n        (\n            instruction,\n            input,\n            response,\n        ) = self.parse_instruction_fields(prompt)\n        user_prompt = next(\n            iter(\n                self.prompter.build_prompt(\n                    instruction,\n                    input,\n                )\n            )\n        )\n        tokenized_prompt = self._tokenize(user_prompt, add_eos_token=False)\n     
   if not self.train_on_inputs:\n            user_prompt_len = len(tokenized_prompt[\"input_ids\"])\n            # TODO this could be sped up using numpy array slicing\n            tokenized_prompt[\"labels\"] = [IGNORE_INDEX] * user_prompt_len\n        tokenized_res_prompt = self._tokenize(\n            response, strip_bos_token=True, add_eos_token=True\n        )\n        tokenized_prompt[\"input_ids\"] += tokenized_res_prompt[\"input_ids\"]\n        tokenized_prompt[\"attention_mask\"] += tokenized_res_prompt[\"attention_mask\"]\n        tokenized_prompt[\"labels\"] += tokenized_res_prompt[\"input_ids\"]\n\n        return tokenized_prompt\n\n    def _build_full_prompt(\n        self,\n        instruction,\n        input,\n        response,\n    ):\n        return next(\n            iter(\n                self.prompter.build_prompt(\n                    instruction,\n                    input,\n                    response,\n                )\n            )\n        )\n\n\nclass AlpacaPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):\n    \"\"\"\n    Tokenizing strategy for Alpaca prompts.\n    \"\"\"\n\n    def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]:\n        return (\n            prompt[\"instruction\"],\n            prompt[\"input\"] if \"input\" in prompt else \"\",\n            prompt[\"output\"],\n        )\n\n\nclass AlpacaMultipleChoicePromptTokenizingStrategy(InstructionPromptTokenizingStrategy):\n    \"\"\"\n    Tokenizing strategy for Alpaca Multiple Choice prompts.\n    \"\"\"\n\n    def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]:\n        return (\n            prompt[\"question\"],\n            \"\\n\".join(f'- \"{choice}\"' for choice in prompt[\"choices\"]),\n            prompt[\"solution\"] if \"solution\" in prompt else prompt[\"explanation\"],\n        )\n\n\nclass JeopardyPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):\n    \"\"\"\n    Tokenizing strategy for Jeopardy 
prompts.\n    \"\"\"\n\n    def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]:\n        return (\n            prompt[\"question\"],\n            prompt[\"category\"],\n            \"what is \" + prompt[\"answer\"],\n        )\n\n\nclass OpenAssistantPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):\n    \"\"\"\n    Tokenizing strategy for OpenAssistant prompts.\n    \"\"\"\n\n    def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]:\n        return (\n            prompt[\"INSTRUCTION\"],\n            \"\",\n            prompt[\"RESPONSE\"],\n        )\n\n\nclass SummarizeTLDRPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):\n    \"\"\"\n    Tokenizing strategy for SummarizeTLDR prompts.\n    \"\"\"\n\n    def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]:\n        return (\n            prompt[\"article\"],\n            \"\",\n            prompt[\"summary\"],\n        )\n\n\nclass GPTeacherPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):\n    \"\"\"\n    Tokenizing strategy for GPTeacher prompts.\n    \"\"\"\n\n    def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]:\n        return (\n            prompt[\"instruction\"],\n            prompt[\"input\"] if \"input\" in prompt else \"\",\n            prompt[\"response\"],\n        )\n\n\nclass NomicGPT4AllPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):\n    \"\"\"\n    Tokenizing strategy for NomicGPT4All prompts.\n    \"\"\"\n\n    def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]:\n        return (\n            prompt[\"prompt\"],\n            \"\",\n            prompt[\"response\"],\n        )\n\n\nclass ReflectionPromptTokenizingStrategy(PromptTokenizingStrategy):\n    \"\"\"\n    Tokenizing strategy for Reflection prompts.\n    \"\"\"\n\n    def parse_instruction_fields(self, prompt) -> Tuple[str, str, str, str, str]:\n        raise NotImplementedError\n\n    def 
tokenize_prompt(self, prompt):\n        (\n            instruction,\n            input,\n            output,\n            reflection,\n            corrected,\n        ) = self.parse_instruction_fields(prompt)\n        full_prompt = self._build_full_prompt(\n            instruction, input, output, reflection, corrected\n        )\n        tokenized_full_prompt = self._tokenize(full_prompt)\n        if not self.train_on_inputs:\n            user_prompt = next(\n                iter(\n                    self.prompter.build_prompt(\n                        instruction,\n                        input,\n                    )\n                )\n            )\n            tokenized_user_prompt = self._tokenize(user_prompt, add_eos_token=False)\n            user_prompt_len = len(tokenized_user_prompt[\"input_ids\"])\n            # TODO this could be sped up using numpy array slicing\n            tokenized_full_prompt[\"labels\"] = [\n                IGNORE_INDEX\n            ] * user_prompt_len + tokenized_full_prompt[\"labels\"][user_prompt_len:]\n\n        return tokenized_full_prompt\n\n    def _build_full_prompt(self, instruction, input, output, reflection, corrected):\n        return next(\n            iter(\n                self.prompter.build_prompt(\n                    instruction,\n                    input,\n                    output,\n                    reflection,\n                    corrected,\n                )\n            )\n        )\n\n    def _tokenize(self, prompt, add_eos_token=True, strip_bos_token=False):\n        result = self.tokenizer(\n            prompt,\n            truncation=True,\n            max_length=self.sequence_len,\n            padding=False,\n            return_tensors=None,\n        )\n        if (\n            result[\"input_ids\"][-1] != self.tokenizer.eos_token_id\n            and len(result[\"input_ids\"]) < self.sequence_len\n            and add_eos_token\n        ):\n            
result[\"input_ids\"].append(self.tokenizer.eos_token_id)\n            result[\"attention_mask\"].append(1)\n\n        result[\"labels\"] = result[\"input_ids\"].copy()\n        return result\n\n\nclass AlpacaReflectionPTStrategy(ReflectionPromptTokenizingStrategy):\n    \"\"\"\n    Tokenizing strategy for Alpaca Reflection prompts.\n    \"\"\"\n\n    def parse_instruction_fields(self, prompt) -> Tuple[str, str, str, str, str]:\n        return (\n            prompt[\"instruction\"],\n            prompt[\"input\"] if \"input\" in prompt else \"\",\n            prompt[\"output\"],\n            prompt[\"reflection\"],\n            prompt[\"corrected\"],\n        )\n\n\ndef tokenize_prompt_default() -> Tuple[Dict[str, List[int]], int]:\n    \"\"\"\n    Returns the default values for the tokenize prompt function\n    \"\"\"\n\n    result: Dict[str, List[int]] = {\n        \"input_ids\": [],\n        \"attention_mask\": [],\n        \"labels\": [],\n    }\n    current_len = 0\n    return result, current_len\n\n\ndef parse_tokenized_to_result(\n    result: Dict[str, List[int]],\n    current_len: int,\n    res: Dict[str, List[int]],\n    labels: List[int],\n    pad_token_id: Union[int, None] = None,\n) -> Tuple[Dict[str, List[int]], int]:\n    \"\"\"\n    Parses the tokenized prompt and append the tokenized input_ids, attention_mask and labels to the result\n    \"\"\"\n\n    input_ids = res[\"input_ids\"]\n    input_len = len(input_ids)\n    result[\"input_ids\"][current_len : current_len + input_len] = input_ids\n    result[\"attention_mask\"][current_len : current_len + input_len] = [\n        1 if x != pad_token_id else 0 for x in input_ids\n    ]\n    result[\"labels\"][current_len : current_len + input_len] = labels\n    current_len += input_len\n\n    return result, current_len\n"
  },
  {
    "path": "src/axolotl/prompters.py",
    "content": "\"\"\"Module containing prompters\"\"\"\n\nfrom enum import Enum\nfrom typing import Generator, Optional, Union\n\nfrom colorama import Fore\n\nfrom axolotl.utils.logging import get_logger\n\nLOG = get_logger(__name__)\nIGNORE_TOKEN_ID = -100\nREPR_TEMPLATE = \"\\n<start>\\n\" + Fore.CYAN + \"{full_prompt}\" + Fore.RESET + \"\\n<end>\\n\"\n\n\nclass PromptStyle(Enum):\n    \"\"\"\n    Enum for prompt styles\n    \"\"\"\n\n    INSTRUCT = \"instruct\"\n    CHAT = \"chat\"\n    CHATML = \"chatml\"\n    PHI = \"phi\"\n\n\nclass Prompter:\n    \"\"\"\n    Base prompter class for all prompters\n    \"\"\"\n\n\nclass AlpacaPrompter(Prompter):\n    \"\"\"\n    Base class for alpaca prompters\n    \"\"\"\n\n    system_prompt = \"Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\"\n    system_no_input_prompt = \"Below is an instruction that describes a task. Write a response that appropriately completes the request.\"\n    system_format: str = \"{system}\"\n    turn_format: str\n    turn_no_input_format: str\n    prompt_style: Optional[str] = None\n\n    def __init__(self, prompt_style: Optional[str] = PromptStyle.INSTRUCT.value):\n        self.prompt_style = prompt_style if prompt_style else PromptStyle.INSTRUCT.value\n        self.match_prompt_style()\n\n    def match_prompt_style(self):\n        if self.prompt_style == PromptStyle.INSTRUCT.value:\n            self.turn_format = \"### Instruction:\\n{instruction}\\n\\n### Input:\\n{input}\\n\\n### Response:\\n\"\n            self.turn_no_input_format = (\n                \"### Instruction:\\n{instruction}\\n\\n### Response:\\n\"\n            )\n            self.system_format = \"{system}\\n\\n\"\n        elif self.prompt_style == PromptStyle.CHAT.value:\n            self.turn_format = \"USER: {instruction}\\n{input}\\nASSISTANT:\"\n            self.turn_no_input_format = \"USER: 
{instruction}\\nASSISTANT:\"\n            self.system_format = \"SYSTEM: {system}\\n\"\n        elif self.prompt_style == PromptStyle.CHATML.value:\n            self.turn_format = \"<|im_start|>user\\n{instruction}\\n{input}<|im_end|>\\n<|im_start|>assistant\\n\"\n            self.turn_no_input_format = (\n                \"<|im_start|>user\\n{instruction}<|im_end|>\\n<|im_start|>assistant\\n\"\n            )\n            self.system_format = \"<|im_start|>system\\n{system}<|im_end|>\\n\"\n        elif self.prompt_style == PromptStyle.PHI.value:\n            self.turn_format = \"<|user|>\\n{instruction}<|end|>{input}<|assistant|>\"\n            self.turn_no_input_format = (\n                \"<|user|>\\n{instruction}<|end|>\\n<|assistant|>\\n\"\n            )\n            self.system_format = \"<|system|>\\n{system}<|end|>\\n\"\n\n    def _build_result(self, instruction, input_text, output):\n        # returns the full prompt from instruction and optional input\n        # if a label (=response, =output) is provided, it's also appended.\n        if input_text:\n            res = (\n                self.system_format.format(system=self.system_prompt)\n                if self.system_prompt\n                else \"\"\n            ) + self.turn_format.format(instruction=instruction, input=input_text)\n        else:\n            res = (\n                self.system_format.format(system=self.system_no_input_prompt)\n                if self.system_no_input_prompt\n                else \"\"\n            ) + self.turn_no_input_format.format(instruction=instruction)\n        if output:\n            res = f\"{res}{output}\"\n\n        return res\n\n    def build_prompt(\n        self,\n        instruction: str,\n        input: Union[None, str] = None,\n        output: Union[None, str] = None,\n    ) -> Generator[str, None, None]:\n        yield self._build_result(instruction, input, output)\n\n    def __repr__(self) -> str:\n        return REPR_TEMPLATE.format(\n            
full_prompt=self._build_result(\"{instruction}\", \"{input}\", \"{output}\")\n        )\n\n\nclass UnpromptedPrompter(AlpacaPrompter):\n    \"\"\"\n    Prompter for alpaca no system prompt\n    \"\"\"\n\n    system_prompt = \"\"\n    system_no_input_prompt = \"\"\n\n\nclass JeopardyPrompter(AlpacaPrompter):\n    \"\"\"\n    Prompter for Jeopardy\n    \"\"\"\n\n    prompt_input = \"Below is a Jeopardy clue paired with input providing the category of the clue. Write a concise response that best answers tbe clue given the category.\\n\\n### Instruction:\\n{instruction}\\n\\n### Input:\\n{input}\\n\\n### Response:\\n\"\n\n\nclass MultipleChoiceExplainPrompter(AlpacaPrompter):\n    \"\"\"\n    Prompter for multiple choice explain\n    \"\"\"\n\n    system_prompt = (\n        \"Choose the answer that best answers the question. Explain your reasoning.\\n\"\n    )\n    system_no_input_prompt = (\n        \"Choose the answer that best answers the question. Explain your reasoning.\\n\"\n    )\n\n\nclass MultipleChoiceConcisePrompter(AlpacaPrompter):\n    \"\"\"\n    Prompter for multiple choice concise\n    \"\"\"\n\n    system_prompt = \"Choose the answer that best answers the question. Be concise in your response.\\n\\n\"\n    system_no_input_prompt = \"Choose the answer that best answers the question. 
Be concise in your response.\\n\\n\"\n\n    def match_prompt_style(self):\n        self.turn_format = \"USER: {instruction}\\n{input}\\nASSISTANT:\"\n        self.turn_no_input_format = \"USER: {instruction}\\nASSISTANT:\"\n\n\nclass SummarizeTLDRPrompter(AlpacaPrompter):\n    \"\"\"\n    Prompter for summarize TLDR\n    \"\"\"\n\n    system_prompt = \"\"\n    system_no_input_prompt = \"\"\n\n    def match_prompt_style(self):\n        self.turn_format = \"USER: Summarize the following article as a TL;DR.\\n{instruction}\\n{input}\\nASSISTANT:\"\n        self.turn_no_input_format = \"USER: Summarize the following article as a TL;DR.\\n{instruction}\\nASSISTANT:\"\n\n\nclass GPTeacherPrompter(AlpacaPrompter):\n    \"\"\"\n    Prompter for GPTeacher\n    \"\"\"\n\n\nclass NomicGPT4AllPrompter(AlpacaPrompter):\n    \"\"\"\n    Prompter for NomicGPT4All\n    \"\"\"\n\n\nclass ReflectAlpacaPrompter(Prompter):\n    \"\"\"\n    Prompter for ReflectAlpaca\n    \"\"\"\n\n    system_prompt = \"Below is an instruction that describes a task, paired with an input that provides further context. You, the Assistant, should generate a response as if it were an abstract for an academic or technical paper on the query along with a methodology. Then generate an Agent Reflection where you create a long form response as if from subject matter expert, be verbose, diligent, and creative in your application of knowledge, apply it through the lens of the response generated by the assistant. Look for flawed reasoning, faulty logic, or other mistakes in the method. Finally, generate a final response and method for the user with the Assistant abstract and Reflection analysis as augmentations to the generation\\n\\n\"\n    system_no_input_prompt = \"Below is an instruction that describes a task. You, the Assistant, should generate a response as if it were an abstract for an academic or technical paper on the query along with a methodology. 
Then generate an Agent Reflection where you create a long form response as if from subject matter expert, be verbose, diligent, and creative in your application of knowledge, apply it through the lens of the response generated by the assistant. Look for flawed reasoning, faulty logic, or other mistakes in the method. Finally, generate a final response and method for the user with the Assistant abstract and Reflection analysis as augmentations to the generation\\n\\n\"\n\n    prompt_input = (\n        \"### Instruction:\\n{instruction}\\n\\n### Input:\\n{input}\\n\\n### Response:\\n\"\n    )\n    prompt_no_input = \"### Instruction:\\n{instruction}\\n\\n### Response:\\n\"\n    agent_label = \"### Thought:\\n{output}\\n\\n### Agent Reflection:\\n{reflection}\\n\\n### Final Response:\\n{corrected}\"\n    response_split = \"### Response:\"\n\n    def __init__(self, prompt_style=\"instruct\"):\n        self.prompt_style = prompt_style\n        self.match_prompt_style()\n\n    def match_prompt_style(self):\n        if self.prompt_style == PromptStyle.INSTRUCT.value:\n            self.prompt_input = (\n                self.system_prompt\n                + \"### Instruction:\\n{instruction}\\n\\n### Input:\\n{input}\\n\\n### Response:\\n\"\n            )\n            self.prompt_no_input = (\n                self.system_no_input_prompt\n                + \"### Instruction:\\n{instruction}\\n\\n### Response:\\n\"\n            )\n            self.agent_label = \"### Thought:\\n{output}\\n\\n### Agent Reflection:\\n{reflection}\\n\\n### Final Response:\\n{corrected}\"\n            self.response_split = \"### Final Response:\"\n        if self.prompt_style == PromptStyle.CHAT.value:\n            self.prompt_input = (\n                self.system_prompt + \"USER: {instruction}\\n{input}\\nASSISTANT:\"\n            )\n            self.prompt_no_input = (\n                self.system_no_input_prompt + \"USER: {instruction}\\nASSISTANT:\"\n            )\n            
self.agent_label = (\n                \"\\nTHOUGHT: {output}\\nASSISTANT REFLECTION: {reflection}\\nASSISTANT:\"\n            )\n            self.response_split = \"ASSISTANT:\"\n\n    def _build_result(\n        self,\n        instruction: str,\n        input: Union[None, str] = None,\n        output: Union[None, str] = None,\n        reflection: Union[None, str] = None,\n        corrected: Union[None, str] = None,\n    ):\n        # returns the full prompt from instruction and optional input\n        # if a label (=response, =output) is provided, it's also appended.\n        if input:\n            res = self.prompt_input.format(instruction=instruction, input=input)\n        else:\n            res = self.prompt_no_input.format(instruction=instruction)\n        if output and reflection and corrected:\n            label = self.agent_label.format(\n                output=output,\n                reflection=reflection,\n                corrected=corrected,\n            )\n            res = f\"{res}{label}\"\n\n        return res\n\n    def build_prompt(\n        self,\n        instruction: str,\n        input: Union[None, str] = None,\n        output: Union[None, str] = None,\n        reflection: Union[None, str] = None,\n        corrected: Union[None, str] = None,\n    ) -> Generator[str, None, None]:\n        yield self._build_result(\n            instruction,\n            input,\n            output,\n            reflection,\n            corrected,\n        )\n\n    def __repr__(self) -> str:\n        return REPR_TEMPLATE.format(\n            full_prompt=self._build_result(\"{instruction}\", \"{input}\", \"{output}\")\n        )\n\n\nALTERNATING_ASSERTION_FAILED_ROLE = (\n    \"Role did not alternate between turns (gpt and human). 
Please check your data.\"\n)\n\n\nclass UnsupportedPrompter(Prompter):\n    \"\"\"\n    A dummy class for custom prompters\n    \"\"\"\n\n    def __init__(self) -> None:\n        pass\n\n    def __repr__(self):\n        return \"Pre-tokenized or custom dataset types are unsupported for logging\"\n"
  },
  {
    "path": "src/axolotl/scripts/__init__.py",
    "content": ""
  },
  {
    "path": "src/axolotl/scripts/vllm_serve_lora.py",
    "content": "\"\"\"vLLM serve script with native LoRA adapter support.\n\nExtends TRL's vllm_serve to enable direct LoRA adapter loading in vLLM,\ninstead of merging adapter weights into the base model before syncing.\n\nUsage:\n    Set ``vllm.serve_module: axolotl.scripts.vllm_serve_lora`` in your config,\n    or ``trl.vllm_lora_sync: true`` to auto-select.\n\nBenefits over merge-sync:\n    - Syncs only LoRA adapter weights via filesystem instead of full merged model via NCCL\n    - vLLM handles LoRA application natively (Punica kernels)\n    - No NCCL communicator needed for weight sync\n\"\"\"\n\nimport logging\nimport os\nfrom contextlib import asynccontextmanager\nfrom dataclasses import dataclass, field\nfrom itertools import chain\nfrom multiprocessing import Pipe, Process\nfrom multiprocessing.connection import Connection\nfrom typing import Any\n\nfrom trl.scripts.vllm_serve import (\n    ScriptArguments,\n    chunk_list,\n    extract_logprobs,\n    get_open_port,\n)\nfrom vllm import LLM, SamplingParams\nfrom vllm.lora.request import LoRARequest\n\nlogger = logging.getLogger(__name__)\n\n\n@dataclass\nclass LoRAScriptArguments(ScriptArguments):\n    \"\"\"Extended script arguments with LoRA support.\"\"\"\n\n    enable_lora: bool = field(\n        default=True,\n        metadata={\"help\": \"Enable LoRA adapter support in vLLM.\"},\n    )\n    max_lora_rank: int = field(\n        default=64,\n        metadata={\"help\": \"Maximum LoRA rank supported.\"},\n    )\n    max_loras: int = field(\n        default=2,\n        metadata={\"help\": \"Maximum number of LoRA adapters loaded simultaneously.\"},\n    )\n    lora_dtype: str = field(\n        default=\"bfloat16\",\n        metadata={\"help\": \"Data type for LoRA weights.\"},\n    )\n\n\ndef llm_worker(\n    script_args: LoRAScriptArguments,\n    data_parallel_rank: int,\n    master_port: int,\n    connection: Connection,\n) -> None:\n    \"\"\"Worker process that creates a vLLM LLM with LoRA 
enabled.\"\"\"\n    os.environ[\"VLLM_DP_RANK\"] = str(data_parallel_rank)\n    os.environ[\"VLLM_DP_RANK_LOCAL\"] = str(data_parallel_rank)\n    os.environ[\"VLLM_DP_SIZE\"] = str(script_args.data_parallel_size)\n    os.environ[\"VLLM_DP_MASTER_PORT\"] = str(master_port)\n\n    llm = LLM(\n        model=script_args.model,\n        revision=script_args.revision,\n        tensor_parallel_size=script_args.tensor_parallel_size,\n        gpu_memory_utilization=script_args.gpu_memory_utilization,\n        enforce_eager=script_args.enforce_eager,\n        dtype=script_args.dtype,\n        enable_prefix_caching=script_args.enable_prefix_caching,\n        kv_cache_dtype=script_args.kv_cache_dtype,\n        max_model_len=script_args.max_model_len,\n        # Use batch-capable worker extension (adds batch_update_named_params + auto-close)\n        worker_extension_cls=\"axolotl.scripts.vllm_worker_ext.BatchWeightSyncWorkerExtension\",\n        trust_remote_code=script_args.trust_remote_code,\n        model_impl=script_args.vllm_model_impl,\n        logprobs_mode=\"processed_logprobs\",\n        # LoRA\n        enable_lora=script_args.enable_lora,\n        max_lora_rank=script_args.max_lora_rank,\n        max_loras=script_args.max_loras,\n        lora_dtype=script_args.lora_dtype,\n    )\n\n    connection.send({\"status\": \"ready\"})\n\n    while True:\n        try:\n            command = connection.recv()\n        except KeyboardInterrupt:\n            llm.collective_rpc(method=\"close_communicator\")\n            break\n\n        if command[\"type\"] in [\"call\", \"fire_and_forget\"]:\n            method_name = command[\"method\"]\n            args = command.get(\"args\", ())\n            kwargs = command.get(\"kwargs\", {})\n\n            # Reconstruct LoRARequest from serialized dict (can't pickle across pipe)\n            if \"lora_request\" in kwargs and kwargs[\"lora_request\"] is not None:\n                lr = kwargs[\"lora_request\"]\n                
kwargs[\"lora_request\"] = LoRARequest(\n                    lora_name=lr[\"lora_name\"],\n                    lora_int_id=lr[\"lora_int_id\"],\n                    lora_path=lr[\"lora_path\"],\n                    load_inplace=lr.get(\"load_inplace\", False),\n                )\n\n            method = getattr(llm, method_name)\n            result = method(*args, **kwargs)\n            if command[\"type\"] == \"call\":\n                connection.send(result)\n        elif command[\"type\"] == \"shutdown\":\n            break\n\n\ndef main(script_args: ScriptArguments):\n    \"\"\"Start vLLM workers with LoRA support and the HTTP server.\"\"\"\n    import asyncio\n\n    import uvicorn\n    from fastapi import FastAPI\n    from pydantic import BaseModel, Field as PydanticField\n\n    # Request/Response models (defined locally like TRL's vllm_serve.main)\n    class GenerateRequest(BaseModel):\n        prompts: list[str]\n        images: list[str] | None = None\n        n: int = 1\n        repetition_penalty: float = 1.0\n        temperature: float = 1.0\n        top_p: float = 1.0\n        top_k: int = -1\n        min_p: float = 0.0\n        max_tokens: int = 16\n        logprobs: int | None = 0\n        truncate_prompt_tokens: int | None = None\n        structured_outputs_regex: str | None = None\n        generation_kwargs: dict = PydanticField(default_factory=dict)\n\n    class GenerateResponse(BaseModel):\n        prompt_ids: list[list[int]]\n        completion_ids: list[list[int]]\n        logprobs: list[list[list[float]]]\n        logprob_token_ids: list[list[list[int]]]\n\n    class ChatRequest(BaseModel):\n        messages: list[list[dict]]\n        n: int = 1\n        repetition_penalty: float = 1.0\n        temperature: float = 1.0\n        top_p: float = 1.0\n        top_k: int = -1\n        min_p: float = 0.0\n        max_tokens: int = 16\n        logprobs: int | None = 0\n        truncate_prompt_tokens: int | None = None\n        structured_outputs_regex: 
str | None = None\n        generation_kwargs: dict = PydanticField(default_factory=dict)\n        chat_template_kwargs: dict = PydanticField(default_factory=dict)\n\n    class ChatResponse(BaseModel):\n        prompt_ids: list[list[int]]\n        completion_ids: list[list[int]]\n        logprobs: list[list[list[float]]]\n        logprob_token_ids: list[list[list[int]]]\n\n    class InitCommunicatorRequest(BaseModel):\n        host: str\n        port: int\n        world_size: int\n        client_device_uuid: str\n\n    # Wrap plain ScriptArguments with LoRA defaults\n    if not isinstance(script_args, LoRAScriptArguments):\n        lora_args = LoRAScriptArguments.__new__(LoRAScriptArguments)\n        for f in ScriptArguments.__dataclass_fields__:\n            setattr(lora_args, f, getattr(script_args, f))\n        # Apply LoRA defaults\n        for f in LoRAScriptArguments.__dataclass_fields__:\n            if f not in ScriptArguments.__dataclass_fields__:\n                setattr(\n                    lora_args, f, LoRAScriptArguments.__dataclass_fields__[f].default\n                )\n        script_args = lora_args\n\n    # Spawn workers\n    master_port = get_open_port()\n    connections: list[Connection] = []\n    processes: list[Process] = []\n    for dp_rank in range(script_args.data_parallel_size):\n        parent_conn, child_conn = Pipe()\n        process = Process(\n            target=llm_worker,\n            args=(script_args, dp_rank, master_port, child_conn),\n        )\n        process.start()\n        connections.append(parent_conn)\n        processes.append(process)\n\n    @asynccontextmanager\n    async def lifespan(app: FastAPI):\n        import time\n\n        startup_timeout = 300  # 5 minutes\n        start_time = time.monotonic()\n        ready: set[int] = set()\n        while len(ready) < script_args.data_parallel_size:\n            elapsed = time.monotonic() - start_time\n            if elapsed > startup_timeout:\n                raise 
RuntimeError(\n                    f\"vLLM workers failed to start within {startup_timeout}s \"\n                    f\"({len(ready)}/{script_args.data_parallel_size} ready)\"\n                )\n            for i, (conn, proc) in enumerate(zip(connections, processes, strict=True)):\n                if id(conn) in ready:\n                    continue\n                if not proc.is_alive():\n                    raise RuntimeError(\n                        f\"vLLM worker {i} exited unexpectedly during startup\"\n                    )\n                if conn.poll():\n                    msg = conn.recv()\n                    if isinstance(msg, dict) and msg.get(\"status\") == \"ready\":\n                        ready.add(id(conn))\n            await asyncio.sleep(0.1)\n        yield\n        for p in processes:\n            p.join(timeout=10)\n            if p.is_alive():\n                p.terminate()\n                p.join()\n\n    app = FastAPI(lifespan=lifespan)\n\n    # --- Active LoRA state (shared across endpoints via closure) ---\n    active_lora: dict = {\"request\": None}\n\n    # ------------------------------------------------------------------\n    # LoRA-specific endpoints\n    # ------------------------------------------------------------------\n\n    class SetLoRARequest(BaseModel):\n        lora_name: str\n        lora_int_id: int\n        lora_path: str\n        load_inplace: bool = False\n\n    @app.post(\"/set_lora_adapter/\")\n    async def set_lora_adapter(request: SetLoRARequest):\n        \"\"\"Register a LoRA adapter for all subsequent generate/chat calls.\"\"\"\n        active_lora[\"request\"] = {\n            \"lora_name\": request.lora_name,\n            \"lora_int_id\": request.lora_int_id,\n            \"lora_path\": request.lora_path,\n            \"load_inplace\": request.load_inplace,\n        }\n        logger.info(\n            \"Set active LoRA: %s (id=%d, path=%s)\",\n            request.lora_name,\n            
request.lora_int_id,\n            request.lora_path,\n        )\n        return {\"status\": \"ok\"}\n\n    @app.post(\"/clear_lora_adapter/\")\n    async def clear_lora_adapter():\n        \"\"\"Clear active LoRA adapter (revert to base model).\"\"\"\n        active_lora[\"request\"] = None\n        return {\"status\": \"ok\"}\n\n    # ------------------------------------------------------------------\n    # Standard endpoints (mirrors TRL's vllm_serve)\n    # ------------------------------------------------------------------\n\n    @app.get(\"/health/\")\n    async def health():\n        return {\"status\": \"ok\"}\n\n    @app.get(\"/get_world_size/\")\n    async def get_world_size():\n        return {\n            \"world_size\": script_args.tensor_parallel_size\n            * script_args.data_parallel_size\n        }\n\n    @app.post(\"/generate/\", response_model=GenerateResponse)\n    async def generate(request: GenerateRequest):\n        \"\"\"Generate completions with optional LoRA adapter.\"\"\"\n        import base64\n        from io import BytesIO\n\n        import vllm\n        from packaging.version import Version\n        from vllm.sampling_params import GuidedDecodingParams\n\n        images: list[str | None] = request.images or [None] * len(request.prompts)  # type: ignore[assignment,list-item]\n        prompts: list[dict[str, Any]] = []\n        for prompt, image in zip(request.prompts, images, strict=True):\n            row: dict[str, Any] = {\"prompt\": prompt}\n            if image is not None:\n                from PIL import Image\n\n                row[\"multi_modal_data\"] = {\n                    \"image\": Image.open(BytesIO(base64.b64decode(image)))\n                }\n            prompts.append(row)\n\n        generation_kwargs = {\n            \"n\": request.n,\n            \"repetition_penalty\": request.repetition_penalty,\n            \"temperature\": request.temperature,\n            \"top_p\": request.top_p,\n            \"top_k\": 
request.top_k,\n            \"min_p\": request.min_p,\n            \"max_tokens\": request.max_tokens,\n            \"logprobs\": request.logprobs,\n        }\n        generation_kwargs.update(request.generation_kwargs)\n\n        if Version(vllm.__version__) <= Version(\"0.10.2\"):\n            key = \"guided_decoding\"\n            if request.structured_outputs_regex is not None:\n                generation_kwargs[key] = GuidedDecodingParams(\n                    regex=request.structured_outputs_regex\n                )\n            else:\n                generation_kwargs.setdefault(key, None)\n        else:\n            from vllm.sampling_params import StructuredOutputsParams\n\n            key = \"structured_outputs\"\n            if request.structured_outputs_regex is not None:\n                generation_kwargs[key] = StructuredOutputsParams(\n                    regex=request.structured_outputs_regex\n                )\n            elif isinstance(generation_kwargs.get(key), dict):\n                generation_kwargs[key] = StructuredOutputsParams(\n                    **generation_kwargs[key]\n                )\n            else:\n                generation_kwargs.setdefault(key, None)\n\n        sampling_params = SamplingParams(**generation_kwargs)\n        chunked_prompts = chunk_list(prompts, script_args.data_parallel_size)\n\n        for conn, chunk in zip(connections, chunked_prompts, strict=True):\n            if not chunk:\n                chunk = [{\"prompt\": \"<placeholder>\"}]\n            kwargs = {\n                \"prompts\": chunk,\n                \"sampling_params\": sampling_params,\n                \"lora_request\": active_lora[\"request\"],\n            }\n            conn.send({\"type\": \"call\", \"method\": \"generate\", \"kwargs\": kwargs})\n\n        all_outputs = [conn.recv() for conn in connections]\n        all_outputs = [\n            o for o, c in zip(all_outputs, chunked_prompts, strict=True) if c\n        ]\n        
all_outputs = list(chain.from_iterable(all_outputs))\n\n        return {\n            \"prompt_ids\": [o.prompt_token_ids for o in all_outputs],\n            \"completion_ids\": [\n                list(out.token_ids) for o in all_outputs for out in o.outputs\n            ],\n            \"logprobs\": extract_logprobs(all_outputs)[0],\n            \"logprob_token_ids\": extract_logprobs(all_outputs)[1],\n        }\n\n    @app.post(\"/chat/\", response_model=ChatResponse)\n    async def chat(request: ChatRequest):\n        \"\"\"Chat endpoint with optional LoRA adapter.\"\"\"\n        generation_kwargs = {\n            \"n\": request.n,\n            \"repetition_penalty\": request.repetition_penalty,\n            \"temperature\": request.temperature,\n            \"top_p\": request.top_p,\n            \"top_k\": request.top_k,\n            \"min_p\": request.min_p,\n            \"max_tokens\": request.max_tokens,\n            \"logprobs\": request.logprobs,\n        }\n        generation_kwargs.update(request.generation_kwargs)\n        sampling_params = SamplingParams(**generation_kwargs)\n        chunked = chunk_list(request.messages, script_args.data_parallel_size)\n        for conn, chunk in zip(connections, chunked, strict=True):\n            if not chunk:\n                chunk = [[{\"role\": \"user\", \"content\": \"<placeholder>\"}]]\n            kwargs = {\n                \"messages\": chunk,\n                \"sampling_params\": sampling_params,\n                \"use_tqdm\": False,\n                \"lora_request\": active_lora[\"request\"],\n            }\n            conn.send({\"type\": \"call\", \"method\": \"chat\", \"kwargs\": kwargs})\n\n        all_outputs = [conn.recv() for conn in connections]\n        all_outputs = [o for o, c in zip(all_outputs, chunked, strict=True) if c]\n        all_outputs = list(chain.from_iterable(all_outputs))\n\n        return {\n            \"prompt_ids\": [o.prompt_token_ids for o in all_outputs],\n            
\"completion_ids\": [\n                list(out.token_ids) for o in all_outputs for out in o.outputs\n            ],\n            \"logprobs\": extract_logprobs(all_outputs)[0],\n            \"logprob_token_ids\": extract_logprobs(all_outputs)[1],\n        }\n\n    # --- Weight sync endpoints (legacy fallback, same as TRL) ---\n\n    @app.post(\"/init_communicator/\")\n    async def init_communicator(request: InitCommunicatorRequest):\n        world_size = (\n            script_args.tensor_parallel_size * script_args.data_parallel_size + 1\n        )\n        kwargs = {\n            \"method\": \"init_communicator\",\n            \"args\": (\n                request.host,\n                request.port,\n                world_size,\n                request.client_device_uuid,\n            ),\n        }\n        msg = {\"type\": \"fire_and_forget\", \"method\": \"collective_rpc\", \"kwargs\": kwargs}\n        loop = asyncio.get_running_loop()\n        await asyncio.gather(\n            *(loop.run_in_executor(None, c.send, msg) for c in connections)\n        )\n        return {\"message\": \"Initializing communicator\"}\n\n    class UpdateWeightsRequest(BaseModel):\n        name: str\n        dtype: str\n        shape: list[int]\n\n    @app.post(\"/update_named_param/\")\n    async def update_named_param(request: UpdateWeightsRequest):\n        kwargs = {\n            \"method\": \"update_named_param\",\n            \"args\": (request.name, request.dtype, tuple(request.shape)),\n        }\n        msg = {\"type\": \"fire_and_forget\", \"method\": \"collective_rpc\", \"kwargs\": kwargs}\n        loop = asyncio.get_running_loop()\n        await asyncio.gather(\n            *(loop.run_in_executor(None, c.send, msg) for c in connections)\n        )\n        return {\"message\": \"Updating parameter\"}\n\n    class BatchUpdateWeightsRequest(BaseModel):\n        params: list[dict]\n\n    @app.post(\"/batch_update_named_params/\")\n    async def 
batch_update_named_params(request: BatchUpdateWeightsRequest):\n        params_list = [\n            (p[\"name\"], p[\"dtype\"], tuple(p[\"shape\"])) for p in request.params\n        ]\n        kwargs = {\"method\": \"batch_update_named_params\", \"args\": (params_list,)}\n        msg = {\"type\": \"fire_and_forget\", \"method\": \"collective_rpc\", \"kwargs\": kwargs}\n        loop = asyncio.get_running_loop()\n        await asyncio.gather(\n            *(loop.run_in_executor(None, c.send, msg) for c in connections)\n        )\n        return {\"message\": f\"Batch update for {len(params_list)} params\"}\n\n    @app.post(\"/reset_prefix_cache/\")\n    async def reset_prefix_cache():\n        for conn in connections:\n            conn.send({\"type\": \"call\", \"method\": \"reset_prefix_cache\"})\n        results = [conn.recv() for conn in connections]\n        return {\"message\": f\"Reset prefix cache: {all(results)}\"}\n\n    @app.post(\"/close_communicator/\")\n    async def close_communicator():\n        kwargs = {\"method\": \"close_communicator\"}\n        for conn in connections:\n            conn.send(\n                {\n                    \"type\": \"fire_and_forget\",\n                    \"method\": \"collective_rpc\",\n                    \"kwargs\": kwargs,\n                }\n            )\n        return {\"message\": \"Closing communicator\"}\n\n    uvicorn.run(\n        app,\n        host=script_args.host,\n        port=script_args.port,\n        log_level=script_args.log_level,\n        access_log=True,\n    )\n"
  },
  {
    "path": "src/axolotl/scripts/vllm_worker_ext.py",
    "content": "\"\"\"Extended vLLM worker extension with batch weight sync support.\n\nSubclasses TRL's WeightSyncWorkerExtension to add:\n- batch_update_named_params: receives multiple params in one call\n- Auto-close stale communicator on re-init\n- _direct_set_weight: proper handling for stacked (qkv_proj, gate_up_proj) params,\n  including LoRA-wrapped models where vLLM inserts base_layer into the hierarchy\n\"\"\"\n\nimport logging\n\nimport torch\n\ntry:\n    from transformers import is_torch_xpu_available\nexcept ImportError:\n    is_torch_xpu_available = lambda: False  # noqa: E731\n\nfrom trl.scripts.vllm_serve import WeightSyncWorkerExtension\n\nlogger = logging.getLogger(__name__)\n\n# Stacked param name mapping: shard_name -> (packed_name, shard_order)\n_STACKED_PARAMS = {\n    \"q_proj\": (\"qkv_proj\", 0),\n    \"k_proj\": (\"qkv_proj\", 1),\n    \"v_proj\": (\"qkv_proj\", 2),\n    \"gate_proj\": (\"gate_up_proj\", 0),\n    \"up_proj\": (\"gate_up_proj\", 1),\n}\n\n\nclass BatchWeightSyncWorkerExtension(WeightSyncWorkerExtension):\n    \"\"\"Worker extension that adds batch weight update and direct weight setting.\"\"\"\n\n    def init_communicator(self, host, port, world_size, client_device_uuid):\n        \"\"\"Auto-close stale communicator before re-initializing.\"\"\"\n        if self.communicator is not None:\n            self.close_communicator()\n        super().init_communicator(host, port, world_size, client_device_uuid)\n\n    def _direct_set_weight(self, name: str, weight: torch.Tensor) -> None:\n        \"\"\"Directly copy weight data into the model, handling stacked params.\n\n        Bypasses model.load_weights() which may fail on vLLM 0.17's new\n        module-tree weight loader for stacked params (qkv_proj, gate_up_proj).\n\n        Handles LoRA-wrapped params where vLLM inserts ``base_layer`` into the\n        parameter hierarchy (e.g. 
``qkv_proj.base_layer.weight``).\n        \"\"\"\n        model = self.model_runner.model\n        params_dict = dict(model.named_parameters())\n\n        # Check if this is a simple direct param (exists as-is)\n        if name in params_dict:\n            params_dict[name].data.copy_(weight.to(params_dict[name].dtype))\n            return\n\n        # Also check with base_layer inserted: x.y.weight -> x.y.base_layer.weight\n        parts_bl = name.rsplit(\".\", 1)\n        if len(parts_bl) == 2:\n            base_layer_name = f\"{parts_bl[0]}.base_layer.{parts_bl[1]}\"\n            if base_layer_name in params_dict:\n                params_dict[base_layer_name].data.copy_(\n                    weight.to(params_dict[base_layer_name].dtype)\n                )\n                return\n\n        # Handle stacked params: e.g. \"model.layers.0.self_attn.q_proj.weight\"\n        # -> \"model.layers.0.self_attn.qkv_proj.weight\" with shard offset\n        parts = name.rsplit(\".\", 2)  # [prefix, layer_name, suffix]\n        if len(parts) == 3:\n            prefix, layer_name, suffix = parts\n            if layer_name in _STACKED_PARAMS:\n                packed_name, shard_idx = _STACKED_PARAMS[layer_name]\n                for packed_full in [\n                    f\"{prefix}.{packed_name}.{suffix}\",\n                    f\"{prefix}.{packed_name}.base_layer.{suffix}\",\n                ]:\n                    if packed_full not in params_dict:\n                        continue\n                    param = params_dict[packed_full]\n                    # Navigate to the packed module to find shard sizes\n                    module_path = packed_full.rsplit(\".\", 1)[0]  # strip .weight/.bias\n                    if \".base_layer\" in module_path:\n                        module_path = module_path.replace(\".base_layer\", \"\")\n                    module = model\n                    for attr in module_path.split(\".\"):\n                        module = getattr(module, 
attr, None)\n                        if module is None:\n                            break\n                    # LoRA wrappers don't have output_sizes directly;\n                    # check base_layer for the underlying parallel linear\n                    if module is not None and not hasattr(module, \"output_sizes\"):\n                        base = getattr(module, \"base_layer\", None)\n                        if base is not None and hasattr(base, \"output_sizes\"):\n                            module = base\n                    if module is not None and hasattr(module, \"output_sizes\"):\n                        tp_size = getattr(module, \"tp_size\", 1)\n                        sizes = [s // tp_size for s in module.output_sizes]\n                        offset = sum(sizes[:shard_idx])\n                        shard_size = sizes[shard_idx]\n                        param.data[offset : offset + shard_size].copy_(\n                            weight.to(param.dtype)\n                        )\n                        return\n\n        # Fallback: try load_weights (may work for non-stacked params)\n        logger.warning(\"Falling back to load_weights for param: %s\", name)\n        model.load_weights(weights=[(name, weight)])\n\n    def update_named_param(self, name, dtype, shape):\n        \"\"\"Override to use _direct_set_weight instead of load_weights.\"\"\"\n        if self.communicator is None:\n            raise RuntimeError(\"Communicator not initialized.\")\n\n        dtype = getattr(torch, dtype.split(\".\")[-1])\n        weight = torch.empty(shape, dtype=dtype, device=self.device)\n\n        if is_torch_xpu_available():\n            self.communicator.broadcast(weight, root=self.client_rank)\n            self.communicator.barrier()\n        else:\n            self.communicator.broadcast(weight, src=self.client_rank)\n            self.communicator.group.barrier()\n\n        self._direct_set_weight(name, weight)\n\n    def batch_update_named_params(self, 
params_list: list[tuple[str, str, tuple]]):\n        \"\"\"Receive and apply multiple weight tensors in sequence.\n\n        Args:\n            params_list: List of (name, dtype_str, shape) tuples.\n        \"\"\"\n        if self.communicator is None:\n            raise RuntimeError(\"Communicator not initialized.\")\n\n        weights_to_load = []\n        for name, dtype_str, shape in params_list:\n            dtype = getattr(torch, dtype_str.split(\".\")[-1])\n            weight = torch.empty(shape, dtype=dtype, device=self.device)\n\n            if is_torch_xpu_available():\n                self.communicator.broadcast(weight, root=self.client_rank)\n            else:\n                self.communicator.broadcast(weight, src=self.client_rank)\n\n            weights_to_load.append((name, weight))\n\n        # Single barrier after all broadcasts\n        if is_torch_xpu_available():\n            self.communicator.barrier()\n        else:\n            self.communicator.group.barrier()\n\n        # Load weights using direct set (handles stacked params)\n        for name, weight in weights_to_load:\n            self._direct_set_weight(name, weight)\n"
  },
  {
    "path": "src/axolotl/telemetry/__init__.py",
    "content": ""
  },
  {
    "path": "src/axolotl/telemetry/callbacks.py",
    "content": "\"\"\"Trainer callbacks for reporting runtime metrics at regular intervals.\"\"\"\n\nimport logging\nimport time\n\nfrom transformers import (\n    TrainerCallback,\n    TrainerControl,\n    TrainerState,\n    TrainingArguments,\n)\n\nfrom axolotl.telemetry.manager import TelemetryManager\nfrom axolotl.telemetry.runtime_metrics import RuntimeMetricsTracker\n\nLOG = logging.getLogger(__name__)\n\nTIME_SINCE_LAST = 60\n\n\nclass TelemetryCallback(TrainerCallback):\n    \"\"\"\n    Trainer callback for tracking and reporting runtime metrics.\n\n    This callback tracks training progress, runtime, and memory usage,\n    sending telemetry at configurable intervals.\n    \"\"\"\n\n    report_interval_steps: int = 100\n\n    def __init__(self):\n        \"\"\"Initialize the metrics callback.\"\"\"\n        self.tracker = RuntimeMetricsTracker()\n        self.telemetry_manager = TelemetryManager.get_instance()\n        self.current_epoch = -1\n        self.start_time = time.time()\n        self.last_report_time = None\n        self.last_report_step = 0\n\n    # pylint: disable=unused-argument\n    def on_train_begin(\n        self,\n        args: TrainingArguments,\n        state: TrainerState,\n        control: TrainerControl,\n        **kwargs,\n    ):\n        \"\"\"Handle training start.\"\"\"\n        self.telemetry_manager.send_event(event_type=\"train-start\")\n\n    # pylint: disable=unused-argument\n    def on_train_end(\n        self,\n        args: TrainingArguments,\n        state: TrainerState,\n        control: TrainerControl,\n        **kwargs,\n    ):\n        \"\"\"Handle training end.\"\"\"\n        # Send training completion event\n        self.telemetry_manager.send_event(\n            event_type=\"train-end\",\n            properties=self._extract_last_metrics(state)\n            | self.tracker.metrics.to_dict(),\n        )\n\n    # pylint: disable=unused-argument\n    def on_epoch_begin(\n        self,\n        args: 
TrainingArguments,\n        state: TrainerState,\n        control: TrainerControl,\n        **kwargs,\n    ):\n        \"\"\"Handle epoch start.\"\"\"\n        self.current_epoch += 1\n        self.tracker.start_epoch(self.current_epoch)\n\n    # pylint: disable=unused-argument\n    def on_epoch_end(\n        self,\n        args: TrainingArguments,\n        state: TrainerState,\n        control: TrainerControl,\n        **kwargs,\n    ):\n        \"\"\"Handle epoch end.\"\"\"\n        self.tracker.end_epoch(self.current_epoch)\n\n    # pylint: disable=unused-argument\n    def on_step_end(\n        self,\n        args: TrainingArguments,\n        state: TrainerState,\n        control: TrainerControl,\n        **kwargs,\n    ):\n        \"\"\"Handle step end.\"\"\"\n        step = state.global_step\n        self.tracker.update_step(step)\n\n        # Check if we should report metrics\n        should_report = (\n            step % self.report_interval_steps == 0\n            or step == 1  # Always report first step\n            or step - self.last_report_step >= self.report_interval_steps\n        )\n\n        if should_report:\n            current_time = time.time()\n            if self.last_report_time is not None:\n                time_since_last_report = current_time - self.last_report_time\n            else:\n                time_since_last_report = current_time - self.start_time\n            steps_since_last_report = step - self.last_report_step\n\n            # Only report if enough time has passed\n            if (\n                step == 1\n                or time_since_last_report >= TIME_SINCE_LAST\n                or steps_since_last_report >= self.report_interval_steps\n            ):\n                # Calculate steps per second for this interval\n                if time_since_last_report > 0 and steps_since_last_report > 0:\n                    steps_per_second = steps_since_last_report / time_since_last_report\n                else:\n                  
  steps_per_second = 0\n\n                # Update memory metrics\n                self.tracker.update_memory_metrics()\n\n                # Prepare metrics to report\n                metrics = self._extract_last_metrics(state) | {\n                    \"step\": step,\n                    \"epoch\": self.current_epoch,\n                    \"progress\": state.epoch,  # Fractional epoch progress\n                    \"steps_per_second\": steps_per_second,\n                    \"elapsed_time\": current_time - self.start_time,\n                    \"time_since_last_report\": time_since_last_report,\n                }\n\n                # Add memory metrics\n                memory_metrics = self.tracker.get_memory_metrics()\n                metrics.update({\"memory\": memory_metrics})\n\n                # Send telemetry\n                self.telemetry_manager.send_event(\n                    event_type=\"train-progress\", properties=metrics\n                )\n\n                # Update last report time and step\n                self.last_report_time = current_time\n                self.last_report_step = step\n\n    def _extract_last_metrics(self, state: TrainerState) -> dict:\n        \"\"\"Extract last loss, learning_rate, grad_norm, and token metrics from log history.\"\"\"\n        if not state.log_history:\n            return {\n                \"loss\": 0,\n                \"ppl\": 0,\n                \"learning_rate\": 0,\n                \"grad_norm\": 0,\n                \"tokens/total\": 0,\n                \"tokens/trainable\": 0,\n                \"tokens/train_per_sec_per_gpu\": 0,\n            }\n\n        last_log = state.log_history[-1]\n        return {\n            \"loss\": last_log.get(\"loss\", 0),\n            \"ppl\": last_log.get(\"ppl\", 0),\n            \"learning_rate\": last_log.get(\"learning_rate\", 0),\n            \"grad_norm\": last_log.get(\"grad_norm\", 0),\n            \"tokens/total\": last_log.get(\"tokens/total\", 0),\n           
 \"tokens/trainable\": last_log.get(\"tokens/trainable\", 0),\n            \"tokens/train_per_sec_per_gpu\": last_log.get(\n                \"tokens/train_per_sec_per_gpu\", 0\n            ),\n        }\n"
  },
  {
    "path": "src/axolotl/telemetry/errors.py",
    "content": "\"\"\"Telemetry utilities for exception and traceback information.\"\"\"\n\nimport logging\nimport os\nimport re\nimport traceback\nfrom functools import wraps\nfrom inspect import getmodule\nfrom typing import Any, Callable\n\nfrom axolotl.telemetry.manager import TelemetryManager\n\nLOG = logging.getLogger(__name__)\n\nERROR_HANDLED = False\n\n\ndef sanitize_stack_trace(stack_trace: str) -> str:\n    \"\"\"\n    Remove personal information from stack trace messages while keeping Python package codepaths.\n\n    This function identifies Python packages by looking for common patterns in virtual environment\n    and site-packages directories, preserving the package path while removing user-specific paths.\n\n    Args:\n        stack_trace: The original stack trace string.\n\n    Returns:\n        A sanitized version of the stack trace with Python package paths preserved.\n    \"\"\"\n    # Split the stack trace into lines to process each file path separately\n    lines = stack_trace.split(\"\\n\")\n    sanitized_lines = []\n\n    # Regular expression to find file paths in the stack trace\n    path_pattern = re.compile(r'(?:File \")(.*?)(?:\")')\n\n    # Regular expression to identify paths in site-packages or dist-packages\n    # This matches path segments like \"site-packages/package_name\" or \"dist-packages/package_name\"\n    site_packages_pattern = re.compile(\n        r\"(?:site-packages|dist-packages)[/\\\\]([\\w\\-\\.]+)\"\n    )\n\n    # Additional common virtual environment patterns\n    venv_lib_pattern = re.compile(\n        r\"(?:lib|Lib)[/\\\\](?:python\\d+(?:\\.\\d+)?[/\\\\])?(?:site-packages|dist-packages)[/\\\\]([\\w\\-\\.]+)\"\n    )\n\n    for line in lines:\n        # Check if this line contains a file path\n        path_match = path_pattern.search(line)\n\n        if path_match:\n            full_path = path_match.group(1)\n            sanitized_path = \"\"\n\n            # Try to match site-packages pattern\n            
site_packages_match = site_packages_pattern.search(full_path)\n            venv_lib_match = venv_lib_pattern.search(full_path)\n\n            if site_packages_match:\n                # Find the index where the matched pattern starts\n                idx = full_path.find(\"site-packages\")\n                if idx == -1:\n                    idx = full_path.find(\"dist-packages\")\n\n                # Keep from 'site-packages' onward\n                if idx >= 0:\n                    sanitized_path = full_path[idx:]\n            elif venv_lib_match:\n                # For other virtual environment patterns, find the package directory\n                match_idx = venv_lib_match.start(1)\n                if match_idx > 0:\n                    # Keep from the package name onward\n                    package_name = venv_lib_match.group(1)\n                    idx = full_path.rfind(\n                        package_name, 0, match_idx + len(package_name)\n                    )\n                    if idx >= 0:\n                        sanitized_path = full_path[idx:]\n\n            # If we couldn't identify a package pattern but path contains 'axolotl'\n            elif \"axolotl\" in full_path:\n                idx = full_path.rfind(\"axolotl\")\n                if idx >= 0:\n                    sanitized_path = full_path[idx:]\n\n            # Apply the sanitization to the line\n            if sanitized_path:\n                line = line.replace(full_path, sanitized_path)\n            else:\n                # If we couldn't identify a package pattern, just keep the filename\n                filename = os.path.basename(full_path)\n                if filename:\n                    line = line.replace(full_path, filename)\n                else:\n                    line = line.replace(full_path, \"\")\n\n        sanitized_lines.append(line)\n\n    return \"\\n\".join(sanitized_lines)\n\n\ndef send_errors(func: Callable) -> Callable:\n    \"\"\"\n    Decorator to send 
exception info in a function. If an exception is raised, we send\n    telemetry containing the stack trace and error message.\n\n    If an error occurs in a decorated function that is called by another decorated\n    function, we'll only send telemetry corresponding to the lower-level function.\n\n    Args:\n        func: Function to decorate.\n\n    Returns:\n        Decorated function.\n    \"\"\"\n\n    @wraps(func)\n    def wrapper(*args, **kwargs) -> Any:\n        telemetry_manager = TelemetryManager.get_instance()\n\n        if not telemetry_manager.enabled:\n            return func(*args, **kwargs)\n\n        try:\n            return func(*args, **kwargs)\n        except Exception as exception:\n            # Only track if we're not already handling an error. This prevents us from\n            # capturing an error more than once in nested decorated function calls.\n            global ERROR_HANDLED  # pylint: disable=global-statement\n            if not ERROR_HANDLED:\n                ERROR_HANDLED = True\n\n                # Get function module path\n                module = getmodule(func)\n                module_path = (\n                    f\"{module.__name__}.{func.__name__}\" if module else func.__name__\n                )\n\n                # Get stack trace\n                stack_trace = \"\".join(\n                    traceback.format_exception(\n                        type(exception), exception, exception.__traceback__\n                    )\n                )\n                stack_trace = sanitize_stack_trace(stack_trace)\n\n                # Send error telemetry\n                telemetry_manager.send_event(\n                    event_type=f\"{module_path}-error\",\n                    properties={\n                        \"exception\": str(exception),\n                        \"stack_trace\": stack_trace,\n                    },\n                )\n\n                LOG.error(\n                    f\"Error captured in telemetry. 
Run ID: {telemetry_manager.run_id}\"\n                )\n\n            raise\n\n    return wrapper\n"
  },
  {
    "path": "src/axolotl/telemetry/manager.py",
    "content": "\"\"\"Telemetry manager and associated utilities.\"\"\"\n\nimport atexit\nimport importlib\nimport logging\nimport os\nimport platform\nimport uuid\nfrom pathlib import Path\nfrom typing import Any\n\nimport posthog\nimport psutil\nimport torch\nimport yaml\n\nLOG = logging.getLogger(__name__)\n\nPOSTHOG_HOST = \"https://app.posthog.com\"\nPOSTHOG_WRITE_KEY = \"phc_1kUR0o04oJKKTTeSsIz2Mfm5mpiVsQEf2WOlzljMD7y\"\n\nWHITELIST_PATH = str(Path(__file__).parent / \"whitelist.yaml\")\n\n# NOTE: Need to keep these up to date with any config schema changes\nFIELDS_TO_REDACT = {\n    \"base_model\",\n    \"tokenizer_config\",\n    \"base_model_config\",\n    \"pretraining_dataset\",  # NOTE: this field may be a string or a dictionary\n    \"resume_from_checkpoint\",\n    \"hub_model_id\",\n}\nPREFIXES_TO_REDACT = {\"wandb_\", \"comet_\", \"mlflow_\", \"gradio_\", \"trackio_\", \"swanlab_\"}\nPATH_INDICATORS = {\"path\", \"dir\", \"data_files\"}\n\n# pylint: disable=duplicate-code\nRELEVANT_PACKAGES = {\n    \"torch\",\n    \"transformers\",\n    \"trl\",\n    \"datasets\",\n    \"peft\",\n    \"bitsandbytes\",\n    \"accelerate\",\n    \"optimum\",\n    \"deepspeed\",\n    \"ray\",\n    \"axolotl\",\n    \"triton\",\n    \"mamba-ssm\",\n    \"flash-attn\",\n    \"xformers\",\n    \"autoawq\",\n    \"tokenizers\",\n    \"sentencepiece\",\n    \"torchao\",\n    \"lm_eval\",\n}\n\n\ndef is_main_process() -> bool:\n    \"\"\"\n    Check whether we're running in the main process.\n\n    Note:\n        We're using this function instead of `torch.utils.distributed.is_main_process`\n        causes issues with DeepSpeed world_size since. 
This function avoids that issue\n        by checking env vars that are set by various launchers.\n\n    Returns:\n        Whether we're running in the main process.\n    \"\"\"\n    # If PyTorch distributed is already initialized, use it\n    if torch.distributed.is_initialized():\n        return torch.distributed.get_rank() == 0\n\n    # Otherwise check environment variables for global rank\n    # NOTE: need to verify this in SLURM / OpenMPI environments\n    global_rank = int(\n        os.environ.get(\n            \"RANK\",\n            os.environ.get(\n                \"GLOBAL_RANK\",\n                os.environ.get(\n                    \"SLURM_PROCID\",\n                    os.environ.get(\n                        \"OMPI_COMM_WORLD_RANK\",\n                        \"0\",\n                    ),\n                ),\n            ),\n        )\n    )\n\n    return global_rank == 0\n\n\nclass TelemetryManager:\n    \"\"\"Manages telemetry collection and transmission\"\"\"\n\n    _instance = None\n    _initialized = False\n\n    def __new__(cls):\n        \"\"\"\n        Telemetry manager constructor. 
Creates the singleton instance of this class if\n        it doesn't already exist.\n        \"\"\"\n        if cls._instance is None:\n            cls._instance = super(TelemetryManager, cls).__new__(cls)\n            cls._instance._initialized = False\n\n        return cls._instance\n\n    def __init__(self):\n        \"\"\"Telemetry manager initializer\"\"\"\n        if self._initialized:\n            return\n\n        self.enabled = self._check_telemetry_enabled()\n\n        if self.enabled:\n            self.run_id = str(uuid.uuid4())\n            self.whitelist = self._load_whitelist()\n\n            try:\n                self.system_info = self._get_system_info()\n            except Exception as e:  # pylint: disable=broad-exception-caught\n                LOG.warning(f\"Error during system info collection: {e}\")\n                self.system_info = None\n\n            self._init_posthog()\n\n            # Register shutdown method to flush posthog telemetry\n            atexit.register(self.shutdown)\n\n        self._initialized = True\n\n    @classmethod\n    def get_instance(cls) -> \"TelemetryManager\":\n        if cls._instance is None:\n            cls._instance = TelemetryManager()\n\n        return cls._instance\n\n    def _check_telemetry_enabled(self) -> bool:\n        \"\"\"\n        Check if telemetry is enabled based on environment variables. We also check\n        whether this is the main process (for the distributed setting and to avoid\n        sending duplicate PostHog events per GPU).\n\n        Note: This is enabled by default on an opt-out basis. Set\n        `AXOLOTL_DO_NOT_TRACK=1` to disable telemetry. 
For more details, see\n        https://axolotl-ai-cloud.github.io/axolotl/docs/telemetry.html.\n\n        Returns:\n            Boolean denoting whether telemetry is enabled or not.\n        \"\"\"\n        # Only rank 0 will send telemetry\n        if not is_main_process():\n            return False\n\n        # Parse relevant env vars\n        axolotl_do_not_track = os.getenv(\"AXOLOTL_DO_NOT_TRACK\")\n        do_not_track = os.getenv(\"DO_NOT_TRACK\")\n\n        # Default to enabled (opt-out model)\n        if axolotl_do_not_track is None or axolotl_do_not_track.lower() not in (\n            \"0\",\n            \"1\",\n            \"false\",\n            \"true\",\n        ):\n            return True\n\n        if do_not_track is None:\n            do_not_track = \"0\"\n\n        # Respect AXOLOTL_DO_NOT_TRACK, DO_NOT_TRACK if enabled\n        enabled = axolotl_do_not_track.lower() not in (\n            \"1\",\n            \"true\",\n        ) and do_not_track.lower() not in (\"1\", \"true\")\n\n        return enabled\n\n    def _load_whitelist(self) -> dict:\n        \"\"\"Load HuggingFace Hub organization whitelist\"\"\"\n        with open(WHITELIST_PATH, encoding=\"utf-8\") as f:\n            whitelist = yaml.safe_load(f)\n\n            # Send org strings to lowercase since model names are case insensitive\n            whitelist[\"organizations\"] = {\n                org.lower() for org in whitelist[\"organizations\"]\n            }\n\n            return whitelist\n\n    def _is_whitelisted(self, value: str) -> bool:\n        \"\"\"\n        Check if model / dataset / etc. 
org is in whitelist.\n\n        Args:\n            value: Value for one of `axolotl.telemetry.manager.FIELDS_WITH_ORGS`\n                (\"base_model\", etc.).\n\n        Returns:\n            Boolean indicating whitelist membership.\n        \"\"\"\n        # NOTE: This membership-checking logic can be improved.\n        # What happens when a local model path matches a whitelisted org?\n        parts = value.split(\"/\")\n        if len(parts) < 2:\n            return False\n        org = parts[0]\n        whitelisted = org.lower() in self.whitelist[\"organizations\"]\n\n        return whitelisted\n\n    def _init_posthog(self):\n        \"\"\"Initialize PostHog client\"\"\"\n        posthog.api_key = POSTHOG_WRITE_KEY\n        posthog.project_api_key = POSTHOG_WRITE_KEY\n        posthog.host = POSTHOG_HOST\n\n    def _redact_paths(self, properties: dict[str, Any]) -> dict[str, Any]:\n        \"\"\"\n        Redact properties to remove any paths, so as to avoid inadvertently collecting\n        private or personally identifiable information (PII). We also remove\n        information related to Wandb, MLflow, etc. 
configuration.\n\n        Args:\n            properties: Dictionary of properties to redact.\n\n        Returns:\n            Properties dictionary with redaction applied.\n        \"\"\"\n        if not properties:\n            return {}\n\n        def redact_value(value: Any, key: str = \"\") -> Any:\n            \"\"\"Recursively sanitize values, redacting those with path-like keys\"\"\"\n            if isinstance(key, str) and isinstance(value, str):\n                # Other redaction special cases\n                if (\n                    key in FIELDS_TO_REDACT\n                    or any(prefix in key for prefix in PREFIXES_TO_REDACT)\n                    or any(indicator in key.lower() for indicator in PATH_INDICATORS)\n                ):\n                    # Fields with whitelisted orgs don't need to be redacted\n                    if not self._is_whitelisted(value):\n                        return \"[REDACTED]\"\n\n            # Handle nested values\n            if isinstance(value, dict):\n                return {k: redact_value(v, k) for k, v in value.items()}\n            if isinstance(value, list):\n                return [redact_value(item) for item in value]\n\n            return value\n\n        # Create new dict with redacted values\n        redacted = {k: redact_value(v, k) for k, v in properties.items()}\n\n        return redacted\n\n    def _get_system_info(self) -> dict[str, Any]:\n        \"\"\"Collect system information for various hardware accelerators\"\"\"\n        gpu_info = []\n        accelerator_type = \"none\"\n\n        # NVIDIA GPUs\n        if torch.cuda.is_available():\n            accelerator_type = \"cuda\"\n            for i in range(torch.cuda.device_count()):\n                gpu_info.append(\n                    {\n                        \"name\": torch.cuda.get_device_name(i),\n                        \"memory\": torch.cuda.get_device_properties(i).total_memory,\n                    }\n                )\n\n        # 
AMD GPUs\n        elif hasattr(torch, \"hip\") and torch.hip.is_available():\n            accelerator_type = \"hip\"\n            for i in range(torch.hip.device_count()):\n                gpu_info.append(\n                    {\n                        \"name\": torch.hip.get_device_name(i),\n                        \"memory\": (\n                            torch.hip.get_device_properties(i).total_memory\n                            if hasattr(torch.hip, \"get_device_properties\")\n                            else None\n                        ),\n                    }\n                )\n\n        # Apple Silicon\n        elif hasattr(torch.backends, \"mps\") and torch.backends.mps.is_available():\n            accelerator_type = \"mps\"\n            gpu_info.append(\n                {\n                    \"name\": \"Apple Silicon\",\n                    # NOTE: this is memory allocated to this process, not total memory\n                    \"memory\": torch.mps.driver_allocated_memory(),\n                }\n            )\n\n        # Intel GPUs\n        elif hasattr(torch, \"xpu\") and torch.xpu.is_available():\n            accelerator_type = \"xpu\"\n            for i in range(torch.xpu.device_count()):\n                memory = None\n                if hasattr(torch.xpu, \"get_device_properties\"):\n                    memory = torch.xpu.get_device_properties(i).total_memory\n\n                gpu_info.append(\n                    {\n                        \"name\": torch.xpu.get_device_name(i),\n                        \"memory\": memory,\n                    }\n                )\n\n        # NPUs\n        elif hasattr(torch, \"npu\") and torch.npu.is_available():\n            accelerator_type = \"npu\"\n            for i in range(torch.npu.device_count()):\n                memory = None\n                if hasattr(torch.npu, \"get_device_properties\"):\n                    memory = torch.npu.get_device_properties(i).total_memory\n\n                
gpu_info.append(\n                    {\n                        \"name\": torch.npu.get_device_name(i),\n                        \"memory\": memory,\n                    }\n                )\n\n        # Get relevant package versions\n        installed_packages = {}\n        for package in RELEVANT_PACKAGES:\n            try:\n                version = importlib.metadata.version(package)\n                installed_packages[f\"{package}_version\"] = version\n            except importlib.metadata.PackageNotFoundError:\n                pass\n\n        return {\n            \"os\": platform.system(),\n            \"python_version\": platform.python_version(),\n            \"cpu_count\": psutil.cpu_count(),\n            \"memory_total\": psutil.virtual_memory().total,\n            \"accelerator_type\": accelerator_type,\n            \"accelerator_count\": len(gpu_info),\n            \"accelerator_info\": gpu_info,\n            **installed_packages,\n        }\n\n    def send_event(self, event_type: str, properties: dict[str, Any] | None = None):\n        \"\"\"Send a telemetry event\"\"\"\n        if not self.enabled:\n            return\n\n        if properties is None:\n            properties = {}\n\n        # Sanitize properties to remove PII\n        properties = self._redact_paths(properties)\n\n        # Wrap PostHog errors in try / except to not raise errors during Axolotl usage\n        try:\n            # Send event via PostHog\n            posthog.capture(\n                distinct_id=self.run_id,\n                event=event_type,\n                properties=properties,\n                disable_geoip=True,\n            )\n        except Exception as e:  # pylint: disable=broad-exception-caught\n            LOG.warning(f\"Failed to send telemetry event: {e}\")\n\n        # Additionally, send system info telemetry when loading config.\n        # NOTE: Is this the best place for this?\n        if event_type == \"config-loaded\":\n            
self.send_system_info()\n\n    def send_system_info(self):\n        \"\"\"Helper method for sending system info\"\"\"\n        if self.system_info is not None:\n            self.send_event(event_type=\"system-info\", properties=self.system_info)\n\n    def shutdown(self):\n        \"\"\"Ensure all queued events are processed before shutdown\"\"\"\n        if self.enabled:\n            posthog.shutdown()\n"
  },
  {
    "path": "src/axolotl/telemetry/runtime_metrics.py",
    "content": "\"\"\"Telemetry utilities for runtime and memory metrics.\"\"\"\n\nimport logging\nimport time\nfrom dataclasses import dataclass, field\nfrom typing import Any\n\nimport psutil\nimport torch\n\nfrom axolotl.telemetry.manager import TelemetryManager\n\nLOG = logging.getLogger(__name__)\n\n\n@dataclass\nclass RuntimeMetrics:\n    \"\"\"Container for runtime metrics to be tracked throughout training.\"\"\"\n\n    # Timing metrics\n    start_time: float\n    epoch_start_times: dict[int, float] = field(init=False)\n    epoch_end_times: dict[int, float] = field(init=False)\n\n    # Memory metrics\n    peak_cpu_memory: int = 0\n    peak_gpu_memory: dict[int, int] = field(init=False)\n\n    # Progress metrics\n    total_steps: int = 0\n    current_epoch: int = 0\n    current_step: int = 0\n\n    def __post_init__(self):\n        \"\"\"Initialize empty metric mappings.\"\"\"\n        self.epoch_start_times = {}\n        self.epoch_end_times = {}\n        self.peak_gpu_memory = {}\n\n    @property\n    def elapsed_time(self) -> float:\n        \"\"\"Calculate total elapsed time in seconds.\"\"\"\n        return time.time() - self.start_time\n\n    def epoch_time(self, epoch: int) -> float | None:\n        \"\"\"Calculate time taken for a specific epoch in seconds.\"\"\"\n        if epoch in self.epoch_start_times and epoch in self.epoch_end_times:\n            return self.epoch_end_times[epoch] - self.epoch_start_times[epoch]\n\n        return None\n\n    def average_epoch_time(self) -> float | None:\n        \"\"\"Calculate average time per epoch in seconds.\"\"\"\n        completed_epochs = [\n            epoch for epoch in self.epoch_start_times if epoch in self.epoch_end_times\n        ]\n        if not completed_epochs:\n            return None\n\n        total_time = 0.0\n        for epoch in completed_epochs:\n            epoch_time = self.epoch_time(epoch)\n            if epoch_time is not None:  # Check to avoid mypy warning\n                
total_time += epoch_time\n\n        return total_time / len(completed_epochs)\n\n    def steps_per_second(self) -> float | None:\n        \"\"\"Calculate average steps per second across all training.\"\"\"\n        if self.total_steps == 0 or self.elapsed_time == 0:\n            return None\n\n        return self.total_steps / self.elapsed_time\n\n    def to_dict(self) -> dict[str, Any]:\n        \"\"\"Convert metrics to a dictionary for telemetry reporting.\"\"\"\n        metrics = {\n            \"total_time_seconds\": self.elapsed_time,\n            \"total_steps\": self.total_steps,\n            \"steps_per_second\": self.steps_per_second(),\n            \"epochs_completed\": len(\n                [\n                    epoch\n                    for epoch in self.epoch_start_times\n                    if epoch in self.epoch_end_times\n                ]\n            ),\n            \"peak_cpu_memory_bytes\": self.peak_cpu_memory,\n        }\n\n        # Add per-epoch timing if available\n        epoch_times: dict[str, float] = {}\n        for epoch in sorted(self.epoch_end_times.keys()):\n            time_taken = self.epoch_time(epoch)\n            if time_taken is not None:\n                epoch_times[f\"epoch_{epoch}_seconds\"] = time_taken\n\n        if epoch_times:\n            metrics[\"epoch_times\"] = epoch_times  # type: ignore\n            metrics[\"average_epoch_time_seconds\"] = self.average_epoch_time()\n\n        # Add GPU memory metrics if available\n        if self.peak_gpu_memory:\n            gpu_metrics: dict[str, int] = {}\n            for gpu_id, memory in self.peak_gpu_memory.items():\n                gpu_metrics[f\"gpu_{gpu_id}_peak_memory_bytes\"] = memory\n            metrics[\"gpu_memory\"] = gpu_metrics  # type: ignore\n\n        return metrics\n\n\nclass RuntimeMetricsTracker:\n    \"\"\"Tracker for runtime metrics during training.\"\"\"\n\n    update_interval = 100\n\n    def __init__(self):\n        \"\"\"Initialize the runtime 
metrics tracker.\"\"\"\n        self.metrics = RuntimeMetrics(start_time=time.time())\n        self.telemetry_manager = TelemetryManager.get_instance()\n        self._process = psutil.Process()\n\n    def start_epoch(self, epoch: int):\n        \"\"\"Record the start of a new epoch.\"\"\"\n        self.metrics.current_epoch = epoch\n        self.metrics.epoch_start_times[epoch] = time.time()\n        self.update_memory_metrics()\n\n    def end_epoch(self, epoch: int):\n        \"\"\"Record the end of an epoch.\"\"\"\n        self.metrics.epoch_end_times[epoch] = time.time()\n\n    def update_step(self, step: int):\n        \"\"\"Update the current step count.\"\"\"\n        self.metrics.current_step = step\n        self.metrics.total_steps += 1\n\n        # Periodically update memory metrics\n        if step % self.update_interval == 0:\n            self.update_memory_metrics()\n\n    def _get_allocated_memory(self) -> dict[int, int]:\n        \"\"\"\n        Helper function for getting accelerator-agnostic allocated memory.\n\n        Returns:\n            A dictionary mapping device IDs to allocated memory in bytes\n        \"\"\"\n        memory_used: dict[int, int] = {}\n\n        # NVIDIA GPUs\n        if torch.cuda.is_available():\n            for i in range(torch.cuda.device_count()):\n                memory_used[i] = torch.cuda.memory_allocated(i)\n\n        # AMD GPUs\n        elif hasattr(torch, \"hip\") and torch.hip.is_available():\n            for i in range(torch.hip.device_count()):\n                if hasattr(torch.hip, \"memory_allocated\"):\n                    memory_used[i] = torch.hip.memory_allocated(i)\n\n        # Apple Silicon\n        elif hasattr(torch.backends, \"mps\") and torch.backends.mps.is_available():\n            # MPS doesn't have per-device memory stats since there's only one device\n            if hasattr(torch.mps, \"current_allocated_memory\"):\n                memory_used[0] = torch.mps.current_allocated_memory()\n\n        
# Intel GPUs\n        elif hasattr(torch, \"xpu\") and torch.xpu.is_available():\n            for i in range(torch.xpu.device_count()):\n                if hasattr(torch.xpu, \"memory_allocated\"):\n                    memory_used[i] = torch.xpu.memory_allocated(i)\n\n        # NPUs\n        elif hasattr(torch, \"npu\") and torch.npu.is_available():\n            for i in range(torch.npu.device_count()):\n                if hasattr(torch.npu, \"memory_allocated\"):\n                    memory_used[i] = torch.npu.memory_allocated(i)\n\n        return memory_used\n\n    def update_memory_metrics(self):\n        \"\"\"Update peak memory usage metrics.\"\"\"\n        # CPU memory\n        cpu_memory = self._process.memory_info().rss\n        self.metrics.peak_cpu_memory = max(self.metrics.peak_cpu_memory, cpu_memory)\n\n        # GPU memory (if available)\n        memory_used = self._get_allocated_memory()\n        for i, memory in memory_used.items():\n            self.metrics.peak_gpu_memory[i] = max(\n                self.metrics.peak_gpu_memory.get(i, 0), memory\n            )\n\n    def get_memory_metrics(self) -> dict[str, Any]:\n        \"\"\"Get the current memory metrics as a dictionary.\"\"\"\n        memory_metrics = {\n            \"cpu_memory_bytes\": self._process.memory_info().rss,\n            \"peak_cpu_memory_bytes\": self.metrics.peak_cpu_memory,\n        }\n\n        # GPU memory (if available)\n        memory_used = self._get_allocated_memory()\n        for i, memory in memory_used.items():\n            memory_metrics[f\"gpu_{i}_memory_bytes\"] = memory\n            memory_metrics[f\"gpu_{i}_peak_memory_bytes\"] = (\n                self.metrics.peak_gpu_memory.get(i, 0)\n            )\n\n        return memory_metrics\n"
  },
  {
    "path": "src/axolotl/telemetry/whitelist.yaml",
    "content": "organizations:\n  - \"axolotl-ai-co\"\n  - \"meta-llama\"\n  - \"huggingface\"\n  - \"nvidia\"\n  - \"facebook\"\n  - \"google\"\n  - \"microsoft\"\n  - \"deepseek-ai\"\n  - \"HuggingFaceTB\"\n  - \"mistralai\"\n  - \"Qwen\"\n  - \"unsloth\"\n  - \"NousResearch\"\n  - \"allenai\"\n  - \"amd\"\n  - \"tiiuae\"\n  - \"tencent\"\n  - \"zai-org\"\n  - \"openai\"\n  - \"ibm-granite\"\n  - \"arcee-ai\"\n  - \"swiss-ai\"\n  - \"CohereForAI\"\n  - \"deepcogito\"\n  - \"THUDM\"\n  - \"ai21labs\"\n  - \"LiquidAI\"\n  - \"canopylabs\"\n  - \"state-spaces\"\n  - \"mistral-community\"\n  - \"llava-hf\"\n  - \"ByteDance-Seed\"\n  - \"ACE-Step\"\n  - \"openbmb\"\n  - \"MiniMaxAI\"\n  - \"stepfun-ai\"\n  - \"internlm\"\n  - \"katanemo\"\n  - \"XiaomiMiMo\"\n"
  },
  {
    "path": "src/axolotl/train.py",
    "content": "\"\"\"Prepare and train a model on a dataset. Can also infer from a model or merge lora\"\"\"\n\nfrom __future__ import annotations\n\nimport importlib\nimport inspect\nimport json\nimport os\nimport shutil\nimport signal\nimport sys\nimport typing\nimport weakref\nfrom collections import OrderedDict\nfrom contextlib import ExitStack\nfrom pathlib import Path\nfrom typing import Any, Dict\n\nimport torch\nimport transformers.modelcard\nfrom datasets import Dataset\nfrom huggingface_hub.errors import OfflineModeIsEnabled\nfrom peft import PeftConfig, PeftModel\nfrom transformers import PreTrainedModel, PreTrainedTokenizer, ProcessorMixin\nfrom transformers.integrations.deepspeed import is_deepspeed_zero3_enabled\nfrom transformers.trainer import Trainer\n\nfrom axolotl.common.datasets import TrainDatasetMeta\nfrom axolotl.contribs.lgpl import (  # pylint: disable = no-name-in-module\n    fix_untrained_tokens,\n)\nfrom axolotl.integrations.base import PluginManager\nfrom axolotl.loaders import ModelLoader, load_processor, load_tokenizer\nfrom axolotl.telemetry.errors import send_errors\nfrom axolotl.telemetry.manager import TelemetryManager\nfrom axolotl.utils.ctx_managers.sequence_parallel import SequenceParallelContextManager\nfrom axolotl.utils.dict import DictDefault\nfrom axolotl.utils.distributed import cleanup_distributed\nfrom axolotl.utils.freeze import freeze_layers_except\nfrom axolotl.utils.logging import get_logger\nfrom axolotl.utils.schemas.enums import RLType\nfrom axolotl.utils.train import determine_last_checkpoint\nfrom axolotl.utils.trainer import setup_trainer\n\nif typing.TYPE_CHECKING:\n    from axolotl.core.builders import HFCausalTrainerBuilder, HFRLTrainerBuilder\n\nLOG = get_logger(__name__)\n\nTELEMETRY_MANAGER = TelemetryManager.get_instance()\nPLUGIN_MANAGER = PluginManager.get_instance()\n\n\ndef setup_model_and_tokenizer(\n    cfg: DictDefault,\n) -> tuple[\n    PreTrainedModel, PreTrainedTokenizer, PeftConfig | None, 
ProcessorMixin | None\n]:\n    \"\"\"Load the tokenizer, processor (for multimodal models), and model based on\n    configuration.\n\n    Args:\n        cfg: Dictionary mapping `axolotl` config keys to values.\n\n    Returns:\n        Tuple containing model, tokenizer, `peft_config` (if LoRA / QLoRA, else\n            `None`), and processor (if multimodal, else `None`).\n    \"\"\"\n    # Load tokenizer\n    LOG.debug(\n        f\"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}\",\n    )\n    tokenizer = load_tokenizer(cfg)\n\n    # Load processor for multimodal models if needed\n    processor = None\n    if cfg.is_multimodal:\n        processor = load_processor(cfg, tokenizer)\n\n    # Load the model\n    LOG.debug(\"Loading model\")\n\n    model_loader = ModelLoader(cfg, tokenizer, processor=processor)\n    model, peft_config = model_loader.load()\n    if model.generation_config is not None:\n        model.generation_config.do_sample = True\n\n    model_properties = model.config.to_dict()\n    try:\n        model_properties[\"num_parameters\"] = model.num_parameters()\n    except Exception:  # pylint: disable=broad-exception-caught\n        model_properties[\"num_parameters\"] = sum(p.numel() for p in model.parameters())\n    # if the num_parameters is less than 2B, let's round to nearest 100M, else round to nearest 1B\n    if model_properties[\"num_parameters\"] < 2e9:\n        model_properties[\"num_parameters_est\"] = (\n            f\"{round(model_properties['num_parameters'] / 1e8) * 100}M\"\n        )\n    else:\n        model_properties[\"num_parameters_est\"] = (\n            f\"{round(model_properties['num_parameters'] / 1e9)}B\"\n        )\n    TELEMETRY_MANAGER.send_event(event_type=\"model-load\", properties=model_properties)\n    if peft_config:\n        TELEMETRY_MANAGER.send_event(\n            event_type=\"peft-config-load\", properties=peft_config.to_dict()\n        )\n\n    # Apply freezing if specified\n    if 
cfg.unfrozen_parameters:\n        freeze_layers_except(model, cfg.unfrozen_parameters)\n        if any(\n            any(embed in param for embed in [\"lm_head\", \"embed_tokens\"])\n            for param in cfg.unfrozen_parameters\n        ):\n            model.enable_input_require_grads()\n\n    return model, tokenizer, peft_config, processor\n\n\ndef setup_reference_model(\n    cfg: DictDefault, tokenizer: PreTrainedTokenizer\n) -> PreTrainedModel | None:\n    \"\"\"\n    Set up the reference model for RL training if needed.\n\n    Args:\n        cfg: Dictionary mapping `axolotl` config keys to values.\n        tokenizer: The tokenizer to use for the reference model.\n\n    Returns:\n        Reference model if needed for RL training, `None` otherwise.\n    \"\"\"\n    model_ref = None\n    if cfg.rl and cfg.rl != RLType.ORPO:\n        if cfg.adapter and not cfg.rl_adapter_ref_model:\n            # use built-in trl autounwrap\n            LOG.debug(\"Passing model_ref: None to RL trainer\")\n            model_ref = None  # explicit setting to None\n        else:\n            reference_model: bool = True\n            if cfg.rl == RLType.GRPO and cfg.trl.beta == 0:\n                reference_model = False\n            # load the model again for model_ref/baseline\n            model_loader = ModelLoader(cfg, tokenizer, reference_model=reference_model)\n            model_ref, _ = model_loader.load()\n    return model_ref\n\n\ndef setup_signal_handler(cfg: DictDefault, model: PreTrainedModel):\n    \"\"\"\n    Set up signal handler for graceful termination.\n\n    Args:\n        cfg: Dictionary mapping `axolotl` config keys to values.\n        model: The model to save on termination\n    \"\"\"\n    # ray workers don't have access to this signal\n    if cfg.local_rank == 0 and not cfg.use_ray:\n\n        def terminate_handler(_, __, model_weakref):\n            if model_weakref() is not None:\n                _model = model_weakref()\n                
_model.save_pretrained(cfg.output_dir)\n\n            cleanup_distributed()\n            sys.exit(0)\n\n        _model_weakref = weakref.ref(model)\n        signal.signal(\n            signal.SIGINT,\n            lambda signum, frame: terminate_handler(signum, frame, _model_weakref),\n        )\n\n\ndef execute_training(\n    cfg: DictDefault, trainer: Any, resume_from_checkpoint: str | None\n):\n    \"\"\"\n    Execute the training process with appropriate SDP kernel configurations.\n\n    Args:\n        cfg: Dictionary mapping `axolotl` config keys to values.\n        trainer: The configured trainer object.\n        resume_from_checkpoint: Path to checkpoint to resume from, if applicable.\n    \"\"\"\n    with ExitStack() as stack:\n        # Define the context managers to use\n        if cfg.flash_optimum:\n            stack.enter_context(\n                torch.backends.cuda.sdp_kernel(\n                    enable_flash=True,\n                    enable_math=True,\n                    enable_mem_efficient=True,\n                )\n            )\n\n        if cfg.context_parallel_size > 1:\n            models = [trainer.model]\n            if hasattr(trainer, \"ref_model\") and trainer.ref_model:\n                models.append(trainer.ref_model)\n\n            stack.enter_context(\n                SequenceParallelContextManager(\n                    models=models,\n                    context_parallel_size=cfg.context_parallel_size,\n                    gradient_accumulation_steps=cfg.gradient_accumulation_steps,\n                    ring_attn_func=cfg.ring_attn_func,\n                    heads_k_stride=cfg.heads_k_stride,\n                    gather_outputs=cfg.rl is RLType.GRPO,\n                    device_mesh=trainer.accelerator.torch_device_mesh,\n                )\n            )\n\n        # TODO: disabling for now as not compatible with FSDP2 + torchao low bit optimizers\n        # if cfg.bf16:\n        #     torch.set_default_dtype(torch.bfloat16)\n\n    
    LOG.info(\"Starting trainer...\")\n        trainer.train(resume_from_checkpoint=resume_from_checkpoint)\n\n        PLUGIN_MANAGER.post_train(cfg, trainer.model)\n\n\ndef save_trained_model(\n    cfg: DictDefault,\n    trainer: Any,\n    model: PreTrainedModel,\n):\n    \"\"\"\n    Save the trained model according to configuration and training setup.\n\n    Args:\n        cfg: Dictionary mapping `axolotl` config keys to values.\n        trainer: The trainer object.\n        model: The trained model to save.\n    \"\"\"\n    LOG.info(f\"Training completed! Saving trained model to {cfg.output_dir}.\")\n\n    # Post training module hooks\n    for name, module in model.named_modules():\n        if hasattr(module, \"_post_training\"):\n            module._post_training(model, name)\n\n    # handle QAT\n    if cfg.qat:\n        from axolotl.utils.quantization import convert_qat_model\n\n        convert_qat_model(\n            model,\n            quantize_embedding=cfg.qat.quantize_embedding,\n        )\n        LOG.info(\n            \"QAT usage note: please ensure you quantize your model fine-tuned using QAT by running `axolotl quantize`\"\n            \" with the same config which you used for training.\"\n        )\n    # Handle ReLoRA early return case\n    if cfg.relora:\n        if cfg.adapter == \"lora\" and not (cfg.load_in_4bit or cfg.load_in_8bit):\n            model = model.merge_and_unload()\n        else:\n            # final model weights have already been saved by `ReLoRACallback.on_train_end`\n            return\n\n    if trainer.is_fsdp_enabled or cfg.fsdp_config:\n        if cfg.fsdp_config or cfg.fsdp:\n            if cfg.fsdp_config.final_state_dict_type:\n                state_dict_type = cfg.fsdp_config.final_state_dict_type\n            else:\n                state_dict_type = cfg.fsdp_config.state_dict_type\n            trainer.accelerator.state.fsdp_plugin.set_state_dict_type(state_dict_type)\n        trainer.save_model(cfg.output_dir)  # only 
handles FULL_STATE_DICT\n        if state_dict_type == \"SHARDED_STATE_DICT\":\n            LOG.info(\n                \"The final model was saved with a sharded state dict. Please ensure you merge \"\n                \"the sharded weights with `merge-sharded-fsdp-weights`.\"\n            )\n            checkpoint_dir = determine_last_checkpoint(cfg, update=False)\n            if (\n                not (Path(cfg.output_dir) / \"model.safetensors.index.json\").exists()\n                and checkpoint_dir\n            ):\n                # import here to prevent circular import\n                from axolotl.cli.merge_sharded_fsdp_weights import merge_fsdp_weights\n\n                fsdp_dir = Path(checkpoint_dir) / \"pytorch_model_fsdp_0\"\n                merged_path = str(Path(cfg.output_dir) / \"merged\")\n                merge_fsdp_weights(\n                    checkpoint_dir=str(fsdp_dir),\n                    output_path=merged_path,\n                )\n                trainer.accelerator.wait_for_everyone()\n                if trainer.accelerator.is_main_process:\n                    # move all files in merged_path to cfg.output_dir\n                    for merged_file in Path(merged_path).iterdir():\n                        if (Path(cfg.output_dir) / merged_file.name).exists():\n                            (Path(cfg.output_dir) / merged_file.name).unlink()\n                        shutil.move(str(merged_file), cfg.output_dir)\n                    shutil.rmtree(merged_path)  # remove what should be an empty dir\n        # TODO(wing):see https://github.com/huggingface/transformers/pull/40207\n        # cleanup the FSDP prefix in the model config.json\n        if trainer.accelerator.is_main_process:\n            with open(\n                Path(cfg.output_dir) / \"config.json\", \"r\", encoding=\"utf-8\"\n            ) as config_file_io:\n                # read the model config as an OrderedDict\n                config = json.load(config_file_io, 
object_pairs_hook=OrderedDict)\n                config[\"architectures\"] = [\n                    name.lstrip(\"FSDP\") for name in config[\"architectures\"]\n                ]\n            # write the updated model config back\n            with open(\n                os.path.join(cfg.output_dir, \"config.json\"), \"w\", encoding=\"utf-8\"\n            ) as config_file_io:\n                json.dump(config, config_file_io, indent=2)\n    elif cfg.deepspeed and is_deepspeed_zero3_enabled():\n        # Copied over from: https://github.com/huggingface/accelerate/blob/5ae611118057232f441055f7ef9ba0b0f2b8d533/docs/source/usage_guides/deepspeed.md#saving-and-loading\n        trainer.accelerator.wait_for_everyone()\n        trainer.save_model(cfg.output_dir)\n\n        # the trainer saved a model.safetensors file in the output directory,\n        # but it is most likely a proxy model and if so, should be deleted\n        maybe_proxy = os.path.exists(os.path.join(cfg.output_dir, \"model.safetensors\"))\n        maybe_sharded = os.path.exists(\n            os.path.join(cfg.output_dir, \"model.safetensors.index.json\")\n        )\n\n        if maybe_proxy and maybe_sharded:\n            LOG.info(f\"Deleting {os.path.join(cfg.output_dir, 'model.safetensors')}\")\n            LOG.info(\"This is a proxy model and should be deleted\")\n            try:\n                os.remove(os.path.join(cfg.output_dir, \"model.safetensors\"))\n            except FileNotFoundError:\n                pass\n    elif cfg.local_rank == 0:\n        if cfg.rl and cfg.adapter and not cfg.rl_adapter_ref_model:\n            trainer.model.save_pretrained(cfg.output_dir)\n\n        model.save_pretrained(cfg.output_dir)\n\n    if hasattr(cfg, \"llmcompressor\") and cfg.llmcompressor:\n        # TODO: add integration support so this can be implemented completely within the plugin\n        from axolotl.integrations.llm_compressor.utils import save_compressed_model\n\n        save_compressed_model(\n       
     model=model,\n            output_dir=cfg.output_dir,\n            trainer=trainer,\n            save_compressed=cfg.llmcompressor.save_compressed,\n        )\n\n    LOG.info(f\"Model successfully saved to {cfg.output_dir}\")\n\n\ndef create_model_card(cfg: DictDefault, trainer: Trainer):\n    \"\"\"\n    Create a model card for the trained model if needed.\n\n    Args:\n        cfg: Dictionary mapping `axolotl` config keys to values.\n        trainer: The trainer object with model card creation capabilities.\n    \"\"\"\n    if not cfg.hub_model_id:\n        # Guard since create_model_card may fail if dataset_tags is empty list\n        try:\n            model_card_kwarg = {\n                \"model_name\": cfg.output_dir.lstrip(\"./\")\n                .encode(\"utf-8\")\n                .decode(\"utf-8\")\n            }\n\n            # We check if we're using a TRL trainer; if so, `dataset_tags` is not consumed.\n            rl = cfg.rl is not None or cfg.reward_model or cfg.process_reward_model\n            if cfg.datasets is not None and not rl:\n                dataset_tags = [\n                    d[\"path\"] for d in cfg.datasets if not Path(d[\"path\"]).is_dir()\n                ]\n                dataset_tags = [d for d in dataset_tags if not d.startswith(\"https://\")]\n\n                if dataset_tags:\n                    model_card_kwarg[\"dataset_tags\"] = dataset_tags\n\n            trainer.create_model_card(**model_card_kwarg)\n        except (AttributeError, UnicodeDecodeError, OfflineModeIsEnabled):\n            pass\n    elif cfg.hub_model_id:\n        # Defensively push to the hub to ensure the model card is updated\n        trainer.push_to_hub()\n\n\ndef save_initial_configs(\n    cfg: DictDefault,\n    tokenizer: PreTrainedTokenizer,\n    model: PreTrainedModel,\n    peft_config: PeftConfig | None,\n    processor: ProcessorMixin | None,\n):\n    \"\"\"\n    Save initial configurations before training.\n\n    Args:\n        cfg: 
Dictionary mapping `axolotl` config keys to values.\n        tokenizer: The tokenizer to save.\n        model: The model to save configuration for.\n        peft_config: The PEFT configuration to save if applicable.\n    \"\"\"\n    # Create output_dir if it doesn't already exist\n    output_dir = Path(cfg.output_dir)\n    if not output_dir.is_dir():\n        os.makedirs(cfg.output_dir, exist_ok=True)\n\n    # Pre-save adapter config so it's available to inspect\n    if peft_config:\n        LOG.info(f\"Pre-saving adapter config to {cfg.output_dir}...\")\n        peft_config.save_pretrained(cfg.output_dir)\n\n    # Pre-save the tokenizer and model configs\n    LOG.info(f\"Pre-saving tokenizer to {cfg.output_dir}...\")\n    tokenizer.save_pretrained(\n        str(Path(cfg.output_dir)), save_jinja_files=cfg.tokenizer_save_jinja_files\n    )\n    if hasattr(model, \"config\"):\n        LOG.info(f\"Pre-saving model config to {cfg.output_dir}...\")\n        model.config.save_pretrained(str(output_dir))\n\n    if processor:\n        LOG.info(f\"Pre-saving processor to {cfg.output_dir}...\")\n        processor.save_pretrained(str(output_dir))\n\n\ndef setup_model_card(cfg: DictDefault):\n    \"\"\"\n    Set up the Axolotl badge and add the Axolotl config to the model card if available.\n\n    Args:\n        cfg: Dictionary mapping `axolotl` config keys to values.\n    \"\"\"\n    badge_markdown = \"\"\"[<img src=\"https://raw.githubusercontent.com/axolotl-ai-cloud/axolotl/main/image/axolotl-badge-web.png\" alt=\"Built with Axolotl\" width=\"200\" height=\"32\"/>](https://github.com/axolotl-ai-cloud/axolotl)\"\"\"\n    transformers.modelcard.AUTOGENERATED_TRAINER_COMMENT += f\"\\n{badge_markdown}\"\n\n    if cfg.axolotl_config_path:\n        raw_axolotl_cfg = Path(cfg.axolotl_config_path)\n        version = importlib.metadata.version(\"axolotl\")\n        if raw_axolotl_cfg.is_file():\n            transformers.modelcard.AUTOGENERATED_TRAINER_COMMENT += 
f\"\\n<details><summary>See axolotl config</summary>\\n\\naxolotl version: `{version}`\\n```yaml\\n{raw_axolotl_cfg.read_text(encoding='utf-8')}\\n```\\n\\n</details><br>\\n\"\n\n\ndef handle_untrained_tokens_fix(\n    cfg: DictDefault,\n    model: PreTrainedModel,\n    tokenizer: PreTrainedTokenizer,\n    train_dataset: Dataset,\n):\n    \"\"\"\n    Apply fixes for untrained tokens if configured.\n\n    Args:\n        cfg: Dictionary mapping `axolotl` config keys to values.\n        model: The model to apply fixes to.\n        tokenizer: The tokenizer for token identification.\n        train_dataset: The training dataset to use.\n    \"\"\"\n    if not cfg.fix_untrained_tokens:\n        return\n\n    is_ds_zero3: bool = False\n    if os.environ.get(\"ACCELERATE_DEEPSPEED_ZERO_STAGE\") == \"3\":\n        is_ds_zero3 = True\n\n    # Check if the `token_ids_to_fix` kwarg exists in the fix_untrained_tokens args\n    sig = inspect.signature(fix_untrained_tokens)\n\n    fix_kwargs: Dict[str, Any] = {}\n    # If the function has the `token_ids_to_fix` arg, and fix_untrained_tokens is a list\n    if \"token_ids_to_fix\" in sig.parameters and isinstance(\n        cfg.fix_untrained_tokens, list\n    ):\n        fix_kwargs[\"token_ids_to_fix\"] = cfg.fix_untrained_tokens\n    if \"is_ds_zero3\" in sig.parameters:\n        fix_kwargs[\"is_ds_zero3\"] = is_ds_zero3\n\n    fix_untrained_tokens(model, tokenizer, train_dataset, **fix_kwargs)\n\n    if cfg.local_rank == 0:\n        model.save_pretrained(str(Path(cfg.output_dir)))\n\n\ndef setup_model_and_trainer(\n    cfg: DictDefault, dataset_meta: TrainDatasetMeta\n) -> tuple[\n    \"HFRLTrainerBuilder\" | \"HFCausalTrainerBuilder\",\n    PeftModel | PreTrainedModel,\n    PreTrainedTokenizer,\n    PeftConfig | None,\n    ProcessorMixin | None,\n]:\n    \"\"\"\n    Load model, tokenizer, trainer, etc. 
Helper function to encapsulate the full\n    trainer setup.\n\n    Args:\n        cfg: The configuration dictionary with training parameters.\n        dataset_meta: Object with training, validation datasets and metadata.\n\n    Returns:\n        Tuple of:\n            - Trainer (Causal or RLHF)\n            - Model\n            - Tokenizer\n            - PEFT config\n            - Processor\n    \"\"\"\n    # Load tokenizer, processor and model\n    model, tokenizer, peft_config, processor = setup_model_and_tokenizer(cfg)\n\n    # Set up reference model for RL if needed\n    model_ref = setup_reference_model(cfg, tokenizer)\n\n    # Get datasets from metadata\n    train_dataset = dataset_meta.train_dataset\n    eval_dataset = dataset_meta.eval_dataset\n    total_num_steps = dataset_meta.total_num_steps\n\n    # Set up trainer\n    trainer = setup_trainer(\n        cfg=cfg,\n        train_dataset=train_dataset,\n        eval_dataset=eval_dataset,\n        model=model,\n        tokenizer=tokenizer,\n        processor=processor,\n        total_num_steps=total_num_steps,\n        model_ref=model_ref,\n        peft_config=peft_config,\n    )\n    PLUGIN_MANAGER.post_trainer_create(cfg, trainer)\n\n    if cfg.use_ray:\n        try:\n            import ray.train.huggingface.transformers\n\n            trainer = ray.train.huggingface.transformers.prepare_trainer(trainer)\n        except ImportError:\n            LOG.warning(\n                \"The Ray integration with Hugging Face Transformers is not available. 
\"\n                \"To use Ray, install the 'ray[train]' package.\"\n            )\n\n    return (\n        trainer,\n        model,\n        tokenizer,\n        peft_config,\n        processor,\n    )\n\n\n@send_errors\ndef train(\n    cfg: DictDefault, dataset_meta: TrainDatasetMeta\n) -> tuple[PeftModel | PreTrainedModel, PreTrainedTokenizer, Trainer]:\n    \"\"\"\n    Train a model on the given dataset.\n\n    Args:\n        cfg: The configuration dictionary with training parameters\n        dataset_meta: Object with training, validation datasets and metadata\n\n    Returns:\n        Tuple of (model, tokenizer) after training\n    \"\"\"\n    # Setup model, tokenizer, (causal or RLHF) trainer, etc.\n    (\n        trainer,\n        model,\n        tokenizer,\n        peft_config,\n        processor,\n    ) = setup_model_and_trainer(cfg, dataset_meta)\n\n    # Handle untrained tokens if configured\n    train_dataset = dataset_meta.train_dataset\n    handle_untrained_tokens_fix(cfg, model, tokenizer, train_dataset)\n\n    # Additional setup\n    save_initial_configs(cfg, tokenizer, model, peft_config, processor)\n    setup_signal_handler(cfg, model)\n    setup_model_card(cfg)\n\n    # Execute the training\n    resume_from_checkpoint = determine_last_checkpoint(cfg)\n    execute_training(cfg, trainer, resume_from_checkpoint)\n\n    # clear cache\n    if torch.cuda.is_available():\n        torch.cuda.empty_cache()\n\n    # Save the trained model and cleanup\n    save_trained_model(cfg, trainer, model)\n    tokenizer.save_pretrained(\n        str(Path(cfg.output_dir)), save_jinja_files=cfg.tokenizer_save_jinja_files\n    )\n    create_model_card(cfg, trainer)\n    if not cfg.use_ray:\n        cleanup_distributed()\n    PLUGIN_MANAGER.post_train(cfg, model)\n\n    return model, tokenizer, trainer\n"
  },
  {
    "path": "src/axolotl/utils/__init__.py",
    "content": "\"\"\"\nBasic utils for Axolotl\n\"\"\"\n\nimport importlib.util\nimport os\nimport re\n\nimport torch\n\n\ndef is_mlflow_available():\n    return importlib.util.find_spec(\"mlflow\") is not None\n\n\ndef is_comet_available():\n    return importlib.util.find_spec(\"comet_ml\") is not None\n\n\ndef is_opentelemetry_available():\n    return (\n        importlib.util.find_spec(\"opentelemetry\") is not None\n        and importlib.util.find_spec(\"prometheus_client\") is not None\n    )\n\n\ndef is_trackio_available():\n    return importlib.util.find_spec(\"trackio\") is not None\n\n\ndef get_pytorch_version() -> tuple[int, int, int]:\n    \"\"\"\n    Get Pytorch version as a tuple of (major, minor, patch).\n    \"\"\"\n    torch_version = torch.__version__\n    version_match = re.match(r\"^(\\d+)\\.(\\d+)(?:\\.(\\d+))?\", torch_version)\n\n    if not version_match:\n        raise ValueError(\"Invalid version format\")\n\n    major, minor, patch = version_match.groups()\n    major, minor = int(major), int(minor)\n    patch = int(patch) if patch is not None else 0  # Default patch to 0 if not present\n    return major, minor, patch\n\n\ndef set_pytorch_cuda_alloc_conf():\n    \"\"\"\n    Set up the CUDA caching-allocator config unless the user already did.\n\n    From PyTorch 2.9 the config is written to ``PYTORCH_ALLOC_CONF`` (the newer\n    name for ``PYTORCH_CUDA_ALLOC_CONF``); from 2.2 up to 2.8 it is written to\n    ``PYTORCH_CUDA_ALLOC_CONF``. On 2.9+ a value the user set under *either*\n    name is respected, so we never add a second, conflicting variable. Torch\n    versions older than 2.2 are left untouched.\n    \"\"\"\n    torch_version = get_pytorch_version()[:2]\n    config_value = \"expandable_segments:True,roundup_power2_divisions:16\"\n    cuda_conf_unset = os.getenv(\"PYTORCH_CUDA_ALLOC_CONF\") is None\n    # compare (major, minor) tuples so that torch 3.x is covered as well\n    if torch_version >= (2, 9):\n        # respect a user value under either the new or the legacy name\n        if cuda_conf_unset and os.getenv(\"PYTORCH_ALLOC_CONF\") is None:\n            os.environ[\"PYTORCH_ALLOC_CONF\"] = config_value\n    elif torch_version >= (2, 2) and cuda_conf_unset:\n        os.environ[\"PYTORCH_CUDA_ALLOC_CONF\"] = config_value\n\n\ndef set_misc_env():\n    if os.getenv(\"XFORMERS_IGNORE_FLASH_VERSION_CHECK\") is None:\n        os.environ[\"XFORMERS_IGNORE_FLASH_VERSION_CHECK\"] = \"1\"\n\n\ndef get_not_null(value, default=None):\n    \"\"\"\n    return the value if it's not None, otherwise return the default value\n    \"\"\"\n    return value if value is not None else default\n"
  },
  {
    "path": "src/axolotl/utils/bench.py",
    "content": "\"\"\"Benchmarking and measurement utilities\"\"\"\n\nimport functools\nimport inspect\nimport logging\n\nimport torch\nfrom transformers.utils.import_utils import is_torch_npu_available\n\nfrom axolotl.utils.distributed import get_device_type\n\ntry:\n    from pynvml import (\n        NVMLError,\n        nvmlDeviceGetHandleByIndex,\n        nvmlDeviceGetMemoryInfo,\n        nvmlInit,\n    )\nexcept ImportError:\n    NVMLError = None\n    nvmlDeviceGetHandleByIndex = None\n    nvmlDeviceGetMemoryInfo = None\n    nvmlInit = None\n\n\ndef check_cuda_device(default_value):\n    \"\"\"\n    Wrap a function taking a ``device`` argument so that it returns\n    ``default_value`` instead of running when CUDA is unavailable or the\n    device is ``None``, ``\"auto\"``, a CPU device or a meta device.\n\n    The device is resolved the same way the wrapped function sees it: passed\n    positionally, passed by keyword, or taken from the function's own default\n    when the caller omits it.\n\n    :param default_value: value returned when the wrapped function is skipped\n    :return: the decorator\n    \"\"\"\n\n    def deco(func):\n        signature = inspect.signature(func)\n\n        @functools.wraps(func)\n        def wrapper(*args, **kwargs):\n            # Bind like a real call so an omitted `device` falls back to the\n            # wrapped function's default (e.g. 0) instead of None.\n            bound = signature.bind_partial(*args, **kwargs)\n            bound.apply_defaults()\n            device = bound.arguments.get(\"device\")\n\n            if (\n                device is None\n                or not torch.cuda.is_available()\n                or device == \"auto\"\n                or torch.device(device).type == \"cpu\"\n                or torch.device(device).type == \"meta\"\n            ):\n                return default_value\n            return func(*args, **kwargs)\n\n        return wrapper\n\n    return deco\n\n\n@check_cuda_device(0.0)\ndef gpu_memory_usage(device=0):\n    return torch.cuda.memory_allocated(device) / 1024.0**3\n\n\n@check_cuda_device((0.0, 0.0, 0.0))\ndef gpu_memory_usage_all(device=0):\n    \"\"\"\n    Return peak (active, allocated, reserved) memory of ``device`` in GiB and\n    reset that device's peak statistics.\n    \"\"\"\n    # query the same device whose peaks are reset below, not the current device\n    active = (\n        torch.cuda.memory_stats(device).get(\"active_bytes.all.peak\", 0) / 1024.0**3\n    )\n    allocated = torch.cuda.max_memory_allocated(device) / 1024.0**3\n    reserved = torch.cuda.max_memory_reserved(device) / 1024.0**3\n    torch.cuda.reset_peak_memory_stats(device)\n    return active, allocated, reserved\n\n\ndef mps_memory_usage_all():\n    active = torch.mps.current_allocated_memory() / 1024.0**3\n    allocated = torch.mps.driver_allocated_memory() / 1024.0**3\n    return active, allocated, 0\n\n\ndef npu_memory_usage_all(device=0):\n    usage = torch.npu.memory_allocated(device) / 1024.0**3\n    reserved = torch.npu.memory_reserved(device) / 1024.0**3\n    return usage, reserved - usage, 0\n\n\n@check_cuda_device(0.0)\ndef gpu_memory_usage_smi(device=0):\n    \"\"\"\n    Return the memory used on ``device`` according to NVML, in GiB, or 0.0 when\n    pynvml is not installed or the NVML query fails.\n    \"\"\"\n    if isinstance(device, str):\n        device = torch.device(device)\n    if isinstance(device, torch.device):\n        # an index-less device such as \"cuda\" refers to the current device\n        device = (\n            device.index if device.index is not None else torch.cuda.current_device()\n        )\n    # NOTE: NVML enumerates physical GPUs; under CUDA_VISIBLE_DEVICES this index\n    # may not match the CUDA ordinal.\n    if not nvmlInit:\n        return 0.0\n    try:\n        nvmlInit()\n        handle = nvmlDeviceGetHandleByIndex(device)\n        info = nvmlDeviceGetMemoryInfo(handle)\n        return info.used / 1024.0**3\n    except NVMLError:\n        return 0.0\n\n\ndef get_gpu_memory_usage(device: int | torch.device = 0):\n    cur_device_type = str(get_device_type())\n    if torch.backends.mps.is_available():\n        usage, cache, misc = mps_memory_usage_all()\n    elif \"npu\" in cur_device_type and is_torch_npu_available():\n        usage, cache, misc = npu_memory_usage_all(device)\n    elif \"cuda\" in cur_device_type and torch.cuda.is_available():\n        usage, cache, misc = gpu_memory_usage_all(device)\n    else:\n        return 0.0, 0.0, 0.0\n\n    return usage, cache, misc\n\n\ndef log_gpu_memory_usage(\n    log: logging.Logger | logging.LoggerAdapter,\n    msg: str = \"\",\n    device: int | torch.device = 0,\n):\n    try:\n        active, allocated, reserved = get_gpu_memory_usage(device)\n    except ValueError:\n        # likely CPU, ignore\n        return\n    cur_device_type = str(get_device_type())\n    extras = []\n    if allocated > 0:\n        extras.append(f\"+{allocated:.03f}GB allocated\")\n    if reserved > 0:\n        extras.append(f\"+{reserved:.03f}GB reserved\")\n    msg = f\"{cur_device_type} memory active:\" if not msg else msg\n    log.debug(\n        f\"{msg} {active:.03f}GB ({', '.join(extras)})\",\n        stacklevel=2,\n    )\n"
  },
  {
    "path": "src/axolotl/utils/callbacks/__init__.py",
    "content": "\"\"\"Callbacks for Trainer class\"\"\"\n\nfrom __future__ import annotations\n\nimport gc\nimport json\nimport os\nimport traceback\nfrom shutil import copyfile\nfrom tempfile import NamedTemporaryFile\nfrom typing import TYPE_CHECKING, Any, Dict, List\n\nimport evaluate\nimport numpy as np\nimport pandas as pd\nimport torch\nimport torch.distributed as dist\nimport wandb\nimport yaml\nfrom datasets import load_dataset\nfrom tqdm import tqdm\nfrom transformers import (\n    GenerationConfig,\n    Trainer,\n    TrainerCallback,\n    TrainerControl,\n    TrainerState,\n    TrainingArguments,\n)\nfrom transformers.trainer_utils import (\n    SaveStrategy,\n)\nfrom trl.models import unwrap_model_for_generation\n\nfrom axolotl.utils import is_comet_available, is_mlflow_available\nfrom axolotl.utils.callbacks.perplexity import Perplexity\nfrom axolotl.utils.distributed import (\n    barrier,\n    broadcast_dict,\n    gather_scalar_from_all_ranks,\n    get_world_size,\n    is_distributed,\n    is_main_process,\n    zero_first,\n)\nfrom axolotl.utils.logging import get_logger\nfrom axolotl.utils.schemas.config import AxolotlInputConfig\n\nif TYPE_CHECKING:\n    from axolotl.core.training_args import AxolotlTrainingArguments\n\n\nIGNORE_INDEX = -100\nLOG = get_logger(__name__)\n\n\nclass LossWatchDogCallback(TrainerCallback):\n    \"\"\"Callback to track loss and stop training if loss is too high\"\"\"\n\n    def __init__(self, cfg):\n        self.cfg = cfg\n        self.violations = 0\n        self.threshold = cfg.loss_watchdog_threshold\n        self.patience = cfg.loss_watchdog_patience or 3\n\n    def on_step_end(\n        self,\n        args: TrainingArguments,\n        state: TrainerState,\n        control: TrainerControl,\n        **_kwargs,\n    ) -> TrainerControl:\n        if len(state.log_history) > 0 and \"loss\" in state.log_history[-1]:\n            if state.log_history[-1][\"loss\"] > self.threshold:\n                self.violations += 1\n    
            if self.violations >= self.patience:\n                    LOG.warning(\n                        \"Loss is too high, stopping training (loss_watchdog_threshold)\"\n                    )\n                    control.should_training_stop = True\n            else:\n                self.violations = 0\n        return control\n\n\nclass SaveModelOnFirstStepCallback(TrainerCallback):\n    \"\"\"Callback to save the model on the first step of training if enabled\"\"\"\n\n    def on_step_end(\n        self,\n        args: TrainingArguments,\n        state: TrainerState,\n        control: TrainerControl,\n        **_kwargs,\n    ) -> TrainerControl:\n        if state.global_step == 1:\n            control.should_save = True\n        return control\n\n\ndef bench_eval_callback_factory(trainer, tokenizer):\n    accuracy = evaluate.load(\"accuracy\")\n    abcd_idx = [\n        tokenizer(\"A\", add_special_tokens=False).input_ids[0],\n        tokenizer(\"B\", add_special_tokens=False).input_ids[0],\n        tokenizer(\"C\", add_special_tokens=False).input_ids[0],\n        tokenizer(\"D\", add_special_tokens=False).input_ids[0],\n        tokenizer(\"E\", add_special_tokens=False).input_ids[0],\n        tokenizer(\"F\", add_special_tokens=False).input_ids[0],\n        tokenizer(\"G\", add_special_tokens=False).input_ids[0],\n    ]\n    bench_split = \"eval\"\n\n    def transform_bench_subject(example):\n        # Split on ':' and trim whitespace\n        parts = example[\"subject\"].split(\":\")\n        first_part = (\n            parts[0].strip().lower().replace(\"-\", \"_\")\n        )  # Lowercase the first part\n        second_part = (\n            parts[1].strip().replace(\"-\", \"_\") if len(parts) > 1 else \"all\"\n        )  # Replace hyphens with underscores\n\n        # Return the transformed values\n        return {\"name\": first_part, \"subject\": second_part}\n\n    if trainer.args.bench_dataset == \"mmlu-zs\":\n        bench_dataset = load_dataset(\n    
        \"openaccess-ai-collective/mmlu-evals\",\n            data_files={\n                \"eval\": \"zero_shot_mmlu_val.json\",\n                \"test\": \"zero_shot_mmlu_test.json\",\n            },\n        )\n        # bench_dataset = bench_dataset.remove_columns(\"subject\")\n    # MMLU Five-shot (Eval/Test only)\n    elif trainer.args.bench_dataset in [\"mmlu\", \"mmlu-fs\"]:\n        bench_dataset = load_dataset(\n            \"openaccess-ai-collective/mmlu-evals\",\n            data_files={\n                \"eval\": \"five_shot_mmlu_val.json\",\n                \"test\": \"five_shot_mmlu_test.json\",\n            },\n        )\n        # bench_dataset = bench_dataset.remove_columns('subject')\n    elif \"/\" in trainer.args.bench_dataset:\n        bench_ds = trainer.args.bench_dataset\n        bench_ds_name = \"/\".join(bench_ds.split(\"/\", 2)[:2])\n        bench_ds_data_file = \"/\".join(bench_ds.split(\"/\", 2)[2:])\n        bench_dataset = load_dataset(\n            bench_ds_name,\n            data_files={\n                \"eval\": bench_ds_data_file,\n            },\n        )\n        bench_dataset[\"eval\"] = bench_dataset[\"eval\"].map(transform_bench_subject)\n    else:\n        raise ValueError(\n            f\"unhandled value `{trainer.args.bench_dataset}` for bench_dataset training args\"\n        )\n    bench_dataset = bench_dataset[trainer.args.bench_split]\n    if trainer.args.max_bench_samples is not None:\n        bench_dataset = bench_dataset.select(range(trainer.args.max_bench_samples))\n\n    def tokenize_evals(example):\n        source = f\"{tokenizer.bos_token}{example['input']}\"\n        target = f\"{example['output']}{tokenizer.eos_token}\"\n\n        tokenized_source = tokenizer(\n            source,\n            max_length=2048,\n            truncation=True,\n            add_special_tokens=False,\n        )\n        tokenized_target = tokenizer(\n            target,\n            max_length=2048,\n            
truncation=True,\n            add_special_tokens=False,\n        )\n        input_ids = tokenized_source[\"input_ids\"] + tokenized_target[\"input_ids\"]\n        labels = [IGNORE_INDEX] * len(tokenized_source[\"input_ids\"]) + tokenized_target[\n            \"input_ids\"\n        ]\n\n        return {\n            \"input_ids\": input_ids,\n            \"labels\": labels,\n            \"subject\": example[\"subject\"],\n        }\n\n    with zero_first(is_main_process()):\n        bench_dataset = bench_dataset.map(tokenize_evals)\n        bench_dataset = bench_dataset.filter(lambda x: x[\"labels\"][-2] in abcd_idx)\n\n    class BenchEvalCallback(TrainerCallback):\n        \"\"\"\n        TrainerCallback that runs the MMLU evals\n        \"\"\"\n\n        def on_evaluate(\n            self,\n            args: AxolotlTrainingArguments,\n            state: TrainerState,\n            control: TrainerControl,\n            metrics: Dict[str, float],\n            **kwargs,\n        ):\n            data_loader = trainer.get_bench_dataloader(\n                bench_dataset.remove_columns([\"input\", \"subject\", \"output\", \"name\"])\n            )\n            trainer.model.eval()\n            preds, refs = [], []\n            loss_bench = 0\n            for batch in tqdm(data_loader, total=len(data_loader)):\n                (loss, logits, labels) = trainer.prediction_step(\n                    trainer.model,\n                    batch,\n                    prediction_loss_only=False,\n                )\n                # There are two tokens, the output, and eos token.\n                for i, logit in enumerate(logits):\n                    label_non_zero_id = (batch[\"labels\"][i] != IGNORE_INDEX).nonzero()[\n                        0\n                    ][0]\n                    logit_abcd = logit[label_non_zero_id - 1][abcd_idx]\n                    preds.append(torch.argmax(logit_abcd).item())\n                labels = labels[labels != IGNORE_INDEX].view(-1, 2)[:, 
0]\n                refs += [\n                    abcd_idx.index(label) if label in abcd_idx else -1\n                    for label in labels.tolist()\n                ]\n                loss_bench += loss.item()\n            # Extract results by subject.\n            bench_name = bench_dataset[\"name\"]\n            bench_names: dict = {s: {\"refs\": [], \"preds\": []} for s in set(bench_name)}\n            for s, p, r in zip(bench_name, preds, refs, strict=False):\n                bench_names[s][\"preds\"].append(p)\n                bench_names[s][\"refs\"].append(r)\n            barrier()\n            local_bench_names = bench_names\n            gathered_bench_names: List[Dict] = [{} for _ in range(get_world_size())]\n            # Gather results from all GPUs to GPU 0\n\n            loss_bench_ranks = gather_scalar_from_all_ranks(\n                lambda: loss_bench, get_world_size()\n            )\n            len_data_loader_ranks = gather_scalar_from_all_ranks(\n                lambda: len(data_loader), get_world_size()\n            )\n\n            results = {}\n            if is_distributed() and not is_main_process():\n                dist.gather_object(local_bench_names, dst=0)\n            else:\n                if is_distributed():\n                    dist.gather_object(local_bench_names, gathered_bench_names, dst=0)\n                else:\n                    gathered_bench_names = [local_bench_names]\n                bench_loss = sum(loss_bench_ranks) / sum(len_data_loader_ranks)\n                results = {f\"{bench_split}_bench_loss\": bench_loss}\n\n                # Combine results from all GPUs\n                combined_bench_names: Dict[str, Dict[str, List]] = {}\n                for bench_name in gathered_bench_names:\n                    for name, data in bench_name.items():\n                        if name not in combined_bench_names:\n                            combined_bench_names[name] = {\"refs\": [], \"preds\": []}\n                  
      combined_bench_names[name][\"refs\"].extend(data[\"refs\"])\n                        combined_bench_names[name][\"preds\"].extend(data[\"preds\"])\n\n                bench_scores = []\n                bench_refs = []\n                bench_preds = []\n                for bench_name in combined_bench_names:\n                    bench_score = accuracy.compute(\n                        references=combined_bench_names[bench_name][\"refs\"],\n                        predictions=combined_bench_names[bench_name][\"preds\"],\n                    )[\"accuracy\"]\n                    bench_refs.extend(combined_bench_names[bench_name][\"refs\"])\n                    bench_preds.extend(combined_bench_names[bench_name][\"preds\"])\n                    if not pd.isna(bench_score):\n                        results[f\"{bench_split}_bench_accuracy_{bench_name}\"] = (\n                            bench_score\n                        )\n                        bench_scores.append(bench_score)\n                    else:\n                        results[f\"{bench_split}_bench_accuracy_{bench_name}\"] = 0.0\n                        bench_scores.append(0.0)\n                results[f\"{bench_split}_bench_average_accuracy\"] = np.mean(bench_scores)\n                results[f\"{bench_split}_bench_total_accuracy\"] = accuracy.compute(\n                    references=bench_refs, predictions=bench_preds\n                )[\"accuracy\"]\n                trainer.log(results)\n\n            results = broadcast_dict(results)\n            for key, val in results.items():\n                metrics[key] = val\n\n    return BenchEvalCallback\n\n\ndef causal_lm_bench_eval_callback_factory(trainer: Trainer, tokenizer):\n    class CausalLMBenchEvalCallback(TrainerCallback):\n        \"\"\"Callback to log prediction values during each evaluation\"\"\"\n\n        def __init__(self, cfg):\n            self.cfg = cfg\n            self.logged = False\n            self.metrics = 
self.__maybe_load_metrics()\n\n        def __maybe_load_metrics(self):\n            metrics = {}\n            for metric in self.cfg.eval_causal_lm_metrics:\n                if metric == \"perplexity\":\n                    max_seq_len = self.cfg.eval_max_new_tokens\n                    metrics[metric] = Perplexity(\n                        tokenizer=tokenizer,\n                        max_seq_len=max_seq_len,\n                    )\n                else:\n                    try:\n                        metrics[metric] = evaluate.load(metric)\n                    except Exception as exc:\n                        LOG.warning(f\"{metric}: {exc.args}\")\n            return metrics\n\n        def on_evaluate(\n            self,\n            args: AxolotlTrainingArguments,\n            state: TrainerState,\n            control: TrainerControl,\n            train_dataloader,\n            eval_dataloader,\n            **kwargs,\n        ):\n            trainer.model_wrapped.eval()\n\n            device = torch.device(\n                self.cfg.device\n            )  # Use this instead of trainer.model_wrapped.device as it may return cpu if fsdp offloaded\n\n            generation_config = GenerationConfig(\n                max_new_tokens=self.cfg.eval_max_new_tokens,\n                bos_token_id=tokenizer.bos_token_id,\n                eos_token_id=tokenizer.eos_token_id,\n                pad_token_id=tokenizer.pad_token_id,\n                do_sample=False,\n                use_cache=True,\n                return_dict_in_generate=True,\n                output_attentions=False,\n                output_hidden_states=False,\n                output_scores=False,\n            )\n\n            def find_ranges(lst):\n                ranges = []\n                start = 0\n                for i in range(1, len(lst)):\n                    if lst[i] == 0:\n                        ranges.append((start, i - 1))\n                        start = i\n                end = len(lst) - 
1\n                ranges.append((start, end))\n                return ranges\n\n            def compute(metric: evaluate.Metric, **kwargs):\n                # safely compute a metric and return the score if the format is correct\n                metric_score = None\n                try:\n                    # Only pass the kwargs that are in the metric's feature list\n                    metric_kwargs = {\n                        k: kwargs[k] for k in metric._feature_names() if k in kwargs\n                    }\n\n                    if isinstance(metric, Perplexity):\n                        metric_kwargs[\"model\"] = trainer.model_wrapped\n\n                    metric_score = metric.compute(**metric_kwargs)\n                    return (\n                        metric_score[\"score\"]\n                        if \"score\" in metric_score\n                        else metric_score[\"mean_score\"]\n                    )\n                except Exception:\n                    traceback.print_exc()\n                    LOG.debug(\n                        f\"Failed to compute metric {metric.name} with kwargs {kwargs.keys()}\"\n                    )\n                return metric_score\n\n            def evaluate_preds(sources, predictions, references):\n                scores = {}\n\n                for metric_name, metric in self.metrics.items():\n                    score = compute(\n                        metric,\n                        references=references,\n                        predictions=predictions,\n                        sources=sources,\n                    )\n                    if score is None:\n                        score = compute(\n                            metric,\n                            references=[[r] for r in references],\n                            predictions=predictions,\n                        )\n                    scores[\"eval_\" + metric_name] = score\n                return scores\n\n            def 
predict_with_generate():\n                eval_src, eval_pred, eval_ref = [], [], []\n\n                with unwrap_model_for_generation(\n                    trainer.model_wrapped, trainer.accelerator\n                ) as unwrapped_model:\n                    for batch in tqdm(eval_dataloader, disable=not is_main_process()):\n                        batch_labels = batch[\"labels\"].to(device)\n                        batch_input_ids = batch[\"input_ids\"].to(device)\n\n                        if \"position_ids\" in batch:\n                            batch_pos_ids = batch[\"position_ids\"].tolist()\n                        else:\n                            batch_pos_ids = [None] * len(batch[\"input_ids\"])\n\n                        prompt_token_ids_list = []\n                        completion_token_ids_list = []\n\n                        for input_ids_all, labels_all, pos_ids in zip(\n                            batch_input_ids,\n                            batch_labels,\n                            batch_pos_ids,\n                            strict=False,\n                        ):\n                            if pos_ids is None:\n                                pos_ranges = [(0, len(input_ids_all) - 1)]\n                            else:\n                                pos_ranges = find_ranges(pos_ids)\n\n                            for pos_range in pos_ranges:\n                                start, end = pos_range\n                                if start == end:\n                                    continue\n\n                                input_ids = input_ids_all[start : end + 1]\n                                labels = labels_all[start : end + 1]\n\n                                tokens_without_loss = labels == IGNORE_INDEX\n                                tokens_with_loss = labels != IGNORE_INDEX\n                                tokens_exclude_padding = (\n                                    input_ids != tokenizer.pad_token_id\n                 
               )\n                                prompt_token_includes = (\n                                    tokens_without_loss & tokens_exclude_padding\n                                )\n\n                                prompt_token_ids = input_ids[prompt_token_includes]\n                                prompt_token_ids_list.append(prompt_token_ids)\n\n                                completion_token_ids = input_ids[tokens_with_loss]\n                                completion_token_ids_list.append(completion_token_ids)\n\n                        prompt_texts = tokenizer.batch_decode(\n                            prompt_token_ids_list, skip_special_tokens=True\n                        )\n                        completion_texts = tokenizer.batch_decode(\n                            completion_token_ids_list, skip_special_tokens=True\n                        )\n\n                        with torch.no_grad():\n                            prompt_encoding = tokenizer(\n                                prompt_texts, padding=True, return_tensors=\"pt\"\n                            ).to(device)\n\n                            predictions = unwrapped_model.generate(\n                                **prompt_encoding, generation_config=generation_config\n                            )\n\n                            del prompt_encoding\n\n                        prediction_all_tokens = predictions[\"sequences\"].cpu().tolist()\n                        prediction_without_prompt_tokens_list = []\n                        for prompt_token_ids, prediction_tokens in zip(\n                            prompt_token_ids_list, prediction_all_tokens, strict=False\n                        ):\n                            prediction_without_prompt_tokens = prediction_tokens[\n                                len(prompt_token_ids) :\n                            ]\n                            prediction_without_prompt_tokens_list.append(\n                                
prediction_without_prompt_tokens\n                            )\n\n                        predicted_texts = tokenizer.batch_decode(\n                            prediction_without_prompt_tokens_list,\n                            skip_special_tokens=True,\n                        )\n\n                        eval_src.extend(prompt_texts)\n                        eval_pred.extend(predicted_texts)\n                        eval_ref.extend(completion_texts)\n\n                return eval_src, eval_pred, eval_ref\n\n            eval_preds = predict_with_generate()\n            trainer.log(evaluate_preds(*eval_preds))\n\n            return control\n\n    return CausalLMBenchEvalCallback\n\n\ndef log_prediction_callback_factory(trainer: Trainer, tokenizer, logger: str):\n    class LogPredictionCallback(TrainerCallback):\n        \"\"\"Callback to log prediction values during each evaluation\"\"\"\n\n        def __init__(self, cfg):\n            self.cfg = cfg\n            self.logged = False\n\n        def on_evaluate(\n            self,\n            args: AxolotlTrainingArguments,\n            state: TrainerState,\n            control: TrainerControl,\n            train_dataloader,\n            eval_dataloader,\n            **kwargs,\n        ):\n            eval_table_size = self.cfg.eval_table_size\n\n            if eval_table_size <= 0:\n                return control\n\n            trainer.model.eval()\n            device = torch.device(self.cfg.device)\n\n            generation_config = GenerationConfig(\n                max_new_tokens=self.cfg.eval_max_new_tokens,\n                bos_token_id=tokenizer.bos_token_id,\n                eos_token_id=tokenizer.eos_token_id,\n                pad_token_id=tokenizer.pad_token_id,\n                do_sample=False,\n                use_cache=True,\n                return_dict_in_generate=True,\n                output_attentions=False,\n                output_hidden_states=False,\n                output_scores=False,\n       
     )\n\n            def logits_to_tokens(logits) -> torch.Tensor:\n                probabilities = torch.softmax(logits, dim=-1)\n                # Get the predicted token ids (the ones with the highest probability)\n                predicted_token_ids = torch.argmax(probabilities, dim=-1)\n                return predicted_token_ids\n\n            def find_ranges(lst):\n                ranges = []\n                start = 0\n                for i in range(1, len(lst)):\n                    if lst[i] == 0:\n                        ranges.append((start, i - 1))\n                        start = i\n                end = len(lst) - 1\n                ranges.append((start, end))\n                return ranges\n\n            def log_table_from_dataloader(name: str, table_dataloader):\n                table_data: Dict[str, List[Any]] = {\n                    \"id\": [],\n                    \"Prompt\": [],\n                    \"Correct Completion\": [],\n                    \"Predicted Completion (model.generate)\": [],\n                    \"Predicted Completion (trainer.prediction_step)\": [],\n                }\n                row_index = 0\n\n                for batch in tqdm(table_dataloader):\n                    if row_index > eval_table_size:\n                        break\n\n                    batch_labels = batch[\"labels\"].to(device)\n                    batch_input_ids = batch[\"input_ids\"].to(device)\n\n                    if \"position_ids\" in batch:\n                        batch_pos_ids = batch[\"position_ids\"].tolist()\n                    else:\n                        batch_pos_ids = [None] * len(batch[\"input_ids\"])\n\n                    (_, batch_logits, _) = trainer.prediction_step(\n                        trainer.model,\n                        batch,\n                        prediction_loss_only=False,\n                    )\n\n                    prompt_token_ids_list = []\n                    pred_step_token_ids_list = []\n             
       completion_token_ids_list = []\n\n                    for input_ids_all, labels_all, pos_ids, logits in zip(\n                        batch_input_ids,\n                        batch_labels,\n                        batch_pos_ids,\n                        batch_logits,\n                        strict=False,\n                    ):\n                        if pos_ids is None:\n                            pos_ranges = [(0, len(input_ids_all) - 1)]\n                        else:\n                            pos_ranges = find_ranges(pos_ids)\n\n                        for pos_range in pos_ranges:\n                            start, end = pos_range\n                            if start == end:\n                                continue\n\n                            input_ids = input_ids_all[start : end + 1]\n                            labels = labels_all[start : end + 1]\n\n                            tokens_without_loss = labels == IGNORE_INDEX\n                            tokens_with_loss = labels != IGNORE_INDEX\n                            tokens_exclude_padding = input_ids != tokenizer.pad_token_id\n                            prompt_token_includes = (\n                                tokens_without_loss & tokens_exclude_padding\n                            )\n\n                            prompt_token_ids = input_ids[prompt_token_includes]\n                            prompt_token_ids_list.append(prompt_token_ids)\n\n                            completion_token_ids = input_ids[tokens_with_loss]\n                            completion_token_ids_list.append(completion_token_ids)\n\n                            pred_step_token_ids = logits_to_tokens(\n                                logits[start : end + 1]\n                            )[tokens_with_loss]\n                            pred_step_token_ids_list.append(pred_step_token_ids)\n\n                    prompt_texts = tokenizer.batch_decode(\n                        prompt_token_ids_list, 
skip_special_tokens=True\n                    )\n                    completion_texts = tokenizer.batch_decode(\n                        completion_token_ids_list, skip_special_tokens=True\n                    )\n                    pred_step_texts = tokenizer.batch_decode(\n                        pred_step_token_ids_list, skip_special_tokens=True\n                    )\n\n                    with torch.no_grad():\n                        prompt_encoding = tokenizer(\n                            prompt_texts, padding=True, return_tensors=\"pt\"\n                        ).to(self.cfg.device)\n                        predictions = trainer.model.generate(\n                            **prompt_encoding, generation_config=generation_config\n                        )\n\n                    prediction_all_tokens = predictions[\"sequences\"].cpu().tolist()\n                    prediction_without_prompt_tokens_list = []\n                    for prompt_token_ids, prediction_tokens in zip(\n                        prompt_token_ids_list, prediction_all_tokens, strict=False\n                    ):\n                        prediction_without_prompt_tokens = prediction_tokens[\n                            len(prompt_token_ids) :\n                        ]\n                        prediction_without_prompt_tokens_list.append(\n                            prediction_without_prompt_tokens\n                        )\n\n                    predicted_texts = tokenizer.batch_decode(\n                        prediction_without_prompt_tokens_list, skip_special_tokens=True\n                    )\n\n                    for (\n                        prompt_text,\n                        completion_text,\n                        prediction_text,\n                        pred_step_text,\n                    ) in zip(\n                        prompt_texts,\n                        completion_texts,\n                        predicted_texts,\n                        pred_step_texts,\n           
             strict=False,\n                    ):\n                        table_data[\"id\"].append(row_index)\n                        table_data[\"Prompt\"].append(prompt_text)\n                        table_data[\"Correct Completion\"].append(completion_text)\n                        table_data[\"Predicted Completion (model.generate)\"].append(\n                            prediction_text\n                        )\n                        table_data[\n                            \"Predicted Completion (trainer.prediction_step)\"\n                        ].append(pred_step_text)\n                        row_index += 1\n                if logger == \"wandb\":\n                    # type: ignore[attr-defined]\n                    wandb.run.log(\n                        {\n                            f\"{name} - Predictions vs Ground Truth\": pd.DataFrame(\n                                table_data\n                            )\n                        }\n                    )\n                elif logger == \"mlflow\" and is_mlflow_available():\n                    import mlflow\n\n                    tracking_uri = AxolotlInputConfig(\n                        **self.cfg.to_dict()\n                    ).mlflow_tracking_uri\n                    mlflow.log_table(\n                        data=table_data,\n                        artifact_file=\"PredictionsVsGroundTruth.json\",\n                        tracking_uri=tracking_uri,\n                    )\n                elif logger == \"comet_ml\" and is_comet_available():\n                    import comet_ml\n\n                    experiment = comet_ml.get_running_experiment()\n                    if experiment:\n                        experiment.log_table(\n                            f\"{name} - Predictions vs Ground Truth.csv\",\n                            pd.DataFrame(table_data),\n                        )\n\n            if is_main_process():\n                log_table_from_dataloader(\"Eval\", 
eval_dataloader)\n\n            return control\n\n    return LogPredictionCallback\n\n\nclass SaveAxolotlConfigtoWandBCallback(TrainerCallback):\n    \"\"\"Callback to save axolotl config to wandb\"\"\"\n\n    def __init__(self, axolotl_config_path):\n        self.axolotl_config_path = axolotl_config_path\n\n    def on_train_begin(\n        self,\n        args: AxolotlTrainingArguments,\n        state: TrainerState,\n        control: TrainerControl,\n        **kwargs,\n    ):\n        if state.is_world_process_zero:\n            try:\n                # sync config to top level in run, cannot delete file right away because wandb schedules it to be synced even w/policy = 'now', so let OS delete it later.\n                with NamedTemporaryFile(\n                    mode=\"w\", delete=False, suffix=\".yml\", prefix=\"axolotl_config_\"\n                ) as temp_file:\n                    copyfile(self.axolotl_config_path, temp_file.name)\n                    artifact = wandb.Artifact(\n                        f\"config-{wandb.run.id}\", type=\"axolotl-config\"\n                    )\n                    artifact.add_file(temp_file.name)\n                    wandb.log_artifact(artifact)\n                    wandb.save(temp_file.name)\n                    LOG.info(\n                        \"The Axolotl config has been saved to the WandB run under files.\"\n                    )\n            except (FileNotFoundError, ConnectionError) as err:\n                LOG.warning(f\"Error while saving Axolotl config to WandB: {err}\")\n\n            try:\n                with open(self.axolotl_config_path, \"r\", encoding=\"utf-8\") as f:\n                    cfg = yaml.safe_load(f) or {}\n\n                chat_tpl = cfg.get(\"chat_template_jinja\")\n                if chat_tpl:\n                    with NamedTemporaryFile(\n                        mode=\"w\", delete=True, suffix=\".jinja\", prefix=\"chat_template_\"\n                    ) as temp_ct_file:\n                    
    if (\n                            isinstance(chat_tpl, str)\n                            and os.path.exists(chat_tpl)\n                            and os.path.isfile(chat_tpl)\n                        ):\n                            copyfile(chat_tpl, temp_ct_file.name)\n                        else:\n                            temp_ct_file.write(str(chat_tpl))\n                            temp_ct_file.flush()\n\n                        artifact = wandb.Artifact(\n                            f\"chat-template-{wandb.run.id}\", type=\"jinja-template\"\n                        )\n                        artifact.add_file(temp_ct_file.name)\n                        wandb.log_artifact(artifact)\n                        wandb.save(temp_ct_file.name)\n                        LOG.info(\n                            \"The chat_template_jinja has been saved to the WandB run under files.\"\n                        )\n            except (FileNotFoundError, ConnectionError, yaml.YAMLError) as err:\n                LOG.warning(f\"Error while saving chat_template_jinja to WandB: {err}\")\n\n            if args.deepspeed:\n                try:\n                    # sync config to top level in run, cannot delete file right away because wandb schedules it to be synced even w/policy = 'now', so let OS delete it later.\n                    with NamedTemporaryFile(\n                        mode=\"w\",\n                        delete=False,\n                        suffix=\".json\",\n                        prefix=\"deepspeed_config_\",\n                    ) as temp_file:\n                        skip_upload = False\n                        if isinstance(args.deepspeed, dict):\n                            json.dump(args.deepspeed, temp_file, indent=4)\n                        elif isinstance(args.deepspeed, str) and os.path.exists(\n                            args.deepspeed\n                        ):\n                            copyfile(args.deepspeed, temp_file.name)\n         
               else:\n                            skip_upload = True\n                        if not skip_upload:\n                            artifact = wandb.Artifact(\n                                f\"deepspeed-config-{wandb.run.id}\",\n                                type=\"deepspeed-config\",\n                            )\n                            artifact.add_file(temp_file.name)\n                            wandb.log_artifact(artifact)\n                            wandb.save(temp_file.name)\n                            LOG.info(\n                                \"The DeepSpeed config has been saved to the WandB run under files.\"\n                            )\n                except (FileNotFoundError, ConnectionError) as err:\n                    LOG.warning(f\"Error while saving DeepSpeed config to WandB: {err}\")\n\n        return control\n\n\nclass GCCallback(TrainerCallback):\n    \"\"\"Callback to garbage collect torch cache\"\"\"\n\n    def __init__(self, gc_steps: int | None = -1):\n        self.gc_steps: int = gc_steps or -1\n        self.next_gc_on_begin_step: int = -1\n\n    def _gc(self):\n        torch.cuda.empty_cache()\n        gc.collect()\n\n    def on_train_begin(\n        self,\n        args,\n        state,\n        control,\n        **kwargs,\n    ):\n        self._gc()\n\n    def on_step_begin(\n        self,\n        args,\n        state,\n        control,\n        **kwargs,\n    ):\n        if self.next_gc_on_begin_step == state.global_step or state.global_step == 0:\n            self._gc()\n\n    def on_step_end(\n        self,\n        args,\n        state,\n        control,\n        **kwargs,\n    ):\n        if control.should_evaluate:\n            # automatically GC before evals so the eval memory spike from the CEL doesn't OOM the trainer\n            self._gc()\n            # also GC on the start of the next step after the eval\n            self.next_gc_on_begin_step = state.global_step + 1\n        elif self.gc_steps > 
0 and state.global_step % self.gc_steps == 0:\n            self._gc()\n        elif (\n            args.save_strategy == SaveStrategy.STEPS\n            and state.save_steps > 0\n            and state.global_step % state.save_steps == 0\n        ):\n            # gc on save steps in case anything is loaded to CPU RAM like offloaded tensors\n            self._gc()\n        elif state.global_step >= state.max_steps:\n            if args.save_strategy == SaveStrategy.STEPS:\n                # gc on save steps in case anything is loaded to CPU RAM like offloaded tensors\n                self._gc()\n\n    def on_epoch_end(\n        self,\n        args,\n        state,\n        control,\n        **kwargs,\n    ):\n        self._gc()\n\n\ndef colab_inference_post_train_callback(trainer: Trainer):\n    class ColabCallback(TrainerCallback):\n        \"\"\"Callback to prep model for inference on Google Colab\"\"\"\n\n        def __init__(self, cfg):\n            self.gpu_name = torch.cuda.get_device_name(0)\n            self.cfg = cfg\n\n        def on_train_end(self, args, state, control, **kwargs):\n            \"\"\"\n            handle T4 gpu, we need to convert attention to eager for inference\n            \"\"\"\n            if \"Tesla T4\" in self.gpu_name and self.cfg.xformers_attention:\n                trainer.model.config._attn_implementation = \"eager\"\n            trainer.model.gradient_checkpointing_disable()\n            trainer.model.config.use_cache = True\n            trainer.model.eval()\n\n    return ColabCallback\n"
  },
  {
    "path": "src/axolotl/utils/callbacks/comet_.py",
    "content": "\"\"\"Comet module for trainer callbacks\"\"\"\n\nfrom typing import TYPE_CHECKING\n\nimport comet_ml\nfrom transformers import TrainerCallback, TrainerControl, TrainerState\n\nfrom axolotl.utils.distributed import is_main_process\nfrom axolotl.utils.logging import get_logger\n\nif TYPE_CHECKING:\n    from axolotl.core.training_args import AxolotlTrainingArguments\n\nLOG = get_logger(__name__)\n\n\nclass SaveAxolotlConfigtoCometCallback(TrainerCallback):\n    \"\"\"Callback to save axolotl config to comet\"\"\"\n\n    def __init__(self, axolotl_config_path):\n        self.axolotl_config_path = axolotl_config_path\n\n    def on_train_begin(\n        self,\n        args: \"AxolotlTrainingArguments\",\n        state: TrainerState,\n        control: TrainerControl,\n        **kwargs,\n    ):\n        if is_main_process():\n            try:\n                comet_experiment = comet_ml.start(source=\"axolotl\")\n                comet_experiment.log_other(\"Created from\", \"axolotl\")\n                comet_experiment.log_asset(\n                    self.axolotl_config_path,\n                    file_name=\"axolotl-config\",\n                )\n                LOG.info(\n                    \"The Axolotl config has been saved to the Comet Experiment under assets.\"\n                )\n            except (FileNotFoundError, ConnectionError) as err:\n                LOG.warning(f\"Error while saving Axolotl config to Comet: {err}\")\n        return control\n"
  },
  {
    "path": "src/axolotl/utils/callbacks/dynamic_checkpoint.py",
    "content": "from pathlib import Path\n\nfrom transformers import (\n    TrainerCallback,\n    TrainerControl,\n    TrainerState,\n    TrainingArguments,\n)\n\nfrom axolotl.utils.distributed import (\n    barrier,\n    is_distributed,\n    is_main_process,\n)\nfrom axolotl.utils.logging import get_logger\n\nLOG = get_logger(__name__)\n\nDEFAULT_TRIGGER_FILENAME = \"axolotl_checkpoint.save\"\n\n\nclass DynamicCheckpointCallback(TrainerCallback):\n    \"\"\"\n    Callback to save checkpoints on-demand during training via:\n    1. File-based trigger (works everywhere, rank 0 checks file)\n\n    Thread-safe for multi-GPU distributed training.\n\n    Usage:\n        # File-based:\n        touch /path/to/output_dir/axolotl_checkpoint.save\n    \"\"\"\n\n    def _get_config_value(self, config, key, default=None):\n        \"\"\"Helper to get config value from dict or object.\"\"\"\n        if isinstance(config, dict):\n            return config.get(key, default)\n        return getattr(config, key, default)\n\n    def __init__(self, cfg):\n        self.cfg = cfg\n        if not cfg.dynamic_checkpoint or not cfg.dynamic_checkpoint.enabled:\n            self.enabled = False\n            return\n\n        self.enabled = True\n        dc_config = cfg.dynamic_checkpoint\n\n        trigger_file_path = self._get_config_value(dc_config, \"trigger_file_path\")\n        self.trigger_filename = (\n            trigger_file_path if trigger_file_path else DEFAULT_TRIGGER_FILENAME\n        )\n\n        check_interval = self._get_config_value(dc_config, \"check_interval\")\n        self.check_interval = check_interval if check_interval is not None else 100\n        self.should_save_checkpoint = False\n\n        LOG.info(\n            f\"Dynamic checkpoint enabled. 
To trigger checkpoint save:\\n\"\n            f\"  • File: touch {cfg.output_dir}/{self.trigger_filename}\\n\"\n            f\"  • Check interval: every {self.check_interval} steps\",\n        )\n\n    def on_step_end(\n        self,\n        args: TrainingArguments,\n        state: TrainerState,\n        control: TrainerControl,\n        **_kwargs,\n    ) -> TrainerControl:\n        \"\"\"\n        Check for checkpoint triggers at the end of each step.\n        ONLY rank 0 checks the file, then all ranks synchronize.\n        \"\"\"\n        if not self.enabled:\n            return control\n\n        trigger_detected = False\n\n        if state.global_step % self.check_interval == 0:\n            if is_main_process():\n                trigger_path = Path(args.output_dir) / self.trigger_filename\n\n                if trigger_path.exists():\n                    trigger_detected = True\n                    try:\n                        trigger_path.unlink()  # Delete the trigger file\n                        LOG.info(\n                            f\"Dynamic checkpoint triggered via file '{self.trigger_filename}' \"\n                            f\"at step {state.global_step}\",\n                        )\n                    except OSError as exc:\n                        LOG.warning(\n                            f\"Failed to delete trigger file: {exc}\",\n                        )\n\n                if self.should_save_checkpoint:\n                    trigger_detected = True\n                    self.should_save_checkpoint = False  # Reset flag\n\n            if is_distributed():\n                import torch\n                import torch.distributed as dist\n\n                device = getattr(\n                    args,\n                    \"device\",\n                    torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\"),\n                )\n\n                trigger_tensor = torch.tensor(\n                    1 if trigger_detected else 0,\n          
          dtype=torch.long,\n                    device=device,\n                )\n\n                dist.broadcast(trigger_tensor, src=0)\n\n                trigger_detected = bool(trigger_tensor.item())\n\n                barrier()\n\n        if trigger_detected:\n            control.should_save = True\n            LOG.info(\n                f\"Saving dynamic checkpoint at step {state.global_step}\",\n            )\n        return control\n"
  },
  {
    "path": "src/axolotl/utils/callbacks/generation.py",
    "content": "\"\"\"Callback for generating samples during SFT/Pretrain training.\"\"\"\n\nfrom transformers.trainer_callback import TrainerCallback, TrainerControl, TrainerState\nfrom transformers.training_args import TrainingArguments\n\nfrom axolotl.utils.generation.sft import generate_samples\nfrom axolotl.utils.logging import get_logger\n\nLOG = get_logger(__name__)\n\n\nclass SFTGenerationCallback(TrainerCallback):\n    \"\"\"Callback for generating samples during SFT/Pretrain training.\"\"\"\n\n    def __init__(self, trainer):\n        self.trainer = trainer\n\n    def on_evaluate(\n        self,\n        args: TrainingArguments,\n        state: TrainerState,\n        control: TrainerControl,\n        **kwargs,\n    ):\n        \"\"\"Generate and log samples after each evaluation.\n\n        Does nothing unless ``generate_samples`` is set in the axolotl config.\n        Prompts come from the eval dataloader when an eval dataset exists,\n        otherwise from the train dataloader.\n        \"\"\"\n        cfg = self.trainer.axolotl_cfg\n\n        if not getattr(cfg, \"generate_samples\", False):\n            return control\n\n        # Prefer eval data for prompts; fall back to the train dataloader.\n        dataloader = None\n        try:\n            if getattr(self.trainer, \"eval_dataset\", None) is not None:\n                dataloader = self.trainer.get_eval_dataloader()\n                LOG.info(\n                    f\"Using eval dataloader for generation at step {state.global_step}\"\n                )\n        except Exception as e:\n            LOG.warning(f\"Could not get eval dataloader: {e}\")\n            dataloader = None\n\n        if dataloader is None:\n            dataloader = self.trainer.get_train_dataloader()\n            LOG.info(\n                f\"Using train dataloader for generation at step {state.global_step}\"\n            )\n\n        samples = generate_samples(\n            model=self.trainer.model,\n            tokenizer=self.trainer.processing_class,\n            dataloader=dataloader,\n            num_generation_samples=getattr(cfg, \"num_generation_samples\", 3),\n            max_new_tokens=getattr(cfg, \"generation_max_new_tokens\", 50),\n            temperature=getattr(cfg, \"generation_temperature\", 0.7),\n            top_p=getattr(cfg, \"generation_top_p\", None),\n            top_k=getattr(cfg, \"generation_top_k\", None),\n            do_sample=getattr(cfg, \"generation_do_sample\", True),\n            prompt_ratio=getattr(cfg, \"generation_prompt_ratio\", 0.5),\n        )\n        self._log_samples(samples, state.global_step)\n        return control\n\n    def _log_samples(self, samples: list, step: int):\n        \"\"\"Log generated samples to console and W&B.\"\"\"\n        from axolotl.utils.generation.sft import format_generation_for_logging\n\n        for i, sample in enumerate(samples):\n            console_text, wandb_text = format_generation_for_logging(sample, i, step)\n\n            LOG.info(console_text)\n\n            try:\n                import wandb\n\n                if wandb.run is not None:\n                    wandb.log(\n                        {\n                            f\"samples/sample_{i + 1}\": wandb.Html(\n                                f\"<pre>{wandb_text}</pre>\"\n                            )\n                        },\n                        step=step,\n                    )\n            except Exception as exc:  # W&B logging is best-effort\n                LOG.debug(f\"Skipping W&B sample logging: {exc}\")\n"
  },
  {
    "path": "src/axolotl/utils/callbacks/lisa.py",
    "content": "\"\"\"\nmodule for LISA\n\nAdapted from https://github.com/OptimalScale/LMFlow/pull/701 for HF transformers & Axolotl\nArxiv: https://arxiv.org/abs/2403.17919\nLicense: Apache 2.0\n\"\"\"\n\nfrom functools import reduce\nfrom typing import TYPE_CHECKING\n\nimport numpy as np\nfrom transformers import TrainerCallback\n\nfrom axolotl.utils.logging import get_logger\n\nif TYPE_CHECKING:\n    from axolotl.core.trainers import AxolotlTrainer\n\nLOG = get_logger(__name__)\n\n\ndef lisa_callback_factory(trainer: \"AxolotlTrainer\"):\n    class LISACallback(TrainerCallback):\n        \"\"\"trainer callback for lisa layer switching\"\"\"\n\n        def __init__(\n            self, n_layers, step_interval, trainer, layers_attribute=\"model.layers\"\n        ):\n            super().__init__()\n            self.n_layers = n_layers\n            self.step_interval = step_interval\n            self.layers_attribute = layers_attribute\n            self.trainer = trainer\n\n            reduce(getattr, self.layers_attribute.split(\".\"), self.trainer.model)\n\n            self.total_layers = len(\n                reduce(getattr, self.layers_attribute.split(\".\"), self.trainer.model)\n            )\n            self.active_layers_indices = []\n\n            layers = reduce(\n                getattr, self.layers_attribute.split(\".\"), self.trainer.model\n            )\n            LOG.info(\n                f\"LISA will activate {self.n_layers}/{len(layers)} layers ({self.n_layers * 100 / len(layers)}%) every {self.step_interval} steps\"\n            )\n\n        def freeze_all_layers(self):\n            layers = reduce(\n                getattr, self.layers_attribute.split(\".\"), self.trainer.model\n            )\n            for layer in layers:\n                for param in layer.parameters():\n                    param.requires_grad = False\n\n        def on_step_begin(self, args, state, control, **kwargs):\n            # Check if it's time to switch active 
layers, including at step 0\n            if state.global_step % self.step_interval == 0 or state.global_step == 1:\n                self.switch_active_layers()\n\n        def switch_active_layers(self):\n            # First, disable gradients for all layers\n            self.freeze_all_layers()\n\n            # Randomly select n_layers to activate\n            layers = reduce(\n                getattr, self.layers_attribute.split(\".\"), self.trainer.model\n            )\n            self.active_layers_indices = np.random.choice(\n                range(self.total_layers), self.n_layers, replace=False\n            )\n            LOG.info(\n                f\"Activating layers at indices: {self.active_layers_indices} for the next steps.\"\n            )\n\n            # Enable gradients only for the selected layers\n            for idx in self.active_layers_indices:\n                for param in layers[idx].parameters():\n                    param.requires_grad = True\n\n    lisa_callback = LISACallback(\n        n_layers=trainer.args.lisa_n_layers,\n        step_interval=trainer.args.lisa_step_interval,\n        trainer=trainer,\n        layers_attribute=trainer.args.lisa_layers_attribute,\n    )\n\n    return lisa_callback\n"
  },
  {
    "path": "src/axolotl/utils/callbacks/mlflow_.py",
    "content": "\"\"\"MLFlow module for trainer callbacks\"\"\"\n\nimport os\nfrom shutil import copyfile\nfrom tempfile import NamedTemporaryFile\nfrom typing import TYPE_CHECKING\n\nimport mlflow\nfrom transformers import TrainerCallback, TrainerControl, TrainerState\n\nfrom axolotl.utils.distributed import is_main_process\nfrom axolotl.utils.logging import get_logger\n\nif TYPE_CHECKING:\n    from axolotl.core.training_args import AxolotlTrainingArguments\n\nLOG = get_logger(__name__)\n\n\ndef should_log_artifacts() -> bool:\n    truths = [\"TRUE\", \"1\", \"YES\"]\n    return os.getenv(\"HF_MLFLOW_LOG_ARTIFACTS\", \"FALSE\").upper() in truths\n\n\nclass SaveAxolotlConfigtoMlflowCallback(TrainerCallback):\n    \"\"\"Callback to save axolotl config to mlflow\"\"\"\n\n    def __init__(self, axolotl_config_path):\n        self.axolotl_config_path = axolotl_config_path\n\n    def on_train_begin(\n        self,\n        args: \"AxolotlTrainingArguments\",\n        state: TrainerState,\n        control: TrainerControl,\n        **kwargs,\n    ):\n        if is_main_process():\n            try:\n                if should_log_artifacts():\n                    with NamedTemporaryFile(\n                        mode=\"w\", delete=False, suffix=\".yml\", prefix=\"axolotl_config_\"\n                    ) as temp_file:\n                        copyfile(self.axolotl_config_path, temp_file.name)\n                        mlflow.log_artifact(temp_file.name, artifact_path=\"\")\n                        LOG.info(\n                            \"The Axolotl config has been saved to the MLflow artifacts.\"\n                        )\n                else:\n                    LOG.info(\n                        \"Skipping logging artifacts to MLflow (hf_mlflow_log_artifacts is false)\"\n                    )\n            except (FileNotFoundError, ConnectionError) as err:\n                LOG.warning(f\"Error while saving Axolotl config to MLflow: {err}\")\n        return control\n"
  },
  {
    "path": "src/axolotl/utils/callbacks/models.py",
    "content": "\"\"\"Helper functions for model classes\"\"\"\n\nfrom typing import Tuple\n\nfrom transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES\n\n\ndef get_causal_lm_model_cls_prefix(model_type: str) -> Tuple[str, str]:\n    if model_type in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:\n        causal_lm_cls = MODEL_FOR_CAUSAL_LM_MAPPING_NAMES[model_type]\n        causal_lm_cls_prefix = causal_lm_cls\n        for suffix in [\n            \"ForCausalLM\",\n            \"ForConditionalGeneration\",\n            \"LMHeadModel\",\n            \"GenerationDecoder\",\n        ]:\n            causal_lm_cls_prefix = causal_lm_cls_prefix.replace(suffix, \"\")\n        return causal_lm_cls_prefix, causal_lm_cls\n    causal_lm_cls_prefix = \"\".join(\n        [part.capitalize() for part in model_type.split(\"_\")]\n    )\n    return causal_lm_cls_prefix, f\"{causal_lm_cls_prefix}ForCausalLM\"\n"
  },
  {
    "path": "src/axolotl/utils/callbacks/opentelemetry.py",
    "content": "\"\"\"OpenTelemetry metrics callback for Axolotl training\"\"\"\n\nimport threading\nfrom typing import Dict, Optional\n\nfrom transformers import (\n    TrainerCallback,\n    TrainerControl,\n    TrainerState,\n    TrainingArguments,\n)\n\nfrom axolotl.utils.logging import get_logger\n\nLOG = get_logger(__name__)\n\ntry:\n    from opentelemetry import metrics\n    from opentelemetry.exporter.prometheus import PrometheusMetricReader\n    from opentelemetry.metrics import set_meter_provider\n    from opentelemetry.sdk.metrics import MeterProvider as SDKMeterProvider\n    from prometheus_client import start_http_server\n\n    OPENTELEMETRY_AVAILABLE = True\nexcept ImportError:\n    LOG.warning(\"OpenTelemetry not available. pip install [opentelemetry]\")\n    OPENTELEMETRY_AVAILABLE = False\n\n\nclass OpenTelemetryMetricsCallback(TrainerCallback):\n    \"\"\"\n    TrainerCallback that exports training metrics to OpenTelemetry/Prometheus.\n\n    This callback automatically tracks key training metrics including:\n    - Training loss\n    - Evaluation loss\n    - Learning rate\n    - Epoch progress\n    - Global step count\n    - Gradient norm\n\n    Metrics are exposed via HTTP endpoint for Prometheus scraping.\n    \"\"\"\n\n    def __init__(self, cfg):\n        if not OPENTELEMETRY_AVAILABLE:\n            LOG.warning(\"OpenTelemetry not available, metrics will not be collected\")\n            self.metrics_enabled = False\n            return\n\n        self.cfg = cfg\n        self.metrics_host = getattr(cfg, \"otel_metrics_host\", \"localhost\")\n        self.metrics_port = getattr(cfg, \"otel_metrics_port\", 8000)\n        self.metrics_enabled = True\n        self.server_started = False\n        self.metrics_lock = threading.Lock()\n\n        try:\n            # Create Prometheus metrics reader\n            prometheus_reader = PrometheusMetricReader()\n\n            # Create meter provider with Prometheus exporter\n            provider = 
SDKMeterProvider(metric_readers=[prometheus_reader])\n            set_meter_provider(provider)\n\n            # Get meter for creating metrics\n            self.meter = metrics.get_meter(\"axolotl.training\")\n\n            # Create metrics\n            self._create_metrics()\n\n        except Exception as e:\n            LOG.warning(f\"Failed to initialize OpenTelemetry metrics: {e}\")\n            self.metrics_enabled = False\n\n    def _create_metrics(self):\n        \"\"\"Create all metrics that will be tracked\"\"\"\n        self.train_loss_gauge = self.meter.create_gauge(\n            name=\"axolotl_train_loss\",\n            description=\"Current training loss\",\n            unit=\"1\",\n        )\n\n        self.eval_loss_gauge = self.meter.create_gauge(\n            name=\"axolotl_eval_loss\",\n            description=\"Current evaluation loss\",\n            unit=\"1\",\n        )\n\n        self.learning_rate_gauge = self.meter.create_gauge(\n            name=\"axolotl_learning_rate\",\n            description=\"Current learning rate\",\n            unit=\"1\",\n        )\n\n        self.epoch_gauge = self.meter.create_gauge(\n            name=\"axolotl_epoch\",\n            description=\"Current training epoch\",\n            unit=\"1\",\n        )\n\n        self.global_step_counter = self.meter.create_counter(\n            name=\"axolotl_global_steps\",\n            description=\"Total training steps completed\",\n            unit=\"1\",\n        )\n\n        self.grad_norm_gauge = self.meter.create_gauge(\n            name=\"axolotl_gradient_norm\",\n            description=\"Gradient norm\",\n            unit=\"1\",\n        )\n\n        self.memory_usage_gauge = self.meter.create_gauge(\n            name=\"axolotl_memory_usage\",\n            description=\"Current memory usage in MB\",\n            unit=\"MB\",\n        )\n\n    def _start_metrics_server(self):\n        \"\"\"Start the HTTP server for metrics exposure\"\"\"\n        if 
self.server_started:\n            return\n\n        try:\n            start_http_server(self.metrics_port, addr=self.metrics_host)\n            self.server_started = True\n            LOG.info(\n                f\"OpenTelemetry metrics server started on http://{self.metrics_host}:{self.metrics_port}/metrics\"\n            )\n\n        except Exception as e:\n            LOG.error(f\"Failed to start OpenTelemetry metrics server: {e}\")\n\n    def on_train_begin(\n        self,\n        args: TrainingArguments,\n        state: TrainerState,\n        control: TrainerControl,\n        **kwargs,\n    ):\n        \"\"\"Called at the beginning of training\"\"\"\n        if not self.metrics_enabled:\n            return\n\n        self._start_metrics_server()\n        LOG.info(\"OpenTelemetry metrics collection started\")\n\n    def on_log(\n        self,\n        args: TrainingArguments,\n        state: TrainerState,\n        control: TrainerControl,\n        logs: Optional[Dict[str, float]] = None,\n        **kwargs,\n    ):\n        \"\"\"Called when logging occurs\"\"\"\n        if not self.metrics_enabled or not logs:\n            return\n\n        if \"loss\" in logs:\n            self.train_loss_gauge.set(logs[\"loss\"])\n\n        if \"eval_loss\" in logs:\n            self.eval_loss_gauge.set(logs[\"eval_loss\"])\n\n        if \"learning_rate\" in logs:\n            self.learning_rate_gauge.set(logs[\"learning_rate\"])\n\n        if \"epoch\" in logs:\n            self.epoch_gauge.set(logs[\"epoch\"])\n\n        if \"grad_norm\" in logs:\n            self.grad_norm_gauge.set(logs[\"grad_norm\"])\n        if \"memory_usage\" in logs:\n            self.memory_usage_gauge.set(logs[\"memory_usage\"])\n\n    def on_step_end(\n        self,\n        args: TrainingArguments,\n        state: TrainerState,\n        control: TrainerControl,\n        **kwargs,\n    ):\n        \"\"\"Called at the end of each training step\"\"\"\n        if not self.metrics_enabled:\n          
  return\n\n        # Update step counter and epoch\n        self.global_step_counter.add(1)\n        if state.epoch is not None:\n            self.epoch_gauge.set(state.epoch)\n\n    def on_evaluate(\n        self,\n        args: TrainingArguments,\n        state: TrainerState,\n        control: TrainerControl,\n        metrics: Optional[Dict[str, float]] = None,\n        **kwargs,\n    ):\n        \"\"\"Called after evaluation\"\"\"\n        if not self.metrics_enabled or not metrics:\n            return\n\n        if \"eval_loss\" in metrics:\n            self.eval_loss_gauge.set(metrics[\"eval_loss\"])\n\n        # Record any other eval metrics as gauges\n        for key, value in metrics.items():\n            if key.startswith(\"eval_\") and isinstance(value, (int, float)):\n                # Create gauge for this metric if it doesn't exist\n                gauge_name = f\"axolotl_{key}\"\n                try:\n                    gauge = self.meter.create_gauge(\n                        name=gauge_name,\n                        description=f\"Evaluation metric: {key}\",\n                        unit=\"1\",\n                    )\n                    gauge.set(value)\n                except Exception as e:\n                    LOG.warning(f\"Failed to create/update metric {gauge_name}: {e}\")\n\n    def on_train_end(\n        self,\n        args: TrainingArguments,\n        state: TrainerState,\n        control: TrainerControl,\n        **kwargs,\n    ):\n        \"\"\"Called at the end of training\"\"\"\n        if not self.metrics_enabled:\n            return\n\n        LOG.info(\"Training completed. OpenTelemetry metrics collection finished.\")\n        LOG.info(\n            f\"Metrics are still available at http://{self.metrics_host}:{self.metrics_port}/metrics\"\n        )\n"
  },
  {
    "path": "src/axolotl/utils/callbacks/perplexity.py",
    "content": "\"\"\"callback to calculate perplexity as an evaluation metric.\"\"\"\n\nfrom typing import Dict, List, Optional\n\nimport torch\nfrom torch import Tensor\nfrom tqdm import tqdm\nfrom transformers.modeling_outputs import CausalLMOutput\nfrom transformers.modeling_utils import PreTrainedModel\n\ntry:\n    from transformers.tokenization_python import PreTrainedTokenizer\nexcept ImportError:\n    from transformers.tokenization_utils import PreTrainedTokenizer\n\nfrom axolotl.utils.distributed import is_main_process\n\n\nclass Perplexity:\n    \"\"\"\n    Calculate perplexity as defined in https://huggingface.co/docs/transformers/en/perplexity.\n    This is a custom variant that doesn't re-tokenize the input or re-load the model.\n    \"\"\"\n\n    def __init__(\n        self,\n        tokenizer: PreTrainedTokenizer,\n        max_seq_len: int,\n        stride: int = 512,\n    ) -> None:\n        self.max_seq_len = max_seq_len\n        self.stride = stride\n        self.tokenizer = tokenizer\n        self.name = \"perplexity\"\n\n    def _feature_names(self) -> List[str]:\n        return [\"references\"]\n\n    def compute(\n        self,\n        model: PreTrainedModel,\n        references: Optional[List[str]] = None,\n    ) -> Dict[str, float]:\n        \"\"\"\n        Compute perplexity in a fixed length sliding window across the sequence.\n        \"\"\"\n        assert references is not None, \"Missing parameter: references\"\n\n        model.eval()\n\n        references_tokenized = self.tokenizer(\n            references, return_tensors=\"pt\", padding=True, truncation=True\n        )\n        input_ids: Tensor = references_tokenized[\"input_ids\"]  # type: ignore\n        input_ids = input_ids.to(model.device)\n\n        sequence_length = input_ids.size(1)\n\n        losses = []\n        prev_end_loc = 0\n        for begin_loc in tqdm(\n            range(0, sequence_length, self.stride), disable=not is_main_process()\n        ):\n            
end_loc = min(begin_loc + self.max_seq_len, sequence_length)\n            trg_len = end_loc - prev_end_loc\n            input_ids_slice = input_ids[:, begin_loc:end_loc]\n            labels_slice = input_ids_slice.clone()\n            labels_slice[:, :-trg_len] = -100\n\n            with torch.no_grad():\n                outputs: CausalLMOutput = model(\n                    input_ids=input_ids_slice, labels=labels_slice\n                )\n\n            losses.append(outputs.loss)\n\n            prev_end_loc = end_loc\n            if end_loc == sequence_length:\n                break\n\n        perplexity = torch.exp(torch.stack(losses).mean()).item()\n\n        return {\n            \"score\": perplexity,\n        }\n"
  },
  {
    "path": "src/axolotl/utils/callbacks/profiler.py",
    "content": "\"\"\"\nHF Trainer callback for creating pytorch profiling snapshots\n\"\"\"\n\nfrom pathlib import Path\nfrom pickle import dump  # nosec B403\n\nimport torch\nfrom transformers import (\n    TrainerCallback,\n    TrainerControl,\n    TrainerState,\n    TrainingArguments,\n)\n\n\nclass PytorchProfilerCallback(TrainerCallback):\n    \"\"\"\n    PyTorch Profiler callback to create snapshots of GPU memory usage at specified steps.\n\n    Also runs torch.profiler to produce a Chrome trace for timing analysis.\n    \"\"\"\n\n    def __init__(self, steps_to_profile: int = 5, profiler_steps_start: int = 0):\n        # steps are 0 indexed, so to start at 0-th step, we start at beginning of first step,\n        # and finish at end of last step, so 5 steps_to_profile is steps [0, 1, 2, 3, 4]\n        self.profiler_steps_end = profiler_steps_start + steps_to_profile - 1\n        if profiler_steps_start == 0:\n            # start recording memory allocations before everything is allocated, because if we start\n            # at the beginning of step 0, we won't have any memory allocations in the traces\n            torch.cuda.memory._record_memory_history(enabled=\"all\", stacks=\"all\")\n            profiler_steps_start = -1\n        self.profiler_steps_start = profiler_steps_start\n        self._profiler = None\n\n    def on_step_begin(\n        self,\n        args: TrainingArguments,\n        state: TrainerState,\n        control: TrainerControl,\n        **kwargs,\n    ):\n        if state.global_step == self.profiler_steps_start:\n            torch.cuda.memory._record_memory_history(enabled=\"all\", stacks=\"all\")\n\n        # Start torch.profiler on the first profiled step\n        if state.global_step == max(self.profiler_steps_start, 0):\n            profiler = torch.profiler.profile(\n                activities=[\n                    torch.profiler.ProfilerActivity.CPU,\n                    torch.profiler.ProfilerActivity.CUDA,\n                ],\n   
             record_shapes=True,\n                profile_memory=True,\n                with_stack=True,\n            )\n            profiler.__enter__()\n            self._profiler = profiler\n\n    def on_step_end(\n        self,\n        args: TrainingArguments,\n        state: TrainerState,\n        control: TrainerControl,\n        **kwargs,\n    ):\n        if state.global_step == self.profiler_steps_end:\n            snapshot = torch.cuda.memory._snapshot()\n            with open(Path(args.output_dir) / \"snapshot.pickle\", \"wb\") as fout:\n                dump(snapshot, fout)\n\n            # tell CUDA to stop recording memory allocations now\n            torch.cuda.memory._record_memory_history(enabled=None)\n\n            # Stop and export torch.profiler trace\n            if self._profiler is not None:\n                self._profiler.__exit__(None, None, None)\n                trace_path = Path(args.output_dir) / \"profiler_trace.json\"\n                self._profiler.export_chrome_trace(str(trace_path))\n                self._profiler = None\n\n    def on_train_end(\n        self,\n        args: TrainingArguments,\n        state: TrainerState,\n        control: TrainerControl,\n        **kwargs,\n    ):\n        # make sure to record if we happen to have more steps than steps to profile\n        if (\n            state.global_step >= self.profiler_steps_start\n            and state.global_step < self.profiler_steps_end\n        ):\n            snapshot = torch.cuda.memory._snapshot()\n            with open(Path(args.output_dir) / \"snapshot.pickle\", \"wb\") as fout:\n                dump(snapshot, fout)\n\n            # tell CUDA to stop recording memory allocations now\n            torch.cuda.memory._record_memory_history(enabled=None)\n\n        if self._profiler is not None:\n            self._profiler.__exit__(None, None, None)\n            trace_path = Path(args.output_dir) / \"profiler_trace.json\"\n            
self._profiler.export_chrome_trace(str(trace_path))\n            self._profiler = None\n"
  },
  {
    "path": "src/axolotl/utils/callbacks/qat.py",
    "content": "\"\"\"QAT Callback for HF Causal Trainer\"\"\"\n\nfrom functools import partial\n\nfrom torch import nn\nfrom torchao.quantization.qat.embedding import FakeQuantizedEmbedding\nfrom torchao.quantization.qat.linear import FakeQuantizedLinear\nfrom transformers import TrainerCallback\n\nfrom axolotl.utils.logging import get_logger\nfrom axolotl.utils.schemas.quantization import QATConfig\n\nLOG = get_logger(__name__)\n\n\ndef toggle_fake_quant(mod: nn.Module, enable: bool):\n    \"\"\"\n    Toggle fake quantization for any fake quantized linear or embedding layers in the model.\n\n    Args:\n        mod: The module to toggle fake quantization for.\n        enable: Whether to enable or disable fake quantization.\n    \"\"\"\n    if isinstance(mod, (FakeQuantizedLinear, FakeQuantizedEmbedding)):\n        if (\n            isinstance(mod, FakeQuantizedLinear)\n            and mod.activation_fake_quantizer is not None\n        ):\n            mod.activation_fake_quantizer.enabled = enable\n        mod.weight_fake_quantizer.enabled = enable\n\n\nclass QATCallback(TrainerCallback):\n    \"\"\"\n    Callback to toggle fake quantization for the model.\n    \"\"\"\n\n    def __init__(self, cfg: QATConfig):\n        self.cfg = cfg\n\n    def on_step_begin(self, args, state, control, model, **kwargs):\n        if self.cfg.fake_quant_after_n_steps is not None:\n            if state.global_step == 0:\n                LOG.info(f\"Disabling fake quantization at step {state.global_step}\")\n                model.apply(partial(toggle_fake_quant, enable=False))\n            elif state.global_step == self.cfg.fake_quant_after_n_steps:\n                LOG.info(f\"Enabling fake quantization at step {state.global_step}\")\n                model.apply(partial(toggle_fake_quant, enable=True))\n"
  },
  {
    "path": "src/axolotl/utils/callbacks/swanlab.py",
    "content": "\"\"\"Callbacks for SwanLab integration\"\"\"\n\nfrom __future__ import annotations\n\nimport json\nimport os\nfrom shutil import copyfile\nfrom tempfile import NamedTemporaryFile\nfrom typing import TYPE_CHECKING\n\nfrom transformers import (\n    TrainerCallback,\n    TrainerControl,\n    TrainerState,\n    TrainingArguments,\n)\n\nfrom axolotl.utils.logging import get_logger\n\nif TYPE_CHECKING:\n    from axolotl.core.training_args import AxolotlTrainingArguments\n\nLOG = get_logger(__name__)\n\n\nclass CustomSwanLabCallback(TrainerCallback):\n    \"\"\"\n    Lightweight SwanLab callback that directly logs metrics without using\n    SwanLab's transformers integration (which requires omegaconf).\n\n    This avoids the antlr4 version conflict between omegaconf and axolotl.\n    \"\"\"\n\n    def __init__(self):\n        self._initialized = False\n        self.swanlab = None\n\n    def setup(self):\n        \"\"\"Lazy initialization of SwanLab\"\"\"\n        if self._initialized:\n            return\n\n        try:\n            import swanlab\n\n            self.swanlab = swanlab\n\n            # Check if SwanLab run is initialized\n            if swanlab.get_run() is None:\n                LOG.warning(\"SwanLab run is not initialized\")\n                return\n\n            self._initialized = True\n            LOG.info(\"CustomSwanLabCallback initialized successfully\")\n        except ImportError:\n            LOG.error(\"SwanLab is not installed\")\n\n    def on_train_begin(\n        self,\n        args: TrainingArguments,\n        state: TrainerState,\n        control: TrainerControl,\n        **kwargs,\n    ):\n        \"\"\"Called at the beginning of training\"\"\"\n        if not state.is_world_process_zero:\n            return control\n\n        self.setup()\n\n        if not self._initialized:\n            return control\n\n        # Log training configuration\n        try:\n            self.swanlab.config.update(\n                {\n     
               \"train_batch_size\": args.per_device_train_batch_size,\n                    \"eval_batch_size\": args.per_device_eval_batch_size,\n                    \"learning_rate\": args.learning_rate,\n                    \"num_train_epochs\": args.num_train_epochs,\n                    \"max_steps\": args.max_steps,\n                    \"warmup_steps\": args.warmup_steps,\n                    \"logging_steps\": args.logging_steps,\n                    \"save_steps\": args.save_steps,\n                    \"gradient_accumulation_steps\": args.gradient_accumulation_steps,\n                }\n            )\n            LOG.debug(\"Training configuration logged to SwanLab\")\n        except Exception as err:\n            LOG.warning(f\"Failed to log training config: {err}\")\n\n        return control\n\n    def on_log(\n        self,\n        args: TrainingArguments,\n        state: TrainerState,\n        control: TrainerControl,\n        logs=None,\n        **kwargs,\n    ):\n        \"\"\"Called when logging metrics\"\"\"\n        if not state.is_world_process_zero:\n            return control\n\n        if not self._initialized:\n            self.setup()\n\n        if not self._initialized or logs is None:\n            return control\n\n        # Log metrics to SwanLab\n        try:\n            # Filter out non-numeric values and prepare for logging\n            metrics = {}\n            for key, value in logs.items():\n                if isinstance(value, (int, float)):\n                    # Use step from state\n                    metrics[key] = value\n\n            if metrics and state.global_step is not None:\n                self.swanlab.log(metrics, step=state.global_step)\n        except Exception as err:\n            LOG.warning(f\"Failed to log metrics to SwanLab: {err}\")\n\n        return control\n\n    def on_train_end(\n        self,\n        args: TrainingArguments,\n        state: TrainerState,\n        control: TrainerControl,\n        
**kwargs,\n    ):\n        \"\"\"Called at the end of training\"\"\"\n        if not state.is_world_process_zero:\n            return control\n\n        if self._initialized:\n            LOG.info(\"Training completed. SwanLab logs are available.\")\n\n        return control\n\n\nclass SaveAxolotlConfigtoSwanLabCallback(TrainerCallback):\n    \"\"\"Callback to save axolotl config to SwanLab\"\"\"\n\n    def __init__(self, axolotl_config_path):\n        self.axolotl_config_path = axolotl_config_path\n\n    def on_train_begin(\n        self,\n        args: AxolotlTrainingArguments,\n        state: TrainerState,\n        control: TrainerControl,\n        **kwargs,\n    ):\n        if state.is_world_process_zero:\n            try:\n                import swanlab\n\n                # Check if SwanLab is initialized\n                if swanlab.get_run() is None:\n                    LOG.warning(\n                        \"SwanLab run is not initialized. Please initialize SwanLab before training.\"\n                    )\n                    return control\n\n                # Log Axolotl config as artifact\n                with NamedTemporaryFile(\n                    mode=\"w\", delete=False, suffix=\".yml\", prefix=\"axolotl_config_\"\n                ) as temp_file:\n                    copyfile(self.axolotl_config_path, temp_file.name)\n\n                    # Log config file to SwanLab\n                    with open(temp_file.name, \"r\", encoding=\"utf-8\") as config_file:\n                        swanlab.log(\n                            {\n                                \"axolotl_config\": swanlab.Text(\n                                    config_file.read(), caption=\"Axolotl Config\"\n                                )\n                            }\n                        )\n\n                    LOG.info(\n                        \"The Axolotl config has been saved to the SwanLab run under logs.\"\n                    )\n\n                    # Clean up temp 
file\n                    os.unlink(temp_file.name)\n\n            except ImportError:\n                LOG.warning(\n                    \"SwanLab is not installed. Install it with: pip install swanlab\"\n                )\n            except (FileNotFoundError, ConnectionError) as err:\n                LOG.warning(f\"Error while saving Axolotl config to SwanLab: {err}\")\n\n            # Log DeepSpeed config if available\n            if args.deepspeed:\n                try:\n                    import swanlab\n\n                    with NamedTemporaryFile(\n                        mode=\"w\",\n                        delete=False,\n                        suffix=\".json\",\n                        prefix=\"deepspeed_config_\",\n                    ) as temp_file:\n                        skip_upload = False\n                        if isinstance(args.deepspeed, dict):\n                            json.dump(args.deepspeed, temp_file, indent=4)\n                        elif isinstance(args.deepspeed, str) and os.path.exists(\n                            args.deepspeed\n                        ):\n                            copyfile(args.deepspeed, temp_file.name)\n                        else:\n                            skip_upload = True\n\n                        if not skip_upload:\n                            temp_file.flush()\n                            with open(\n                                temp_file.name, \"r\", encoding=\"utf-8\"\n                            ) as ds_config_file:\n                                swanlab.log(\n                                    {\n                                        \"deepspeed_config\": swanlab.Text(\n                                            ds_config_file.read(),\n                                            caption=\"DeepSpeed Config\",\n                                        )\n                                    }\n                                )\n                            LOG.info(\n                 
               \"The DeepSpeed config has been saved to the SwanLab run under logs.\"\n                            )\n\n                        # Clean up temp file\n                        os.unlink(temp_file.name)\n\n                except (FileNotFoundError, ConnectionError) as err:\n                    LOG.warning(\n                        f\"Error while saving DeepSpeed config to SwanLab: {err}\"\n                    )\n                except ImportError:\n                    pass\n\n        return control\n"
  },
  {
    "path": "src/axolotl/utils/callbacks/tokens_per_second.py",
    "content": "\"\"\"A callback for calculating tokens per second during training.\"\"\"\n\nimport json\nimport os\nimport time\n\nimport torch\nfrom transformers import (\n    TrainerCallback,\n    TrainerControl,\n    TrainerState,\n    TrainingArguments,\n)\n\nfrom axolotl.utils.logging import get_logger\n\nLOG = get_logger(__name__)\n\nTOKENS_STATE_FILE = \"tokens_state.json\"\n\n\nclass TokensPerSecondCallback(TrainerCallback):\n    \"\"\"\n    A callback to measure and log tokens per second during training.\n    Also handles saving/restoring total_tokens state across checkpoint resumes.\n    \"\"\"\n\n    def __init__(\n        self, tensor_parallel_size, context_parallel_size, resume_from_checkpoint=None\n    ):\n        super().__init__()\n        self.step_time = 0.0\n        self.start_time = 0.0\n        self.non_data_parallel_size = 1\n        self.resume_from_checkpoint = resume_from_checkpoint\n        if tensor_parallel_size is not None:\n            self.non_data_parallel_size *= tensor_parallel_size\n        if context_parallel_size is not None:\n            self.non_data_parallel_size *= context_parallel_size\n\n    def on_train_begin(\n        self,\n        args: TrainingArguments,\n        state: TrainerState,\n        control: TrainerControl,\n        **kwargs,\n    ):  # pylint: disable=unused-argument\n        \"\"\"Restore total_tokens state when resuming from checkpoint.\"\"\"\n        if not isinstance(self.resume_from_checkpoint, str):\n            return\n        tokens_state_path = os.path.join(self.resume_from_checkpoint, TOKENS_STATE_FILE)\n        if os.path.isfile(tokens_state_path):\n            with open(tokens_state_path, \"r\", encoding=\"utf-8\") as f:\n                tokens_state = json.load(f)\n            state.tokens = {\n                \"total\": torch.tensor(tokens_state.get(\"total\", 0)),\n                \"trainable\": torch.tensor(tokens_state.get(\"trainable\", 0)),\n            }\n            
LOG.info(f\"Restored total_tokens: {state.tokens['total']}\")\n\n    def on_step_begin(\n        self,\n        args: TrainingArguments,\n        state: TrainerState,\n        control: TrainerControl,\n        **kwargs,\n    ):  # pylint: disable=unused-argument\n        if not hasattr(state, \"tokens\"):\n            state.tokens = {\"trainable\": torch.zeros(1), \"total\": torch.zeros(1)}\n        self.start_time = time.perf_counter()\n        state.last_tokens_per_second = torch.zeros(1)\n\n    def on_step_end(\n        self,\n        args: TrainingArguments,\n        state: TrainerState,\n        control: TrainerControl,\n        **kwargs,\n    ):  # pylint: disable=unused-argument\n        tokens = getattr(state, \"tokens\", None)\n        if not (tokens and \"trainable_tokens\" in tokens):\n            return\n        step_time = time.perf_counter() - self.start_time\n        if step_time <= 0:\n            return\n\n        num_tokens = tokens[\"trainable_tokens\"].clone() / self.non_data_parallel_size\n        if torch.distributed.is_initialized():\n            dp_size = max(\n                1, torch.distributed.get_world_size() // self.non_data_parallel_size\n            )\n            num_tokens = num_tokens / dp_size\n        state.last_tokens_per_second = num_tokens / step_time\n\n    def on_log(\n        self,\n        args: TrainingArguments,\n        state: TrainerState,\n        control: TrainerControl,\n        logs=None,\n        **kwargs,\n    ):  # pylint: disable=unused-argument\n        # after logging, clear the running metrics\n        if hasattr(state, \"last_tokens_per_second\"):\n            logs[\"tokens/train_per_sec_per_gpu\"] = state.last_tokens_per_second.item()\n            state.last_tokens_per_second.zero_()\n        tokens = getattr(state, \"tokens\", None)\n        # Clear per-step tokens after logging\n        if tokens and \"trainable_tokens\" in tokens:\n            tokens[\"trainable_tokens\"] = 
torch.zeros_like(tokens[\"trainable_tokens\"])\n"
  },
  {
    "path": "src/axolotl/utils/callbacks/trackio_.py",
    "content": "\"\"\"Trackio module for trainer callbacks\"\"\"\n\nfrom typing import TYPE_CHECKING\n\nimport trackio\nfrom transformers import TrainerCallback, TrainerControl, TrainerState\n\nfrom axolotl.utils.distributed import is_main_process\nfrom axolotl.utils.environment import is_package_version_ge\nfrom axolotl.utils.logging import get_logger\n\nif TYPE_CHECKING:\n    from axolotl.core.training_args import AxolotlTrainingArguments\n\nLOG = get_logger(__name__)\n\n\nclass SaveAxolotlConfigtoTrackioCallback(TrainerCallback):\n    \"\"\"Callback for trackio integration\"\"\"\n\n    def __init__(self, axolotl_config_path):\n        self.axolotl_config_path = axolotl_config_path\n\n    def on_train_begin(\n        self,\n        args: \"AxolotlTrainingArguments\",\n        state: TrainerState,\n        control: TrainerControl,\n        **kwargs,\n    ):\n        if is_main_process():\n            try:\n                if not is_package_version_ge(\"trackio\", \"0.11.0\"):\n                    LOG.warning(\n                        \"Trackio version 0.11.0 or higher is required to save config files. \"\n                        \"Please upgrade trackio: pip install --upgrade trackio\"\n                    )\n                    return control\n\n                trackio.save(self.axolotl_config_path)\n                LOG.info(\"The Axolotl config has been saved to Trackio.\")\n            except (FileNotFoundError, ConnectionError, AttributeError) as err:\n                LOG.warning(f\"Error while saving Axolotl config to Trackio: {err}\")\n        return control\n"
  },
  {
    "path": "src/axolotl/utils/chat_templates/__init__.py",
    "content": "\"\"\"\nThis module provides functionality for selecting chat templates based on user choices.\nThese templates are used for formatting messages in a conversation.\n\"\"\"\n\nfrom .base import (\n    _CHAT_TEMPLATES,\n    extract_chat_template_args,\n    get_chat_template,\n    get_chat_template_from_config,\n    register_chat_template,\n)\n\n__all__ = [\n    \"get_chat_template\",\n    \"extract_chat_template_args\",\n    \"get_chat_template_from_config\",\n    \"register_chat_template\",\n    \"_CHAT_TEMPLATES\",\n]\n"
  },
  {
    "path": "src/axolotl/utils/chat_templates/base.py",
    "content": "\"\"\"\nutility functions for chat templates\n\"\"\"\n\nimport os\nfrom typing import TYPE_CHECKING, Any, Dict, Optional\n\nfrom axolotl.utils.logging import get_logger\n\nif TYPE_CHECKING:\n    from transformers import PreTrainedTokenizerBase\n\nLOG = get_logger(\"axolotl.utils.chat_templates\")\n\n_JINJA_TEMPLATE_CHOICE = \"jinja\"\n_DEFAULT_TEMPLATE_CHOICE = \"tokenizer_default\"\n_DEFAULT_FALLBACK_CHATML_TEMPLATE_CHOICE_PREFIX = \"tokenizer_default_fallback_\"\n\nTEMPLATE_DIR = os.path.join(os.path.dirname(__file__), \"templates\")\n_CHAT_TEMPLATES: dict[str, str] = {}\nfor filename in [f for f in os.listdir(TEMPLATE_DIR) if f.endswith(\".jinja\")]:\n    with open(os.path.join(TEMPLATE_DIR, filename), \"r\", encoding=\"utf-8\") as f:\n        _CHAT_TEMPLATES[filename[:-6]] = f.read()\n\n\ndef get_chat_template(\n    user_choice: str,\n    jinja_template: str | None = None,\n    tokenizer: Optional[\"PreTrainedTokenizerBase\"] = None,\n) -> str:\n    \"\"\"\n    Finds the correct chat_template based on the user's choice, jinja_template, and tokenizer.\n\n    Args:\n        user_choice (str): The user's choice of template.\n        jinja_template (str, optional): The jinja template string or Path to a valid jinja template file. Defaults to None.\n        tokenizer (PreTrainedTokenizerBase, optional): The tokenizer. 
Defaults to None.\n\n    Returns:\n        str: The chosen template string.\n\n    Raises:\n        ValueError: If the user_choice is not found in the templates.\n    \"\"\"\n    if user_choice == _JINJA_TEMPLATE_CHOICE:\n        if not jinja_template:\n            raise ValueError(\n                f\"`jinja_template` cannot be None when `chat_template` choice is {_JINJA_TEMPLATE_CHOICE}\"\n            )\n        if os.path.exists(jinja_template) and os.path.isfile(jinja_template):\n            with open(jinja_template, \"r\", encoding=\"utf-8\") as file:\n                jinja_template = file.read()\n        return jinja_template\n\n    if user_choice == _DEFAULT_TEMPLATE_CHOICE:\n        if not tokenizer:\n            raise ValueError(\n                f\"`tokenizer` cannot be None when chat_template choice is {_DEFAULT_TEMPLATE_CHOICE}\"\n            )\n        if not tokenizer.chat_template:\n            raise ValueError(\n                f\"`chat_template choice is {_DEFAULT_TEMPLATE_CHOICE} but tokenizer's chat_template is null. \"\n                f\"Please add a chat_template in tokenizer config\"\n            )\n        return tokenizer.chat_template  # type: ignore\n\n    if user_choice.startswith(_DEFAULT_FALLBACK_CHATML_TEMPLATE_CHOICE_PREFIX):\n        if not tokenizer:\n            raise ValueError(\n                f\"`tokenizer` cannot be None when chat_template choice starts with {_DEFAULT_FALLBACK_CHATML_TEMPLATE_CHOICE_PREFIX}\"\n            )\n        if tokenizer.chat_template:\n            return tokenizer.chat_template  # type: ignore\n\n        user_choice = user_choice[\n            len(_DEFAULT_FALLBACK_CHATML_TEMPLATE_CHOICE_PREFIX) :\n        ]\n        LOG.warning(\n            f\"No chat template found on tokenizer, falling back to {user_choice}. 
It is recommended to set --train_on_inputs to True for the model to learn this chat template.\"\n        )\n\n    if user_choice in _CHAT_TEMPLATES:\n        return _CHAT_TEMPLATES[user_choice]\n\n    raise ValueError(f\"Template '{user_choice}' not found.\")\n\n\ndef extract_chat_template_args(cfg, ds_cfg: Dict[str, Any] | None = None):\n    if ds_cfg and ds_cfg.get(\"chat_template\"):\n        chat_template_choice = ds_cfg.get(\"chat_template\") or _DEFAULT_TEMPLATE_CHOICE\n        chat_template_jinja = ds_cfg.get(\"chat_template_jinja\")\n    else:\n        chat_template_choice = cfg.get(\"chat_template\") or _DEFAULT_TEMPLATE_CHOICE\n        chat_template_jinja = cfg.get(\"chat_template_jinja\")\n    return chat_template_choice, chat_template_jinja\n\n\ndef get_chat_template_from_config(\n    cfg,\n    ds_cfg: Dict[str, Any] | None = None,\n    tokenizer: Optional[\"PreTrainedTokenizerBase\"] = None,\n) -> str:\n    chat_template_choice, chat_template_jinja = extract_chat_template_args(\n        cfg=cfg, ds_cfg=ds_cfg\n    )\n    return get_chat_template(\n        user_choice=chat_template_choice,\n        jinja_template=chat_template_jinja,\n        tokenizer=tokenizer,\n    )\n\n\ndef register_chat_template(template_name: str, chat_template: str):\n    \"\"\"\n    Registers chat templates.\n\n    Args:\n        template_name (str): The name of the template.\n        chat_template (str): The template string.\n    \"\"\"\n\n    if template_name in _CHAT_TEMPLATES:\n        raise ValueError(f\"Template '{template_name}' already exists.\")\n\n    _CHAT_TEMPLATES[template_name] = chat_template\n"
  },
  {
    "path": "src/axolotl/utils/chat_templates/templates/alpaca.jinja",
    "content": "{{ bos_token }}{% for message in messages %}{% if message['role'] == 'system' and loop.first %}{{ message['content'] }}{% elif message['role'] == 'user' %}{{ '### Instruction:\n' + message['content'] }}{% elif message['role'] == 'assistant' %}{{ '### Response:\n' + message['content'] + eos_token }}{% endif %}{% if not loop.last %}{{ '\n\n' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '\n\n### Response:\n' }}{% endif %}\n"
  },
  {
    "path": "src/axolotl/utils/chat_templates/templates/aya.jinja",
    "content": "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% elif false == true %}{% set loop_messages = messages %}{% set system_message = 'You are Aya, a brilliant, sophisticated, AI-assistant trained to assist human users by providing thorough responses. You are trained by Cohere.' %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% if system_message != false %}{{ '<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>' + system_message + '<|END_OF_TURN_TOKEN|>' }}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% set content = message['content'] %}{% if message['role'] == 'user' %}{{ '<|START_OF_TURN_TOKEN|><|USER_TOKEN|>' + content.strip() + '<|END_OF_TURN_TOKEN|>' }}{% elif message['role'] == 'assistant' %}{{ '<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>'  + content.strip() + '<|END_OF_TURN_TOKEN|>' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>' }}{% endif %}\n"
  },
  {
    "path": "src/axolotl/utils/chat_templates/templates/chatml.jinja",
    "content": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}\n"
  },
  {
    "path": "src/axolotl/utils/chat_templates/templates/cohere.jinja",
    "content": "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% elif false == true %}{% set loop_messages = messages %}{% set system_message = 'You are Command-R, a brilliant, sophisticated, AI-assistant trained to assist human users by providing thorough responses. You are trained by Cohere.' %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% if system_message != false %}{{ '<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>' + system_message + '<|END_OF_TURN_TOKEN|>' }}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% set content = message['content'] %}{% if message['role'] == 'user' %}{{ '<|START_OF_TURN_TOKEN|><|USER_TOKEN|>' + content.strip() + '<|END_OF_TURN_TOKEN|>' }}{% elif message['role'] == 'assistant' %}{{ '<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>'  + content.strip() + '<|END_OF_TURN_TOKEN|>' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>' }}{% endif %}\n"
  },
  {
    "path": "src/axolotl/utils/chat_templates/templates/command_a.jinja",
    "content": "{{ bos_token }}{% if documents %}\n{% set tools = [] %}\n{%- macro document_turn(documents) -%}\n{# format documents into chat turn #}\n<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|><|START_THINKING|>I will look through the document to address the users needs.<|END_THINKING|><|START_ACTION|>[\n    {\"tool_call_id\": \"0\", \"tool_name\": \"direct-injected-document\", \"parameters\": {}}\n]<|END_ACTION|><|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|><|START_TOOL_RESULT|>[\n    {\n        \"tool_call_id\": \"0\",\n        \"results\": {\n{% for doc in documents %}\n            \"{{ loop.index0 }}\": {{doc|tojson}}{% if not loop.last %},\n            {% endif %}\n{% endfor %}\n\n        },\n        \"is_error\": null\n    }\n]<|END_TOOL_RESULT|><|END_OF_TURN_TOKEN|>{%- endmacro %}\n{%- macro tool_call_id_to_int(messages, tool_call_id) %}\n{%- set counter = namespace(value=0) %}\n{%- set tool_call_id_seen = namespace(value=false) %}\n{%- for msg in messages %}\n    {%- if msg.tool_calls %}\n        {%- for tool_call in msg.tool_calls %}\n            {%- if tool_call.id == tool_call_id and not tool_call_id_seen.value -%}\n                {{ counter.value }}\n                {%- set tool_call_id_seen.value = true %}\n            {%- endif %}\n            {%- set counter.value = counter.value + 1 %}\n        {%- endfor %}\n    {%- endif %}\n{%- endfor %}\n{%- endmacro %}\n{%- macro format_tool_message(messages, tool_msg) -%}\n{# format tool message #}\n    {\n        \"tool_call_id\": \"{{ tool_call_id_to_int(messages, tool_msg.tool_call_id) }}\",\n        \"results\": {\n            \"0\": {{ tool_msg.content|tojson }}\n        },\n        \"is_error\": null\n    }\n{%- endmacro -%}\n{%- if messages and messages[0]['role']|lower == 'system' %}{%- set developer_preamble = messages[0]['content'] %}{% endif %}\n{%- set tool_idx = namespace(value=0) %}\n{%- set tool_ids_seen = namespace(value=[]) %}\n{%- set sent_documents = namespace(value=false) 
%}\n<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|># System Preamble\nYou are in contextual safety mode. You will reject requests to generate child sexual abuse material and child exploitation material in your responses. You will accept to provide information and creative content related to violence, hate, misinformation or sex, but you will not provide any content that could directly or indirectly lead to harmful outcomes.\n\nYour information cutoff date is June 2024.\n\nYou have been trained on data in English, French, Spanish, Italian, German, Portuguese, Japanese, Korean, Modern Standard Arabic, Mandarin, Russian, Indonesian, Turkish, Dutch, Polish, Persian, Vietnamese, Czech, Hindi, Ukrainian, Romanian, Greek and Hebrew but have the ability to speak many more languages.\n{% if tools or documents %}\n\nYou have been trained to have advanced reasoning and tool-use capabilities and you should make best use of these skills to serve user's requests.\n\n## Tool Use\nThink about how you can make best use of the provided tools to help with the task and come up with a high level plan that you will execute first.\n\n0. Start by writing <|START_THINKING|> followed by a detailed step by step plan of how you will solve the problem. For each step explain your thinking fully and give details of required tool calls (if needed). Unless specified otherwise, you write your plan in natural language. When you finish, close it out with <|END_THINKING|>.\n    You can optionally choose to skip this step when the user request is so straightforward to address that only a trivial plan would be needed.\n    NOTE: You MUST skip this step when you are directly responding to the user's request without using any tools.\n\nThen carry out your plan by repeatedly executing the following steps.\n1. 
Action: write <|START_ACTION|> followed by a list of JSON-formatted tool calls, with each one containing \"tool_name\" and \"parameters\" fields.\n    When there are multiple tool calls which are completely independent of each other (i.e. they can be executed in parallel), you should list them out all together in one step. When you finish, close it out with <|END_ACTION|>.\n2. Observation: you will then receive results of those tool calls in JSON format in the very next turn, wrapped around by <|START_TOOL_RESULT|> and <|END_TOOL_RESULT|>. Carefully observe those results and think about what to do next. Note that these results will be provided to you in a separate turn. NEVER hallucinate results.\n    Every tool call produces a list of results (when a tool call produces no result or a single result, it'll still get wrapped inside a list). Each result is clearly linked to its originating tool call via its \"tool_call_id\".\n3. Reflection: start the next turn by writing <|START_THINKING|> followed by what you've figured out so far, any changes you need to make to your plan, and what you will do next. When you finish, close it out with <|END_THINKING|>.\n    You can optionally choose to skip this step when everything is going according to plan and no special pieces of information or reasoning chains need to be recorded.\n    NOTE: You MUST skip this step when you are done with tool-use actions and are ready to respond to the user.\n\nYou can repeat the above 3 steps multiple times (could be 0 times too if no suitable tool calls are available or needed), until you decide it's time to finally respond to the user.\n\n4. Response: then break out of the loop and write <|START_RESPONSE|> followed by a piece of text which serves as a response to the user's last request. Use all previous tool calls and results to help you when formulating your response. 
When you finish, close it out with <|END_RESPONSE|>.\n{% if enable_citations %}\n\n## Grounding\nImportantly, note that \"Reflection\" and \"Response\" above can be grounded.\nGrounding means you associate pieces of texts (called \"spans\") with those specific tool results that support them (called \"sources\"). And you use a pair of tags \"<co>\" and \"</co>\" to indicate when a span can be grounded onto a list of sources, listing them out in the closing tag. Sources from the same tool call are grouped together and listed as \"{tool_call_id}:[{list of result indices}]\", before they are joined together by \",\". E.g., \"<co>span</co: 0:[1,2],1:[0]>\" means that \"span\" is supported by result 1 and 2 from \"tool_call_id=0\" as well as result 0 from \"tool_call_id=1\".\n{% endif %}\n\n## Available Tools\nHere is the list of tools that you have available to you.\nYou can ONLY use the tools listed here. When a tool is not listed below, it is NOT available and you should NEVER attempt to use it.\nEach tool is represented as a JSON object with fields like \"name\", \"description\", \"parameters\" (per JSON Schema), and optionally, \"responses\" (per JSON Schema).\n\n```json\n[\n{% if documents %}\n    {\"name\": \"direct-injected-document\", \"description\": \"This is a special tool to directly inject user-uploaded documents into the chat as additional context. 
DO NOT use this tool by yourself!\", \"parameters\": {\"type\": \"object\", \"properties\": {}, \"required\": []}, \"responses\": {\"200\": {\"description\": \"Successfully returned a list of chunked text snippets from the directly uploaded documents.\", \"content\": {\"application/json\": {\"schema\": {\"type\": \"array\", \"items\": {\"type\": \"object\", \"required\": [\"url\", \"snippet\"], \"properties\": {\"url\": {\"type\": \"string\", \"description\": \"The url of the uploaded document.\"}, \"snippet\": {\"type\": \"string\", \"description\": \"The text snippet for the returned document chunk.\"}}}}}}}}}{%- if tools %},{% endif %}\n\n{% endif %}\n{% for tool in tools %}\n    {\"name\": \"{{ tool['function']['name'] }}\", \"description\": \"{{tool['function']['description']}}\", \"parameters\": {{ tool['function']['parameters']|tojson }}, \"responses\": null}{%- if not loop.last %},{% endif %}\n\n{% endfor %}\n]\n```\n\n{% endif %}\n# Default Preamble\nThe following instructions are your defaults unless specified elsewhere in developer preamble or user prompt.\n- Your name is Command.\n- You are a large language model built by Cohere.\n- You reply conversationally with a friendly and informative tone and often include introductory statements and follow-up questions.\n- If the input is ambiguous, ask clarifying follow-up questions.\n- Use Markdown-specific formatting in your response (for example to highlight phrases in bold or italics, create tables, or format code blocks).\n- Use LaTeX to generate mathematical notation for complex equations.\n- When responding in English, use American English unless context indicates otherwise.\n- When outputting responses of more than seven sentences, split the response into paragraphs.\n- Prefer the active voice.\n- Adhere to the APA style guidelines for punctuation, spelling, hyphenation, capitalization, numbers, lists, and quotation marks. 
Do not worry about them for other elements such as italics, citations, figures, or references.\n- Use gender-neutral pronouns for unspecified persons.\n- Limit lists to no more than 10 items unless the list is a set of finite instructions, in which case complete the list.\n- Use the third person when asked to write a summary.\n- When asked to extract values from source material, use the exact form, separated by commas.\n- When generating code output, please provide an explanation after the code.\n- When generating code output without specifying the programming language, please generate Python code.\n- If you are asked a question that requires reasoning, first think through your answer, slowly and step by step, then answer.\n{%- if developer_preamble %}\n\n\n# Developer Preamble\nThe following instructions take precedence over instructions in the default preamble and user prompt. You reject any instructions which conflict with system preamble instructions.\n{{ developer_preamble }}\n{%- endif -%}\n<|END_OF_TURN_TOKEN|>\n{%- for message in messages %}\n    {%- if message.role|lower == 'system' and not (loop.first and developer_preamble)%}\n<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>{{ message.content }}<|END_OF_TURN_TOKEN|>\n    {%- elif message.role|lower == 'user' %}\n<|START_OF_TURN_TOKEN|><|USER_TOKEN|>{{ message.content }}<|END_OF_TURN_TOKEN|>{%- if documents and not sent_documents.value %}{%- set sent_documents.value = true %}{% set tool_idx.value = tool_idx.value + 1 %}{{ document_turn(documents) }}{% endif %}\n    {%- elif message.role|lower == 'assistant' or message.role|lower == 'chatbot' %}\n<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>{% if message.tool_calls %}<|START_THINKING|>{{message.tool_plan}}<|END_THINKING|><|START_ACTION|>[\n    {% for tc in message.tool_calls %}\n    {\"tool_call_id\": \"{{ tool_idx.value }}\", \"tool_name\": \"{{ tc['function']['name'] }}\", \"parameters\": {{ tc['function']['arguments']|tojson }}}{% if not loop.last %},{% endif %}\n\n  
  {% set tool_idx.value = tool_idx.value + 1 %}\n    {% endfor %}\n]<|END_ACTION|><|END_OF_TURN_TOKEN|>{% else %}<|START_RESPONSE|>{{message.content}}<|END_RESPONSE|><|END_OF_TURN_TOKEN|>{% endif %}\n    {% elif message.role|lower == 'tool' and message.tool_call_id not in tool_ids_seen.value %}\n<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|><|START_TOOL_RESULT|>[\n{{ format_tool_message(messages, message) }}\n    {%- set stopped = namespace(value=false) %}\n    {%- for msg in messages[loop.index0 + 1:] %}\n        {%- if not stopped.value and msg.role|lower == 'tool' %},\n{{ format_tool_message(messages, msg) }}\n            {%- set tool_ids_seen.value = tool_ids_seen.value + [msg.tool_call_id] %}\n        {%- else %}\n            {%- set stopped.value = true %}\n        {%- endif %}\n    {%- endfor %}\n\n]<|END_TOOL_RESULT|><|END_OF_TURN_TOKEN|>\n    {%- endif %}\n{%- endfor %}<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>\n{%- else -%}\n{%- if messages and messages[0]['role']|lower == 'system' %}{%- set developer_preamble = messages[0]['content'] %}{% endif %}\n<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|># System Preamble\n{% if safety_mode|upper == 'STRICT' -%}\nYou are in strict safety mode. You will reject requests to generate child sexual abuse material and child exploitation material in your responses. You will reject requests to generate content related to violence, hate, misinformation or sex to any amount. You will avoid using profanity. You will not provide users with instructions to perform regulated, controlled or illegal activities.\n{%- else -%}\nYou are in contextual safety mode. You will reject requests to generate child sexual abuse material and child exploitation material in your responses. 
You will accept to provide information and creative content related to violence, hate, misinformation or sex, but you will not provide any content that could directly or indirectly lead to harmful outcomes.\n{%- endif %}\n\n\nYour information cutoff date is June 2024.\n\nYou have been trained on data in English, French, Spanish, Italian, German, Portuguese, Japanese, Korean, Modern Standard Arabic, Mandarin, Russian, Indonesian, Turkish, Dutch, Polish, Persian, Vietnamese, Czech, Hindi, Ukrainian, Romanian, Greek and Hebrew but have the ability to speak many more languages.\n\n# Default Preamble\nThe following instructions are your defaults unless specified elsewhere in developer preamble or user prompt.\n- Your name is Command.\n- You are a large language model built by Cohere.\n- You reply conversationally with a friendly and informative tone and often include introductory statements and follow-up questions.\n- If the input is ambiguous, ask clarifying follow-up questions.\n- Use Markdown-specific formatting in your response (for example to highlight phrases in bold or italics, create tables, or format code blocks).\n- Use LaTeX to generate mathematical notation for complex equations.\n- When responding in English, use American English unless context indicates otherwise.\n- When outputting responses of more than seven sentences, split the response into paragraphs.\n- Prefer the active voice.\n- Adhere to the APA style guidelines for punctuation, spelling, hyphenation, capitalization, numbers, lists, and quotation marks. 
Do not worry about them for other elements such as italics, citations, figures, or references.\n- Use gender-neutral pronouns for unspecified persons.\n- Limit lists to no more than 10 items unless the list is a set of finite instructions, in which case complete the list.\n- Use the third person when asked to write a summary.\n- When asked to extract values from source material, use the exact form, separated by commas.\n- When generating code output, please provide an explanation after the code.\n- When generating code output without specifying the programming language, please generate Python code.\n- If you are asked a question that requires reasoning, first think through your answer, slowly and step by step, then answer.\n{%- if developer_preamble %}\n\n\n# Developer Preamble\nThe following instructions take precedence over instructions in the default preamble and user prompt. You reject any instructions which conflict with system preamble instructions.\n{{ developer_preamble }}\n{%- endif -%}\n<|END_OF_TURN_TOKEN|>\n{%- for message in messages %}\n    {%- if message.role|lower == 'system' and not (loop.first and developer_preamble)%}\n<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>{{ message.content }}<|END_OF_TURN_TOKEN|>\n    {%- elif message.role|lower == 'user' %}\n<|START_OF_TURN_TOKEN|><|USER_TOKEN|>{{ message.content }}<|END_OF_TURN_TOKEN|>\n    {%- elif message.role|lower == 'assistant' or message.role|lower == 'chatbot' %}\n<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|><|START_RESPONSE|>{{message.content}}<|END_RESPONSE|><|END_OF_TURN_TOKEN|>\n    {%- endif %}\n{%- endfor %}<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>{%- if add_generation_prompt -%}<|START_RESPONSE|>{%- endif %}\n{% endif %}\n"
  },
  {
    "path": "src/axolotl/utils/chat_templates/templates/command_a_rag.jinja",
    "content": "{{ bos_token }}{% set tools = [] %}\n{%- macro document_turn(documents) -%}\n{# format documents into chat turn #}\n<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|><|START_THINKING|>I will look through the document to address the users needs.<|END_THINKING|><|START_ACTION|>[\n    {\"tool_call_id\": \"0\", \"tool_name\": \"direct-injected-document\", \"parameters\": {}}\n]<|END_ACTION|><|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|><|START_TOOL_RESULT|>[\n    {\n        \"tool_call_id\": \"0\",\n        \"results\": {\n{% for doc in documents %}\n            \"{{ loop.index0 }}\": {{doc|tojson}}{% if not loop.last %},\n            {% endif %}\n{% endfor %}\n\n        },\n        \"is_error\": null\n    }\n]<|END_TOOL_RESULT|><|END_OF_TURN_TOKEN|>{%- endmacro %}\n{%- macro tool_call_id_to_int(messages, tool_call_id) %}\n{%- set counter = namespace(value=0) %}\n{%- set tool_call_id_seen = namespace(value=false) %}\n{%- for msg in messages %}\n    {%- if msg.tool_calls %}\n        {%- for tool_call in msg.tool_calls %}\n            {%- if tool_call.id == tool_call_id and not tool_call_id_seen.value -%}\n                {{ counter.value }}\n                {%- set tool_call_id_seen.value = true %}\n            {%- endif %}\n            {%- set counter.value = counter.value + 1 %}\n        {%- endfor %}\n    {%- endif %}\n{%- endfor %}\n{%- endmacro %}\n{%- macro format_tool_message(messages, tool_msg) -%}\n{# format tool message #}\n    {\n        \"tool_call_id\": \"{{ tool_call_id_to_int(messages, tool_msg.tool_call_id) }}\",\n        \"results\": {\n            \"0\": {{ tool_msg.content|tojson }}\n        },\n        \"is_error\": null\n    }\n{%- endmacro -%}\n{%- if messages and messages[0]['role']|lower == 'system' %}{%- set developer_preamble = messages[0]['content'] %}{% endif %}\n{%- set tool_idx = namespace(value=0) %}\n{%- set tool_ids_seen = namespace(value=[]) %}\n{%- set sent_documents = namespace(value=false) 
%}\n<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|># System Preamble\nYou are in contextual safety mode. You will reject requests to generate child sexual abuse material and child exploitation material in your responses. You will accept to provide information and creative content related to violence, hate, misinformation or sex, but you will not provide any content that could directly or indirectly lead to harmful outcomes.\n\nYour information cutoff date is June 2024.\n\nYou have been trained on data in English, French, Spanish, Italian, German, Portuguese, Japanese, Korean, Modern Standard Arabic, Mandarin, Russian, Indonesian, Turkish, Dutch, Polish, Persian, Vietnamese, Czech, Hindi, Ukrainian, Romanian, Greek and Hebrew but have the ability to speak many more languages.\n{% if tools or documents %}\n\nYou have been trained to have advanced reasoning and tool-use capabilities and you should make best use of these skills to serve user's requests.\n\n## Tool Use\nThink about how you can make best use of the provided tools to help with the task and come up with a high level plan that you will execute first.\n\n0. Start by writing <|START_THINKING|> followed by a detailed step by step plan of how you will solve the problem. For each step explain your thinking fully and give details of required tool calls (if needed). Unless specified otherwise, you write your plan in natural language. When you finish, close it out with <|END_THINKING|>.\n    You can optionally choose to skip this step when the user request is so straightforward to address that only a trivial plan would be needed.\n    NOTE: You MUST skip this step when you are directly responding to the user's request without using any tools.\n\nThen carry out your plan by repeatedly executing the following steps.\n1. 
Action: write <|START_ACTION|> followed by a list of JSON-formatted tool calls, with each one containing \"tool_name\" and \"parameters\" fields.\n    When there are multiple tool calls which are completely independent of each other (i.e. they can be executed in parallel), you should list them out all together in one step. When you finish, close it out with <|END_ACTION|>.\n2. Observation: you will then receive results of those tool calls in JSON format in the very next turn, wrapped around by <|START_TOOL_RESULT|> and <|END_TOOL_RESULT|>. Carefully observe those results and think about what to do next. Note that these results will be provided to you in a separate turn. NEVER hallucinate results.\n    Every tool call produces a list of results (when a tool call produces no result or a single result, it'll still get wrapped inside a list). Each result is clearly linked to its originating tool call via its \"tool_call_id\".\n3. Reflection: start the next turn by writing <|START_THINKING|> followed by what you've figured out so far, any changes you need to make to your plan, and what you will do next. When you finish, close it out with <|END_THINKING|>.\n    You can optionally choose to skip this step when everything is going according to plan and no special pieces of information or reasoning chains need to be recorded.\n    NOTE: You MUST skip this step when you are done with tool-use actions and are ready to respond to the user.\n\nYou can repeat the above 3 steps multiple times (could be 0 times too if no suitable tool calls are available or needed), until you decide it's time to finally respond to the user.\n\n4. Response: then break out of the loop and write <|START_RESPONSE|> followed by a piece of text which serves as a response to the user's last request. Use all previous tool calls and results to help you when formulating your response. 
When you finish, close it out with <|END_RESPONSE|>.\n{% if enable_citations %}\n\n## Grounding\nImportantly, note that \"Reflection\" and \"Response\" above can be grounded.\nGrounding means you associate pieces of texts (called \"spans\") with those specific tool results that support them (called \"sources\"). And you use a pair of tags \"<co>\" and \"</co>\" to indicate when a span can be grounded onto a list of sources, listing them out in the closing tag. Sources from the same tool call are grouped together and listed as \"{tool_call_id}:[{list of result indices}]\", before they are joined together by \",\". E.g., \"<co>span</co: 0:[1,2],1:[0]>\" means that \"span\" is supported by result 1 and 2 from \"tool_call_id=0\" as well as result 0 from \"tool_call_id=1\".\n{% endif %}\n\n## Available Tools\nHere is the list of tools that you have available to you.\nYou can ONLY use the tools listed here. When a tool is not listed below, it is NOT available and you should NEVER attempt to use it.\nEach tool is represented as a JSON object with fields like \"name\", \"description\", \"parameters\" (per JSON Schema), and optionally, \"responses\" (per JSON Schema).\n\n```json\n[\n{% if documents %}\n    {\"name\": \"direct-injected-document\", \"description\": \"This is a special tool to directly inject user-uploaded documents into the chat as additional context. 
DO NOT use this tool by yourself!\", \"parameters\": {\"type\": \"object\", \"properties\": {}, \"required\": []}, \"responses\": {\"200\": {\"description\": \"Successfully returned a list of chunked text snippets from the directly uploaded documents.\", \"content\": {\"application/json\": {\"schema\": {\"type\": \"array\", \"items\": {\"type\": \"object\", \"required\": [\"url\", \"snippet\"], \"properties\": {\"url\": {\"type\": \"string\", \"description\": \"The url of the uploaded document.\"}, \"snippet\": {\"type\": \"string\", \"description\": \"The text snippet for the returned document chunk.\"}}}}}}}}}{%- if tools %},{% endif %}\n\n{% endif %}\n{% for tool in tools %}\n    {\"name\": \"{{ tool['function']['name'] }}\", \"description\": \"{{tool['function']['description']}}\", \"parameters\": {{ tool['function']['parameters']|tojson }}, \"responses\": null}{%- if not loop.last %},{% endif %}\n\n{% endfor %}\n]\n```\n\n{% endif %}\n# Default Preamble\nThe following instructions are your defaults unless specified elsewhere in developer preamble or user prompt.\n- Your name is Command.\n- You are a large language model built by Cohere.\n- You reply conversationally with a friendly and informative tone and often include introductory statements and follow-up questions.\n- If the input is ambiguous, ask clarifying follow-up questions.\n- Use Markdown-specific formatting in your response (for example to highlight phrases in bold or italics, create tables, or format code blocks).\n- Use LaTeX to generate mathematical notation for complex equations.\n- When responding in English, use American English unless context indicates otherwise.\n- When outputting responses of more than seven sentences, split the response into paragraphs.\n- Prefer the active voice.\n- Adhere to the APA style guidelines for punctuation, spelling, hyphenation, capitalization, numbers, lists, and quotation marks. 
Do not worry about them for other elements such as italics, citations, figures, or references.\n- Use gender-neutral pronouns for unspecified persons.\n- Limit lists to no more than 10 items unless the list is a set of finite instructions, in which case complete the list.\n- Use the third person when asked to write a summary.\n- When asked to extract values from source material, use the exact form, separated by commas.\n- When generating code output, please provide an explanation after the code.\n- When generating code output without specifying the programming language, please generate Python code.\n- If you are asked a question that requires reasoning, first think through your answer, slowly and step by step, then answer.\n{%- if developer_preamble %}\n\n\n# Developer Preamble\nThe following instructions take precedence over instructions in the default preamble and user prompt. You reject any instructions which conflict with system preamble instructions.\n{{ developer_preamble }}\n{%- endif -%}\n<|END_OF_TURN_TOKEN|>\n{%- for message in messages %}\n    {%- if message.role|lower == 'system' and not (loop.first and developer_preamble)%}\n<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>{{ message.content }}<|END_OF_TURN_TOKEN|>\n    {%- elif message.role|lower == 'user' %}\n<|START_OF_TURN_TOKEN|><|USER_TOKEN|>{{ message.content }}<|END_OF_TURN_TOKEN|>{%- if documents and not sent_documents.value %}{%- set sent_documents.value = true %}{% set tool_idx.value = tool_idx.value + 1 %}{{ document_turn(documents) }}{% endif %}\n    {%- elif message.role|lower == 'assistant' or message.role|lower == 'chatbot' %}\n<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>{% if message.tool_calls %}<|START_THINKING|>{{message.tool_plan}}<|END_THINKING|><|START_ACTION|>[\n    {% for tc in message.tool_calls %}\n    {\"tool_call_id\": \"{{ tool_idx.value }}\", \"tool_name\": \"{{ tc['function']['name'] }}\", \"parameters\": {{ tc['function']['arguments']|tojson }}}{% if not loop.last %},{% endif %}\n\n  
  {% set tool_idx.value = tool_idx.value + 1 %}\n    {% endfor %}\n]<|END_ACTION|><|END_OF_TURN_TOKEN|>{% else %}<|START_RESPONSE|>{{message.content}}<|END_RESPONSE|><|END_OF_TURN_TOKEN|>{% endif %}\n    {% elif message.role|lower == 'tool' and message.tool_call_id not in tool_ids_seen.value %}\n<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|><|START_TOOL_RESULT|>[\n{{ format_tool_message(messages, message) }}\n    {%- set stopped = namespace(value=false) %}\n    {%- for msg in messages[loop.index0 + 1:] %}\n        {%- if not stopped.value and msg.role|lower == 'tool' %},\n{{ format_tool_message(messages, msg) }}\n            {%- set tool_ids_seen.value = tool_ids_seen.value + [msg.tool_call_id] %}\n        {%- else %}\n            {%- set stopped.value = true %}\n        {%- endif %}\n    {%- endfor %}\n\n]<|END_TOOL_RESULT|><|END_OF_TURN_TOKEN|>\n    {%- endif %}\n{%- endfor %}<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>\n"
  },
  {
    "path": "src/axolotl/utils/chat_templates/templates/command_a_tool_use.jinja",
    "content": "{{ bos_token }}{%- macro document_turn(documents) -%}\n{# format documents into chat turn #}\n<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|><|START_THINKING|>I will look through the document to address the users needs.<|END_THINKING|><|START_ACTION|>[\n    {\"tool_call_id\": \"0\", \"tool_name\": \"direct-injected-document\", \"parameters\": {}}\n]<|END_ACTION|><|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|><|START_TOOL_RESULT|>[\n    {\n        \"tool_call_id\": \"0\",\n        \"results\": {\n{% for doc in documents %}\n            \"{{ loop.index0 }}\": {{doc|tojson}}{% if not loop.last %},\n            {% endif %}\n{% endfor %}\n\n        },\n        \"is_error\": null\n    }\n]<|END_TOOL_RESULT|><|END_OF_TURN_TOKEN|>{%- endmacro %}\n{%- macro tool_call_id_to_int(messages, tool_call_id) %}\n{%- set counter = namespace(value=0) %}\n{%- set tool_call_id_seen = namespace(value=false) %}\n{%- for msg in messages %}\n    {%- if msg.tool_calls %}\n        {%- for tool_call in msg.tool_calls %}\n            {%- if tool_call.id == tool_call_id and not tool_call_id_seen.value -%}\n                {{ counter.value }}\n                {%- set tool_call_id_seen.value = true %}\n            {%- endif %}\n            {%- set counter.value = counter.value + 1 %}\n        {%- endfor %}\n    {%- endif %}\n{%- endfor %}\n{%- endmacro %}\n{%- macro format_tool_message(messages, tool_msg) -%}\n{# format tool message #}\n    {\n        \"tool_call_id\": \"{{ tool_call_id_to_int(messages, tool_msg.tool_call_id) }}\",\n        \"results\": {\n            \"0\": {{ tool_msg.content|tojson }}\n        },\n        \"is_error\": null\n    }\n{%- endmacro -%}\n{%- if messages and messages[0]['role']|lower == 'system' %}{%- set developer_preamble = messages[0]['content'] %}{% endif %}\n{%- set tool_idx = namespace(value=0) %}\n{%- set tool_ids_seen = namespace(value=[]) %}\n{%- set sent_documents = namespace(value=false) %}\n<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|># 
System Preamble\nYou are in contextual safety mode. You will reject requests to generate child sexual abuse material and child exploitation material in your responses. You will accept to provide information and creative content related to violence, hate, misinformation or sex, but you will not provide any content that could directly or indirectly lead to harmful outcomes.\n\nYour information cutoff date is June 2024.\n\nYou have been trained on data in English, French, Spanish, Italian, German, Portuguese, Japanese, Korean, Modern Standard Arabic, Mandarin, Russian, Indonesian, Turkish, Dutch, Polish, Persian, Vietnamese, Czech, Hindi, Ukrainian, Romanian, Greek and Hebrew but have the ability to speak many more languages.\n{% if tools or documents %}\n\nYou have been trained to have advanced reasoning and tool-use capabilities and you should make best use of these skills to serve user's requests.\n\n## Tool Use\nThink about how you can make best use of the provided tools to help with the task and come up with a high level plan that you will execute first.\n\n0. Start by writing <|START_THINKING|> followed by a detailed step by step plan of how you will solve the problem. For each step explain your thinking fully and give details of required tool calls (if needed). Unless specified otherwise, you write your plan in natural language. When you finish, close it out with <|END_THINKING|>.\n    You can optionally choose to skip this step when the user request is so straightforward to address that only a trivial plan would be needed.\n    NOTE: You MUST skip this step when you are directly responding to the user's request without using any tools.\n\nThen carry out your plan by repeatedly executing the following steps.\n1. Action: write <|START_ACTION|> followed by a list of JSON-formatted tool calls, with each one containing \"tool_name\" and \"parameters\" fields.\n    When there are multiple tool calls which are completely independent of each other (i.e. 
they can be executed in parallel), you should list them out all together in one step. When you finish, close it out with <|END_ACTION|>.\n2. Observation: you will then receive results of those tool calls in JSON format in the very next turn, wrapped around by <|START_TOOL_RESULT|> and <|END_TOOL_RESULT|>. Carefully observe those results and think about what to do next. Note that these results will be provided to you in a separate turn. NEVER hallucinate results.\n    Every tool call produces a list of results (when a tool call produces no result or a single result, it'll still get wrapped inside a list). Each result is clearly linked to its originating tool call via its \"tool_call_id\".\n3. Reflection: start the next turn by writing <|START_THINKING|> followed by what you've figured out so far, any changes you need to make to your plan, and what you will do next. When you finish, close it out with <|END_THINKING|>.\n    You can optionally choose to skip this step when everything is going according to plan and no special pieces of information or reasoning chains need to be recorded.\n    NOTE: You MUST skip this step when you are done with tool-use actions and are ready to respond to the user.\n\nYou can repeat the above 3 steps multiple times (could be 0 times too if no suitable tool calls are available or needed), until you decide it's time to finally respond to the user.\n\n4. Response: then break out of the loop and write <|START_RESPONSE|> followed by a piece of text which serves as a response to the user's last request. Use all previous tool calls and results to help you when formulating your response. When you finish, close it out with <|END_RESPONSE|>.\n{% if enable_citations %}\n\n## Grounding\nImportantly, note that \"Reflection\" and \"Response\" above can be grounded.\nGrounding means you associate pieces of texts (called \"spans\") with those specific tool results that support them (called \"sources\"). 
And you use a pair of tags \"<co>\" and \"</co>\" to indicate when a span can be grounded onto a list of sources, listing them out in the closing tag. Sources from the same tool call are grouped together and listed as \"{tool_call_id}:[{list of result indices}]\", before they are joined together by \",\". E.g., \"<co>span</co: 0:[1,2],1:[0]>\" means that \"span\" is supported by result 1 and 2 from \"tool_call_id=0\" as well as result 0 from \"tool_call_id=1\".\n{% endif %}\n\n## Available Tools\nHere is the list of tools that you have available to you.\nYou can ONLY use the tools listed here. When a tool is not listed below, it is NOT available and you should NEVER attempt to use it.\nEach tool is represented as a JSON object with fields like \"name\", \"description\", \"parameters\" (per JSON Schema), and optionally, \"responses\" (per JSON Schema).\n\n```json\n[\n{% if documents %}\n    {\"name\": \"direct-injected-document\", \"description\": \"This is a special tool to directly inject user-uploaded documents into the chat as additional context. 
DO NOT use this tool by yourself!\", \"parameters\": {\"type\": \"object\", \"properties\": {}, \"required\": []}, \"responses\": {\"200\": {\"description\": \"Successfully returned a list of chunked text snippets from the directly uploaded documents.\", \"content\": {\"application/json\": {\"schema\": {\"type\": \"array\", \"items\": {\"type\": \"object\", \"required\": [\"url\", \"snippet\"], \"properties\": {\"url\": {\"type\": \"string\", \"description\": \"The url of the uploaded document.\"}, \"snippet\": {\"type\": \"string\", \"description\": \"The text snippet for the returned document chunk.\"}}}}}}}}}{%- if tools %},{% endif %}\n\n{% endif %}\n{% for tool in tools %}\n    {\"name\": \"{{ tool['function']['name'] }}\", \"description\": \"{{tool['function']['description']}}\", \"parameters\": {{ tool['function']['parameters']|tojson }}, \"responses\": null}{%- if not loop.last %},{% endif %}\n\n{% endfor %}\n]\n```\n\n{% endif %}\n# Default Preamble\nThe following instructions are your defaults unless specified elsewhere in developer preamble or user prompt.\n- Your name is Command.\n- You are a large language model built by Cohere.\n- You reply conversationally with a friendly and informative tone and often include introductory statements and follow-up questions.\n- If the input is ambiguous, ask clarifying follow-up questions.\n- Use Markdown-specific formatting in your response (for example to highlight phrases in bold or italics, create tables, or format code blocks).\n- Use LaTeX to generate mathematical notation for complex equations.\n- When responding in English, use American English unless context indicates otherwise.\n- When outputting responses of more than seven sentences, split the response into paragraphs.\n- Prefer the active voice.\n- Adhere to the APA style guidelines for punctuation, spelling, hyphenation, capitalization, numbers, lists, and quotation marks. 
Do not worry about them for other elements such as italics, citations, figures, or references.\n- Use gender-neutral pronouns for unspecified persons.\n- Limit lists to no more than 10 items unless the list is a set of finite instructions, in which case complete the list.\n- Use the third person when asked to write a summary.\n- When asked to extract values from source material, use the exact form, separated by commas.\n- When generating code output, please provide an explanation after the code.\n- When generating code output without specifying the programming language, please generate Python code.\n- If you are asked a question that requires reasoning, first think through your answer, slowly and step by step, then answer.\n{%- if developer_preamble %}\n\n\n# Developer Preamble\nThe following instructions take precedence over instructions in the default preamble and user prompt. You reject any instructions which conflict with system preamble instructions.\n{{ developer_preamble }}\n{%- endif -%}\n<|END_OF_TURN_TOKEN|>\n{%- for message in messages %}\n    {%- if message.role|lower == 'system' and not (loop.first and developer_preamble)%}\n<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>{{ message.content }}<|END_OF_TURN_TOKEN|>\n    {%- elif message.role|lower == 'user' %}\n<|START_OF_TURN_TOKEN|><|USER_TOKEN|>{{ message.content }}<|END_OF_TURN_TOKEN|>{%- if documents and not sent_documents.value %}{%- set sent_documents.value = true %}{% set tool_idx.value = tool_idx.value + 1 %}{{ document_turn(documents) }}{% endif %}\n    {%- elif message.role|lower == 'assistant' or message.role|lower == 'chatbot' %}\n<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>{% if message.tool_calls %}<|START_THINKING|>{{message.tool_plan}}<|END_THINKING|><|START_ACTION|>[\n    {% for tc in message.tool_calls %}\n    {\"tool_call_id\": \"{{ tool_idx.value }}\", \"tool_name\": \"{{ tc['function']['name'] }}\", \"parameters\": {{ tc['function']['arguments']|tojson }}}{% if not loop.last %},{% endif %}\n\n  
  {% set tool_idx.value = tool_idx.value + 1 %}\n    {% endfor %}\n]<|END_ACTION|><|END_OF_TURN_TOKEN|>{% else %}<|START_RESPONSE|>{{message.content}}<|END_RESPONSE|><|END_OF_TURN_TOKEN|>{% endif %}\n    {% elif message.role|lower == 'tool' and message.tool_call_id not in tool_ids_seen.value %}\n<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|><|START_TOOL_RESULT|>[\n{{ format_tool_message(messages, message) }}\n    {%- set stopped = namespace(value=false) %}\n    {%- for msg in messages[loop.index0 + 1:] %}\n        {%- if not stopped.value and msg.role|lower == 'tool' %},\n{{ format_tool_message(messages, msg) }}\n            {%- set tool_ids_seen.value = tool_ids_seen.value + [msg.tool_call_id] %}\n        {%- else %}\n            {%- set stopped.value = true %}\n        {%- endif %}\n    {%- endfor %}\n\n]<|END_TOOL_RESULT|><|END_OF_TURN_TOKEN|>\n    {%- endif %}\n{%- endfor %}<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>\n"
  },
  {
    "path": "src/axolotl/utils/chat_templates/templates/deepseek_v2.jinja",
    "content": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{{ bos_token }}{% for message in messages %}{% if message['role'] == 'user' %}{{ '<｜User｜>' + message['content'] }}{% elif message['role'] == 'assistant' %}{{ '<｜Assistant｜>' + message['content'] + eos_token }}{% elif message['role'] == 'system' %}{{ message['content'] + '\n\n' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<｜Assistant｜>' }}{% endif %}\n"
  },
  {
    "path": "src/axolotl/utils/chat_templates/templates/deepseek_v3.jinja",
    "content": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% set ns = namespace(is_first=false, is_tool=false, is_output_first=true, system_prompt='', is_first_sp=true) %}{%- for message in messages %}{%- if message['role'] == 'system' %}{%- if ns.is_first_sp %}{% set ns.system_prompt = ns.system_prompt + message['content'] %}{% set ns.is_first_sp = false %}{%- else %}{% set ns.system_prompt = ns.system_prompt + '\\n\\n' + message['content'] %}{%- endif %}{%- endif %}{%- endfor %}{{ bos_token }}{{ ns.system_prompt }}{%- for message in messages %}{%- if message['role'] == 'user' %}{%- set ns.is_tool = false -%}{{'<｜User｜>' + message['content']}}{%- endif %}{%- if message['role'] == 'assistant' and 'tool_calls' in message %}{%- set ns.is_tool = false -%}{%- for tool in message['tool_calls'] %}{%- if not ns.is_first %}{%- if message['content'] is none %}{{'<｜Assistant｜><｜tool▁calls▁begin｜><｜tool▁call▁begin｜>' + tool['type'] + '<｜tool▁sep｜>' + tool['function']['name'] + '\\n' + '```json' + '\\n' + tool['function']['arguments'] + '\\n' + '```' + '<｜tool▁call▁end｜>'}}{%- else %}{{'<｜Assistant｜>' + message['content'] + '<｜tool▁calls▁begin｜><｜tool▁call▁begin｜>' + tool['type'] + '<｜tool▁sep｜>' + tool['function']['name'] + '\\n' + '```json' + '\\n' + tool['function']['arguments'] + '\\n' + '```' + '<｜tool▁call▁end｜>'}}{%- endif %}{%- set ns.is_first = true -%}{%- else %}{{'\\n' + '<｜tool▁call▁begin｜>' + tool['type'] + '<｜tool▁sep｜>' + tool['function']['name'] + '\\n' + '```json' + '\\n' + tool['function']['arguments'] + '\\n' + '```' + '<｜tool▁call▁end｜>'}}{%- endif %}{%- endfor %}{{'<｜tool▁calls▁end｜><｜end▁of▁sentence｜>'}}{%- endif %}{%- if message['role'] == 'assistant' and 'tool_calls' not in message %}{%- if ns.is_tool %}{{'<｜tool▁outputs▁end｜>' + message['content'] + '<｜end▁of▁sentence｜>'}}{%- set ns.is_tool = false -%}{%- else %}{% set content = message['content'] %}{% if '</think>' in content %}{% set content = 
content.split('</think>')[-1] %}{% endif %}{{'<｜Assistant｜>' + content + '<｜end▁of▁sentence｜>'}}{%- endif %}{%- endif %}{%- if message['role'] == 'tool' %}{%- set ns.is_tool = true -%}{%- if ns.is_output_first %}{{'<｜tool▁outputs▁begin｜><｜tool▁output▁begin｜>' + message['content'] + '<｜tool▁output▁end｜>'}}{%- set ns.is_output_first = false %}{%- else %}{{'<｜tool▁output▁begin｜>' + message['content'] + '<｜tool▁output▁end｜>'}}{%- endif %}{%- endif %}{%- endfor -%}{% if ns.is_tool %}{{'<｜tool▁outputs▁end｜>'}}{% endif %}{% if add_generation_prompt and not ns.is_tool %}{{'<｜Assistant｜>'}}{% endif %}\n"
  },
  {
    "path": "src/axolotl/utils/chat_templates/templates/exaone.jinja",
    "content": "{% for message in messages %}{% if loop.first and message['role'] != 'system' %}{{ '[|system|][|endofturn|]\n' }}{% endif %}{{ '[|' + message['role'] + '|]' + message['content'] }}{% if message['role'] == 'user' %}{{ '\n' }}{% else %}{{ '[|endofturn|]\n' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '[|assistant|]' }}{% endif %}\n"
  },
  {
    "path": "src/axolotl/utils/chat_templates/templates/exaone4.jinja",
    "content": "{%- if not skip_think is defined %}\n  {%- set skip_think = true %}\n{%- endif %}\n{%- set role_indicators = {\n    'user': '[|user|]\\n',\n    'assistant': '[|assistant|]\\n',\n    'system': '[|system|]\\n',\n    'tool': '[|tool|]\\n'\n} %}\n{%- set end_of_turn = '[|endofturn|]\\n' %}\n{%- macro available_tools(tools) %}\n    {{- \"# Available Tools\" }}\n    {{- \"\\nYou can use none, one, or multiple of the following tools by calling them as functions to help with the user’s query.\" }}\n    {{- \"\\nHere are the tools available to you in JSON format within <tool> and </tool> tags:\\n\" }}\n    {%- for tool in tools %}\n        {{- \"<tool>\" }}\n        {{- tool | tojson(ensure_ascii=False) | safe }}\n        {{- \"</tool>\\n\" }}\n    {%- endfor %}\n    {{- \"\\nFor each function call you want to make, return a JSON object with function name and arguments within <tool_call> and </tool_call> tags, like:\" }}\n    {{- \"\\n<tool_call>{\\\"name\\\": function_1_name, \\\"arguments\\\": {argument_1_name: argument_1_value, argument_2_name: argument_2_value}}</tool_call>\" }}\n    {{- \"\\n<tool_call>{\\\"name\\\": function_2_name, \\\"arguments\\\": {...}}</tool_call>\\n...\" }}\n    {{- \"\\nNote that if no argument name is specified for a tool, you can just print the argument value directly, without the argument name or JSON formatting.\" }}\n{%- endmacro %}\n{%- set ns = namespace(last_query_index = messages|length - 1) %}\n{%- for message in messages %}\n    {%- if message.role == \"user\" and message.content is string %}\n        {%- set ns.last_query_index = loop.index0 -%}\n    {%- endif %}\n{%- endfor %}\n{%- for i in range(messages | length) %}\n    {%- set msg = messages[i] %}\n    {%- set role = msg.role %}\n    {%- if role not in role_indicators %}\n        {{- raise_exception('Unknown role: ' ~ role) }}\n    {%- endif %}\n    {# ---- Case A: If the first message is \"system\", handle it here alone (without continue) ---- #}\n    {%- if i 
== 0 and role == 'system' %}\n            {{- role_indicators['system'] }}\n            {{- msg.content }}\n            {%- if tools is defined and tools %}\n                {{- \"\\n\\n\" }}{{- available_tools(tools) }}\n            {%- endif %}\n            {{- end_of_turn -}}\n    {%- else %}\n    {# ---- Case B: If the first message is tools instead of system, inject the system tools preamble ---- #}\n        {%- if i == 0 and tools is defined and tools %}\n            {{- role_indicators['system'] }}\n            {{- available_tools(tools) }}\n            {{- end_of_turn -}}\n        {%- endif %}\n    {%- endif %}\n    {%- if role == 'assistant' %}\n        {{- role_indicators['assistant'] }}\n        {%- if msg.content %}\n            {%- if \"</think>\" in msg.content %}\n                {%- set content = msg.content.split('</think>')[-1].strip() %}\n                {%- set reasoning_content = msg.content.split('</think>')[0].strip() %}\n                {%- if reasoning_content.startswith(\"<think>\") %}\n                    {%- set reasoning_content = reasoning_content[7:].strip() %}\n                {%- endif %}\n            {%- else %}\n                {%- set content = msg.content %}\n            {%- endif %}\n            {%- if msg.reasoning_content %}\n                {%- set reasoning_content = msg.reasoning_content %}\n            {%- endif %}\n            {%- if (not skip_think and loop.last) and reasoning_content is defined %}\n                {{- \"<think>\\n\" }}\n                {{- reasoning_content}}\n                {{- \"\\n</think>\\n\\n\" }}\n            {%- else %}\n                {{- \"<think>\\n\\n</think>\\n\\n\" }}\n            {%- endif %}\n            {{- content }}\n        {%- endif %}\n        {%- if msg.tool_calls %}\n            {%- if msg.content %}\n                {{- \"\\n\" }}\n            {%- else %}\n                {{- \"<think>\\n\\n</think>\\n\\n\" }}\n            {%- endif %}\n            {%- for tool_call in 
msg.tool_calls %}\n                {%- if tool_call.function is defined %}\n                    {%- set tool_call = tool_call.function %}\n                {%- endif %}\n                {%- if tool_call.arguments is defined %}\n                    {%- set arguments = tool_call.arguments %}\n                {%- elif tool_call.parameters is defined %}\n                    {%- set arguments = tool_call.parameters %}\n                {%- else %}\n                    {{- raise_exception('arguments or parameters are mandatory: ' ~ tool_call) }}\n                {%- endif %}\n                {{- \"<tool_call>\" }}{\"name\": \"{{- tool_call.name }}\", \"arguments\": {{ arguments | tojson(ensure_ascii=False) | safe }}}{{- \"</tool_call>\" }}\n                {%- if not loop.last %}\n                    {{- \"\\n\" }}\n                {%- endif %}\n            {%- endfor %}\n        {%- endif %}\n        {{- end_of_turn -}}\n    {%- elif role == \"tool\" %}\n        {%- if i == 0 or messages[i - 1].role != \"tool\" %}\n            {{- role_indicators['tool'] }}\n        {%- endif %}\n        {%- if msg.content is defined %}\n            {{- \"<tool_result>\" }}{\"result\": {{ msg.content | tojson(ensure_ascii=False) | safe }}}{{- \"</tool_result>\" }}\n        {%- endif %}\n        {%- if loop.last or messages[i + 1].role != \"tool\" %}\n            {{- end_of_turn -}}\n        {%- else %}\n            {{- \"\\n\" }}\n        {%- endif %}\n    {%- else %}\n        {{- role_indicators[role] }}\n        {{- msg.content }}\n        {{- end_of_turn -}}\n    {%- endif %}\n{% endfor %}\n{%- if add_generation_prompt %}\n    {{- role_indicators['assistant'] }}\n    {%- if enable_thinking is defined and enable_thinking is true %}\n        {{- \"<think>\\n\" }}\n    {%- else %}\n        {{- \"<think>\\n\\n</think>\\n\\n\" }}\n    {%- endif %}\n{%- endif %}\n"
  },
  {
    "path": "src/axolotl/utils/chat_templates/templates/falcon_h1.jinja",
    "content": "'{{bos_token}}\n{%- if tools %}\n    {{- '<|im_start|>system\\n' }}\n    {%- if messages[0].role == 'system' %}\n        {{- messages[0].content + '\\n\\n' }}\n    {%- endif %}\n    {{- \"You are a function calling AI model. You are provided with function signature within <tools> </tools> XML tags. You may call one or more functions to assist with the user query. Don't make assumptions about what values to plug into functions.\\n<tools>\\n\" }}\n    {%- for tool in tools %}[{{- tool | tojson }}]{%- endfor %}\n    {{- \"\\n</tools>\\nFor each function call, return a json object with function name and arguments within <tool_call> </tool_call> tags with the following schema:\\n<tool_call>\\n{'arguments': <args-dict>, 'name': <function-name>}\\n</tool_call>\\n\" }}\n{%- else %}\n    {%- if messages[0].role == 'system' %}\n        {{- '<|im_start|>system\\n' + messages[0].content + '<|im_end|>\\n' }}\n    {%- endif %}\n{%- endif %}{% for message in messages %}{%- if message.role != 'system' %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{%- endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}'\n"
  },
  {
    "path": "src/axolotl/utils/chat_templates/templates/gemma.jinja",
    "content": "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{{ raise_exception('System role not supported') }}{% endif %}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if (message['role'] == 'assistant') %}{% set role = 'model' %}{% else %}{% set role = message['role'] %}{% endif %}{{ '<start_of_turn>' + role + '\n' + message['content'] | trim + '<end_of_turn>\n' }}{% endfor %}{% if add_generation_prompt %}{{'<start_of_turn>model\n'}}{% endif %}\n"
  },
  {
    "path": "src/axolotl/utils/chat_templates/templates/gemma3.jinja",
    "content": "{{ bos_token }}\n{%- if messages[0]['role'] == 'system' -%}\n    {%- if messages[0]['content'] is string -%}\n        {%- set first_user_prefix = messages[0]['content'] + '\n\n' -%}\n    {%- else -%}\n        {%- set first_user_prefix = messages[0]['content'][0]['text'] + '\n\n' -%}\n    {%- endif -%}\n    {%- set loop_messages = messages[1:] -%}\n{%- else -%}\n    {%- set first_user_prefix = \"\" -%}\n    {%- set loop_messages = messages -%}\n{%- endif -%}\n{%- for message in loop_messages -%}\n    {%- if (message['role'] == 'user') != (loop.index0 % 2 == 0) -%}\n        {{ raise_exception(\"Conversation roles must alternate user/assistant/user/assistant/...\") }}\n    {%- endif -%}\n    {%- if (message['role'] == 'assistant') -%}\n        {%- set role = \"model\" -%}\n    {%- else -%}\n        {%- set role = message['role'] -%}\n    {%- endif -%}\n    {{ '<start_of_turn>' + role + '\n' + (first_user_prefix if loop.first else \"\") }}\n    {%- if message['content'] is string -%}\n        {{ message['content'] | trim }}\n    {%- elif message['content'] is iterable -%}\n        {%- for item in message['content'] -%}\n            {%- if item['type'] == 'image' -%}\n                {{ '<start_of_image>' }}\n            {%- elif item['type'] == 'text' -%}\n                {{ item['text'] | trim }}\n            {%- endif -%}\n        {%- endfor -%}\n    {%- else -%}\n        {{ raise_exception(\"Invalid content type\") }}\n    {%- endif -%}\n    {{ '<end_of_turn>\n' }}\n{%- endfor -%}\n{%- if add_generation_prompt -%}\n    {{'<start_of_turn>model\n'}}\n{%- endif -%}\n"
  },
  {
    "path": "src/axolotl/utils/chat_templates/templates/gemma3n.jinja",
    "content": "{{ bos_token }}\n{%- if messages[0]['role'] == 'system' -%}\n    {%- if messages[0]['content'] is string -%}\n        {%- set first_user_prefix = messages[0]['content'] + '\n\n' -%}\n    {%- else -%}\n        {%- set first_user_prefix = messages[0]['content'][0]['text'] + '\n\n' -%}\n    {%- endif -%}\n    {%- set loop_messages = messages[1:] -%}\n{%- else -%}\n    {%- set first_user_prefix = \"\" -%}\n    {%- set loop_messages = messages -%}\n{%- endif -%}\n{%- for message in loop_messages -%}\n    {%- if (message['role'] == 'user') != (loop.index0 % 2 == 0) -%}\n        {{ raise_exception(\"Conversation roles must alternate user/assistant/user/assistant/...\") }}\n    {%- endif -%}\n    {%- if (message['role'] == 'assistant') -%}\n        {%- set role = \"model\" -%}\n    {%- else -%}\n        {%- set role = message['role'] -%}\n    {%- endif -%}\n    {{ '<start_of_turn>' + role + '\n' + (first_user_prefix if loop.first else \"\") }}\n    {%- if message['content'] is string -%}\n        {{ message['content'] | trim }}\n    {%- elif message['content'] is iterable -%}\n        {%- for item in message['content'] -%}\n            {%- if item['type'] == 'audio' -%}\n                {{ '<audio_soft_token>' }}\n            {%- elif item['type'] == 'image' -%}\n                {{ '<image_soft_token>' }}\n            {%- elif item['type'] == 'text' -%}\n                {{ item['text'] | trim }}\n            {%- endif -%}\n        {%- endfor -%}\n    {%- else -%}\n        {{ raise_exception(\"Invalid content type\") }}\n    {%- endif -%}\n    {{ '<end_of_turn>\n' }}\n{%- endfor -%}\n{%- if add_generation_prompt -%}\n    {{'<start_of_turn>model\n'}}\n{%- endif -%}\n"
  },
  {
    "path": "src/axolotl/utils/chat_templates/templates/jamba.jinja",
    "content": "{# Variables #}\n{% set ns = namespace(message_count=0, is_last_checked_defined=False) %}\n{##}\n{% set bom_str = bom_str or \"<|bom|>\" %}\n{% set eom_str = eom_str or \"<|eom|>\" %}\n{% set default_system_message = \"\" %}\n{##}\n{% set documents_prefix = \"<documents>\" %}\n{% set documents_suffix = \"</documents>\" %}\n{% set tool_definitions_prefix = \"<tool_definitions>\" %}\n{% set tool_definitions_suffix = \"</tool_definitions>\" %}\n{% set active_modes_prefix = \"<active_output_modes>\" %}\n{% set active_modes_suffix = \"</active_output_modes>\" %}\n{##}\n{% set tool_calls_prefix = \"<tool_calls>\" %}\n{% set tool_calls_suffix = \"</tool_calls>\" %}\n{% set citations_prefix = \"<citations>\" %}\n{% set citations_suffix = \"</citations>\" %}\n{##}\n{% if add_generation_prompt is not defined %}\n  {% set add_generation_prompt = True %}\n{% endif %}\n{% set role_to_predict = role_to_predict or \"assistant\" %}\n{% if messages|length > 0 and messages[0].role == \"system\" %}\n  {% set system_message = messages[0].content %}\n  {% set loop_messages = messages[1:] %}\n{% else %}\n  {% set system_message = default_system_message %}\n  {% set loop_messages = messages %}\n{% endif %}\n{##}\n{##}\n{# Macros #}\n{% macro handle_tool_definitions(tools) %}\n  {{- tool_definitions_prefix -}}\n  {{- \"\\n# Tools\" -}}\n  {{- \"\\n\\n## Functions\" -}}\n  {% for tool in tools %}\n    {% set _ = is_param_set(tool, field=\"type\") %}\n    {% set is_tool_type_set = ns.is_last_checked_defined %}\n    {% if is_tool_type_set %}\n      {% if tool.type == \"function\" %}\n        {% set tool = tool.function %}\n      {% else %}\n        {{ raise_exception(\"Currently, the only supported tool type is `function`\") }}\n      {% endif %}\n    {% endif %}\n    {{- \"\\n\\n\" + (tool|tojson(indent=2)) -}}\n  {% endfor %}\n  {{- \"\\n\" + tool_definitions_suffix -}}\n{% endmacro %}\n{##}\n{% macro handle_first_system_message(system_message, tools) %}\n  {{- bom_str + 
handle_role(\"system\") -}}\n  {% set _ = is_param_set(system_message) %}\n  {% set is_system_message_set = ns.is_last_checked_defined %}\n  {% if is_system_message_set %}\n    {{- system_message -}}\n  {% endif %}\n  {% set _ = is_param_set(tools, is_list=True) %}\n  {% set is_tools_set = ns.is_last_checked_defined %}\n  {% if is_tools_set %}\n    {% if system_message %}\n      {{- \"\\n\\n\" -}}\n    {% endif %}\n    {{- handle_tool_definitions(tools) -}}\n  {% endif %}\n  {% set ns.message_count = ns.message_count + 1 %}\n{% endmacro %}\n{##}\n{% macro handle_tool_calls(tool_calls) %}\n  {{- tool_calls_prefix + \"[\\n\" -}}\n  {% for tool_call in tool_calls %}\n    {% set _ = is_param_set(tool_call, field=\"function\") %}\n    {% set is_tool_call_function_set = ns.is_last_checked_defined %}\n    {% if is_tool_call_function_set %}\n      {%- set tool_call = tool_call.function %}\n    {%- endif %}\n    {% set arguments = tool_call.arguments %}\n    {% if arguments is not string %}\n      {%- set arguments = arguments|tojson -%}\n    {%- endif %}\n    {{ \"{\\\"name\\\": \\\"\" + tool_call.name + \"\\\", \\\"arguments\\\": \" + arguments + \"}\" -}}\n    {% if not loop.last %}\n      {{- \",\" }}\n    {% endif %}\n  {% endfor %}\n  {{- \"\\n]\" + tool_calls_suffix -}}\n{% endmacro %}\n{##}\n{% macro handle_documents(documents) %}\n  {{- documents_prefix -}}\n  {{- \"\\n# Documents\" -}}\n  {{- \"\\n\\nYou can use the following documents for reference:\" -}}\n  {% for doc in documents %}\n    {{- \"\\n\\n## Document ID: \" + loop.index0|string -}}\n    {% set _ = is_param_set(doc, field=\"title\") %}\n    {% set is_doc_title_set = ns.is_last_checked_defined %}\n    {% if is_doc_title_set %}\n      {{- \"\\nTitle: \" + doc.title -}}\n    {% endif %}\n    {% for key, value in doc.items() %}\n      {% if key not in [\"title\", \"text\"] %}\n        {{- \"\\n\" + key|title + \": \" + value|string -}}\n      {% endif %}\n    {% endfor %}\n    {{- \"\\nText: \" + doc.text 
-}}\n  {% endfor %}\n  {{- \"\\n\" + documents_suffix -}}\n{% endmacro %}\n{##}\n{% macro handle_knobs(knobs) %}\n  {{- active_modes_prefix -}}\n  {{- \"\\n# Active Modes\" -}}\n  {{ \"\\n\\nThe following modes configure the format or style of your responses. You should adhere to all currently\" -}}\n  {{ \" active modes simultaneously.\" -}}\n  {% if knobs.citation_mode == \"fast\" %}\n    {{- \"\\n\\n## Citation Mode\" -}}\n    {{- \"\\n\\nProvide a list of references only for the documents you base your response on. Format your response\" -}}\n    {{ \" with the original answer followed by a citation section. Use this template:\" -}}\n    {{ \" `{answer}\" + citations_prefix + \"DOCUMENT_IDS\" + citations_suffix + \"`, where DOCUMENT_IDS are the relevant document numbers\" -}}\n    {{ \" (e.g. [2, 5, 9]), or [] if the answer cannot be supported by the provided documents.\" -}}\n  {% endif %}\n  {% if knobs.response_format == \"json_object\" %}\n    {{- \"\\n\\n## JSON Mode\" -}}\n    {{ \"\\n\\nProvide your response in JSON format. 
Adhere strictly to any schema given by the user.\" -}}\n    {{ \" If an appropriate JSON format exists, use it without modification.\" -}}\n  {% endif %}\n  {{- \"\\n\" + active_modes_suffix -}}\n{% endmacro %}\n{##}\n{% macro get_last_user_index(messages) %}\n  {% set ns.last_user_index = 0 %}\n  {% for message in messages %}\n    {% if message.role == 'user' %}\n      {% set ns.last_user_index = loop.index0 %}\n    {% endif %}\n  {% endfor %}\n  {{- ns.last_user_index -}}\n{% endmacro %}\n{##}\n{% macro handle_last_system_message(documents, knobs, use_documents, use_knobs) %}\n  {{- bom_str + handle_role(\"system\") -}}\n  {% set macros_to_call = [] %}\n  {% set params_for_macros = [] %}\n  {% if use_documents %}\n    {% set macros_to_call = macros_to_call + [handle_documents] %}\n    {% set params_for_macros = params_for_macros + [[documents]] %}\n  {% endif %}\n  {% if use_knobs %}\n    {% set macros_to_call = macros_to_call + [handle_knobs] %}\n    {% set params_for_macros = params_for_macros + [[knobs]] %}\n  {% endif %}\n  {% for i in range(macros_to_call|length) %}\n    {% if i > 0 %}\n      {{- \"\\n\\n\" -}}\n    {% endif %}\n    {{- macros_to_call[i](*params_for_macros[i]) -}}\n  {% endfor %}\n  {% set ns.message_count = ns.message_count + 1 %}\n{% endmacro %}\n{##}\n{% macro handle_role(role, add_space=True) %}\n  {{- \"<|\" + role + \"|>\" -}}\n  {% if add_space %}\n    {{- \" \" -}}\n  {% endif %}\n{% endmacro %}\n{##}\n{% macro is_param_set(param, field=none, is_list=False) %}\n  {% if field is not none %}\n    {% if field in param %}\n      {% set param = param[field] %}\n    {% else %}\n      {% set param = none %}\n    {% endif %}\n  {% endif %}\n  {% set is_defined = param is defined and param is not none %}\n  {% if is_list %}\n    {% set ns.is_last_checked_defined = is_defined and param|length > 0 %}\n  {% else %}\n    {% set ns.is_last_checked_defined = is_defined %}\n  {% endif %}\n{% endmacro %}\n{##}\n{##}\n{# Template #}\n{{- 
\"<|startoftext|>\" -}}\n{% set _ = is_param_set(system_message) %}\n{% set is_system_message_set = ns.is_last_checked_defined %}\n{% set _ = is_param_set(tools, is_list=True) %}\n{% set is_tools_set = ns.is_last_checked_defined %}\n{% set has_system_message = (is_system_message_set or is_tools_set) %}\n{% if has_system_message %}\n  {{- handle_first_system_message(system_message, tools) -}}\n{% endif %}\n{% set last_user_index = get_last_user_index(loop_messages)|int %}\n{% for message in loop_messages %}\n  {% if loop.index0 == last_user_index %}\n    {% set _ = is_param_set(documents, is_list=True) %}\n    {% set use_documents = ns.is_last_checked_defined %}\n    {% set _ = is_param_set(knobs) %}\n    {% set use_knobs = ns.is_last_checked_defined and knobs.is_set %}\n    {% set add_last_system_message = use_documents or use_knobs %}\n    {% if add_last_system_message %}\n      {% if ns.message_count > 0 %}\n        {{- eom_str -}}\n      {% endif %}\n      {{- handle_last_system_message(documents, knobs, use_documents, use_knobs) -}}\n    {% endif %}\n  {% endif %}\n  {% set role = message.role %}\n  {% set _ = is_param_set(message, field=\"name\") %}\n  {% set is_message_name_set = ns.is_last_checked_defined %}\n  {% if is_message_name_set %}\n    {% set message_prefix = handle_role(role) + \"(\" + message.name + \")\" %}\n  {% else %}\n    {% set message_prefix = handle_role(role) %}\n  {% endif %}\n  {% set content = (message.content or \"\") %}\n  {% if content is not string %}\n    {% set content = content|tojson %}\n  {% endif %}\n  {% if ns.message_count > 0 %}\n    {{- eom_str -}}\n  {% endif %}\n  {{- bom_str + message_prefix + content -}}\n  {% set _ = is_param_set(message, field=\"tool_calls\", is_list=True) %}\n  {% set is_tool_calls_set = ns.is_last_checked_defined %}\n  {% if role == \"assistant\" and is_tool_calls_set %}\n    {{- handle_tool_calls(message.tool_calls) -}}\n  {% endif %}\n  {% set _ = is_param_set(message, field=\"citations\", 
is_list=True) %}\n  {% set is_citations_set = ns.is_last_checked_defined %}\n  {% if role == \"assistant\" and is_citations_set %}\n    {{- citations_prefix + message.citations|map(attribute=\"document_id\")|list|string + citations_suffix -}}\n  {% endif %}\n  {% set ns.message_count = ns.message_count + 1 %}\n{% endfor %}\n{% if add_generation_prompt %}\n  {% if ns.message_count > 0 %}\n    {{- eom_str -}}\n  {% endif %}\n  {{- bom_str + handle_role(role_to_predict, add_space=False) -}}\n  {% set _ = is_param_set(generation_preamble) %}\n  {% set is_generation_preamble_set = ns.is_last_checked_defined %}\n  {% if is_generation_preamble_set and generation_preamble.strip() != \"\" %}\n    {{- \" \" + generation_preamble -}}\n  {% endif %}\n  {% set ns.message_count = ns.message_count + 1 %}\n{% else %}\n  {% if ns.message_count > 0 %}\n    {{- eom_str -}}\n  {% endif %}\n{% endif %}\n"
  },
  {
    "path": "src/axolotl/utils/chat_templates/templates/llama3.jinja",
    "content": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' %}{% if loop.index0 == 0 %}{% set content = bos_token + content %}{% endif %}{{ content }}{% endfor %}{% if add_generation_prompt %}{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}{% endif %}\n"
  },
  {
    "path": "src/axolotl/utils/chat_templates/templates/llama3_2_vision.jinja",
    "content": "{{- bos_token }}\n{%- if custom_tools is defined %}\n    {%- set tools = custom_tools %}\n{%- endif %}\n{%- if not tools_in_user_message is defined %}\n    {%- set tools_in_user_message = true %}\n{%- endif %}\n{%- if not date_string is defined %}\n    {%- if strftime_now is defined %}\n        {%- set date_string = strftime_now(\"%d %b %Y\") %}\n    {%- else %}\n        {%- set date_string = \"26 Jul 2024\" %}\n    {%- endif %}\n{%- endif %}\n{%- if not tools is defined %}\n    {%- set tools = none %}\n{%- endif %}\n\n{#- This block extracts the system message, so we can slot it into the right place. #}\n{%- if messages[0]['role'] == 'system' %}\n    {%- set system_message = messages[0]['content']|trim %}\n    {%- set messages = messages[1:] %}\n{%- else %}\n    {%- set system_message = \"\" %}\n{%- endif %}\n\n{#- Find out if there are any images #}\n{% set image_ns = namespace(has_images=false) %}\n{%- for message in messages %}\n    {%- for content in message['content'] %}\n        {%- if content['type'] == 'image' %}\n            {%- set image_ns.has_images = true %}\n        {%- endif %}\n    {%- endfor %}\n{%- endfor %}\n\n{#- Error out if there are images and system message #}\n{%- if image_ns.has_images and not system_message == \"\" %}\n    {{- raise_exception(\"Prompting with images is incompatible with system messages.\") }}\n{%- endif %}\n\n{#- System message if there are no images #}\n{%- if not image_ns.has_images %}\n    {{- \"<|start_header_id|>system<|end_header_id|>\\n\\n\" }}\n    {%- if tools is not none %}\n        {{- \"Environment: ipython\\n\" }}\n    {%- endif %}\n    {{- \"Cutting Knowledge Date: December 2023\\n\" }}\n    {{- \"Today Date: \" + date_string + \"\\n\\n\" }}\n    {%- if tools is not none and not tools_in_user_message %}\n        {{- \"You have access to the following functions. 
To call a function, please respond with JSON for a function call.\" }}\n        {{- 'Respond in the format {\"name\": function name, \"parameters\": dictionary of argument name and its value}.' }}\n        {{- \"Do not use variables.\\n\\n\" }}\n        {%- for t in tools %}\n            {{- t | tojson(indent=4) }}\n            {{- \"\\n\\n\" }}\n        {%- endfor %}\n    {%- endif %}\n    {{- system_message }}\n    {{- \"<|eot_id|>\" }}\n{%- endif %}\n\n{#- Custom tools are passed in a user message with some extra guidance #}\n{%- if tools_in_user_message and not tools is none %}\n    {#- Extract the first user message so we can plug it in here #}\n    {%- if messages | length != 0 %}\n        {%- set first_user_message = messages[0]['content']|trim %}\n        {%- set messages = messages[1:] %}\n    {%- else %}\n        {{- raise_exception(\"Cannot put tools in the first user message when there's no first user message!\") }}\n{%- endif %}\n    {{- '<|start_header_id|>user<|end_header_id|>\\n\\n' -}}\n    {{- \"Given the following functions, please respond with a JSON for a function call \" }}\n    {{- \"with its proper arguments that best answers the given prompt.\\n\\n\" }}\n    {{- 'Respond in the format {\"name\": function name, \"parameters\": dictionary of argument name and its value}.' 
}}\n    {{- \"Do not use variables.\\n\\n\" }}\n    {%- for t in tools %}\n        {{- t | tojson(indent=4) }}\n        {{- \"\\n\\n\" }}\n    {%- endfor %}\n    {{- first_user_message + \"<|eot_id|>\"}}\n{%- endif %}\n\n{%- for message in messages %}\n    {%- if not (message.role == 'ipython' or message.role == 'tool' or 'tool_calls' in message) %}\n    {{- '<|start_header_id|>' + message['role'] + '<|end_header_id|>\\n\\n' }}\n        {%- if message['content'] is string %}\n            {{- message['content'] }}\n        {%- else %}\n            {%- for content in message['content'] %}\n                {%- if content['type'] == 'image' %}\n                    {{- '<|image|>' }}\n                {%- elif content['type'] == 'text' %}\n                    {{- content['text'] }}\n                {%- endif %}\n            {%- endfor %}\n        {%- endif %}\n        {{- '<|eot_id|>' }}\n    {%- elif 'tool_calls' in message %}\n        {%- if not message.tool_calls|length == 1 %}\n            {{- raise_exception(\"This model only supports single tool-calls at once!\") }}\n        {%- endif %}\n        {%- set tool_call = message.tool_calls[0].function %}\n        {{- '<|start_header_id|>assistant<|end_header_id|>\\n\\n' -}}\n        {{- '{\"name\": \"' + tool_call.name + '\", ' }}\n        {{- '\"parameters\": ' }}\n        {{- tool_call.arguments | tojson }}\n        {{- \"}\" }}\n        {{- \"<|eot_id|>\" }}\n    {%- elif message.role == \"tool\" or message.role == \"ipython\" %}\n        {{- \"<|start_header_id|>ipython<|end_header_id|>\\n\\n\" }}\n        {%- if message.content is mapping or message.content is iterable %}\n            {{- message.content | tojson }}\n        {%- else %}\n            {{- message.content }}\n        {%- endif %}\n        {{- \"<|eot_id|>\" }}\n    {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n    {{- '<|start_header_id|>assistant<|end_header_id|>\\n\\n' }}\n{%- endif %}\n"
  },
  {
    "path": "src/axolotl/utils/chat_templates/templates/llama4.jinja",
    "content": "{{- bos_token }}\n{%- if custom_tools is defined %}\n    {%- set tools = custom_tools %}\n{%- endif %}\n{%- if not tools_in_user_message is defined %}\n    {%- set tools_in_user_message = true %}\n{%- endif %}\n{%- if not date_string is defined %}\n    {%- if strftime_now is defined %}\n        {%- set date_string = strftime_now(\"%d %b %Y\") %}\n    {%- else %}\n        {%- set date_string = \"26 Jul 2024\" %}\n    {%- endif %}\n{%- endif %}\n{%- if not tools is defined %}\n    {%- set tools = none %}\n{%- endif %}\n\n{#- This block extracts the system message, so we can slot it into the right place. #}\n{%- if messages[0]['role'] == 'system' %}\n    {%- if messages[0]['content'] is string %}\n        {%- set system_message = messages[0]['content']|trim %}\n    {%- else %}\n        {#- FIXME: The processor requires an array, always. #}\n        {%- set system_message = messages[0]['content'][0]['text']|trim %}\n    {%- endif %}\n    {%- set messages = messages[1:] %}\n    {%- set user_supplied_system_message = true %}\n{%- else %}\n    {%- set system_message = \"\" %}\n    {%- set user_supplied_system_message = false %}\n{%- endif %}\n\n{#- System message if the user supplied one #}\n{%- if user_supplied_system_message %}\n    {{- \"<|header_start|>system<|header_end|>\\n\\n\" }}\n    {%- if tools is not none %}\n        {{- \"Environment: ipython\\n\" }}\n    {%- endif %}\n    {%- if tools is not none and not tools_in_user_message %}\n        {{- \"You have access to the following functions. To call a function, please respond with JSON for a function call.\" }}\n        {{- 'Respond in the format {\"name\": function name, \"parameters\": dictionary of argument name and its value}.' 
}}\n        {{- \"Do not use variables.\\n\\n\" }}\n        {%- for t in tools %}\n            {{- t | tojson(indent=4) }}\n            {{- \"\\n\\n\" }}\n        {%- endfor %}\n    {%- endif %}\n    {{- system_message }}\n    {{- \"<|eot|>\" }}\n{%- endif %}\n\n{#- Custom tools are passed in a user message with some extra guidance #}\n{%- if tools_in_user_message and not tools is none %}\n    {#- Extract the first user message so we can plug it in here #}\n    {%- if messages | length != 0 %}\n        {%- set first_user_message = messages[0]['content']|trim %}\n        {%- set messages = messages[1:] %}\n    {%- else %}\n        {{- raise_exception(\"Cannot put tools in the first user message when there's no first user message!\") }}\n{%- endif %}\n    {{- '<|header_start|>user<|header_end|>\\n\\n' -}}\n    {{- \"Given the following functions, please respond with a JSON for a function call \" }}\n    {{- \"with its proper arguments that best answers the given prompt.\\n\\n\" }}\n    {{- 'Respond in the format {\"name\": function name, \"parameters\": dictionary of argument name and its value}.' 
}}\n    {{- \"Do not use variables.\\n\\n\" }}\n    {%- for t in tools %}\n        {{- t | tojson(indent=4) }}\n        {{- \"\\n\\n\" }}\n    {%- endfor %}\n    {{- first_user_message + \"<|eot|>\"}}\n{%- endif %}\n\n{%- for message in messages %}\n    {%- if not (message.role == 'ipython' or message.role == 'tool' or 'tool_calls' in message) %}\n    {{- '<|header_start|>' + message['role'] + '<|header_end|>\\n\\n' }}\n        {%- if message['content'] is string %}\n            {{- message['content'] }}\n        {%- else %}\n            {%- for content in message['content'] %}\n                {%- if content['type'] == 'image' %}\n                    {{- '<|image|>' }}\n                {%- elif content['type'] == 'text' %}\n                    {{- content['text'] }}\n                {%- endif %}\n            {%- endfor %}\n        {%- endif %}\n        {{- \"<|eot|>\" }}\n    {%- elif 'tool_calls' in message and message.tool_calls|length > 0 %}\n       {{- '<|header_start|>assistant<|header_end|>\\n\\n' -}}\n       {{- '<|python_start|>' }}\n        {%- if message['content'] is string %}\n            {{- message['content'] }}\n        {%- else %}\n            {%- for content in message['content'] %}\n                {%- if content['type'] == 'image' %}\n                    {{- '<|image|>' }}\n                {%- elif content['type'] == 'text' %}\n                    {{- content['text'] }}\n                {%- endif %}\n            {%- endfor %}\n        {%- endif %}\n       {{- '<|python_end|>' }}\n        {%- for tool_call in message.tool_calls %}\n           {{- '{\"name\": \"' + tool_call.function.name + '\", ' }}\n           {{- '\"parameters\": ' }}\n           {{- tool_call.function.arguments | tojson }}\n           {{- \"}\" }}\n        {%- endfor %}\n       {{- \"<|eot|>\" }}\n    {%- elif message.role == \"tool\" or message.role == \"ipython\" %}\n        {{- \"<|header_start|>ipython<|header_end|>\\n\\n\" }}\n        {%- if message.content is mapping or 
message.content is iterable %}\n            {{- message.content | tojson }}\n        {%- else %}\n            {{- message.content }}\n        {%- endif %}\n        {{- \"<|eot|>\" }}\n    {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n    {{- '<|header_start|>assistant<|header_end|>\\n\\n' }}\n{%- endif %}\n"
  },
  {
    "path": "src/axolotl/utils/chat_templates/templates/llava.jinja",
    "content": "{% for message in messages %}{% if message['role'] != 'system' %}{{ message['role'].upper() + ': '}}{% endif %}{# Render all images first #}{% for content in message['content'] | selectattr('type', 'equalto', 'image') %}{{ '<image>\n' }}{% endfor %}{# Render all text next #}{% if message['role'] != 'assistant' %}{% for content in message['content'] | selectattr('type', 'equalto', 'text') %}{{ content['text'] + ' '}}{% endfor %}{% else %}{% for content in message['content'] | selectattr('type', 'equalto', 'text') %}{% generation %}{{ content['text'] + ' '}}{% endgeneration %}{% endfor %}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ 'ASSISTANT:' }}{% endif %}\n"
  },
  {
    "path": "src/axolotl/utils/chat_templates/templates/metharme.jinja",
    "content": "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% else %}{% set loop_messages = messages %}{% set system_message = 'Enter RP mode. You shall reply to the user while staying in character. Your responses must be detailed, creative, immersive, and drive the scenario forward.' %}{% endif %}{{ '<|system|>' + system_message }}{% for message in loop_messages %}{% set content = message['content'] %}{% if message['role'] == 'user' %}{{ '<|user|>' + content.strip() }}{% elif message['role'] == 'assistant' %}{{ '<|model|>'  + content.strip() }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|model|>' }}{% else %}{{ eos_token }}{% endif %}\n"
  },
  {
    "path": "src/axolotl/utils/chat_templates/templates/mistral_v1.jinja",
    "content": "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ ' [INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ ' ' + message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}\n"
  },
  {
    "path": "src/axolotl/utils/chat_templates/templates/mistral_v2v3.jinja",
    "content": "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + '[/INST]' }}{% elif message['role'] == 'assistant' %}{{ ' ' + message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}\n"
  },
  {
    "path": "src/axolotl/utils/chat_templates/templates/mistral_v3_tekken.jinja",
    "content": "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST]' + message['content'] + '[/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}\n"
  },
  {
    "path": "src/axolotl/utils/chat_templates/templates/mistral_v7_tekken.jinja",
    "content": "{%- set today = strftime_now(\"%Y-%m-%d\") %}\n{%- set default_system_message = \"You are Mistral Small 3, a Large Language Model (LLM) created by Mistral AI, a French startup headquartered in Paris.\\nYour knowledge base was last updated on 2023-10-01. The current date is \" + today + \".\\n\\nWhen you're not sure about some information, you say that you don't have the information and don't make up anything.\\nIf the user's question is not clear, ambiguous, or does not provide enough context for you to accurately answer the question, you do not try to answer it right away and you rather ask the user to clarify their request (e.g. \\\"What are some good restaurants around me?\\\" => \\\"Where are you?\\\" or \\\"When is the next flight to Tokyo\\\" => \\\"Where do you travel from?\\\")\" %}\n\n{{- bos_token }}\n\n{%- if messages[0]['role'] == 'system' %}\n    {%- if messages[0]['content'] is string %}\n        {%- set system_message = messages[0]['content'] %}\n    {%- else %}\n        {%- set system_message = messages[0]['content'][0]['text'] %}\n    {%- endif %}\n    {%- set loop_messages = messages[1:] %}\n{%- else %}\n    {%- set system_message = default_system_message %}\n    {%- set loop_messages = messages %}\n{%- endif %}\n{{- '[SYSTEM_PROMPT]' + system_message + '[/SYSTEM_PROMPT]' }}\n\n{%- for message in loop_messages %}\n    {%- if message['role'] == 'user' %}\n        {%- if message['content'] is string %}\n            {{- '[INST]' + message['content'] + '[/INST]' }}\n        {%- else %}\n            {{- '[INST]' }}\n            {%- for block in message['content'] %}\n                {%- if block['type'] == 'text' %}\n                    {{- block['text'] }}\n                {%- elif block['type'] in ['image', 'image_url'] %}\n                    {{- '[IMG]' }}\n                {%- else %}\n                    {{- raise_exception('Only text and image blocks are supported in message content!') }}\n                {%- endif %}\n            
{%- endfor %}\n            {{- '[/INST]' }}\n        {%- endif %}\n    {%- elif message['role'] == 'system' %}\n        {%- if message['content'] is string %}\n            {{- '[SYSTEM_PROMPT]' + message['content'] + '[/SYSTEM_PROMPT]' }}\n        {%- else %}\n            {{- '[SYSTEM_PROMPT]' + message['content'][0]['text'] + '[/SYSTEM_PROMPT]' }}\n        {%- endif %}\n    {%- elif message['role'] == 'assistant' %}\n        {%- if message['content'] is string %}\n            {{- message['content'] + eos_token }}\n        {%- else %}\n            {{- message['content'][0]['text'] + eos_token }}\n        {%- endif %}\n    {%- else %}\n        {{- raise_exception('Only user, system and assistant roles are supported!') }}\n    {%- endif %}\n{%- endfor %}\n"
  },
  {
    "path": "src/axolotl/utils/chat_templates/templates/phi_3.jinja",
    "content": "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'system') %}{{'<|system|>' + '\n' + message['content'] + '<|end|>' + '\n'}}{% elif (message['role'] == 'user') %}{{'<|user|>' + '\n' + message['content'] + '<|end|>' + '\n' + '<|assistant|>' + '\n'}}{% elif message['role'] == 'assistant' %}{{message['content'] + '<|end|>' + '\n'}}{% endif %}{% endfor %}\n"
  },
  {
    "path": "src/axolotl/utils/chat_templates/templates/phi_35.jinja",
    "content": "{% for message in messages %}{% if message['role'] == 'system' and message['content'] %}{{'<|system|>\n' + message['content'] + '<|end|>\n'}}{% elif message['role'] == 'user' %}{{'<|user|>\n' + message['content'] + '<|end|>\n'}}{% elif message['role'] == 'assistant' %}{{'<|assistant|>\n' + message['content'] + '<|end|>\n'}}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|assistant|>\n' }}{% endif %}\n"
  },
  {
    "path": "src/axolotl/utils/chat_templates/templates/phi_4.jinja",
    "content": "{% set system_message = 'You are Phi, a language model trained by Microsoft to help users. Your role as an assistant involves thoroughly exploring questions through a systematic thinking process before providing the final precise and accurate solutions. This requires engaging in a comprehensive cycle of analysis, summarizing, exploration, reassessment, reflection, backtracing, and iteration to develop well-considered thinking process. Please structure your response into two main sections: Thought and Solution using the specified format: <think> {Thought section} </think> {Solution section}. In the Thought section, detail your reasoning process in steps. Each step should include detailed considerations such as analysing questions, summarizing relevant findings, brainstorming new ideas, verifying the accuracy of the current steps, refining any errors, and revisiting previous steps. In the Solution section, based on various attempts, explorations, and reflections from the Thought section, systematically present the final solution that you deem correct. The Solution section should be logical, accurate, and concise and detail necessary steps needed to reach the conclusion. Now, try to solve the following question through the above guidelines:' -%}{%- if messages and messages[0]['role'] == 'system' -%}{%- set system_message = messages[0]['content'] -%}{%- set messages = messages[1:] -%}{%- endif -%}<|im_start|>system<|im_sep|>{{ system_message }}<|im_end|>{% for message in messages %}{% if (message['role'] == 'user') %}{{'<|im_start|>user<|im_sep|>' + message['content'] + '<|im_end|>'}}{% elif (message['role'] == 'assistant') %}{{'<|im_start|>assistant<|im_sep|>'}}{% generation %}{{message['content'] + '<|im_end|>'}}{% endgeneration %}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant<|im_sep|>' }}{% endif %}\n"
  },
  {
    "path": "src/axolotl/utils/chat_templates/templates/pixtral.jinja",
    "content": "{%- if messages[0][\"role\"] == \"system\" %}\n    {%- set system_message = messages[0][\"content\"] %}\n    {%- set loop_messages = messages[1:] %}\n{%- else %}\n    {%- set loop_messages = messages %}\n{%- endif %}\n\n{{- bos_token }}\n{%- for message in loop_messages %}\n    {%- if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}\n        {{- raise_exception('After the optional system message, conversation roles must alternate user/assistant/user/assistant/...') }}\n    {%- endif %}\n    {%- if message[\"role\"] == \"user\" %}\n        {%- if loop.last and system_message is defined %}\n            {{- \"[INST]\" + system_message + \"\n\n\" }}\n        {%- else %}\n            {{- \"[INST]\" }}\n        {%- endif %}\n        {%- if message[\"content\"] is not string %}\n            {%- for chunk in message[\"content\"] %}\n                {%- if chunk[\"type\"] == \"text\" %}\n                    {{- chunk[\"text\"] }}\n                {%- elif chunk[\"type\"] == \"image\" %}\n                    {{- \"[IMG]\" }}\n                {%- else %}\n                    {{- raise_exception(\"Unrecognized content type!\") }}\n                {%- endif %}\n            {%- endfor %}\n        {%- else %}\n            {{- message[\"content\"] }}\n        {%- endif %}\n        {{- \"[/INST]\" }}\n    {%- elif message[\"role\"] == \"assistant\" %}\n {%- if message[\"content\"] is not string %}\n {%- for chunk in message[\"content\"] %}\n {%- if chunk[\"type\"] == \"text\" %}\n {{- chunk[\"text\"] }}\n {%- elif chunk[\"type\"] == \"image\" %}\n {{- \"[IMG]\" }}\n {%- else %}\n {{- raise_exception(\"Unrecognized content type!\") }}\n{%- endif %}\n{%- endfor %}\n{{- eos_token }}\n{%- else %}\n{{- message[\"content\"] + eos_token }}\n{%- endif %}\n    {%- else %}\n        {{- raise_exception(\"Only user and assistant roles are supported, with the exception of an initial optional system message!\") }}\n    {%- endif %}\n{%- endfor %}\n"
  },
  {
    "path": "src/axolotl/utils/chat_templates/templates/qwen2_vl.jinja",
    "content": "{% set image_count = namespace(value=0) %}{% set video_count = namespace(value=0) %}{% for message in messages %}{% if loop.first and message['role'] != 'system' %}<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n{% endif %}<|im_start|>{{ message['role'] }}\n{% if message['content'] is string %}{{ message['content'] }}<|im_end|>\n{% else %}{% for content in message['content'] %}{% if content['type'] == 'image' or 'image' in content or 'image_url' in content %}{% set image_count.value = image_count.value + 1 %}{% if add_vision_id %}Picture {{ image_count.value }}: {% endif %}<|vision_start|><|image_pad|><|vision_end|>{% elif content['type'] == 'video' or 'video' in content %}{% set video_count.value = video_count.value + 1 %}{% if add_vision_id %}Video {{ video_count.value }}: {% endif %}<|vision_start|><|video_pad|><|vision_end|>{% elif 'text' in content %}{{ content['text'] }}{% endif %}{% endfor %}<|im_end|>\n{% endif %}{% endfor %}{% if add_generation_prompt %}<|im_start|>assistant\n{% endif %}\n"
  },
  {
    "path": "src/axolotl/utils/chat_templates/templates/qwen3.jinja",
    "content": "{%- if tools %}\n    {{- '<|im_start|>system\\n' }}\n    {%- if messages[0].role == 'system' %}\n        {{- messages[0].content + '\\n\\n' }}\n    {%- endif %}\n    {{- \"# Tools\\n\\nYou may call one or more functions to assist with the user query.\\n\\nYou are provided with function signatures within <tools></tools> XML tags:\\n<tools>\" }}\n    {%- for tool in tools %}\n        {{- \"\\n\" }}\n        {{- tool | tojson }}\n    {%- endfor %}\n    {{- \"\\n</tools>\\n\\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\\n<tool_call>\\n{\\\"name\\\": <function-name>, \\\"arguments\\\": <args-json-object>}\\n</tool_call><|im_end|>\\n\" }}\n{%- else %}\n    {%- if messages[0].role == 'system' %}\n        {{- '<|im_start|>system\\n' + messages[0].content + '<|im_end|>\\n' }}\n    {%- endif %}\n{%- endif %}\n{%- set ns = namespace(multi_step_tool=true, last_query_index=messages|length - 1) %}\n{#- Determine the real last index: use provided value or default to messages length - 1 #}\n{%- if real_last_index is defined and real_last_index is not none %}\n    {%- set ns.real_last_index = real_last_index %}\n{%- else %}\n    {%- set ns.real_last_index = messages|length - 1 %}\n{%- endif %}\n{%- for message in messages[::-1] %}\n    {%- set index = (messages|length - 1) - loop.index0 %}\n    {%- if ns.multi_step_tool and message.role == \"user\" and not(message.content.startswith('<tool_response>') and message.content.endswith('</tool_response>')) %}\n        {%- set ns.multi_step_tool = false %}\n        {%- set ns.last_query_index = index %}\n    {%- endif %}\n{%- endfor %}\n{%- for message in messages %}\n    {%- if (message.role == \"user\") or (message.role == \"system\" and not loop.first) %}\n        {{- '<|im_start|>' + message.role + '\\n' + message.content + '<|im_end|>' + '\\n' }}\n    {%- elif message.role == \"assistant\" %}\n        {%- set content = message.content %}\n      
  {%- set reasoning_content = '' %}\n        {%- if message.reasoning_content is defined and message.reasoning_content is not none %}\n            {%- set reasoning_content = message.reasoning_content %}\n        {%- else %}\n            {%- if '</think>' in message.content %}\n                {%- set content = message.content.split('</think>')[-1].lstrip('\\n') %}\n                {%- set reasoning_content = message.content.split('</think>')[0].rstrip('\\n').split('<think>')[-1].lstrip('\\n') %}\n            {%- endif %}\n        {%- endif %}\n        {%- if loop.index0 > ns.last_query_index %}\n            {%- if loop.index0 == ns.real_last_index or (loop.index0 != ns.real_last_index and reasoning_content) %}\n                {{- '<|im_start|>' + message.role + '\\n<think>\\n' + reasoning_content.strip('\\n') + '\\n</think>\\n\\n' + content.lstrip('\\n') }}\n            {%- else %}\n                {{- '<|im_start|>' + message.role + '\\n' + content }}\n            {%- endif %}\n        {%- else %}\n            {{- '<|im_start|>' + message.role + '\\n' + content }}\n        {%- endif %}\n        {%- if message.tool_calls %}\n            {%- for tool_call in message.tool_calls %}\n                {%- if (loop.first and content) or (not loop.first) %}\n                    {{- '\\n' }}\n                {%- endif %}\n                {%- if tool_call.function %}\n                    {%- set tool_call = tool_call.function %}\n                {%- endif %}\n                {{- '<tool_call>\\n{\"name\": \"' }}\n                {{- tool_call.name }}\n                {{- '\", \"arguments\": ' }}\n                {%- if tool_call.arguments is string %}\n                    {{- tool_call.arguments }}\n                {%- else %}\n                    {{- tool_call.arguments | tojson }}\n                {%- endif %}\n                {{- '}\\n</tool_call>' }}\n            {%- endfor %}\n        {%- endif %}\n        {{- '<|im_end|>\\n' }}\n    {%- elif message.role == \"tool\" 
%}\n        {%- if loop.first or (messages[loop.index0 - 1].role != \"tool\") %}\n            {{- '<|im_start|>user' }}\n        {%- endif %}\n        {{- '\\n<tool_response>\\n' }}\n        {{- message.content }}\n        {{- '\\n</tool_response>' }}\n        {%- if loop.last or (messages[loop.index0 + 1].role != \"tool\") %}\n            {{- '<|im_end|>\\n' }}\n        {%- endif %}\n    {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n    {{- '<|im_start|>assistant\\n' }}\n    {%- if enable_thinking is defined and enable_thinking is false %}\n        {{- '<think>\\n\\n</think>\\n\\n' }}\n    {%- else %}\n        {{- '<think>\\n\\n' }}\n    {%- endif %}\n{%- endif %}\n"
  },
  {
    "path": "src/axolotl/utils/chat_templates/templates/qwen3_5.jinja",
    "content": "{%- if tools %}\n    {{- '<|im_start|>system\\n' }}\n    {%- if messages[0].role == 'system' %}\n        {{- messages[0].content + '\\n\\n' }}\n    {%- endif %}\n    {{- \"# Tools\\n\\nYou may call one or more functions to assist with the user query.\\n\\nYou are provided with function signatures within <tools></tools> XML tags:\\n<tools>\" }}\n    {%- for tool in tools %}\n        {{- \"\\n\" }}\n        {{- tool | tojson }}\n    {%- endfor %}\n    {{- \"\\n</tools>\\n\\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\\n<tool_call>\\n{\\\"name\\\": <function-name>, \\\"arguments\\\": <args-json-object>}\\n</tool_call><|im_end|>\\n\" }}\n{%- else %}\n    {%- if messages[0].role == 'system' %}\n        {{- '<|im_start|>system\\n' + messages[0].content + '<|im_end|>\\n' }}\n    {%- endif %}\n{%- endif %}\n{%- set ns = namespace(multi_step_tool=true, last_query_index=messages|length - 1) %}\n{#- Determine the real last index: use provided value or default to messages length - 1 #}\n{%- if real_last_index is defined and real_last_index is not none %}\n    {%- set ns.real_last_index = real_last_index %}\n{%- else %}\n    {%- set ns.real_last_index = messages|length - 1 %}\n{%- endif %}\n{%- for message in messages[::-1] %}\n    {%- set index = (messages|length - 1) - loop.index0 %}\n    {%- if message['content'] is string %}\n        {%- if ns.multi_step_tool and message.role == \"user\" and not(message.content.startswith('<tool_response>') and message.content.endswith('</tool_response>')) %}\n            {%- set ns.multi_step_tool = false %}\n            {%- set ns.last_query_index = index %}\n        {%- endif %}\n    {%- else %}\n        {%- if ns.multi_step_tool and message.role == \"user\" %}\n            {%- set ns.multi_step_tool = false %}\n            {%- set ns.last_query_index = index %}\n        {%- endif %}\n    {%- endif %}\n{%- endfor %}\n{%- for message in messages %}\n    {%- if (message.role == \"user\") or (message.role == \"system\" and not loop.first) %}\n        {{- '<|im_start|>' + message.role + '\\n' }}\n        {%- if message['content'] is string %}\n            {{- message.content }}\n        {%- else %}\n            {%- for content in message['content'] %}\n                {%- if content['type'] == 'image' or 'image' in content or 'image_url' in content %}\n                    {{- '<|vision_start|><|image_pad|><|vision_end|>' }}\n                {%- elif content['type'] == 'video' or 'video' in content %}\n                    {{- '<|vision_start|><|video_pad|><|vision_end|>' }}\n                {%- elif 'text' in content %}\n                    {{- content['text'] }}\n                {%- endif %}\n            {%- endfor %}\n        {%- endif %}\n        {{- '<|im_end|>\\n' }}\n    {%- elif message.role == \"assistant\" %}\n        {%- if message['content'] is string %}\n            {%- set content = message.content %}\n        {%- else %}\n            {#- A plain set inside a for-loop does not survive the loop in Jinja, so accumulate the text parts in a namespace #}\n            {%- set content_ns = namespace(text='') %}\n            {%- for item in message['content'] %}\n                {%- if 'text' in item %}\n                    {%- set content_ns.text = content_ns.text + item['text'] %}\n                {%- endif %}\n            {%- endfor %}\n            {%- set content = content_ns.text %}\n        {%- endif %}\n        {%- set reasoning_content = '' %}\n        {%- if message.reasoning_content is defined and message.reasoning_content is not none %}\n            {%- set reasoning_content = message.reasoning_content %}\n        {%- else %}\n            {%- if '</think>' in content %}\n                {#- Extract the reasoning before content is overwritten with the post-think text #}\n                {%- set reasoning_content = content.split('</think>')[0].rstrip('\\n').split('<think>')[-1].lstrip('\\n') %}\n                {%- set content = content.split('</think>')[-1].lstrip('\\n') %}\n            {%- endif %}\n        {%- endif %}\n        {%- if loop.index0 > ns.last_query_index %}\n            {%- if loop.index0 == ns.real_last_index or (loop.index0 != ns.real_last_index and reasoning_content) %}\n                {{- '<|im_start|>' + message.role + '\\n<think>\\n' + reasoning_content.strip('\\n') + '\\n</think>\\n\\n' + content.lstrip('\\n') }}\n            {%- else %}\n                {{- '<|im_start|>' + message.role + '\\n' + content }}\n            {%- endif %}\n        {%- else %}\n            {{- '<|im_start|>' + message.role + '\\n' + content }}\n        {%- endif %}\n        {%- if message.tool_calls %}\n            {%- for tool_call in message.tool_calls %}\n                {%- if (loop.first and content) or (not loop.first) %}\n                    {{- '\\n' }}\n                {%- endif %}\n                {%- if tool_call.function %}\n                    {%- set tool_call = tool_call.function %}\n                {%- endif %}\n                {{- '<tool_call>\\n{\"name\": \"' }}\n                {{- tool_call.name }}\n                {{- '\", \"arguments\": ' }}\n                {%- if tool_call.arguments is string %}\n                    {{- tool_call.arguments }}\n                {%- else %}\n                    {{- tool_call.arguments | tojson }}\n                {%- endif %}\n                {{- '}\\n</tool_call>' }}\n            {%- endfor %}\n        {%- endif %}\n        {{- '<|im_end|>\\n' }}\n    {%- elif message.role == \"tool\" %}\n        {%- if loop.first or (messages[loop.index0 - 1].role != \"tool\") %}\n            {{- '<|im_start|>user' }}\n        {%- endif %}\n        {{- '\\n<tool_response>\\n' }}\n        {{- message.content }}\n        {{- '\\n</tool_response>' }}\n        {%- if loop.last or (messages[loop.index0 + 1].role != \"tool\") %}\n            {{- '<|im_end|>\\n' }}\n        {%- endif %}\n    {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n    {{- '<|im_start|>assistant\\n' }}\n    {%- if enable_thinking is defined and enable_thinking is false %}\n        {{- '<think>\\n\\n</think>\\n\\n' }}\n    {%- else %}\n        {{- '<think>\\n\\n' }}\n    {%- endif %}\n{%- endif %}\n"
  },
  {
    "path": "src/axolotl/utils/chat_templates/templates/qwen_25.jinja",
    "content": "{%- if tools %}\n    {{- '<|im_start|>system\\n' }}\n    {%- if messages[0]['role'] == 'system' %}\n        {{- messages[0]['content'] }}\n    {%- else %}\n        {{- 'You are Qwen, created by Alibaba Cloud. You are a helpful assistant.' }}\n    {%- endif %}\n    {{- \"\\n\\n# Tools\\n\\nYou may call one or more functions to assist with the user query.\\n\\nYou are provided with function signatures within <tools></tools> XML tags:\\n<tools>\" }}\n    {%- for tool in tools %}\n        {{- \"\\n\" }}\n        {{- tool | tojson }}\n    {%- endfor %}\n    {{- \"\\n</tools>\\n\\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\\n<tool_call>\\n{\\\"name\\\": <function-name>, \\\"arguments\\\": <args-json-object>}\\n</tool_call><|im_end|>\\n\" }}\n{%- else %}\n    {%- if messages[0]['role'] == 'system' %}\n        {{- '<|im_start|>system\\n' + messages[0]['content'] + '<|im_end|>\\n' }}\n    {%- else %}\n        {{- '<|im_start|>system\\nYou are Qwen, created by Alibaba Cloud. 
You are a helpful assistant.<|im_end|>\\n' }}\n    {%- endif %}\n{%- endif %}\n{%- for message in messages %}\n    {%- if (message.role == \"user\") or (message.role == \"system\" and not loop.first) or (message.role == \"assistant\" and not message.tool_calls) %}\n        {{- '<|im_start|>' + message.role + '\\n' + message.content + '<|im_end|>' + '\\n' }}\n    {%- elif message.role == \"assistant\" %}\n        {{- '<|im_start|>' + message.role }}\n        {%- if message.content %}\n            {{- '\\n' + message.content }}\n        {%- endif %}\n        {%- for tool_call in message.tool_calls %}\n            {%- if tool_call.function is defined %}\n                {%- set tool_call = tool_call.function %}\n            {%- endif %}\n            {{- '\\n<tool_call>\\n{\"name\": \"' }}\n            {{- tool_call.name }}\n            {{- '\", \"arguments\": ' }}\n            {{- tool_call.arguments | tojson }}\n            {{- '}\\n</tool_call>' }}\n        {%- endfor %}\n        {{- '<|im_end|>\\n' }}\n    {%- elif message.role == \"tool\" %}\n        {%- if (loop.index0 == 0) or (messages[loop.index0 - 1].role != \"tool\") %}\n            {{- '<|im_start|>user' }}\n        {%- endif %}\n        {{- '\\n<tool_response>\\n' }}\n        {{- message.content }}\n        {{- '\\n</tool_response>' }}\n        {%- if loop.last or (messages[loop.index0 + 1].role != \"tool\") %}\n            {{- '<|im_end|>\\n' }}\n        {%- endif %}\n    {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n    {{- '<|im_start|>assistant\\n' }}\n{%- endif %}\n"
  },
  {
    "path": "src/axolotl/utils/collators/__init__.py",
    "content": "\"\"\"Shared axolotl collators for multipacking, mamba, multimodal.\"\"\"\n\nfrom .batching import (\n    BatchSamplerDataCollatorForSeq2Seq,\n    DataCollatorForSeq2Seq,\n    PretrainingBatchSamplerDataCollatorForSeq2Seq,\n    V2BatchSamplerDataCollatorForSeq2Seq,\n)\nfrom .mamba import MambaDataCollator\n\n__all__ = [\n    \"DataCollatorForSeq2Seq\",\n    \"BatchSamplerDataCollatorForSeq2Seq\",\n    \"V2BatchSamplerDataCollatorForSeq2Seq\",\n    \"PretrainingBatchSamplerDataCollatorForSeq2Seq\",\n    \"MambaDataCollator\",\n]\n"
  },
  {
    "path": "src/axolotl/utils/collators/batching.py",
    "content": "\"\"\"Data collators for axolotl to pad labels and position_ids for packed sequences\"\"\"\n\nfrom dataclasses import dataclass\nfrom typing import Any, List\n\nimport numpy as np\nfrom transformers import PreTrainedTokenizerBase\nfrom transformers.utils import PaddingStrategy\n\n\n@dataclass\nclass DataCollatorForSeq2Seq:\n    \"\"\"\n    Data collator that will dynamically pad the inputs received, as well as the labels and position_ids\n\n    Args:\n        tokenizer ([`PreTrainedTokenizer`] or [`PreTrainedTokenizerFast`]):\n            The tokenizer used for encoding the data.\n        model ([`PreTrainedModel`]):\n            The model that is being trained. If set and has the *prepare_decoder_input_ids_from_labels*, use it to\n            prepare the *decoder_input_ids*\n\n            This is useful when using *label_smoothing* to avoid calculating loss twice.\n        padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `True`):\n            Select a strategy to pad the returned sequences (according to the model's padding side and padding index)\n            among:\n\n            - `True` or `'longest'` (default): Pad to the longest sequence in the batch (or no padding if only a single\n              sequence is provided).\n            - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum\n              acceptable input length for the model if that argument is not provided.\n            - `False` or `'do_not_pad'`: No padding (i.e., can output a batch with sequences of different lengths).\n        max_length (`int`, *optional*):\n            Maximum length of the returned list and optionally padding length (see above).\n        pad_to_multiple_of (`int`, *optional*):\n            If set will pad the sequence to a multiple of the provided value.\n\n            This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >=\n         
   7.5 (Volta).\n        label_pad_token_id (`int`, *optional*, defaults to -100):\n            The id to use when padding the labels (-100 will be automatically ignored by PyTorch loss functions).\n        return_tensors (`str`):\n            The type of Tensor to return. Allowable values are \"np\", \"pt\" and \"tf\".\n    \"\"\"\n\n    tokenizer: PreTrainedTokenizerBase\n    model: Any | None = None\n    padding: bool | str | PaddingStrategy = True\n    max_length: int | None = None\n    pad_to_multiple_of: int | None = None\n    label_pad_token_id: int = -100\n    position_pad_token_id: int = 0\n    return_tensors: str = \"pt\"\n\n    def __call__(self, features, return_tensors=None):\n        has_attn_mask = \"attention_mask\" in features[0].keys()\n        labels = None\n        if return_tensors is None:\n            return_tensors = self.return_tensors\n\n        for feature_name, pad_token_id in [\n            (\"labels\", self.label_pad_token_id),\n            (\"position_ids\", self.position_pad_token_id),\n        ]:\n            feat = (\n                [feature[feature_name] for feature in features]\n                if feature_name in features[0].keys()\n                else None\n            )\n            labels = feat if feat and feature_name == \"labels\" else labels\n            # We have to pad the labels before calling `tokenizer.pad` as this method won't pad them and needs them of the\n            # same length to return tensors.\n            if feat is not None:\n                max_feature_length = max(len(l) for l in feat)  # noqa: E741\n                if self.pad_to_multiple_of is not None:\n                    max_feature_length = (\n                        (max_feature_length + self.pad_to_multiple_of - 1)\n                        // self.pad_to_multiple_of\n                        * self.pad_to_multiple_of\n                    )\n\n                padding_side = self.tokenizer.padding_side\n                for feature in features:\n   
                 remainder_len = max_feature_length - len(feature[feature_name])\n                    if feature_name == \"position_ids\":\n                        remainder = list(range(remainder_len))\n                    else:\n                        remainder = [pad_token_id] * remainder_len\n                    if isinstance(feature[feature_name], list):\n                        feature[feature_name] = (\n                            feature[feature_name] + remainder\n                            if padding_side == \"right\"\n                            else remainder + feature[feature_name]\n                        )\n                    elif padding_side == \"right\":\n                        feature[feature_name] = np.concatenate(\n                            [feature[feature_name], remainder]\n                        ).astype(np.int64)\n                    else:\n                        feature[feature_name] = np.concatenate(\n                            [remainder, feature[feature_name]]\n                        ).astype(np.int64)\n\n        features = self.tokenizer.pad(\n            features,\n            padding=self.padding,\n            max_length=self.max_length,\n            pad_to_multiple_of=self.pad_to_multiple_of,\n            return_tensors=return_tensors,\n        )\n        if not has_attn_mask and \"attention_mask\" in features:\n            del features[\"attention_mask\"]\n\n        # prepare decoder_input_ids\n        if (\n            labels is not None\n            and self.model is not None\n            and hasattr(self.model, \"prepare_decoder_input_ids_from_labels\")\n        ):\n            decoder_input_ids = self.model.prepare_decoder_input_ids_from_labels(\n                labels=features[\"labels\"]\n            )\n            features[\"decoder_input_ids\"] = decoder_input_ids\n\n        return features\n\n\n@dataclass\nclass BatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):\n    \"\"\"\n    Collator for multipack 
specific to the using the BatchSampler\n    \"\"\"\n\n    def __call__(self, features, return_tensors=None):\n        if not isinstance(features[0], list):\n            features = [features]\n        out_features = [{} for _ in features]\n        for i, features_ in enumerate(features):\n            for feature in features_[0].keys():\n                if feature == \"length\":\n                    continue\n                if feature == \"attention_mask\":\n                    arrays = [\n                        (1) * np.array(item[feature])\n                        for i, item in enumerate(features_)\n                        if feature in item\n                    ]\n                    out_features[i][feature] = np.concatenate(arrays)\n                else:\n                    arrays = [\n                        np.array(item[feature]) for item in features_ if feature in item\n                    ]\n                    out_features[i][feature] = np.concatenate(arrays)\n\n        return super().__call__(out_features, return_tensors=return_tensors)\n\n\n@dataclass\nclass V2BatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):\n    \"\"\"\n    Collator for multipack specific to the using the BatchSampler\n    \"\"\"\n\n    squash_position_ids: bool = False\n\n    def __call__(self, features, return_tensors=None):\n        if not isinstance(features[0], list):\n            features: List[List[dict]] = [features]\n        out_features = [{} for _ in features]\n        for i, features_ in enumerate(features):\n            for feature in features_[0].keys():\n                if feature == \"length\":\n                    continue\n                if feature == \"attention_mask\":\n                    arrays = [\n                        (i + 1) * np.array(item[feature])\n                        for i, item in enumerate(features_)\n                        if feature in item\n                    ]\n                    out_features[i][feature] = 
np.concatenate(arrays)\n                elif feature == \"position_ids\" and self.squash_position_ids:\n                    arrays = [\n                        np.array(item[feature]) for item in features_ if feature in item\n                    ]\n                    # concatenate, get total length and create arange of new total position ids\n                    position_ids = np.concatenate(arrays)\n                    total_length = position_ids.shape[0]\n                    position_ids = np.arange(total_length)\n                    out_features[i][feature] = position_ids\n                else:\n                    arrays = [\n                        np.array(item[feature]) for item in features_ if feature in item\n                    ]\n                    out_features[i][feature] = np.concatenate(arrays)\n\n        return super().__call__(out_features, return_tensors=return_tensors)\n\n\n@dataclass\nclass PretrainingBatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):\n    \"\"\"\n    Collator for multipack specific to the using the BatchSampler\n    \"\"\"\n\n    def __init__(self, *args, multipack_attn=True, **kwargs):\n        super().__init__(*args, **kwargs)\n        self.multipack_attn = multipack_attn\n\n    def __call__(self, features, return_tensors=None):\n        chunked_data = {}\n        for feature in features.keys():\n            if feature == \"length\":\n                continue\n            if feature == \"attention_mask\":\n                if self.multipack_attn:\n                    arrays = [\n                        (i + 1) * np.array(item)\n                        for i, item in enumerate(features[feature])\n                    ]\n                else:\n                    arrays = [(1) * np.array(item) for item in features[feature]]\n                chunked_data[feature] = np.concatenate(arrays)\n            else:\n                arrays = [np.array(item) for item in features[feature]]\n                chunked_data[feature] = 
np.concatenate(arrays)\n        features = [chunked_data]\n        return super().__call__(features, return_tensors=return_tensors)\n"
  },
  {
    "path": "src/axolotl/utils/collators/core.py",
    "content": "\"\"\"\nbasic shared collator constants\n\"\"\"\n\nIGNORE_INDEX = -100\n"
  },
  {
    "path": "src/axolotl/utils/collators/mamba.py",
    "content": "\"\"\"\ncollators for Mamba\n\"\"\"\n\nfrom dataclasses import dataclass\nfrom typing import Dict, Sequence\n\nimport torch\nimport transformers\n\nfrom axolotl.utils.collators.core import IGNORE_INDEX\n\n\n@dataclass\nclass MambaDataCollator:\n    \"\"\"\n    Collator for State Space Models (Mamba)\n    \"\"\"\n\n    tokenizer: transformers.PreTrainedTokenizer\n\n    def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:\n        input_ids, labels = tuple(\n            [torch.LongTensor(instance[key]) for instance in instances]\n            for key in (\"input_ids\", \"labels\")\n        )\n        input_ids = torch.nn.utils.rnn.pad_sequence(\n            input_ids,\n            batch_first=True,\n            padding_value=self.tokenizer.pad_token_id,\n        )\n        labels = torch.nn.utils.rnn.pad_sequence(\n            labels, batch_first=True, padding_value=IGNORE_INDEX\n        )\n\n        return {\n            \"input_ids\": input_ids,\n            \"labels\": labels,\n        }\n"
  },
  {
    "path": "src/axolotl/utils/collators/mm_chat.py",
    "content": "\"\"\"\nCollators for multi-modal chat messages and packing\n\"\"\"\n\nfrom dataclasses import dataclass\nfrom typing import Any, Optional, Union\n\nfrom torch import Tensor\nfrom transformers import PreTrainedTokenizerBase\nfrom transformers.data.data_collator import DataCollatorMixin\nfrom transformers.utils import PaddingStrategy\n\nfrom axolotl.processing_strategies import ProcessingStrategy\n\n\n@dataclass\nclass MultiModalChatDataCollator(DataCollatorMixin):\n    \"\"\"\n    Collator for multi-modal chat messages\n    \"\"\"\n\n    tokenizer: PreTrainedTokenizerBase\n    processing_strategy: ProcessingStrategy\n    packing: bool = False\n    return_tensors: str = \"pt\"\n    padding: Union[bool, str, PaddingStrategy] = True\n    pad_to_multiple_of: Optional[int] = None\n\n    def __post_init__(self):\n        if self.packing:\n            raise ValueError(\"Packing is currently not supported.\")\n\n    def torch_call(self, examples: list[dict]) -> dict[str, Any]:\n        return self.process_rows(examples)\n\n    def process_rows(\n        self,\n        examples: list[dict],\n    ) -> dict[str, Tensor]:\n        # Preprocess the examples\n        examples = self.processing_strategy(examples)\n\n        # Initialize batch\n        messages = [ex[\"messages\"] for ex in examples]\n\n        batch = self.processing_strategy.processor.apply_chat_template(\n            messages,\n            add_generation_prompt=False,\n            tokenize=True,\n            return_tensors=\"pt\",\n            padding=True,\n            return_dict=True,\n            chat_template=self.processing_strategy.chat_template,\n        )\n\n        # Process the labels\n        batch[\"labels\"] = self.processing_strategy.process_labels(batch[\"input_ids\"])\n\n        return batch\n"
  },
  {
    "path": "src/axolotl/utils/comet_.py",
    "content": "\"\"\"Module for Comet ML utilities\"\"\"\n\nimport os\n\nfrom axolotl.utils.dict import DictDefault\nfrom axolotl.utils.logging import get_logger\n\nLOG = get_logger(__name__)\n\nCOMET_ENV_MAPPING_OVERRIDE = {\n    \"comet_mode\": \"COMET_START_MODE\",\n    \"comet_online\": \"COMET_START_ONLINE\",\n}\nCOMET_EXPERIMENT_CONFIG_ENV_MAPPING_OVERRIDE = {\n    \"auto_histogram_activation_logging\": \"COMET_AUTO_LOG_HISTOGRAM_ACTIVATIONS\",\n    \"auto_histogram_epoch_rate\": \"COMET_AUTO_LOG_HISTOGRAM_EPOCH_RATE\",\n    \"auto_histogram_gradient_logging\": \"COMET_AUTO_LOG_HISTOGRAM_GRADIENTS\",\n    \"auto_histogram_tensorboard_logging\": \"COMET_AUTO_LOG_HISTOGRAM_TENSORBOARD\",\n    \"auto_histogram_weight_logging\": \"COMET_AUTO_LOG_HISTOGRAM_WEIGHTS\",\n    \"auto_log_co2\": \"COMET_AUTO_LOG_CO2\",\n    \"auto_metric_logging\": \"COMET_AUTO_LOG_METRICS\",\n    \"auto_metric_step_rate\": \"COMET_AUTO_LOG_METRIC_STEP_RATE\",\n    \"auto_output_logging\": \"COMET_AUTO_LOG_OUTPUT_LOGGER\",\n    \"auto_param_logging\": \"COMET_AUTO_LOG_PARAMETERS\",\n    \"comet_disabled\": \"COMET_AUTO_LOG_DISABLE\",\n    \"display_summary_level\": \"COMET_DISPLAY_SUMMARY_LEVEL\",\n    \"distributed_node_identifier\": \"COMET_DISTRIBUTED_NODE_IDENTIFIER\",\n    \"log_code\": \"COMET_AUTO_LOG_CODE\",\n    \"log_env_cpu\": \"COMET_AUTO_LOG_ENV_CPU\",\n    \"log_env_details\": \"COMET_AUTO_LOG_ENV_DETAILS\",\n    \"log_env_disk\": \"COMET_AUTO_LOG_ENV_DISK\",\n    \"log_env_gpu\": \"COMET_AUTO_LOG_ENV_GPU\",\n    \"log_env_host\": \"COMET_AUTO_LOG_ENV_HOST\",\n    \"log_env_network\": \"COMET_AUTO_LOG_ENV_NETWORK\",\n    \"log_git_metadata\": \"COMET_AUTO_LOG_GIT_METADATA\",\n    \"log_git_patch\": \"COMET_AUTO_LOG_GIT_PATCH\",\n    \"log_graph\": \"COMET_AUTO_LOG_GRAPH\",\n    \"name\": \"COMET_START_EXPERIMENT_NAME\",\n    \"offline_directory\": \"COMET_OFFLINE_DIRECTORY\",\n    \"parse_args\": \"COMET_AUTO_LOG_CLI_ARGUMENTS\",\n    \"tags\": \"COMET_START_EXPERIMENT_TAGS\",\n}\n\n\ndef python_value_to_environ_value(python_value):\n    \"\"\"Convert a config value into the string form stored in ``os.environ``.\n\n    Booleans become ``\"true\"``/``\"false\"``, lists are joined with commas, and\n    every other non-string value (int, float, ...) is converted with ``str``.\n    ``os.environ`` only accepts ``str`` values, so returning anything else\n    would make the caller's assignment raise ``TypeError``.\n    \"\"\"\n    if isinstance(python_value, bool):\n        return \"true\" if python_value else \"false\"\n\n    if isinstance(python_value, list):  # Comet has only one list-of-strings parameter\n        return \",\".join(map(str, python_value))\n\n    if isinstance(python_value, str):\n        return python_value\n\n    return str(python_value)\n\n\ndef setup_comet_env_vars(cfg: DictDefault):\n    \"\"\"Export the Comet settings of an Axolotl config as environment variables.\n\n    Top-level ``comet_*`` keys (other than ``comet_experiment_config``) are\n    exported under the name given in ``COMET_ENV_MAPPING_OVERRIDE``, falling\n    back to the upper-cased key. Entries of ``comet_experiment_config`` are\n    exported only when listed in ``COMET_EXPERIMENT_CONFIG_ENV_MAPPING_OVERRIDE``;\n    unknown entries are logged and skipped. ``None`` and empty-string values\n    are ignored. Finally sets ``cfg.use_comet = True`` when\n    ``cfg.comet_project_name`` is non-empty.\n    \"\"\"\n    # NOTE: the Transformers Comet integration runs first and would create the\n    # Experiment itself, so the Axolotl configuration has to be exposed through\n    # environment variables before that happens.\n\n    for key in cfg.keys():\n        if key.startswith(\"comet_\") and key != \"comet_experiment_config\":\n            value = cfg.get(key, \"\")\n\n            if value is not None and value != \"\":\n                env_variable_name = COMET_ENV_MAPPING_OVERRIDE.get(key, key.upper())\n                final_value = python_value_to_environ_value(value)\n                os.environ[env_variable_name] = final_value\n\n    if cfg.comet_experiment_config:\n        for key, value in cfg.comet_experiment_config.items():\n            if value is not None and value != \"\":\n                config_env_variable_name = (\n                    COMET_EXPERIMENT_CONFIG_ENV_MAPPING_OVERRIDE.get(key)\n                )\n\n                if config_env_variable_name is None:\n                    LOG.warning(\n                        f\"Unknown Comet Experiment Config name {key}, ignoring it\"\n                    )\n                    continue\n\n                final_value = python_value_to_environ_value(value)\n                os.environ[config_env_variable_name] = final_value\n\n    # Enable comet if project name is present\n    if cfg.comet_project_name:\n        cfg.use_comet = True\n"
  },
  {
    "path": "src/axolotl/utils/config/__init__.py",
    "content": "\"\"\"Module for working with config dicts\"\"\"\n\nimport json\nimport os\nfrom typing import Optional\n\nimport torch\nfrom transformers.utils import is_torch_bf16_gpu_available\nfrom transformers.utils.import_utils import (\n    is_torch_greater_or_equal,\n    is_torch_npu_available,\n)\n\nfrom axolotl.integrations.base import PluginManager\nfrom axolotl.integrations.config import merge_input_args\nfrom axolotl.loaders import MULTIMODAL_AUTO_MODEL_MAPPING\nfrom axolotl.loaders.utils import load_model_config\nfrom axolotl.utils.bench import log_gpu_memory_usage\nfrom axolotl.utils.dict import DictDefault\nfrom axolotl.utils.logging import get_logger\nfrom axolotl.utils.schemas.config import (\n    AxolotlConfigWCapabilities as AxolotlConfigWCapabilitiesBase,\n    AxolotlInputConfig as AxolotlInputConfigBase,\n)\nfrom axolotl.utils.schemas.datasets import DPODataset, KTODataset, SFTDataset\n\nLOG = get_logger(__name__)\n\n\ndef choose_device(cfg):\n    def get_device():\n        try:\n            if torch.cuda.is_available():\n                return f\"cuda:{cfg.local_rank}\"\n\n            if torch.backends.mps.is_available():\n                return \"mps\"\n\n            if is_torch_npu_available():\n                return f\"npu:{cfg.local_rank}\"\n\n            raise SystemError(\"No CUDA/mps/npu device found\")\n        except Exception:\n            return \"cpu\"\n\n    cfg.device = get_device()\n    if cfg.world_size == 1:\n        cfg.device_map = cfg.device_map or \"auto\"\n    else:\n        if cfg.device.startswith(\"cuda\"):\n            cfg.device_map = {\"\": torch.cuda.current_device()}\n        elif cfg.device.startswith(\"npu\"):\n            cfg.device_map = {\"npu\": torch.npu.current_device()}\n        else:\n            cfg.device_map = {\"\": cfg.device}\n\n    # in `accelerate launch`, we need to not pass through any device map and let\n    # accelerate figure out which parts of the model to put on which gpu\n    
accelerate_vars = [var for var in os.environ if var.startswith(\"ACCELERATE_USE_\")]\n    if accelerate_vars:\n        cfg.device_map = None\n\n\ndef resolve_dtype(cfg):\n    if (\n        not cfg.fp16 and cfg.bf16 == \"auto\" and not cfg.use_ray\n    ):  # if we use ray we want to defer this check to the worker node\n        if is_torch_bf16_gpu_available():\n            LOG.debug(\"bf16 support detected, enabling for this configuration.\")\n            cfg.bf16 = True\n        else:\n            LOG.debug(\"bf16 support not detected, disabling for this configuration.\")\n            cfg.bf16 = False\n            if cfg.fp16 is None and not cfg.float16:\n                cfg.fp16 = True\n\n    if cfg.fp16 and cfg.bf16 == \"auto\":\n        cfg.bf16 = False\n\n    if cfg.device == \"mps\":\n        cfg.load_in_8bit = False\n        cfg.tf32 = False\n        if cfg.bf16 and cfg.fp16 is not False:\n            cfg.fp16 = True\n        cfg.bf16 = False\n    else:\n        if cfg.tf32 is True:\n            torch.set_float32_matmul_precision(\"high\")\n            if is_torch_greater_or_equal(\"2.9.0\"):\n                torch.backends.fp32_precision = \"tf32\"\n                torch.backends.cuda.matmul.fp32_precision = \"tf32\"\n                torch.backends.cudnn.fp32_precision = \"tf32\"\n            else:\n                torch.backends.cuda.matmul.allow_tf32 = True\n                torch.backends.cudnn.allow_tf32 = True\n        if cfg.bf16:\n            cfg.fp16 = False\n\n    if cfg.bf16 or cfg.bfloat16:\n        cfg.torch_dtype = torch.bfloat16\n    elif cfg.load_in_8bit or cfg.fp16 or cfg.float16:\n        cfg.torch_dtype = torch.float16\n    else:\n        cfg.torch_dtype = torch.float32\n\n\ndef normalize_config(cfg):\n    # setup some derived config / hyperparams\n    cfg.gradient_accumulation_steps = cfg.gradient_accumulation_steps or (\n        cfg.batch_size // cfg.micro_batch_size\n    )\n    cfg.batch_size = (\n        cfg.batch_size or 
cfg.micro_batch_size * cfg.gradient_accumulation_steps\n    )\n    if cfg.eval_batch_size is None:\n        cfg.eval_batch_size = cfg.micro_batch_size\n    cfg.world_size = int(os.environ.get(\"WORLD_SIZE\", 1))\n    cfg.local_rank = int(os.environ.get(\"LOCAL_RANK\", 0))\n    cfg.eval_table_size = cfg.eval_table_size or 0\n    cfg.eval_max_new_tokens = cfg.eval_max_new_tokens or 128\n    cfg.eval_causal_lm_metrics = cfg.eval_causal_lm_metrics or [\n        \"sacrebleu\",\n        \"comet\",\n        \"ter\",\n        \"chrf\",\n    ]\n    choose_device(cfg)\n    cfg.ddp = cfg.ddp if cfg.ddp is not None else cfg.world_size != 1\n    if cfg.world_size != 1:\n        cfg.device_map = {\"\": int(os.environ.get(\"LOCAL_RANK\", 0))}\n        if cfg.fsdp or cfg.fsdp_config or cfg.ddp:\n            effective_world_size = (\n                cfg.world_size\n                // (cfg.context_parallel_size or 1)\n                // (cfg.tensor_parallel_size or 1)\n            )\n            cfg.batch_size = cfg.batch_size * effective_world_size\n\n    if not cfg.use_ray:\n        # delay resolving dtype until on worker node when launching with ray\n        resolve_dtype(cfg)\n\n    if cfg.deepspeed:\n        if isinstance(cfg.deepspeed, str) and os.path.exists(cfg.deepspeed):\n            ds_config_path = cfg.deepspeed\n            with open(ds_config_path, encoding=\"utf-8\") as f:\n                cfg.deepspeed = json.load(f)\n\n    if cfg.saves_per_epoch:\n        save_steps = 1.0 / (cfg.saves_per_epoch * cfg.num_epochs)\n        if save_steps < 1.0:  # prevent saves on every step\n            cfg.save_steps = save_steps\n        elif save_steps > 1:\n            LOG.warning(\n                f\"Invalid value for save_steps ({save_steps}) from saves_per_epoch and/or num_epochs. 
Saving at training end only.\"\n            )\n    if (cfg.val_set_size or cfg.test_datasets) and cfg.evals_per_epoch:\n        eval_steps = 1.0 / (cfg.evals_per_epoch * cfg.num_epochs)\n        if eval_steps < 1.0:  # prevent evals on every step\n            cfg.eval_steps = eval_steps\n        elif eval_steps > 1:\n            LOG.warning(\n                f\"Invalid value for eval_steps ({eval_steps}) from evals_per_epoch and/or num_epochs. Skipping evaluations.\"\n            )\n\n    if not cfg.base_model_config:\n        cfg.base_model_config = cfg.base_model\n\n    # Apply pre-config load patches (e.g., for Kimi Linear remote code patching)\n    from axolotl.loaders.patch_manager import PatchManager\n\n    PatchManager.apply_pre_config_load_patches(cfg)\n\n    model_config = load_model_config(cfg)\n\n    cfg.tokenizer_config = (\n        cfg.tokenizer_config or cfg.base_model_config or cfg.base_model\n    )\n\n    cfg.is_multimodal = (\n        hasattr(model_config, \"model_type\")\n        and model_config.model_type in MULTIMODAL_AUTO_MODEL_MAPPING\n        or any(\n            multimodal_name in cfg.base_model.lower()\n            for multimodal_name in [\n                \"pixtral\",\n            ]\n        )\n        or cfg.is_multimodal\n    )\n    if cfg.is_multimodal:\n        cfg.processor_config = (\n            cfg.processor_config or cfg.base_model_config or cfg.base_model\n        )\n\n    cfg.model_config_type = model_config.model_type\n\n    # Resolve inner text backbone type for VLM wrappers (e.g. 
mistral3 -> mistral4)\n    if callable(getattr(model_config, \"get_text_config\", None)):\n        text_config = model_config.get_text_config()\n        if (\n            hasattr(text_config, \"model_type\")\n            and text_config.model_type != model_config.model_type\n        ):\n            cfg.model_config_type_text = text_config.model_type\n\n    # figure out if the model is llama\n    cfg.is_llama_derived_model = (\n        (\n            hasattr(model_config, \"model_type\")\n            and model_config.model_type in [\"llama\", \"mllama_text_model\"]\n        )\n        or cfg.is_llama_derived_model\n        or \"llama\" in cfg.base_model.lower()\n        or (cfg.type_of_model and \"llama\" in cfg.type_of_model.lower())\n    )\n\n    # figure out if the model is falcon\n    cfg.is_falcon_derived_model = (\n        (\n            hasattr(model_config, \"model_type\")\n            and model_config.model_type\n            in [\n                \"falcon\",\n                \"RefinedWebModel\",\n                \"RefinedWeb\",\n            ]\n        )\n        or cfg.is_falcon_derived_model\n        or \"falcon\" in cfg.base_model.lower()\n        or (cfg.type_of_model and \"rwforcausallm\" in cfg.type_of_model.lower())\n    )\n\n    cfg.is_mistral_derived_model = (\n        (\n            hasattr(model_config, \"model_type\")\n            and model_config.model_type\n            in [\n                \"mistral\",\n            ]\n        )\n        or cfg.is_mistral_derived_model\n        or \"mistral\" in cfg.base_model.lower().split(\"/\")[-1]\n        or (cfg.type_of_model and \"mistral\" in cfg.type_of_model.lower())\n    )\n\n    cfg.is_qwen_derived_model = (\n        hasattr(model_config, \"model_type\")\n        and model_config.model_type\n        in [\n            \"qwen\",\n        ]\n    ) or cfg.is_qwen_derived_model\n\n    if isinstance(cfg.pretraining_dataset, dict):\n        cfg.pretraining_dataset = [cfg.pretraining_dataset]\n\n    if (\n  
      cfg.gradient_checkpointing\n        and cfg.unfrozen_parameters is None\n        and cfg.gradient_checkpointing_kwargs is None\n        and cfg.rl is None\n    ):\n        cfg.gradient_checkpointing_kwargs = {\"use_reentrant\": True}\n\n    log_gpu_memory_usage(LOG, \"baseline\", cfg.device)\n\n\ndef normalize_cfg_datasets(cfg):\n    \"\"\"\n    helpers for mapping chat_template to various dataset configurations as necessary\n    \"\"\"\n\n    if cfg.chat_template:\n        if cfg.datasets:\n            for idx, ds_cfg in enumerate(cfg.datasets):\n                if (\n                    ds_cfg.type in [\"orpo.chat_template\", \"chat_template\"]\n                    and not ds_cfg.chat_template\n                ):\n                    LOG.info(\n                        f\"updating dataset {ds_cfg.path} with `chat_template: {cfg.chat_template}` to match your chat_template\"\n                    )\n                    cfg.datasets[idx].chat_template = cfg.chat_template\n                    cfg.datasets[idx].chat_template_jinja = cfg.chat_template_jinja\n\n\ndef validate_config(\n    cfg: DictDefault,\n    capabilities: Optional[dict] = None,\n    env_capabilities: Optional[dict] = None,\n) -> DictDefault:\n    AxolotlConfigWCapabilities = AxolotlConfigWCapabilitiesBase\n    AxolotlInputConfig = AxolotlInputConfigBase\n\n    if cfg.plugins:\n        (\n            AxolotlConfigWCapabilities,\n            AxolotlInputConfig,\n        ) = merge_input_args()\n\n    # Convert datasets to proper format if needed\n    if cfg.get(\"datasets\"):\n        for idx, ds_cfg in enumerate(cfg[\"datasets\"]):\n            if cfg.get(\"rl\") in [\"dpo\", \"ipo\", \"simpo\"] and not isinstance(\n                ds_cfg, DPODataset\n            ):\n                cfg[\"datasets\"][idx] = DPODataset(**ds_cfg)\n            elif cfg.get(\"rl\") == \"kto\" and not isinstance(ds_cfg, KTODataset):\n                cfg[\"datasets\"][idx] = KTODataset(**dict(ds_cfg))\n            elif 
not isinstance(ds_cfg, SFTDataset):\n                cfg[\"datasets\"][idx] = SFTDataset(**dict(ds_cfg))\n\n    if capabilities or env_capabilities:\n        if (capabilities and env_capabilities is None) or (\n            env_capabilities and capabilities is None\n        ):\n            raise ValueError(\n                \"Both capabilities and env_capabilities must be provided or not provided.\"\n            )\n\n        return DictDefault(\n            dict(\n                AxolotlConfigWCapabilities(\n                    **cfg.to_dict(),\n                    capabilities=capabilities,\n                    env_capabilities=env_capabilities,\n                ).model_dump(exclude_none=True)\n            )\n        )\n\n    return DictDefault(\n        dict(AxolotlInputConfig(**cfg.to_dict()).model_dump(exclude_none=True))\n    )\n\n\ndef prepare_plugins(cfg):\n    \"\"\"\n    Prepare the plugins for the configuration\n    \"\"\"\n\n    if cfg.get(\"plugins\"):\n        plugin_manager = PluginManager.get_instance()\n        for plugin_name in cfg[\"plugins\"]:\n            plugin_manager.register(plugin_name)\n"
  },
  {
    "path": "src/axolotl/utils/config/models/__init__.py",
    "content": ""
  },
  {
    "path": "src/axolotl/utils/ctx_managers/__init__.py",
    "content": "\"\"\"Init for context manager submodule\"\"\"\n\n# flake8: noqa\n\nfrom .sequence_parallel import SequenceParallelContextManager\n"
  },
  {
    "path": "src/axolotl/utils/ctx_managers/sequence_parallel.py",
    "content": "\"\"\"Module for Axolotl trainer sequence parallelism manager and utilities\"\"\"\n\nimport functools\nimport inspect\n\nimport torch\nimport torch.distributed as dist\nfrom torch import nn\nfrom torch.distributed import DeviceMesh\nfrom torch.utils.hooks import RemovableHandle\nfrom transformers.modeling_outputs import CausalLMOutputWithPast\nfrom transformers.utils import ModelOutput\n\nfrom axolotl.monkeypatch.ring_attn import (\n    get_ring_attn_group,\n    register_ring_attn_from_device_mesh,\n    update_ring_attn_params,\n)\nfrom axolotl.utils.schemas.enums import RingAttnFunc\n\n\n# TODO(djsaunde): implement zigzag, stripe patterns here (and elsewhere) in this\n# module. Currently, we just focus on batch ring and varlen llama3 for simplicity.\ndef apply_sequence_parallelism(\n    batch: dict[str, torch.Tensor],\n    local_rank: int,\n    local_world_size: int,\n    gradient_accumulation_steps: int,\n    ring_attn_func: RingAttnFunc,\n) -> tuple[dict[str, torch.Tensor], int, int]:\n    \"\"\"\n    Apply sequence parallelism slicing to a batch.\n\n    Special handling is implemented for integer logits_to_keep, which indicates\n    to only keep the last N tokens in the sequence during generation.\n\n    Args:\n        batch: Batch dictionary (e.g., input_ids, attention_mask, etc.).\n        local_rank: Local rank in the sequence parallel group.\n        local_world_size: World size of the sequence parallel group.\n        gradient_accumulation_steps: Number of steps to accumulate gradients over.\n        ring_attn_func: Which ring attention function to use. 
Currently unused, but\n            related to above TODO.\n\n    Returns:\n        tuple of:\n            - Batch dictionary with sliced tensors.\n            - The original sequence length before padding.\n            - The number of padding tokens added.\n    \"\"\"\n    batch_size, original_seq_len = batch[\"input_ids\"].shape\n\n    # Update ring attention params if needed\n    if batch.get(\"position_ids\") is not None and batch_size == 1:\n        update_ring_attn_params(position_ids=batch[\"position_ids\"])\n    else:\n        # If position_ids aren't already in the batch, create them\n        batch[\"position_ids\"] = torch.arange(\n            0,\n            original_seq_len,\n            dtype=torch.long,\n            device=batch[\"input_ids\"].device,\n        ).expand(batch[\"input_ids\"].size(0), -1)\n\n    if \"logits_to_keep\" in batch and isinstance(batch[\"logits_to_keep\"], int):\n        logits_to_keep = batch[\"logits_to_keep\"]\n\n        # Calculate which positions in the full sequence contain the last N tokens\n        start_position = max(0, original_seq_len - logits_to_keep)\n        chunk_size = original_seq_len // local_world_size\n        rank_start = local_rank * chunk_size\n        rank_end = rank_start + chunk_size\n\n        # Create a boolean mask tensor for this rank's chunk\n        mask = torch.zeros(\n            chunk_size,\n            dtype=torch.bool,\n            device=batch[\"input_ids\"].device,\n        )\n\n        if rank_end > start_position:\n            # Calculate how many of the last N tokens fall within this rank's range\n            tokens_in_rank = min(rank_end, original_seq_len) - max(\n                rank_start, start_position\n            )\n\n            # Calculate where these tokens start in the local chunk\n            local_start_idx = max(0, start_position - rank_start)\n\n            # Set the appropriate positions in the mask to True\n            mask[local_start_idx : local_start_idx + 
tokens_in_rank] = True\n\n        # Replace the integer with the boolean mask\n        batch[\"logits_to_keep\"] = mask\n\n    # Add padding to make sequence length divisible by local_world_size\n    total_seq_len = original_seq_len\n    pad_len = 0\n    divisor = min(local_world_size, 64)\n    if total_seq_len % divisor != 0:\n        pad_len = divisor - (total_seq_len % divisor)\n\n        # Apply padding to all relevant tensors\n        for key in batch:\n            if (\n                isinstance(batch[key], torch.Tensor)\n                and batch[key].dim() > 1\n                and batch[key].size(1) == total_seq_len\n            ):\n                # Create padding tensor\n                pad_value = -100 if key == \"labels\" else 0\n                padding = torch.full(\n                    (batch[key].size(0), pad_len, *batch[key].shape[2:]),\n                    pad_value,\n                    dtype=batch[key].dtype,\n                    device=batch[key].device,\n                )\n\n                # Concatenate padding to the right side of the tensor\n                batch[key] = torch.cat([batch[key], padding], dim=1)\n            if key == \"logits_to_keep\":\n                # Create padding tensor\n                padding = torch.ones(\n                    1,\n                    dtype=batch[key].dtype,\n                    device=batch[key].device,\n                )\n\n                # Concatenate padding to the right side of the tensor\n                batch[key] = torch.cat([batch[key], padding], dim=0)\n\n        # Update the total sequence length after padding\n        total_seq_len = batch[\"input_ids\"].size(1)\n\n    # Slice batch for sequence parallel\n    for key in batch:\n        if not isinstance(batch[key], torch.Tensor) or batch[key].dim() <= 1:\n            continue\n\n        # Split in sequential fashion and grab this rank's chunk\n        if batch[key].size(1) == total_seq_len:\n            batch[key] = (\n                
batch[key].chunk(local_world_size, dim=1)[local_rank].contiguous()\n            )\n        elif key == \"logits_to_keep\":\n            batch[key] = (\n                batch[key].chunk(local_world_size, dim=0)[local_rank].contiguous()\n            )\n\n        # Handle num_items_in_batch\n        if \"num_items_in_batch\" in batch:\n            # Approximation; this needed since num_items_in_batch may be counted across\n            # all samples in a gradient accumulated batch, not on a per-step basis.\n            local_valid_tokens = (batch[\"labels\"] != -100).sum()\n\n            # All-reduce across sequence parallel ranks to get global token count\n            cp_group = get_ring_attn_group()\n            global_valid_tokens = local_valid_tokens.clone()\n            # we use AVG instead of SUM as using sum seems to scale down the loss by over-accounting the number of tokens\n            dist.all_reduce(global_valid_tokens, op=dist.ReduceOp.AVG, group=cp_group)\n            global_valid_tokens = int(global_valid_tokens.item())\n\n            batch[\"num_items_in_batch\"] = (\n                global_valid_tokens * gradient_accumulation_steps\n            )\n\n    return batch, original_seq_len, pad_len\n\n\nclass SequenceParallelContextManager:\n    \"\"\"Context manager for sequence parallelism operations.\n\n    This class provides a context that will automatically apply sequence parallelism\n    during model forward passes using a pre-forward hook, and gather outputs from\n    across the sequence parallelism group using a post-forward hook.\n\n    Args:\n        models: List of models to apply sequence parallelism to pre- and post- forward\n            hooks.\n        context_parallel_size: Number of processes to split sequences over.\n        gradient_accumulation_steps: Number of steps to accumulate gradients over.\n        ring_attn_func: Which ring attention function to use. 
Currently unused.\n        heads_k_stride: Sequence parallelism K head stride size. Passed through to\n            `varlen_llama3` `ring_flash_attn` implementation.\n        gather_outputs: Whether to gather outputs after model forward pass across the\n            sequence parallel group.\n    \"\"\"\n\n    def __init__(\n        self,\n        models: list[nn.Module],\n        context_parallel_size: int,\n        gradient_accumulation_steps: int,\n        ring_attn_func: RingAttnFunc,\n        heads_k_stride: int | None,\n        gather_outputs: bool,\n        device_mesh: DeviceMesh | None = None,\n    ):\n        self.models = models\n        self.context_parallel_size = context_parallel_size\n        self.gradient_accumulation_steps = gradient_accumulation_steps\n        self.ring_attn_func = ring_attn_func\n        self.heads_k_stride = heads_k_stride\n        self.gather_outputs = gather_outputs\n        self.device_mesh = device_mesh\n\n        self._register_ring_attn()\n\n        # Set distributed info for local rank\n        self.process_group = get_ring_attn_group()\n        self.local_rank = dist.get_rank(self.process_group)\n        self.local_world_size = dist.get_world_size(self.process_group)\n\n        # Will store hook handles for removal\n        self.hook_handles: list[RemovableHandle] = []\n\n        # Store original sequence length and padding information\n        self.original_seq_len = 0\n        self.pad_len = 0\n\n        # Track local valid token count for eval loss correction across CP ranks\n        self._local_valid_tokens: torch.Tensor | None = None\n\n        # Create a partially applied version of the apply_sequence_parallelism function\n        self.apply_sequence_parallelism = functools.partial(\n            apply_sequence_parallelism,\n            local_rank=self.local_rank,\n            local_world_size=self.local_world_size,\n            gradient_accumulation_steps=self.gradient_accumulation_steps,\n            
ring_attn_func=self.ring_attn_func,\n        )\n\n    def __enter__(self):\n        self._register_model_hooks()\n\n        return self\n\n    def __exit__(self, exc_type, exc_val, exc_tb):\n        # Remove all hooks\n        for handle in self.hook_handles:\n            handle.remove()\n        self.hook_handles = []\n\n        # TODO(djsaunde): Un-patch attention and accelerate functions (low priority)\n\n    def _register_ring_attn(self):\n        # Initialize ring attn for sequence parallelism\n        register_ring_attn_from_device_mesh(\n            device_mesh=self.device_mesh,\n            context_parallel_dim=(\"cp\",),\n            heads_k_stride=self.heads_k_stride,\n            ring_attn_func=self.ring_attn_func,\n        )\n\n    def _register_model_hooks(self):\n        # Forward pre-hook to apply sequence parallelism\n        def sequence_parallel_pre_hook(_, args, kwargs):\n            # Get parameter names from the model's forward function\n            forward_params = list(\n                inspect.signature(self.models[0].forward).parameters.keys()\n            )\n\n            updated_kwargs = kwargs.copy()\n            for i, arg in enumerate(args):\n                if i < len(forward_params):\n                    updated_kwargs[forward_params[i]] = arg\n\n            # Any excess positional arguments are kept as-is\n            remaining_args = args[len(forward_params) :]\n\n            # Apply sequence parallelism to updated kwargs\n            updated_kwargs, self.original_seq_len, self.pad_len = (\n                self.apply_sequence_parallelism(updated_kwargs)\n            )\n\n            # Track local valid tokens for eval loss correction\n            if \"labels\" in updated_kwargs and not self.models[0].training:\n                self._local_valid_tokens = (\n                    (updated_kwargs[\"labels\"] != -100).sum().float()\n                )\n                # Strip num_items_in_batch during eval so the model uses\n              
  # reduction='mean', allowing the post-hook weighted all-reduce\n                # formula (loss * local_valid) to correctly recover the loss sum\n                updated_kwargs.pop(\"num_items_in_batch\", None)\n            else:\n                self._local_valid_tokens = None\n\n            return remaining_args, updated_kwargs\n\n        # Forward post-hook to gather outputs\n        def sequence_parallel_post_hook(_, __, output: ModelOutput) -> ModelOutput:\n            # Gather the sharded outputs\n            output = self._gather_outputs(output)\n\n            # Remove padding if it was added\n            if self.pad_len > 0:\n                for key, value in output.items():\n                    if isinstance(value, torch.Tensor) and value.dim() > 1:\n                        if value.size(1) == self.original_seq_len + self.pad_len:\n                            # Slice to remove padding\n                            output[key] = value[:, : self.original_seq_len].contiguous()\n\n            return output\n\n        # Post-hook to correct eval loss via weighted all-reduce across CP ranks\n        def eval_loss_correction_post_hook(_, __, output: ModelOutput) -> ModelOutput:\n            if self._local_valid_tokens is None:\n                return output\n            if not hasattr(output, \"loss\") or output.loss is None:\n                return output\n\n            local_valid = self._local_valid_tokens.to(output.loss.device)\n            loss = output.loss.detach().clone()\n\n            # Handle rank with zero valid tokens (loss is NaN)\n            if local_valid.item() == 0:\n                weighted_loss = torch.zeros(1, device=loss.device, dtype=loss.dtype)\n            else:\n                weighted_loss = loss * local_valid\n\n            total_valid = local_valid.clone()\n            dist.all_reduce(\n                weighted_loss,\n                op=dist.ReduceOp.SUM,\n                group=self.process_group,\n            )\n            
dist.all_reduce(\n                total_valid,\n                op=dist.ReduceOp.SUM,\n                group=self.process_group,\n            )\n\n            if total_valid.item() > 0:\n                output[\"loss\"] = (weighted_loss / total_valid).squeeze()\n            else:\n                output[\"loss\"] = torch.tensor(\n                    float(\"nan\"), device=loss.device, dtype=loss.dtype\n                )\n\n            self._local_valid_tokens = None\n            return output\n\n        # Register hooks\n        for model in self.models:\n            self.hook_handles.append(\n                model.register_forward_pre_hook(\n                    sequence_parallel_pre_hook, with_kwargs=True\n                )\n            )\n            if self.gather_outputs:\n                self.hook_handles.append(\n                    model.register_forward_hook(sequence_parallel_post_hook)\n                )\n            # Always register eval loss correction hook\n            self.hook_handles.append(\n                model.register_forward_hook(eval_loss_correction_post_hook)\n            )\n\n    def _gather_outputs(self, output: CausalLMOutputWithPast) -> CausalLMOutputWithPast:\n        \"\"\"Gather sharded outputs from all ranks and reconstruct the full tensor.\"\"\"\n        for key, value in output.items():\n            if isinstance(value, torch.Tensor) and value.dim() > 1:\n                output[key] = AllGatherWithGrad.apply(value, self.process_group)\n\n        return output\n\n\nclass AllGatherWithGrad(torch.autograd.Function):\n    \"\"\"Custom autograd function for all-gather to preserve gradients.\"\"\"\n\n    @staticmethod\n    def forward(\n        ctx: torch.autograd.function.FunctionCtx,\n        input_tensor: torch.Tensor,\n        group: dist.ProcessGroup,\n    ) -> torch.Tensor:\n        \"\"\"\n        Forward pass of all-gather of data with sequence dimension.\n\n        Args:\n            ctx: `torch.autograd` function context.\n     
       input_tensor: Tensor from model output with sequence dimension.\n            group: `torch.distributed` process group.\n\n        Returns:\n            Tensor from gathering the `input_tensor` from across the process group and\n                concatenating along the sequence dimension.\n        \"\"\"\n        ctx.group = group\n        ctx.rank = dist.get_rank(group)\n        world_size = dist.get_world_size(group)\n\n        # Gather shape metadata\n        local_shape = torch.tensor(list(input_tensor.shape), device=input_tensor.device)\n        all_shapes = [torch.zeros_like(local_shape) for _ in range(world_size)]\n        dist.all_gather(all_shapes, local_shape, group=group)\n\n        # Store sequence lengths for backward pass\n        seq_lens = [int(shape[1].item()) for shape in all_shapes]\n        ctx.seq_lens = seq_lens\n\n        # Perform all_gather operation\n        gathered = [\n            torch.zeros(\n                tuple(shape.tolist()),\n                dtype=input_tensor.dtype,\n                device=input_tensor.device,\n            )\n            for shape in all_shapes\n        ]\n        dist.all_gather(gathered, input_tensor, group=group)\n\n        # Concatenate tensors along sequence dimension\n        result = torch.cat(gathered, dim=1)\n\n        return result\n\n    @staticmethod\n    def backward(\n        ctx: torch.autograd.function.FunctionCtx, grad_output: torch.Tensor\n    ) -> tuple[torch.Tensor, None]:\n        \"\"\"\n        Backward pass for all-gather operation.\n\n        Extracts the gradient slice corresponding to this rank's original input\n        from the full gradient tensor.\n\n        Args:\n            ctx: `torch.autograd` function context.\n            grad_output: Gradient from subsequent layers with respect to the\n                concatenated output tensor.\n\n        Returns:\n            Tuple containing the gradient slice for this rank's input tensor and `None`\n                for the process 
group parameter which doesn't require gradients.\n        \"\"\"\n        rank = ctx.rank\n        seq_lens = ctx.seq_lens\n\n        # Extract gradient for this rank's chunk\n        offset = sum(seq_lens[:rank])\n        grad_slice = grad_output[:, offset : offset + seq_lens[rank]].contiguous()\n\n        return grad_slice, None\n"
  },
  {
    "path": "src/axolotl/utils/data/__init__.py",
    "content": "\"\"\"Init for `axolotl.utils.data` module.\"\"\"\n\nfrom axolotl.utils.data.rl import prepare_preference_datasets\nfrom axolotl.utils.data.sft import (\n    get_dataset_wrapper,\n    prepare_datasets,\n)\nfrom axolotl.utils.data.streaming import (\n    encode_streaming,\n    wrap_streaming_dataset,\n)\nfrom axolotl.utils.data.utils import md5\n\n__all__ = [\n    \"encode_streaming\",\n    \"wrap_streaming_dataset\",\n    \"prepare_preference_datasets\",\n    \"get_dataset_wrapper\",\n    \"prepare_datasets\",\n    \"md5\",\n]\n"
  },
  {
    "path": "src/axolotl/utils/data/lock.py",
    "content": "\"\"\"Logic for loading / preparing a dataset once over all processes.\"\"\"\n\nimport time\nfrom pathlib import Path\nfrom typing import Any, Callable\n\nfrom filelock import FileLock\n\nfrom axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH\nfrom axolotl.utils.dict import DictDefault\n\nLOCK_FILE_NAME = \"datasets_prep.lock\"\nREADY_FILE_NAME = \"datasets_ready.flag\"\nPROCESS_COUNTER_FILE_NAME = \"process_counter.txt\"\n\n\nclass FileLockLoader:\n    \"\"\"\n    Simple class for abstracting single process data loading / processing. The first\n    process that creates a lock file does the work; the remaining processes simply load\n    the preprocessed dataset once the first process is done.\n    \"\"\"\n\n    def __init__(self, cfg: DictDefault):\n        self.cfg = cfg\n        self.dataset_prepared_path = (\n            cfg.dataset_prepared_path or DEFAULT_DATASET_PREPARED_PATH\n        )\n        self.lock_file_path = Path(self.dataset_prepared_path) / LOCK_FILE_NAME\n        self.ready_flag_path = Path(self.dataset_prepared_path) / READY_FILE_NAME\n        self.counter_path = Path(self.dataset_prepared_path) / PROCESS_COUNTER_FILE_NAME\n\n    def load(self, load_fn: Callable[[], Any]) -> Any:\n        with FileLock(str(self.lock_file_path)):\n            self._increment_counter()\n\n            if not self.ready_flag_path.exists():\n                result = load_fn()\n                self.ready_flag_path.touch()\n                return result\n\n            while not self.ready_flag_path.exists():\n                time.sleep(1)\n            return load_fn()\n\n    def _increment_counter(self):\n        \"\"\"Safely increment the process counter.\"\"\"\n        if self.counter_path.exists():\n            counter_content = self.counter_path.read_text().strip()\n            count = int(counter_content) if counter_content else 0\n        else:\n            count = 0\n        self.counter_path.write_text(str(count + 1))\n\n    def 
cleanup(self):\n        \"\"\"Clean up ready flag when last process is done.\"\"\"\n        try:\n            with FileLock(str(self.lock_file_path)):\n                counter_content = self.counter_path.read_text().strip()\n                count = int(counter_content) if counter_content else 0\n                count -= 1\n\n                if count <= 0:\n                    # Last process cleans everything up\n                    self.ready_flag_path.unlink(missing_ok=True)\n                    self.counter_path.unlink(missing_ok=True)\n                else:\n                    # Still have active processes\n                    self.counter_path.write_text(str(count))\n        except FileNotFoundError:\n            # Lock file might have already been deleted by another process\n            pass\n"
  },
  {
    "path": "src/axolotl/utils/data/rl.py",
    "content": "\"\"\"Data handling specific to RL trainers.\"\"\"\n\nimport inspect\nfrom functools import partial\nfrom typing import Any, Callable, Literal\n\nfrom datasets import Dataset, DatasetDict\nfrom transformers import PreTrainedTokenizer\n\nfrom axolotl.loaders import load_tokenizer\nfrom axolotl.prompt_strategies.dpo import load as load_dpo\nfrom axolotl.prompt_strategies.kto import load as load_kto\nfrom axolotl.prompt_strategies.orpo import load as load_orpo\nfrom axolotl.utils.data.lock import FileLockLoader\nfrom axolotl.utils.data.shared import (\n    create_train_validation_split,\n    datasets_with_name_generator,\n    generate_dataset_hash_from_config,\n    load_dataset_with_config,\n    load_preprocessed_dataset,\n    merge_datasets,\n    save_preprocessed_dataset,\n    try_load_from_hub,\n)\nfrom axolotl.utils.data.utils import (\n    deduplicate_and_log_datasets,\n    retry_on_request_exceptions,\n)\nfrom axolotl.utils.dict import DictDefault\nfrom axolotl.utils.logging import get_logger\nfrom axolotl.utils.schemas.enums import RLType\n\nLOG = get_logger(__name__)\n\n\n@retry_on_request_exceptions(max_retries=3, delay=5)\ndef prepare_preference_datasets(\n    cfg: DictDefault, tokenizer: PreTrainedTokenizer\n) -> tuple[Dataset, Dataset | None]:\n    \"\"\"Load and prepare preference datasets for RL training.\n\n    Loads training and evaluation datasets, handling preprocessing, caching, and\n    deduplication as configured. Uses FileLock for distributed coordination.\n\n    Args:\n        cfg: Configuration object containing dataset and training settings.\n        tokenizer: Tokenizer to use for processing text.\n\n    Returns:\n        Tuple of (train_dataset, eval_dataset). 
eval_dataset may be None\n            if no evaluation dataset is configured.\n    \"\"\"\n\n    def _load_datasets():\n        # Load training dataset\n        train_dataset = _load_or_create_dataset_split(cfg, tokenizer, split=\"train\")\n\n        # Load or create evaluation dataset\n        eval_dataset: Dataset | None = None\n        if cfg.test_datasets:\n            eval_dataset = _load_or_create_dataset_split(cfg, tokenizer, split=\"test\")\n        elif cfg.val_set_size:\n            # Create validation split from training data\n            train_dataset, eval_dataset = create_train_validation_split(\n                train_dataset, cfg, cfg.val_set_size\n            )\n\n        return train_dataset, eval_dataset\n\n    # Prepare datasets (with file locking logic for multiple ranks)\n    loader = FileLockLoader(cfg)\n    try:\n        train_dataset, eval_dataset = loader.load(_load_datasets)\n    finally:\n        loader.cleanup()\n\n    # Apply deduplication if configured\n    if cfg.dataset_exact_deduplication:\n        train_dataset, eval_dataset = deduplicate_and_log_datasets(\n            dataset=train_dataset, other_dataset=eval_dataset\n        )\n\n    return train_dataset, eval_dataset\n\n\ndef _map_dataset(\n    cfg: DictDefault,\n    dataset: Dataset | DatasetDict,\n    ds_transform_fn: Callable[..., Any],\n    tokenizer: Any | None = None,\n    **map_kwargs: Any,\n) -> Dataset:\n    \"\"\"Apply transformation function to dataset.\n\n    Args:\n        cfg: Configuration object.\n        dataset: Dataset to transform.\n        ds_transform_fn: Transformation function to apply.\n        tokenizer: Optional tokenizer for transformation.\n        **map_kwargs: Additional arguments for dataset mapping.\n\n    Returns:\n        Transformed dataset.\n    \"\"\"\n    sig = inspect.signature(ds_transform_fn)\n    if \"tokenizer\" in sig.parameters:\n        if not tokenizer:\n            tokenizer = load_tokenizer(cfg)\n        ds_transform_fn = 
partial(ds_transform_fn, tokenizer=tokenizer)\n\n    if isinstance(dataset, DatasetDict):\n        dataset = dataset[\"train\"]\n\n    dataset = dataset.map(\n        ds_transform_fn,\n        num_proc=cfg.dataset_num_proc,\n        load_from_cache_file=not cfg.is_preprocess,\n        desc=\"Mapping RL Dataset\",\n        **map_kwargs,\n    )\n\n    return dataset\n\n\ndef _drop_long_sequences(\n    sample: dict[str, Any], rl: RLType, tokenizer: Any, sequence_len: int\n) -> bool:\n    \"\"\"Filter out samples that exceed maximum sequence length.\n\n    Args:\n        sample: Dataset sample to check.\n        rl: Reinforcement learning type.\n        tokenizer: Tokenizer for length calculation.\n        sequence_len: Maximum allowed sequence length.\n\n    Returns:\n        True if sample should be kept, False if it should be dropped.\n\n    Raises:\n        ValueError: If required keys are missing or RL type is unknown.\n    \"\"\"\n    if rl in {RLType.DPO, RLType.IPO, RLType.ORPO, RLType.SIMPO}:\n        if not (\n            sample.get(\"prompt\") and sample.get(\"chosen\") and sample.get(\"rejected\")\n        ):\n            raise ValueError(\n                \"Prompt, chosen and rejected keys are required for DPO/ORPO datasets\"\n            )\n\n        prompt = sample[\"prompt\"]\n        chosen = sample[\"chosen\"]\n        rejected = sample[\"rejected\"]\n\n        len_prompt = len(tokenizer(prompt, add_special_tokens=False)[\"input_ids\"])\n        len_chosen = len(tokenizer(chosen, add_special_tokens=False)[\"input_ids\"])\n        len_rejected = len(tokenizer(rejected, add_special_tokens=False)[\"input_ids\"])\n\n        return (len_prompt + len_chosen) <= sequence_len and (\n            len_prompt + len_rejected\n        ) <= sequence_len\n\n    if rl is RLType.KTO:\n        if not (sample.get(\"prompt\") and sample.get(\"completion\")):\n            raise ValueError(\"Prompt and completion keys are required for KTO datasets\")\n\n        prompt = 
sample[\"prompt\"]\n        completion = sample[\"completion\"]\n\n        len_prompt = len(tokenizer(prompt, add_special_tokens=False)[\"input_ids\"])\n        len_completion = len(\n            tokenizer(completion, add_special_tokens=False)[\"input_ids\"]\n        )\n\n        return (len_prompt + len_completion) <= sequence_len\n\n    if rl in {RLType.GRPO, RLType.GDPO}:\n        return True\n\n    raise ValueError(\"Unknown RL type\")\n\n\ndef _load_split(cfg: DictDefault, split: Literal[\"train\", \"test\"]) -> Dataset:\n    \"\"\"Load and process dataset split for RL training.\n\n    Args:\n        cfg: Configuration object containing dataset settings.\n        split: Dataset split to load (\"train\" or \"test\").\n\n    Returns:\n        Combined and processed dataset for the specified split.\n    \"\"\"\n    datasets_configs = cfg.datasets if split == \"train\" else cfg.test_datasets\n    split_datasets: list[Dataset | DatasetDict] = []\n\n    for dataset_config in datasets_with_name_generator(datasets_configs):\n        dataset: Dataset | DatasetDict = load_dataset_with_config(\n            dataset_config, cfg.hf_use_auth_token, streaming=False\n        )\n        split_datasets.append(dataset)\n\n    tokenizer = load_tokenizer(cfg)\n\n    for i, dataset in enumerate(split_datasets):\n        _type = datasets_configs[i][\"type\"]\n        if _type:\n            if isinstance(_type, DictDefault):\n                _type = \"user_defined.default\"\n            if cfg.rl is RLType.ORPO:\n                ds_transform_fn = load_orpo(_type, cfg, dataset_idx=i)\n            elif cfg.rl is RLType.KTO:\n                ds_transform_fn = load_kto(_type, cfg, dataset_idx=i)\n            else:\n                ds_transform_fn = load_dpo(_type, cfg, dataset_idx=i)\n\n            map_kwargs: dict[str, Any] = {}\n            if isinstance(ds_transform_fn, tuple):\n                ds_transform_fn, map_kwargs = ds_transform_fn\n            split_datasets[i] = 
_map_dataset(\n                cfg, dataset, ds_transform_fn, tokenizer, **map_kwargs\n            )\n        else:\n            # If no `type` is provided, assume the dataset is already in the expected format with\n            # \"prompt\", \"chosen\", and \"rejected\" already preprocessed\n            split_datasets[i] = dataset\n\n        if not cfg.skip_prepare_dataset:\n            drop_long = partial(\n                _drop_long_sequences,\n                rl=cfg.rl,\n                tokenizer=tokenizer,\n                sequence_len=cfg.sequence_len,\n            )\n\n            prior_len = len(split_datasets[i])\n            split_datasets[i] = split_datasets[i].filter(\n                drop_long,\n                num_proc=cfg.dataset_num_proc,\n                load_from_cache_file=not cfg.is_preprocess,\n                desc=\"Dropping Long Sequences\",\n            )\n            dropped = prior_len - len(split_datasets[i])\n            if dropped:\n                LOG.warning(f\"Dropped {dropped} long samples from dataset index {i}\")\n\n    # Merge datasets\n    dataset = merge_datasets(split_datasets, cfg)\n\n    if not cfg.skip_prepare_dataset:\n        # Deduplicate before saving so the saved dataset is already de-duplicated\n        if cfg.dataset_exact_deduplication:\n            dataset, _ = deduplicate_and_log_datasets(dataset=dataset)\n\n        # Save preprocessed dataset\n        dataset_hash = generate_dataset_hash_from_config(\n            cfg, datasets_configs, tokenizer.name_or_path\n        )\n        save_preprocessed_dataset(cfg, dataset, dataset_hash, split)\n\n    return dataset\n\n\ndef _load_or_create_dataset_split(\n    cfg: DictDefault, tokenizer: PreTrainedTokenizer, split: Literal[\"train\", \"test\"]\n) -> Dataset:\n    \"\"\"Load preprocessed dataset or create new one for given split.\n\n    Args:\n        cfg: Configuration object.\n        tokenizer: Tokenizer to use for processing text.\n        split: Dataset split to 
load.\n\n    Returns:\n        The dataset for the requested split: loaded from the Hub (when\n        `push_dataset_to_hub` is set), from a saved preprocessed copy, or built\n        from the raw dataset configs.\n    \"\"\"\n    # Select correct dataset configuration based on split\n    datasets_config = cfg.datasets if split == \"train\" else cfg.test_datasets\n\n    # Generate dataset hash for caching\n    dataset_hash = generate_dataset_hash_from_config(\n        cfg, datasets_config, tokenizer.name_or_path\n    )\n\n    # Try loading from hub if push_dataset_to_hub is configured\n    dataset = None\n    if cfg.push_dataset_to_hub:\n        dataset = try_load_from_hub(cfg, dataset_hash, split)\n\n    # Attempt to load preprocessed dataset\n    if dataset is None:\n        dataset = load_preprocessed_dataset(cfg, dataset_hash)\n\n    # Otherwise, load and preprocess the raw datasets\n    if dataset is None:\n        dataset = _load_split(cfg, split=split)\n\n    return dataset\n"
  },
  {
    "path": "src/axolotl/utils/data/sft.py",
    "content": "\"\"\"Data handling specific to SFT.\"\"\"\n\nimport functools\nimport os\nimport tempfile\nfrom typing import Literal\n\nfrom datasets import (\n    Dataset,\n    DatasetDict,\n    IterableDataset,\n    IterableDatasetDict,\n    load_dataset,\n)\nfrom transformers import PreTrainedTokenizer, ProcessorMixin\n\nfrom axolotl.prompters import Prompter\nfrom axolotl.utils.data.lock import FileLockLoader\nfrom axolotl.utils.data.shared import (\n    create_train_validation_split,\n    datasets_with_name_generator,\n    generate_dataset_hash_from_config,\n    load_dataset_with_config,\n    load_preprocessed_dataset,\n    merge_datasets,\n    save_preprocessed_dataset,\n    try_load_from_hub,\n)\nfrom axolotl.utils.data.streaming import wrap_streaming_dataset\nfrom axolotl.utils.data.utils import (\n    deduplicate_and_log_datasets,\n    handle_long_seq_in_dataset,\n    retry_on_request_exceptions,\n)\nfrom axolotl.utils.data.wrappers import get_dataset_wrapper\nfrom axolotl.utils.dict import DictDefault\nfrom axolotl.utils.distributed import is_local_main_process\nfrom axolotl.utils.logging import get_logger\nfrom axolotl.utils.trainer import (\n    calculate_total_num_steps,\n    process_datasets_for_packing,\n)\n\nLOG = get_logger(__name__)\n\n\n@retry_on_request_exceptions(max_retries=3, delay=5)\ndef prepare_datasets(\n    cfg: DictDefault,\n    tokenizer: PreTrainedTokenizer,\n    processor: ProcessorMixin | None = None,\n) -> tuple[IterableDataset | Dataset, Dataset | None, int, list[Prompter | None]]:\n    \"\"\"Prepare training and evaluation datasets based on configuration.\n\n    Args:\n        cfg: Dictionary mapping `axolotl` config keys to values.\n        tokenizer: Tokenizer to use for processing text.\n        processor: Optional processor for multimodal datasets.\n\n    Returns:\n        Tuple of (train_dataset, eval_dataset, total_steps, prompters).\n    \"\"\"\n    if cfg.streaming or cfg.pretraining_dataset:\n        return 
_prepare_streaming_dataset(cfg, tokenizer, processor)\n    return _prepare_standard_dataset(cfg, tokenizer, processor)\n\n\ndef _prepare_standard_dataset(\n    cfg: DictDefault,\n    tokenizer: PreTrainedTokenizer,\n    processor: ProcessorMixin | None,\n) -> tuple[Dataset, Dataset | None, int, list[Prompter | None]]:\n    \"\"\"Prepare standard (non-pretraining) datasets.\"\"\"\n\n    def _load_datasets():\n        # Always load training dataset\n        train_dataset, eval_dataset, prompters = _load_and_prepare_datasets(\n            tokenizer,\n            cfg,\n            split=\"train\",\n            processor=processor,\n        )\n\n        # Overwrite eval_dataset if test data exists\n        if cfg.test_datasets:\n            _, eval_dataset, _ = _load_and_prepare_datasets(\n                tokenizer,\n                cfg,\n                split=\"test\",\n                processor=processor,\n            )\n\n        return train_dataset, eval_dataset, prompters\n\n    # Prepare datasets (with file locking logic for multiple ranks)\n    loader = FileLockLoader(cfg)\n    try:\n        train_dataset, eval_dataset, prompters = loader.load(_load_datasets)\n    finally:\n        loader.cleanup()\n\n    if os.environ.get(\"AXOLOTL_IS_PREPROCESS\") == \"1\":\n        return train_dataset, eval_dataset, -1, prompters\n\n    # Validate sample packing configuration for evaluation\n    if eval_dataset and cfg.sample_packing and cfg.eval_sample_packing is not False:\n        total_eval_steps = calculate_total_num_steps(cfg, eval_dataset, update=False)\n        if total_eval_steps == 0:\n            raise ValueError(\n                \"eval dataset split is too small for sample_packing. 
\"\n                \"You should set `eval_sample_packing: False` in your config.\"\n            )\n\n    # Calculate total number of training steps\n    if cfg.max_steps:\n        total_num_steps = min(\n            calculate_total_num_steps(cfg, train_dataset), cfg.max_steps\n        )\n    else:\n        total_num_steps = calculate_total_num_steps(cfg, train_dataset)\n    LOG.info(f\"Maximum number of steps set at {total_num_steps}\")\n    return train_dataset, eval_dataset, total_num_steps, prompters\n\n\ndef _prepare_streaming_dataset(\n    cfg: DictDefault,\n    tokenizer: PreTrainedTokenizer,\n    processor: ProcessorMixin | None,\n) -> tuple[IterableDataset, Dataset | None, int, list[Prompter | None]]:\n    \"\"\"\n    Prepare dataset for streaming mode.\n\n    Note: Streaming datasets are loaded incrementally from the source.\n    \"\"\"\n    if cfg.pretraining_dataset:\n        dataset_config = _extract_pretraining_config(cfg)\n        train_dataset = _load_streaming_dataset(dataset_config, cfg, tokenizer)\n    elif cfg.sample_packing:\n        # TODO(djsaunde): Implement for multiple datasets\n        dataset_config = DictDefault(cfg.datasets[0])\n\n        # Ensure we have a split set - default to 'train' if not specified\n        if not hasattr(dataset_config, \"split\") or not dataset_config.split:\n            dataset_config.split = \"train\"\n        train_dataset = _load_streaming_dataset(dataset_config, cfg, tokenizer)\n    else:\n        # Use legacy loading function for non-packed streaming datasets\n        train_dataset, eval_dataset, prompters = _load_and_prepare_datasets(\n            tokenizer,\n            cfg,\n            split=\"train\",\n            processor=processor,\n            streaming=True,\n        )\n\n        # Return early for non-packed streaming datasets\n        total_num_steps = cfg.max_steps if cfg.max_steps else -1\n        return train_dataset, eval_dataset, total_num_steps, prompters\n\n    # Load evaluation dataset 
if specified\n    eval_dataset = None\n    if cfg.test_datasets:\n        _, eval_dataset, _ = _load_and_prepare_datasets(\n            tokenizer,\n            cfg,\n            split=\"test\",\n            processor=processor,\n            streaming=False,\n        )\n\n    # For streaming, we return max_steps directly from config or -1 if not set\n    total_num_steps = cfg.max_steps if cfg.max_steps else -1\n    return train_dataset, eval_dataset, total_num_steps, []\n\n\ndef _extract_pretraining_config(cfg: DictDefault) -> DictDefault:\n    \"\"\"Extract pretraining configuration from the main config.\"\"\"\n    if isinstance(cfg.pretraining_dataset, list) and isinstance(\n        cfg.pretraining_dataset[0], dict\n    ):\n        config = cfg.pretraining_dataset[0]\n        return DictDefault(\n            {\n                \"path\": config[\"path\"],\n                \"name\": config[\"name\"],\n                \"skip\": config[\"skip\"],\n                \"split\": config.get(\"split\", \"train\"),\n                \"data_files\": config.get(\"data_files\"),\n                \"type\": config.get(\"type\", \"pretrain\"),\n            }\n        )\n    # Simple string path case\n    return DictDefault(\n        {\n            \"path\": cfg.pretraining_dataset,\n            \"name\": None,\n            \"skip\": 0,\n            \"split\": \"train\",\n            \"data_files\": None,\n            \"type\": \"pretrain\",\n        }\n    )\n\n\ndef _load_streaming_dataset(\n    pretraining_config: DictDefault, cfg: DictDefault, tokenizer: PreTrainedTokenizer\n) -> IterableDataset:\n    \"\"\"Load and prepare a streaming dataset for pretraining.\"\"\"\n    # Create dataset wrapper partial function\n    dataset_wrapper_partial = functools.partial(\n        get_dataset_wrapper,\n        dataset_config=pretraining_config,\n        tokenizer=tokenizer,\n        cfg=cfg,\n        dataset_base_type=pretraining_config[\"type\"],\n    )\n\n    # Load the actual dataset\n   
 if (\n        cfg.accelerator_config\n        and cfg.accelerator_config.dispatch_batches\n        and not is_local_main_process()\n    ):\n        iter_dataset = _create_placeholder_dataset()\n    else:\n        iter_dataset = load_dataset(\n            pretraining_config[\"path\"],\n            streaming=True,\n            split=pretraining_config[\"split\"],\n            name=pretraining_config[\"name\"],\n            data_files=pretraining_config[\"data_files\"],\n        )\n\n    # Apply skip if specified\n    if pretraining_config[\"skip\"]:\n        LOG.info(f\"Skipping {pretraining_config['skip']} samples from the dataset\")\n        iter_dataset = iter_dataset.skip(pretraining_config[\"skip\"])\n\n    # Wrap the dataset for pretraining\n    train_dataset = wrap_streaming_dataset(\n        iter_dataset,\n        tokenizer,\n        cfg,\n        dataset_wrapper_partial,\n    )\n\n    # Format for PyTorch\n    return train_dataset.with_format(\"torch\")\n\n\ndef _create_placeholder_dataset() -> IterableDataset:\n    \"\"\"Create a minimal placeholder dataset for non-main processes.\"\"\"\n    with tempfile.NamedTemporaryFile(mode=\"w+\", delete=False) as f:\n        f.write(\"text\\n\")\n        f.write(\"lorem ipsum dolor sit amet\\n\")\n        f.seek(0)\n        return load_dataset(\"csv\", data_files=f.name, split=\"train\", streaming=True)\n\n\ndef _load_tokenized_prepared_datasets(\n    tokenizer: PreTrainedTokenizer,\n    cfg: DictDefault,\n    split: Literal[\"train\", \"test\"] = \"train\",\n    processor: ProcessorMixin | None = None,\n    streaming: bool = False,\n) -> tuple[Dataset | DatasetDict, list[Prompter | None]]:\n    \"\"\"Load or create tokenized and prepared datasets for training or testing.\n\n    Args:\n        tokenizer: Tokenizer for processing text.\n        cfg: Configuration object.\n        split: Dataset split to load ('train' or 'test').\n        processor: Optional processor for multimodal datasets.\n        streaming: 
Whether to use iterable preprocessing.\n\n    Returns:\n        Tuple of (dataset, prompters list).\n    \"\"\"\n    # Select correct dataset configuration based on split\n    datasets_configs = cfg.datasets if split == \"train\" else cfg.test_datasets\n\n    # Generate dataset hash for caching\n    dataset_hash = generate_dataset_hash_from_config(\n        cfg, datasets_configs, tokenizer.name_or_path\n    )\n\n    # Try loading from hub if push_dataset_to_hub is configured\n    dataset = None\n    if cfg.push_dataset_to_hub:\n        dataset = try_load_from_hub(cfg, dataset_hash, split)\n\n    # If not found on hub, try loading from disk\n    if dataset is None:\n        dataset = load_preprocessed_dataset(cfg, dataset_hash)\n\n    # If not found on disk or skipping prepared dataset, load and process raw datasets\n    prompters: list[Prompter | None] = []\n    if dataset is None:\n        dataset, prompters = _load_raw_datasets(\n            cfg,\n            datasets_configs,\n            tokenizer,\n            split,\n            processor,\n            streaming,\n        )\n\n    return dataset, prompters\n\n\ndef _load_raw_datasets(\n    cfg: DictDefault,\n    datasets_configs: list,\n    tokenizer: PreTrainedTokenizer,\n    split: str,\n    processor: ProcessorMixin | None = None,\n    streaming: bool = False,\n) -> tuple[Dataset, list[Prompter | None]]:\n    \"\"\"Load, process, merge, and save raw datasets.\"\"\"\n    LOG.info(\"Loading raw datasets...\", main_process_only=False)\n    if not cfg.is_preprocess and not cfg.skip_prepare_dataset:\n        LOG.warning(\n            \"Processing datasets during training can lead to VRAM instability. 
Please \"\n            \"pre-process your dataset using `axolotl preprocess path/to/config.yml`.\"\n        )\n\n    # Load and process individual datasets\n    datasets = []\n    prompters = []\n    for dataset_config in datasets_with_name_generator(datasets_configs):\n        dataset_wrapper, dataset_prompter = _load_and_process_single_dataset(\n            dataset_config=dataset_config,\n            cfg=cfg,\n            tokenizer=tokenizer,\n            split=split,\n            seed=cfg.seed,\n            processor=processor,\n            streaming=streaming,\n        )\n        datasets.append(dataset_wrapper)\n        prompters.append(dataset_prompter)\n\n    # Merge datasets\n    dataset = merge_datasets(datasets, cfg)\n\n    if not cfg.skip_prepare_dataset and not streaming:\n        if split == \"test\" and cfg.eval_sequence_len:\n            dataset = handle_long_seq_in_dataset(dataset, cfg.eval_sequence_len, cfg)\n        else:\n            dataset = handle_long_seq_in_dataset(dataset, cfg.sequence_len, cfg)\n        if (split == \"train\" and cfg.sample_packing) or (\n            split == \"test\" and cfg.eval_sample_packing\n        ):\n            dataset, _ = process_datasets_for_packing(cfg, dataset, None)\n\n        # Deduplicate before saving so the saved dataset is already de-duplicated\n        if cfg.dataset_exact_deduplication:\n            dataset, _ = deduplicate_and_log_datasets(dataset=dataset)\n\n        # Save the prepared dataset\n        dataset_hash = generate_dataset_hash_from_config(\n            cfg, datasets_configs, tokenizer.name_or_path\n        )\n        save_preprocessed_dataset(cfg, dataset, dataset_hash, split)\n\n    return dataset, prompters\n\n\ndef _load_and_process_single_dataset(\n    dataset_config: DictDefault,\n    cfg: DictDefault,\n    tokenizer: PreTrainedTokenizer,\n    split: str,\n    seed: int,\n    processor: ProcessorMixin | None = None,\n    streaming: bool = False,\n) -> tuple[Dataset | 
IterableDataset, Prompter | None]:\n    \"\"\"Load and process a single dataset based on the passed config.\"\"\"\n    # Load the dataset\n    dataset = load_dataset_with_config(\n        dataset_config, cfg.hf_use_auth_token, streaming=streaming\n    )\n\n    # Parse dataset type\n    d_base_type, d_prompt_style = _parse_dataset_type(dataset_config.type)\n\n    # Select the appropriate split\n    if isinstance(dataset, (DatasetDict, IterableDatasetDict)):\n        if dataset_config.split and dataset_config.split in dataset:\n            dataset = dataset[dataset_config.split]\n        elif split in dataset:\n            dataset = dataset[split]\n        else:\n            raise ValueError(\n                f\"no {split} split found for dataset {dataset_config.path}, you may \"\n                \"specify a split with 'split: ...'\"\n            )\n\n    # Apply sharding if configured\n    if dataset_config.shards:\n        shards_idx = dataset_config.get(\"shards_idx\", 0)\n        dataset = dataset.shuffle(seed=seed).shard(\n            num_shards=dataset_config.shards, index=shards_idx\n        )\n\n    # Apply dataset wrapper\n    dataset_wrapper, dataset_prompter = get_dataset_wrapper(\n        dataset_config=dataset_config,\n        tokenizer=tokenizer,\n        cfg=cfg,\n        dataset_base_type=d_base_type,\n        dataset=dataset,\n        dataset_prompt_style=d_prompt_style,\n        processor=processor,\n    )\n\n    return dataset_wrapper, dataset_prompter\n\n\ndef _parse_dataset_type(d_type: str) -> tuple[str | None, str | None]:\n    \"\"\"Parse the dataset type string into base type and prompt style.\"\"\"\n    if not isinstance(d_type, str):\n        return None, None\n\n    d_type_split = d_type.split(\":\")\n    d_base_type = d_type_split[0]\n    d_prompt_style = d_type_split[1] if len(d_type_split) > 1 else None\n\n    return d_base_type, d_prompt_style\n\n\ndef _handle_train_dataset_split(\n    dataset: Dataset, cfg: DictDefault\n) -> 
tuple[Dataset, Dataset | None]:\n    \"\"\"Handle processing for train split, including validation set creation.\"\"\"\n    val_set_size = (\n        int(cfg.val_set_size) if cfg.val_set_size > 1 else float(cfg.val_set_size)\n    )\n\n    if val_set_size:\n        # Create train/validation split\n        train_dataset, eval_dataset = create_train_validation_split(\n            dataset, cfg, val_set_size\n        )\n        return train_dataset, eval_dataset\n\n    # No validation split - deduplication already applied during preprocessing\n    return dataset, None\n\n\ndef _apply_dataset_sharding(dataset: Dataset, cfg: DictDefault) -> Dataset:\n    \"\"\"Apply dataset sharding if configured.\n\n    Args:\n        dataset: Dataset to shard.\n        cfg: Configuration object containing shard settings.\n\n    Returns:\n        Sharded dataset or original dataset if no sharding configured.\n    \"\"\"\n    if cfg.dataset_shard_num and cfg.dataset_shard_idx is not None:\n        LOG.info(\n            f\"Using index #{cfg.dataset_shard_idx} of {cfg.dataset_shard_num} shards\"\n        )\n        dataset = dataset.shard(\n            num_shards=cfg.dataset_shard_num,\n            index=cfg.dataset_shard_idx,\n        )\n    return dataset\n\n\ndef _load_and_prepare_datasets(\n    tokenizer: PreTrainedTokenizer,\n    cfg: DictDefault,\n    split: Literal[\"train\", \"test\"] = \"train\",\n    processor: ProcessorMixin | None = None,\n    streaming: bool = False,\n) -> tuple[Dataset | None, Dataset | None, list[Prompter | None]]:\n    \"\"\"Load and prepare datasets with optional validation split and sharding.\n\n    Args:\n        tokenizer: Tokenizer for processing text.\n        cfg: Configuration object.\n        split: Dataset split to load ('train' or 'test').\n        processor: Optional processor for multimodal datasets.\n        streaming: Whether to use iterable preprocessing.\n\n    Returns:\n        Tuple of (train_dataset, eval_dataset, prompters).\n    
\"\"\"\n    # Load the base dataset\n    dataset, prompters = _load_tokenized_prepared_datasets(\n        tokenizer,\n        cfg,\n        split=split,\n        processor=processor,\n        streaming=streaming,\n    )\n\n    # Apply dataset sharding if configured using shared function\n    dataset = _apply_dataset_sharding(dataset, cfg)\n\n    # Train: optionally carve out a validation set; test: use the whole split for eval\n    if split == \"train\":\n        train_dataset, eval_dataset = _handle_train_dataset_split(dataset, cfg)\n    else:\n        # Deduplication, if enabled, happens during preprocessing, not here\n        train_dataset, eval_dataset = None, dataset\n\n    return train_dataset, eval_dataset, prompters\n"
  },
  {
    "path": "src/axolotl/utils/data/shared.py",
    "content": "\"\"\"Dataset loading shared utils.\"\"\"\n\nfrom __future__ import annotations\n\nimport functools\nimport os\nfrom pathlib import Path\nfrom typing import TYPE_CHECKING, Any, Generator\n\nfrom datasets import (\n    Dataset,\n    DatasetDict,\n    IterableDataset,\n    IterableDatasetDict,\n    concatenate_datasets,\n    load_dataset,\n    load_from_disk,\n)\nfrom huggingface_hub import hf_hub_download, snapshot_download\nfrom huggingface_hub.errors import (\n    HFValidationError,\n    RepositoryNotFoundError,\n    RevisionNotFoundError,\n)\n\nfrom axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH\nfrom axolotl.utils.data.utils import deduplicate_and_log_datasets, md5\nfrom axolotl.utils.datasets import get_default_process_count\nfrom axolotl.utils.dict import DictDefault\nfrom axolotl.utils.logging import get_logger\n\nif TYPE_CHECKING:\n    from adlfs import AzureBlobFileSystem\n    from gcsfs import GCSFileSystem\n    from ocifs import OCIFileSystem\n    from s3fs import S3FileSystem\n\nLOG = get_logger(__name__)\n\nEXTENSIONS_TO_DATASET_TYPES = {\n    \".parquet\": \"parquet\",\n    \".arrow\": \"arrow\",\n    \".csv\": \"csv\",\n    \".txt\": \"text\",\n}\n\n\ndef get_dataset_type(dataset_config: DictDefault) -> str:\n    \"\"\"Get the dataset type from the path if it's not specified.\"\"\"\n    if dataset_config.ds_type:\n        return dataset_config.ds_type\n\n    for extension, dataset_type in EXTENSIONS_TO_DATASET_TYPES.items():\n        if extension in dataset_config.path:\n            return dataset_type\n\n    return \"json\"\n\n\ndef datasets_with_name_generator(\n    dataset_configs: list[DictDefault],\n) -> Generator[DictDefault, None, None]:\n    \"\"\"Yields expanded dataset configurations based on multiple names or preprocessing\n    shards.\n\n    When a dataset config has a list of names, it yields separate configs for each\n    name. 
When a dataset config specifies preprocessing shards, it yields configs for\n    each shard.\n\n    Args:\n        dataset_configs: List of dataset configuration objects.\n\n    Yields:\n        Individual dataset configurations, expanded as needed for names or shards.\n    \"\"\"\n    for config in dataset_configs:\n        if config.name and isinstance(config.name, list):\n            for name in config.name:\n                yield DictDefault({**config, \"name\": name})\n        elif config.preprocess_shards and not config.shards:\n            for shard_idx in range(config.preprocess_shards):\n                yield DictDefault(\n                    {\n                        **config,\n                        \"shards\": config.preprocess_shards,\n                        \"shards_idx\": shard_idx,\n                    }\n                )\n        else:\n            yield config\n\n\ndef load_dataset_with_config(\n    dataset_config: DictDefault, use_auth_token: bool, streaming=False\n) -> Dataset | IterableDataset:\n    \"\"\"Load a dataset from a config. 
Handles datasets that are stored locally, in the\n    HuggingFace Hub, in a remote filesystem (S3, GCS, Azure, OCI), a URL, or\n    `data_files`.\n\n    Args:\n        dataset_config: Single dataset config.\n        use_auth_token: Whether to use HF auth token.\n        streaming: Whether to stream the dataset.\n\n    Returns:\n        Loaded dataset.\n    \"\"\"\n    # Set up common kwargs for dataset loading\n    load_dataset_kwargs = {\n        \"split\": dataset_config.split if dataset_config.split else None,\n        \"name\": dataset_config.name,\n        \"streaming\": streaming,\n        \"trust_remote_code\": dataset_config.trust_remote_code,\n    }\n\n    # First check if it's a local path\n    if Path(dataset_config.path).exists():\n        return _load_from_local_path(dataset_config, load_dataset_kwargs)\n\n    # Check if it's a HuggingFace dataset\n    is_hub_dataset = _check_if_hub_dataset(dataset_config, use_auth_token)\n\n    # Check if it's a cloud storage path and get appropriate filesystem\n    remote_fs, storage_options = _get_remote_filesystem(dataset_config.path)\n    is_cloud_dataset = False\n    if remote_fs:\n        try:\n            is_cloud_dataset = remote_fs.exists(dataset_config.path)\n        except (FileNotFoundError, ConnectionError):\n            pass\n\n    # Load from appropriate source\n    if is_hub_dataset:\n        return _load_from_hub(dataset_config, use_auth_token, load_dataset_kwargs)\n    if is_cloud_dataset:\n        return _load_from_cloud(\n            dataset_config, remote_fs, storage_options, load_dataset_kwargs\n        )\n    if dataset_config.path.startswith(\"https://\"):\n        return _load_from_url(dataset_config, load_dataset_kwargs)\n    if dataset_config.data_files:\n        return _load_from_data_files(dataset_config, load_dataset_kwargs)\n\n    raise ValueError(\n        f\"The dataset could not be loaded. This could be due to a misconfigured dataset path \"\n        f\"({dataset_config.path}). 
Try double-check your path / name / data_files. \"\n        f\"This is not caused by the dataset type.\"\n    )\n\n\ndef _check_if_hub_dataset(dataset_config: DictDefault, use_auth_token: bool) -> bool:\n    \"\"\"Check if a dataset exists on the HuggingFace Hub.\"\"\"\n    try:\n        snapshot_download(\n            repo_id=dataset_config.path,\n            repo_type=\"dataset\",\n            token=use_auth_token,\n            revision=dataset_config.revision,\n            ignore_patterns=[\"*\"],\n        )\n        return True\n    except (\n        RepositoryNotFoundError,\n        RevisionNotFoundError,\n        FileNotFoundError,\n        ConnectionError,\n        HFValidationError,\n        ValueError,\n    ):\n        return False\n\n\ndef _get_remote_filesystem(\n    path: str,\n) -> tuple[\n    S3FileSystem | GCSFileSystem | AzureBlobFileSystem | OCIFileSystem | None, dict\n]:\n    \"\"\"Get the appropriate filesystem for a remote path.\"\"\"\n    if path.startswith(\"s3://\"):\n        try:\n            import s3fs\n\n            storage_options = {\"anon\": False}\n            return s3fs.S3FileSystem(**storage_options), storage_options\n        except ImportError as exc:\n            raise ImportError(\"s3:// paths require s3fs to be installed\") from exc\n\n    elif path.startswith((\"gs://\", \"gcs://\")):\n        try:\n            import gcsfs\n\n            storage_options = {\"token\": None}  # type: ignore  # nosec B105\n            return gcsfs.GCSFileSystem(**storage_options), storage_options\n        except ImportError as exc:\n            raise ImportError(\n                \"gs:// or gcs:// paths require gcsfs to be installed\"\n            ) from exc\n\n    elif path.startswith((\"adl://\", \"abfs://\", \"az://\")):\n        try:\n            import adlfs\n\n            storage_options = {\"anon\": False}\n            return adlfs.AzureBlobFileSystem(**storage_options), storage_options\n        except ImportError as exc:\n            
raise ImportError(\n                \"adl:// or abfs:// paths require adlfs to be installed\"\n            ) from exc\n\n    elif path.startswith(\"oci://\"):\n        try:\n            import ocifs\n\n            storage_options = {}\n            return ocifs.OCIFileSystem(**storage_options), storage_options\n        except ImportError as exc:\n            raise ImportError(\"oci:// paths require ocifs to be installed\") from exc\n\n    return None, {}\n\n\ndef _load_from_local_path(\n    dataset_config: DictDefault, load_dataset_kwargs: dict\n) -> Dataset | IterableDataset | DatasetDict | IterableDatasetDict:\n    \"\"\"Load a dataset from a local path.\"\"\"\n    local_path = Path(dataset_config.path)\n\n    if local_path.is_dir():\n        if dataset_config.data_files:\n            dataset_type = get_dataset_type(dataset_config)\n            return load_dataset(\n                dataset_type,\n                data_files=dataset_config.data_files,\n                **load_dataset_kwargs,\n            )\n        try:\n            return load_from_disk(dataset_config.path)\n        except FileNotFoundError:\n            return load_dataset(dataset_config.path, **load_dataset_kwargs)\n    elif local_path.is_file():\n        dataset_type = get_dataset_type(dataset_config)\n\n        # For single file datasets, HF always creates only a \"train\" split\n        if dataset_type in (\"json\", \"csv\", \"text\"):\n            load_dataset_kwargs[\"split\"] = \"train\"\n\n        return load_dataset(\n            dataset_type,\n            data_files=dataset_config.path,\n            **load_dataset_kwargs,\n        )\n    else:\n        raise ValueError(\n            \"Unhandled dataset load: local path exists, but is neither a directory or a file\"\n        )\n\n\ndef _load_from_hub(\n    dataset_config: DictDefault, use_auth_token: bool, load_dataset_kwargs: dict\n) -> Dataset | IterableDataset | DatasetDict | IterableDatasetDict:\n    \"\"\"Load a dataset from the 
HuggingFace Hub.\"\"\"\n    return load_dataset(\n        dataset_config.path,\n        data_files=dataset_config.data_files,\n        token=use_auth_token,\n        revision=dataset_config.revision,\n        **load_dataset_kwargs,\n    )\n\n\ndef _load_from_cloud(\n    dataset_config: DictDefault,\n    remote_fs: S3FileSystem | GCSFileSystem | AzureBlobFileSystem | OCIFileSystem,\n    storage_options: dict,\n    load_dataset_kwargs: dict,\n) -> Dataset | IterableDataset | DatasetDict | IterableDatasetDict:\n    \"\"\"Load a dataset from cloud storage.\"\"\"\n    if remote_fs.isdir(dataset_config.path):\n        return load_from_disk(\n            dataset_config.path,\n            storage_options=storage_options,\n        )\n\n    if remote_fs.isfile(dataset_config.path):\n        dataset_type = get_dataset_type(dataset_config)\n        return load_dataset(\n            dataset_type,\n            data_files=dataset_config.path,\n            storage_options=storage_options,\n            **load_dataset_kwargs,\n        )\n\n    raise ValueError(\n        f\"Cloud path {dataset_config.path} is neither a directory nor a file\"\n    )\n\n\ndef _load_from_url(\n    dataset_config: DictDefault, load_dataset_kwargs: dict\n) -> Dataset | IterableDataset | DatasetDict | IterableDatasetDict:\n    \"\"\"Load a dataset from a URL.\"\"\"\n    dataset_type = get_dataset_type(dataset_config)\n    return load_dataset(\n        dataset_type,\n        data_files=dataset_config.path,\n        **load_dataset_kwargs,\n    )\n\n\ndef _load_from_data_files(\n    dataset_config: DictDefault, load_dataset_kwargs: dict\n) -> Dataset | IterableDataset | DatasetDict | IterableDatasetDict:\n    \"\"\"Load a dataset from data files.\"\"\"\n    file_path = None\n\n    if isinstance(dataset_config.data_files, str):\n        file_path = hf_hub_download(\n            repo_id=dataset_config.path,\n            repo_type=\"dataset\",\n            filename=dataset_config.data_files,\n            
revision=dataset_config.revision,\n        )\n    elif isinstance(dataset_config.data_files, list):\n        file_path = [\n            hf_hub_download(\n                repo_id=dataset_config.path,\n                repo_type=\"dataset\",\n                filename=file,\n                revision=dataset_config.revision,\n            )\n            for file in dataset_config.data_files\n        ]\n    else:\n        raise ValueError(\"data_files must be either a string or list of strings\")\n\n    return load_dataset(\"json\", data_files=file_path, **load_dataset_kwargs)\n\n\ndef generate_split_fingerprints(\n    dataset: Dataset, val_set_size: int | float, seed: int\n) -> tuple[str, str]:\n    \"\"\"Generate consistent fingerprints for train/test splits.\"\"\"\n    fingerprint = dataset._fingerprint\n\n    train_hash_input = f\"{fingerprint}|{val_set_size}|train|{seed}\"\n    test_hash_input = f\"{fingerprint}|{val_set_size}|test|{seed}\"\n\n    train_fingerprint = md5(train_hash_input)\n    test_fingerprint = md5(test_hash_input)\n\n    return train_fingerprint, test_fingerprint\n\n\ndef get_prepared_dataset_path(cfg: DictDefault, dataset_hash: str) -> Path:\n    \"\"\"Get standardized path for prepared datasets.\n\n    Args:\n        cfg: Configuration object.\n        dataset_hash: Hash identifying the specific dataset configuration.\n\n    Returns:\n        Path where the prepared dataset should be stored.\n    \"\"\"\n    base_path = cfg.dataset_prepared_path or DEFAULT_DATASET_PREPARED_PATH\n    return Path(base_path) / dataset_hash\n\n\ndef create_train_validation_split(\n    dataset: Dataset, cfg: DictDefault, val_set_size: int | float\n) -> tuple[Dataset, Dataset]:\n    \"\"\"Create train/validation split with consistent fingerprinting.\n\n    Args:\n        dataset: Dataset to split.\n        cfg: Configuration object containing seed and other settings.\n        val_set_size: Size of validation set (absolute number or fraction).\n\n    Returns:\n        
Tuple of (train_dataset, eval_dataset).\n    \"\"\"\n    train_fingerprint, test_fingerprint = generate_split_fingerprints(\n        dataset, val_set_size, cfg.seed\n    )\n\n    # Apply deduplication before splitting if configured\n    if cfg.dataset_exact_deduplication:\n        dataset, _ = deduplicate_and_log_datasets(dataset=dataset)\n\n    split_dataset = dataset.train_test_split(\n        test_size=val_set_size,\n        shuffle=False,\n        seed=cfg.seed,\n        train_new_fingerprint=train_fingerprint,\n        test_new_fingerprint=test_fingerprint,\n    )\n\n    return split_dataset[\"train\"], split_dataset[\"test\"]\n\n\ndef _generate_from_iterable_dataset(\n    dataset: IterableDataset, worker_id: list[int], num_workers: list[int]\n) -> Generator[Any, None, None]:\n    \"\"\"Generator function to correctly split the dataset for each worker\"\"\"\n    for i, item in enumerate(dataset):\n        if i % num_workers[0] == worker_id[0]:\n            yield item\n\n\ndef save_preprocessed_dataset(\n    cfg: DictDefault,\n    dataset: Dataset,\n    dataset_hash: str,\n    split: str,\n) -> None:\n    \"\"\"Save preprocessed dataset to disk and optionally push to the HF Hub.\"\"\"\n    prepared_ds_path = get_prepared_dataset_path(cfg, dataset_hash)\n    num_workers = cfg.dataset_num_proc or get_default_process_count()\n    if isinstance(dataset, IterableDataset):\n        ds_from_iter = Dataset.from_generator(\n            functools.partial(_generate_from_iterable_dataset, dataset),\n            features=dataset.features,\n            num_proc=num_workers,\n            split=split,\n            gen_kwargs={\n                \"worker_id\": list(range(num_workers)),\n                \"num_workers\": [num_workers] * num_workers,\n            },\n        )\n        ds_from_iter.save_to_disk(\n            str(prepared_ds_path),\n            num_proc=num_workers,\n            max_shard_size=None,\n            num_shards=cfg.num_dataset_shards_to_save,\n        
)\n    else:\n        min_rows_per_proc = 256\n        os.makedirs(prepared_ds_path, exist_ok=True)\n        dataset.save_to_disk(\n            str(prepared_ds_path),\n            num_proc=min(max(1, len(dataset) // min_rows_per_proc), num_workers),\n            max_shard_size=None,\n            num_shards=cfg.num_dataset_shards_to_save,\n        )\n    if cfg.push_dataset_to_hub:\n        LOG.info(\n            \"Pushing merged prepared dataset to Huggingface hub at \"\n            f\"{cfg.push_dataset_to_hub} (version {dataset_hash})...\",\n            main_process_only=False,\n        )\n        dataset.push_to_hub(\n            cfg.push_dataset_to_hub,\n            dataset_hash,\n            private=True,\n        )\n\n\ndef load_preprocessed_dataset(cfg: DictDefault, dataset_hash: str) -> Dataset | None:\n    \"\"\"Load preprocessed dataset from disk if available.\n\n    Args:\n        cfg: Configuration object.\n        dataset_hash: Hash identifying the dataset configuration.\n\n    Returns:\n        Loaded dataset if found and conditions are met, None otherwise.\n    \"\"\"\n    prepared_ds_path = get_prepared_dataset_path(cfg, dataset_hash)\n\n    if (\n        cfg.dataset_prepared_path\n        and any(prepared_ds_path.glob(\"*\"))\n        and not cfg.skip_prepare_dataset\n        and not cfg.is_preprocess\n    ):\n        LOG.info(\n            f\"Loading prepared dataset from disk at {prepared_ds_path}...\",\n        )\n        return load_from_disk(str(prepared_ds_path))\n\n    LOG.info(\n        f\"Unable to find prepared dataset in {prepared_ds_path}\",\n    )\n    return None\n\n\ndef try_load_from_hub(\n    cfg: DictDefault, dataset_hash: str, split: str\n) -> Dataset | None:\n    \"\"\"Try to load the prepared dataset from HuggingFace Hub.\"\"\"\n    try:\n        LOG.info(\n            \"Attempting to load prepared dataset from HuggingFace Hub at \"\n            f\"{cfg.push_dataset_to_hub} (version {dataset_hash})...\"\n        )\n        
dataset = load_dataset(\n            cfg.push_dataset_to_hub,\n            dataset_hash,\n            token=cfg.hf_use_auth_token,\n        )\n        return dataset[split]\n    except Exception:\n        LOG.info(\"Unable to find prepared dataset in HuggingFace Hub\")\n        return None\n\n\ndef generate_dataset_hash_from_config(\n    cfg: DictDefault, cfg_datasets: list, tokenizer_name: str\n) -> str:\n    \"\"\"Generate a hash to uniquely identify a dataset configuration for SFT.\n\n    Args:\n        cfg: Main configuration object.\n        cfg_datasets: List of dataset configurations.\n        tokenizer_name: Name of the tokenizer being used.\n\n    Returns:\n        MD5 hash string representing the configuration.\n    \"\"\"\n    config_str = (\n        f\"{cfg.sequence_len}@{cfg.sample_packing}@{cfg.eval_sample_packing}@\"\n        f\"{cfg.group_by_length}@{cfg.kd_temperature or 1.0}@\"\n        f\"{cfg.dataset_exact_deduplication or False}|\"\n        f\"{'|'.join(sorted([f'{d.path}:{d.type}:{d.shards}:{d.conversation}:{d.split}:{d.temperature or 1.0}' for d in cfg_datasets]))}\"\n        f\"|{tokenizer_name}\"\n    )\n    return str(md5(config_str))\n\n\ndef merge_datasets(datasets: list[Dataset], cfg: DictDefault) -> Dataset:\n    \"\"\"Merge multiple datasets into one with optional shuffling.\n\n    Args:\n        datasets: List of datasets to merge.\n        cfg: Configuration object containing shuffle settings.\n\n    Returns:\n        Merged dataset.\n    \"\"\"\n    if len(datasets) == 1:\n        ds = datasets[0]\n\n        # Do not shuffle if curriculum sampling is enabled or\n        # shuffle_merged_datasets is disabled\n        if cfg.curriculum_sampling or not cfg.shuffle_merged_datasets:\n            return ds\n\n        return ds.shuffle(seed=cfg.seed)\n\n    # If enabled, shuffle each dataset independently before merging.\n    # This allows curriculum learning strategies to be applied at the dataset level.\n    if 
cfg.shuffle_before_merging_datasets:\n        LOG.info(\"Shuffling each dataset individually before merging...\")\n        datasets = [ds.shuffle(seed=cfg.seed) for ds in datasets]\n\n    LOG.info(\"Merging datasets...\")\n    merged_dataset = concatenate_datasets(datasets)\n\n    if cfg.shuffle_merged_datasets:\n        LOG.debug(\"Shuffling merged datasets...\")\n        if cfg.curriculum_sampling:\n            LOG.warning(\n                \"Shuffling merged datasets with curriculum sampling is not recommended. \"\n                \"This will randomize the order of samples.\"\n            )\n        merged_dataset = merged_dataset.shuffle(seed=cfg.seed)\n    else:\n        LOG.debug(\"Not shuffling merged datasets.\")\n\n    return merged_dataset\n"
  },
  {
    "path": "src/axolotl/utils/data/streaming.py",
    "content": "\"\"\"Data handling specific to streaming datasets.\"\"\"\n\nimport functools\nfrom collections import defaultdict\nfrom typing import Callable, Dict, List, Optional\n\nimport torch\nfrom datasets import Dataset\nfrom torch.utils.data import RandomSampler\nfrom transformers import PreTrainedTokenizerBase\n\nfrom axolotl.utils.collators import PretrainingBatchSamplerDataCollatorForSeq2Seq\nfrom axolotl.utils.logging import get_logger\nfrom axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths\nfrom axolotl.utils.trainer import process_pretraining_datasets_for_packing\n\nLOG = get_logger(__name__)\n\n\ndef encode_streaming(\n    examples: Dict[str, List],\n    tokenizer: PreTrainedTokenizerBase,\n    max_tokens: int,\n    text_column: str = \"text\",\n    concatenate: bool = True,\n) -> Dict[str, List]:\n    res = tokenizer(\n        examples[text_column],\n        truncation=True,\n        max_length=max_tokens - 2,\n        add_special_tokens=True,\n    )\n    # Convert to PyTorch tensors\n    input_ids = [torch.tensor(seq) for seq in res[\"input_ids\"]]\n    targets = [torch.tensor(seq) for seq in res[\"input_ids\"]]\n    attention_mask = [torch.tensor(seq) for seq in res[\"attention_mask\"]]\n    if not concatenate:\n        return {\n            \"input_ids\": [seq.tolist() for seq in input_ids],\n            \"labels\": [seq.tolist() for seq in targets],\n            \"attention_mask\": [seq.tolist() for seq in attention_mask],\n        }\n\n    new_input_ids = []\n    new_labels = []\n    new_attention_mask = []\n    # Append EOS and PAD tokens to input_ids, and correct attention_mask\n    for i, _ in enumerate(input_ids):\n        input_ids[i] = torch.cat(\n            (\n                input_ids[i],\n                torch.tensor([tokenizer.eos_token_id, tokenizer.pad_token_id]),\n            ),\n            dim=0,\n        )\n        targets[i] = torch.cat(\n            (\n                targets[i],\n                
torch.tensor([tokenizer.eos_token_id, -100]),\n            ),\n            dim=0,\n        )\n        attention_mask[i] = torch.cat((attention_mask[i], torch.tensor([1, 0])), dim=0)\n\n    # Concatenate tokens so that their lengths are less than max_tokens\n    buffer_input_ids = torch.tensor([], dtype=torch.long)\n    buffer_labels = torch.tensor([], dtype=torch.long)\n    buffer_attention_mask = torch.tensor([], dtype=torch.long)\n\n    for ids, labels, mask in zip(input_ids, targets, attention_mask, strict=False):\n        if buffer_input_ids.numel() == max_tokens:\n            new_input_ids.append(buffer_input_ids)\n            new_labels.append(buffer_labels)\n            new_attention_mask.append(buffer_attention_mask)\n            buffer_input_ids = torch.tensor([], dtype=torch.long)\n            buffer_labels = torch.tensor([], dtype=torch.long)\n            buffer_attention_mask = torch.tensor([], dtype=torch.long)\n            buffer_input_ids = torch.cat((buffer_input_ids, ids), dim=0)\n            buffer_labels = torch.cat((buffer_labels, labels), dim=0)\n            buffer_attention_mask = torch.cat((buffer_attention_mask, mask), dim=0)\n        elif buffer_input_ids.numel() + ids.numel() <= max_tokens:\n            buffer_input_ids = torch.cat((buffer_input_ids, ids), dim=0)\n            buffer_labels = torch.cat((buffer_labels, labels), dim=0)\n            buffer_attention_mask = torch.cat((buffer_attention_mask, mask), dim=0)\n        else:\n            buffer_input_ids = torch.cat(\n                (\n                    buffer_input_ids,\n                    torch.full(\n                        (max_tokens - buffer_input_ids.numel(),),\n                        tokenizer.pad_token_id,\n                        dtype=torch.long,\n                    ),\n                ),\n                dim=0,\n            )\n            buffer_labels = torch.cat(\n                (\n                    buffer_labels,\n                    torch.full(\n              
          (max_tokens - buffer_labels.numel(),),\n                        -100,\n                        dtype=torch.long,\n                    ),\n                ),\n                dim=0,\n            )\n            buffer_attention_mask = torch.cat(\n                (\n                    buffer_attention_mask,\n                    torch.full(\n                        (max_tokens - buffer_attention_mask.numel(),),\n                        0,\n                        dtype=torch.long,\n                    ),\n                ),\n                dim=0,\n            )\n            new_input_ids.append(buffer_input_ids)\n            new_labels.append(buffer_labels)\n            new_attention_mask.append(buffer_attention_mask)\n            buffer_input_ids = torch.tensor([], dtype=torch.long)\n            buffer_labels = torch.tensor([], dtype=torch.long)\n            buffer_attention_mask = torch.tensor([], dtype=torch.long)\n\n            buffer_input_ids = torch.cat((buffer_input_ids, ids), dim=0)\n            buffer_labels = torch.cat((buffer_labels, labels), dim=0)\n            buffer_attention_mask = torch.cat((buffer_attention_mask, mask), dim=0)\n\n    if buffer_input_ids.numel() > 0:  # for any leftover tokens\n        while buffer_input_ids.numel() < max_tokens:  # make all sequences equal in size\n            buffer_input_ids = torch.cat(\n                (\n                    buffer_input_ids,\n                    torch.full(\n                        (max_tokens - buffer_input_ids.numel(),),\n                        tokenizer.pad_token_id,\n                        dtype=torch.long,\n                    ),\n                ),\n                dim=0,\n            )\n            buffer_labels = torch.cat(\n                (\n                    buffer_labels,\n                    torch.full(\n                        (max_tokens - buffer_labels.numel(),),\n                        -100,\n                        dtype=torch.long,\n                    ),\n     
           ),\n                dim=0,\n            )\n            buffer_attention_mask = torch.cat(\n                (\n                    buffer_attention_mask,\n                    torch.full(\n                        (max_tokens - buffer_attention_mask.numel(),),\n                        0,\n                        dtype=torch.long,\n                    ),\n                ),\n                dim=0,\n            )\n        new_input_ids.append(buffer_input_ids)\n        new_labels.append(buffer_labels)\n        new_attention_mask.append(buffer_attention_mask)\n\n    ret = {\n        \"input_ids\": [seq.tolist() for seq in new_input_ids],\n        \"labels\": [seq.tolist() for seq in new_labels],\n        \"attention_mask\": [seq.tolist() for seq in new_attention_mask],\n    }\n\n    LOG.debug(len(ret[\"input_ids\"]))\n    return ret\n\n\ndef wrap_streaming_dataset(\n    dataset,\n    tokenizer,\n    cfg,\n    ds_wrapper_fn,\n):\n    if cfg.sample_packing:\n        # For SFT (non-pretraining) datasets, always use multipack_attn=True to ensure\n        # attention isolation between packed sequences\n        multipack_attn = (\n            True if not cfg.pretraining_dataset else cfg.pretrain_multipack_attn\n        )\n\n        collate_fn = PretrainingBatchSamplerDataCollatorForSeq2Seq(\n            tokenizer,\n            return_tensors=\"pt\",\n            padding=True,\n            pad_to_multiple_of=cfg.sequence_len,\n            multipack_attn=multipack_attn,\n        )\n        encode = functools.partial(\n            encode_packed_streaming,\n            collate_fn,\n            ds_wrapper_fn,\n            max_seq_length=cfg.sequence_len,\n            batch_size=cfg.micro_batch_size,\n            multipack_attn=multipack_attn,\n            bin_size=cfg.sample_packing_bin_size,\n        )\n\n        # Set this to 1 so downstream data_loader doesn't try to increase the batch size\n        # again\n        cfg.micro_batch_size = 1\n    else:\n        # NOTE: 
This is not reachable for SFT datasets since we use the pre-existing\n        # loading function for non-packed streaming datasets. Refer to\n        # _prepare_streaming_datasets in sft.py for that code path.\n        text_column = (\n            getattr(cfg.pretraining_dataset[0], \"text_column\", \"text\") or \"text\"\n        )\n        encode = functools.partial(\n            encode_streaming,\n            tokenizer=tokenizer,\n            max_tokens=cfg.sequence_len,\n            text_column=text_column,\n            concatenate=cfg.pretraining_sample_concatenation is True,\n        )\n\n    if cfg.shuffle_merged_datasets:\n        dataset = dataset.shuffle(\n            seed=cfg.seed, buffer_size=cfg.streaming_multipack_buffer_size\n        )\n    else:\n        LOG.debug(\"NOT shuffling merged pretraining datasets\")\n\n    # remove all the existing columns after mapping since they end up having\n    # a different length than the encoded/tokenized column\n    # this is empty during streaming/pretraining\n    remove_columns = []\n    if dataset.features is None:\n        for first_row in dataset:\n            remove_columns = list(first_row.keys())\n            break\n    else:\n        remove_columns = list(dataset.features.keys())\n\n    dataset = dataset.map(\n        encode,\n        batched=True,\n        batch_size=cfg.streaming_multipack_buffer_size,\n        remove_columns=remove_columns,\n    )\n    return dataset\n\n\ndef encode_packed_streaming(\n    collate_fn,\n    ds_wrapper: Callable,\n    examples: Dict[str, List],\n    bin_size: int,\n    max_seq_length: int = 2048,\n    batch_size: int = 4,\n    multipack_attn: Optional[bool] = True,\n) -> Dict[str, List]:\n    # tokenize all the examples\n    # rows get split with stride (overlap)\n    train_dataset = ds_wrapper(dataset=Dataset.from_dict(examples))[0]\n\n    train_dataset = process_pretraining_datasets_for_packing(\n        train_dataset,\n        max_seq_length,\n        
skip_position_ids=not multipack_attn,\n        # FIXME using attention mask unpad/pad with trainer and packed pretraining is broken atm\n        # workaround by using the position id logic for now in trainer\n        drop_attention_mask=multipack_attn,\n    )\n\n    sampler = MultipackBatchSampler(\n        sampler=RandomSampler(train_dataset),\n        lengths=get_dataset_lengths(train_dataset),\n        batch_size=1,\n        batch_max_len=batch_size * max_seq_length,\n        drop_last=True,\n        num_processes=1,\n        bin_size=bin_size,\n    )\n\n    chunked_data = defaultdict(list)\n\n    for batch in sampler:\n        for data in batch:\n            features = train_dataset[data]\n            if \"num_truncated_tokens\" in features:\n                del features[\"num_truncated_tokens\"]\n            if \"overflow_to_sample_mapping\" in features:\n                del features[\"overflow_to_sample_mapping\"]\n            if \"labels\" not in features:\n                features[\"labels\"] = features[\"input_ids\"].copy()\n            collated_features = collate_fn(features)\n\n            for feature in features.keys():\n                if feature == \"length\":\n                    continue\n                chunked_data[feature].append(collated_features[feature].squeeze(0))\n\n    return chunked_data\n"
  },
  {
    "path": "src/axolotl/utils/data/utils.py",
    "content": "\"\"\"Data handling helpers\"\"\"\n\nimport contextlib\nimport functools\nimport hashlib\nimport time\nfrom enum import Enum\nfrom typing import Callable\n\nimport huggingface_hub\nimport numpy as np\nimport requests\nfrom datasets import Dataset, IterableDataset\n\nfrom axolotl.utils.dict import DictDefault\nfrom axolotl.utils.logging import get_logger\nfrom axolotl.utils.samplers.utils import get_dataset_lengths\nfrom axolotl.utils.trainer import filter_sequences_by_length\n\nLOG = get_logger(__name__)\n\n\nclass RetryStrategy(Enum):\n    \"\"\"Enum for retry strategies.\"\"\"\n\n    CONSTANT = 1\n    LINEAR = 2\n    EXPONENTIAL = 3\n\n\ndef retry_on_request_exceptions(\n    max_retries=3, delay=1, retry_strategy: RetryStrategy = RetryStrategy.LINEAR\n) -> Callable:\n    \"\"\"Decorator that retries function calls on specific request exceptions.\n\n    Args:\n        max_retries: Maximum number of retry attempts.\n        delay: Base delay between retries in seconds.\n        retry_strategy: Strategy for calculating retry delays.\n\n    Returns:\n        Decorated function with retry logic.\n    \"\"\"\n\n    def decorator(func):\n        @functools.wraps(func)\n        def wrapper(*args, **kwargs):\n            for attempt in range(max_retries):\n                try:\n                    return func(*args, **kwargs)\n                except (\n                    requests.exceptions.ReadTimeout,\n                    requests.exceptions.ConnectionError,\n                    requests.exceptions.HTTPError,\n                    huggingface_hub.errors.HfHubHTTPError,\n                ) as exc:\n                    if attempt < max_retries - 1:\n                        if retry_strategy == RetryStrategy.EXPONENTIAL:\n                            step_delay = delay * 2**attempt\n                        elif retry_strategy == RetryStrategy.LINEAR:\n                            step_delay = delay * (attempt + 1)\n                        else:\n               
             step_delay = delay  # Use constant delay.\n                        time.sleep(step_delay)\n                    else:\n                        raise exc\n\n        return wrapper\n\n    return decorator\n\n\ndef md5(to_hash: str, encoding: str = \"utf-8\") -> str:\n    \"\"\"Generate MD5 hash of a string.\"\"\"\n    try:\n        return hashlib.md5(to_hash.encode(encoding), usedforsecurity=False).hexdigest()\n    except TypeError:\n        return hashlib.md5(to_hash.encode(encoding)).hexdigest()  # nosec\n\n\ndef sha256(to_hash: str, encoding: str = \"utf-8\") -> str:\n    \"\"\"Generate SHA256 hash of a string.\"\"\"\n    return hashlib.sha256(to_hash.encode(encoding)).hexdigest()\n\n\ndef _deduplicate_dataset(\n    dataset: Dataset,\n    seen_hashes: set[str] | None = None,\n) -> tuple[Dataset, set[str]]:\n    \"\"\"Remove duplicate rows from a dataset using SHA256 hashes.\n\n    Args:\n        dataset: Dataset to deduplicate.\n        seen_hashes: Set of previously seen row hashes (for cross-deduplication).\n\n    Returns:\n        Tuple of deduplicated dataset and the set of seen hashes.\n    \"\"\"\n    if seen_hashes is None:\n        seen_hashes = set()\n\n    unique_indices = []\n    for idx, row in enumerate(dataset):\n        row_hash = sha256(str(row))  # Using SHA256 for collision resistance\n        if row_hash not in seen_hashes:\n            seen_hashes.add(row_hash)\n            unique_indices.append(idx)\n\n    return dataset.select(unique_indices), seen_hashes\n\n\ndef deduplicate_and_log_datasets(\n    dataset: Dataset,\n    other_dataset: Dataset | None = None,\n    dataset_name: str | None = \"train\",\n    other_name: str | None = \"eval\",\n) -> tuple[Dataset, Dataset | None]:\n    \"\"\"Deduplicate datasets, with optional cross-dataset deduplication.\n\n    Args:\n        dataset: Primary dataset to deduplicate.\n        other_dataset: Optional second dataset to deduplicate against the first.\n        dataset_name: Name for the 
primary dataset (for logging).\n        other_name: Name for the second dataset (for logging).\n\n    Returns:\n        Tuple of (deduplicated_dataset, deduplicated_other_dataset).\n    \"\"\"\n    # Deduplicate primary dataset\n    LOG.info(\n        f\"Starting deduplication for {dataset_name} dataset. Original size: {len(dataset)}\"\n    )\n    dataset, seen_rows = _deduplicate_dataset(dataset)\n    LOG.info(\n        f\"Deduplication complete for {dataset_name} dataset. New size: {len(dataset)}\"\n    )\n\n    # Deduplicate second dataset if provided\n    if other_dataset is not None:\n        LOG.info(\n            f\"Starting deduplication for {other_name} dataset. Original size: {len(other_dataset)}\"\n        )\n        other_dataset, _ = _deduplicate_dataset(other_dataset, seen_rows)\n        LOG.info(\n            f\"Deduplication complete for {other_name} dataset. New size: {len(other_dataset)}\"\n        )\n\n    return dataset, other_dataset\n\n\ndef keep_min_len(sample, min_sequence_len=2):\n    \"\"\"\n    Batched filter function that keeps only samples with sequence length >= min_sequence_len.\n    Returns a list of booleans indicating which samples to keep.\n    \"\"\"\n    min_sequence_len = min_sequence_len or 2\n\n    input_ids = sample[\"input_ids\"]\n\n    # Batched (input_ids is a list of lists)\n    results = []\n    for seq in input_ids:\n        results.append(len(seq) >= min_sequence_len)\n    return results\n\n\ndef truncate_long_seq(sample, sequence_len=2048):\n    \"\"\"\n    Truncate samples whose sequence length is too long (> sequence_len).\n    Modifies the sample in-place and returns the modified sample.\n    \"\"\"\n    input_ids = sample[\"input_ids\"]\n\n    # Batched (input_ids is a list of lists)\n    for i, seq in enumerate(input_ids):\n        length = len(seq)\n        if length > sequence_len:\n            sample[\"input_ids\"][i] = seq[:sequence_len]\n            if \"attention_mask\" in sample:\n                
sample[\"attention_mask\"][i] = sample[\"attention_mask\"][i][:sequence_len]\n            if \"labels\" in sample:\n                sample[\"labels\"][i] = sample[\"labels\"][i][:sequence_len]\n            if \"position_ids\" in sample:\n                sample[\"position_ids\"][i] = sample[\"position_ids\"][i][:sequence_len]\n    return sample\n\n\ndef _should_skip_processing(dataset: Dataset) -> bool:\n    \"\"\"Check if dataset should skip long sequence handling.\"\"\"\n    if (\n        hasattr(dataset, \"column_names\")\n        and dataset.column_names\n        and \"input_ids\" not in dataset.column_names\n    ):\n        LOG.warning(\n            \"Dataset does not contain 'input_ids' column. Skip drop long seq. This is \"\n            \"expected for reward modeling.\"\n        )\n        return True\n    elif not hasattr(dataset, \"column_names\") or dataset.column_names is None:\n        LOG.info(\n            \"Dataset is streaming (IterableDataset), skipping long sequence handling\"\n        )\n        return True\n    return False\n\n\ndef _log_dataset_stats(dataset: Dataset) -> None:\n    \"\"\"Log min/max sequence lengths for debugging.\"\"\"\n    with contextlib.suppress(AttributeError, ValueError):\n        ds_lengths = get_dataset_lengths(dataset, from_arrow=True)\n        LOG.info(f\"min_input_len: {np.min(ds_lengths)}\")\n        LOG.info(f\"max_input_len: {np.max(ds_lengths)}\")\n\n\ndef _build_filter_kwargs(dataset: Dataset, cfg: DictDefault) -> dict:\n    \"\"\"Build kwargs for dataset filter/map operations.\"\"\"\n    kwargs = {}\n    if not isinstance(dataset, IterableDataset):\n        kwargs[\"num_proc\"] = cfg.dataset_num_proc\n        kwargs[\"load_from_cache_file\"] = not cfg.is_preprocess\n    return kwargs\n\n\ndef _filter_short_sequences(\n    dataset: Dataset, min_len: int, filter_kwargs: dict\n) -> tuple[Dataset, int]:\n    \"\"\"Filter out sequences shorter than min_len. 
Returns (dataset, num_dropped).\"\"\"\n    prior_len = len(dataset) if hasattr(dataset, \"__len__\") else None\n\n    desc_kwargs = {}\n    if filter_kwargs:\n        desc_kwargs[\"desc\"] = f\"Filtering Short Sequences (<{min_len})\"\n\n    dataset = dataset.filter(\n        functools.partial(keep_min_len, min_sequence_len=min_len),\n        batched=True,\n        **filter_kwargs,\n        **desc_kwargs,\n    )\n\n    dropped = 0\n    if prior_len:\n        dropped = prior_len - len(dataset)\n        if dropped > 0:\n            LOG.info(f\"Dropped {dropped} short sequences (<{min_len} tokens)\")\n\n    return dataset, dropped\n\n\ndef _truncate_long_sequences(\n    dataset: Dataset, max_len: int, map_kwargs: dict\n) -> Dataset:\n    \"\"\"Truncate sequences longer than max_len.\"\"\"\n    desc_kwargs = {}\n    if map_kwargs:\n        desc_kwargs[\"desc\"] = f\"Truncating Sequences (target_len={max_len})\"\n\n    dataset = dataset.map(\n        functools.partial(truncate_long_seq, sequence_len=max_len),\n        batched=True,\n        **map_kwargs,\n        **desc_kwargs,\n    )\n    LOG.info(f\"Truncated long sequences to max length {max_len}\")\n    return dataset\n\n\ndef _drop_outside_range(\n    dataset: Dataset,\n    max_len: int,\n    min_len: int,\n    raise_on_long: bool,\n    filter_kwargs: dict,\n) -> tuple[Dataset, int]:\n    \"\"\"Drop sequences outside valid length range [min_len, max_len].\n\n    Returns (dataset, num_dropped).\"\"\"\n    prior_len = len(dataset) if hasattr(dataset, \"__len__\") else None\n\n    desc_kwargs = {}\n    if filter_kwargs:\n        action = (\n            \"Checking Sequence Lengths\"\n            if raise_on_long\n            else \"Dropping Invalid Sequences\"\n        )\n        desc_kwargs[\"desc\"] = f\"{action} (<{min_len} or >{max_len})\"\n\n    dataset = dataset.filter(\n        functools.partial(\n            filter_sequences_by_length,\n            sequence_len=max_len,\n            min_sequence_len=min_len,\n  
          raise_on_drop=raise_on_long,\n        ),\n        batched=True,\n        **filter_kwargs,\n        **desc_kwargs,\n    )\n\n    dropped = 0\n    if not raise_on_long and prior_len:\n        dropped = prior_len - len(dataset)\n        if dropped > 0:\n            LOG.info(\n                f\"Dropped {dropped} sequences outside valid range \"\n                f\"([{min_len}, {max_len}])\"\n            )\n\n    return dataset, dropped\n\n\ndef handle_long_seq_in_dataset(\n    dataset: Dataset, sequence_len: int, cfg: DictDefault\n) -> Dataset:\n    \"\"\"Remove sequences longer than configured maximum from dataset.\n\n    Args:\n        dataset: Dataset to filter.\n        sequence_len: Maximum length for sequences to keep\n        cfg: Dictionary mapping `axolotl` config keys to values.\n\n    Returns:\n        Filtered dataset with long sequences handled according to the excess_length_strategy value:\n            'drop' (default)    excludes any sequence longer than sequence_len\n            'truncate'          truncates them down to sequence_len\n            'raise'             raises a ValueError if any sequence was found that was longer than sequence_len\n    \"\"\"\n    # Early returns for special cases\n    if _should_skip_processing(dataset):\n        return dataset\n\n    excess_length_strategy = (cfg.excess_length_strategy or \"drop\").lower()\n\n    _log_dataset_stats(dataset)\n\n    # Setup kwargs\n    filter_kwargs = _build_filter_kwargs(dataset, cfg)\n\n    # Handle sequences based on strategy\n    if excess_length_strategy == \"truncate\":\n        dataset, _ = _filter_short_sequences(dataset, cfg.min_sample_len, filter_kwargs)\n        dataset = _truncate_long_sequences(dataset, sequence_len, filter_kwargs)\n    else:\n        raise_on_long = excess_length_strategy == \"raise\"\n        dataset, _ = _drop_outside_range(\n            dataset, sequence_len, cfg.min_sample_len, raise_on_long, filter_kwargs\n        )\n\n    return dataset\n"
  },
  {
    "path": "src/axolotl/utils/data/wrappers.py",
    "content": "\"\"\"Data handling specific to SFT.\"\"\"\n\nimport logging\nfrom typing import Any, NoReturn, cast\n\nfrom datasets import (\n    Dataset,\n    IterableDataset,\n    Sequence,\n    Value,\n)\nfrom transformers import PreTrainedTokenizer\nfrom transformers.processing_utils import ProcessorMixin\n\nfrom axolotl.datasets import TokenizedPromptDataset, wrap_dataset_for_tokenized_prompt\nfrom axolotl.prompt_strategies import load\nfrom axolotl.prompt_strategies.bradley_terry import load as bradley_terry_load\nfrom axolotl.prompt_tokenizers import (\n    AlpacaMultipleChoicePromptTokenizingStrategy,\n    AlpacaPromptTokenizingStrategy,\n    AlpacaReflectionPTStrategy,\n    DatasetWrappingStrategy,\n    GPTeacherPromptTokenizingStrategy,\n    JeopardyPromptTokenizingStrategy,\n    OpenAssistantPromptTokenizingStrategy,\n    PromptTokenizingStrategy,\n    SummarizeTLDRPromptTokenizingStrategy,\n)\nfrom axolotl.prompters import (\n    AlpacaPrompter,\n    GPTeacherPrompter,\n    JeopardyPrompter,\n    MultipleChoiceConcisePrompter,\n    MultipleChoiceExplainPrompter,\n    Prompter,\n    ReflectAlpacaPrompter,\n    SummarizeTLDRPrompter,\n    UnsupportedPrompter,\n)\nfrom axolotl.utils.dict import DictDefault\n\nLOG = logging.getLogger(__name__)\n\n\ndef handle_unknown_dataset_strategy(dataset_config: DictDefault) -> NoReturn:\n    \"\"\"Raise error for unknown dataset strategy.\"\"\"\n    ds_type = dataset_config.type\n    suffix = \"\"\n    if \":load_\" in ds_type:\n        suffix = f\"Did you mean {ds_type.replace(':load_', '.load_')}?\"\n\n    error_message = f\"unhandled prompt tokenization strategy: {ds_type}. 
{suffix}\"\n    LOG.error(error_message)\n    raise ValueError(error_message)\n\n\ndef get_dataset_wrapper(\n    dataset_config: DictDefault,\n    tokenizer: PreTrainedTokenizer,\n    cfg: DictDefault,\n    dataset_base_type: str | None,\n    dataset: Dataset | IterableDataset,\n    dataset_prompt_style: str | None = None,\n    processor: ProcessorMixin | None = None,\n) -> tuple[Dataset | IterableDataset, Prompter | None]:\n    \"\"\"Create an appropriate dataset wrapper and prompter based on dataset\n    configuration.\n\n    Args:\n        dataset_config: Configuration for the dataset.\n        tokenizer: Tokenizer to use for processing text.\n        cfg: Global configuration object.\n        dataset_base_type: The base type of the dataset.\n        dataset: The actual dataset object.\n        dataset_prompt_style: Optional prompt style specification.\n        processor: Optional processor for multimodal datasets.\n\n    Returns:\n        tuple of (dataset_wrapper, dataset_prompter).\n    \"\"\"\n    # Common parameters for dataset wrapping\n    dataset_kwargs: dict[str, Any] = {\n        \"process_count\": cfg.dataset_num_proc,\n        \"keep_in_memory\": cfg.dataset_keep_in_memory is True,\n    }\n\n    LOG.info(\n        f\"Loading dataset: {dataset_config['path']} with base_type: \"\n        f\"{dataset_base_type} and prompt_style: {dataset_prompt_style}\"\n    )\n\n    # Dataset is already tokenized\n    if _is_dataset_already_tokenized(dataset):\n        return dataset, UnsupportedPrompter()\n\n    # Custom dataset type definition\n    if isinstance(dataset_config.type, DictDefault):\n        return _handle_custom_dataset_type(\n            dataset_config, tokenizer, cfg, dataset, dataset_kwargs\n        )\n\n    # Skip preparation if configured\n    if cfg.skip_prepare_dataset:\n        return dataset, None\n\n    # Bradley-Terry dataset\n    if dataset_config.type.startswith(\"bradley_terry\"):\n        return _handle_bradley_terry_dataset(\n           
 dataset_config, tokenizer, cfg, dataset, dataset_kwargs\n        )\n\n    # Stepwise supervised dataset\n    if dataset_config.type.startswith(\"stepwise_supervised\"):\n        return _handle_stepwise_supervised_dataset(\n            dataset_config, tokenizer, cfg, dataset, dataset_kwargs\n        )\n\n    # Try to load prompt tokenizer / dataset wrapper strategy from registry\n    dataset_strategy = load(\n        dataset_config.type, tokenizer, cfg, dataset_config, processor=processor\n    )\n    if dataset_strategy:\n        return _handle_loaded_strategy(dataset_strategy, dataset, dataset_kwargs)\n\n    # Known dataset types with specific handling\n    if dataset_base_type in DATASET_HANDLERS:\n        handler = DATASET_HANDLERS[dataset_base_type]\n        return handler(dataset_prompt_style, tokenizer, cfg, dataset, dataset_kwargs)\n\n    # Unhandled dataset type\n    handle_unknown_dataset_strategy(dataset_config)\n\n\ndef _is_dataset_already_tokenized(dataset: Dataset | IterableDataset) -> bool:\n    \"\"\"Check if the dataset is already tokenized.\"\"\"\n    return (\n        isinstance(dataset, Dataset)\n        and \"input_ids\" in dataset.features\n        and \"attention_mask\" in dataset.features\n        and \"labels\" in dataset.features\n    )\n\n\ndef _handle_custom_dataset_type(\n    dataset_config: DictDefault,\n    tokenizer: PreTrainedTokenizer,\n    cfg: DictDefault,\n    dataset: Dataset | IterableDataset,\n    dataset_kwargs: dict[str, Any],\n) -> tuple[Dataset | IterableDataset, Prompter]:\n    \"\"\"Handle a custom dataset type defined in the configuration.\"\"\"\n    dataset_strategy = cast(\n        PromptTokenizingStrategy,\n        load(\"user_defined\", tokenizer, cfg, dataset_config.type.to_dict()),\n    )\n    dataset_prompter = UnsupportedPrompter()\n    dataset_wrapper = wrap_dataset_for_tokenized_prompt(\n        dataset_strategy,\n        dataset,\n        **dataset_kwargs,\n    )\n    return dataset_wrapper, 
dataset_prompter\n\n\ndef _handle_bradley_terry_dataset(\n    dataset_config: DictDefault,\n    tokenizer: PreTrainedTokenizer,\n    cfg: DictDefault,\n    dataset: Dataset | IterableDataset,\n    dataset_kwargs: dict[str, Any],\n) -> tuple[Dataset | IterableDataset, Prompter | None]:\n    \"\"\"Handle a Bradley-Terry dataset.\"\"\"\n    bt_type = dataset_config.type.split(\".\", 1)[1]\n    dataset_strategy = bradley_terry_load(bt_type, tokenizer, cfg, dataset_config)\n\n    if not dataset_strategy:\n        handle_unknown_dataset_strategy(dataset_config)\n\n    dataset_prompter = UnsupportedPrompter()\n    dataset_wrapper = wrap_dataset_for_tokenized_prompt(\n        dataset_strategy,\n        dataset,\n        **dataset_kwargs,\n    )\n\n    return dataset_wrapper, dataset_prompter\n\n\ndef _handle_stepwise_supervised_dataset(\n    dataset_config: DictDefault,\n    tokenizer: PreTrainedTokenizer,\n    cfg: DictDefault,\n    dataset: Dataset | IterableDataset,\n    dataset_kwargs: dict[str, Any],\n) -> tuple[Dataset | IterableDataset, Prompter]:\n    \"\"\"Handle a stepwise supervised dataset.\"\"\"\n    dataset_prompter = UnsupportedPrompter()\n    dataset_strategy = load(dataset_config.type, tokenizer, cfg, dataset_config)\n\n    # We need to explicitly cast boolean labels to int\n    # for compatibility with how trl's PRMTrainer works\n    if isinstance(dataset, Dataset):\n        dataset = dataset.cast_column(\"labels\", Sequence(Value(\"int64\")))\n\n    dataset_wrapper = TokenizedPromptDataset(\n        dataset_strategy,\n        dataset,\n        **dataset_kwargs,\n    )\n    return dataset_wrapper, dataset_prompter\n\n\ndef _handle_loaded_strategy(\n    dataset_strategy: PromptTokenizingStrategy | DatasetWrappingStrategy,\n    dataset: Dataset | IterableDataset,\n    dataset_kwargs: dict[str, Any],\n) -> tuple[Dataset | IterableDataset, Prompter | None]:\n    \"\"\"Handle a dataset with a strategy loaded from the registry.\"\"\"\n    if 
isinstance(dataset_strategy, DatasetWrappingStrategy):\n        return dataset_strategy.wrap_dataset(dataset, **dataset_kwargs), None\n\n    dataset_prompter = UnsupportedPrompter()\n    dataset_wrapper = wrap_dataset_for_tokenized_prompt(\n        dataset_strategy,\n        dataset,\n        **dataset_kwargs,\n    )\n    return dataset_wrapper, dataset_prompter\n\n\ndef _handle_alpaca_dataset(\n    dataset_prompt_style: str | None,\n    tokenizer: PreTrainedTokenizer,\n    cfg: DictDefault,\n    dataset: Dataset | IterableDataset,\n    dataset_kwargs: dict[str, Any],\n) -> tuple[Dataset | IterableDataset, Prompter]:\n    \"\"\"Handle an Alpaca dataset.\"\"\"\n    dataset_prompter = AlpacaPrompter(dataset_prompt_style)\n    dataset_strategy = AlpacaPromptTokenizingStrategy(\n        dataset_prompter,\n        tokenizer,\n        cfg.train_on_inputs,\n        cfg.sequence_len,\n    )\n    dataset_wrapper = wrap_dataset_for_tokenized_prompt(\n        dataset_strategy,\n        dataset,\n        **dataset_kwargs,\n    )\n    return dataset_wrapper, dataset_prompter\n\n\ndef _handle_explainchoice_dataset(\n    dataset_prompt_style: str | None,\n    tokenizer: PreTrainedTokenizer,\n    cfg: DictDefault,\n    dataset: Dataset | IterableDataset,\n    dataset_kwargs: dict[str, Any],\n) -> tuple[Dataset | IterableDataset, Prompter]:\n    \"\"\"Handle an ExplainChoice dataset.\"\"\"\n    dataset_prompter = MultipleChoiceExplainPrompter(dataset_prompt_style)\n    dataset_strategy = AlpacaMultipleChoicePromptTokenizingStrategy(\n        dataset_prompter,\n        tokenizer,\n        cfg.train_on_inputs,\n        cfg.sequence_len,\n    )\n    dataset_wrapper = wrap_dataset_for_tokenized_prompt(\n        dataset_strategy,\n        dataset,\n        **dataset_kwargs,\n    )\n    return dataset_wrapper, dataset_prompter\n\n\ndef _handle_concisechoice_dataset(\n    dataset_prompt_style: str | None,\n    tokenizer: PreTrainedTokenizer,\n    cfg: DictDefault,\n    dataset: Dataset | 
IterableDataset,\n    dataset_kwargs: dict[str, Any],\n) -> tuple[Dataset | IterableDataset, Prompter]:\n    \"\"\"Handle a ConciseChoice dataset.\"\"\"\n    dataset_prompter = MultipleChoiceConcisePrompter(dataset_prompt_style)\n    dataset_strategy = AlpacaMultipleChoicePromptTokenizingStrategy(\n        dataset_prompter,\n        tokenizer,\n        cfg.train_on_inputs,\n        cfg.sequence_len,\n    )\n    dataset_wrapper = wrap_dataset_for_tokenized_prompt(\n        dataset_strategy,\n        dataset,\n        **dataset_kwargs,\n    )\n    return dataset_wrapper, dataset_prompter\n\n\ndef _handle_summarizetldr_dataset(\n    dataset_prompt_style: str | None,\n    tokenizer: PreTrainedTokenizer,\n    cfg: DictDefault,\n    dataset: Dataset | IterableDataset,\n    dataset_kwargs: dict[str, Any],\n) -> tuple[Dataset | IterableDataset, Prompter]:\n    \"\"\"Handle a SummarizeTLDR dataset.\"\"\"\n    dataset_prompter = SummarizeTLDRPrompter(dataset_prompt_style)\n    dataset_strategy = SummarizeTLDRPromptTokenizingStrategy(\n        dataset_prompter,\n        tokenizer,\n        cfg.train_on_inputs,\n        cfg.sequence_len,\n    )\n    dataset_wrapper = wrap_dataset_for_tokenized_prompt(\n        dataset_strategy,\n        dataset,\n        **dataset_kwargs,\n    )\n    return dataset_wrapper, dataset_prompter\n\n\ndef _handle_jeopardy_dataset(\n    dataset_prompt_style: str | None,\n    tokenizer: PreTrainedTokenizer,\n    cfg: DictDefault,\n    dataset: Dataset | IterableDataset,\n    dataset_kwargs: dict[str, Any],\n) -> tuple[Dataset | IterableDataset, Prompter]:\n    \"\"\"Handle a Jeopardy dataset.\"\"\"\n    dataset_prompter = JeopardyPrompter(dataset_prompt_style)\n    dataset_strategy = JeopardyPromptTokenizingStrategy(\n        dataset_prompter,\n        tokenizer,\n        cfg.train_on_inputs,\n        cfg.sequence_len,\n    )\n    dataset_wrapper = wrap_dataset_for_tokenized_prompt(\n        dataset_strategy,\n        dataset,\n        
**dataset_kwargs,\n    )\n    return dataset_wrapper, dataset_prompter\n\n\ndef _handle_oasst_dataset(\n    dataset_prompt_style: str | None,\n    tokenizer: PreTrainedTokenizer,\n    cfg: DictDefault,\n    dataset: Dataset | IterableDataset,\n    dataset_kwargs: dict[str, Any],\n) -> tuple[Dataset | IterableDataset, Prompter]:\n    \"\"\"Handle an OpenAssistant dataset.\"\"\"\n    dataset_prompter = AlpacaPrompter(dataset_prompt_style)\n    dataset_strategy = OpenAssistantPromptTokenizingStrategy(\n        dataset_prompter,\n        tokenizer,\n        cfg.train_on_inputs,\n        cfg.sequence_len,\n    )\n    dataset_wrapper = wrap_dataset_for_tokenized_prompt(\n        dataset_strategy,\n        dataset,\n        **dataset_kwargs,\n    )\n    return dataset_wrapper, dataset_prompter\n\n\ndef _handle_gpteacher_dataset(\n    dataset_prompt_style: str | None,\n    tokenizer: PreTrainedTokenizer,\n    cfg: DictDefault,\n    dataset: Dataset | IterableDataset,\n    dataset_kwargs: dict[str, Any],\n) -> tuple[Dataset | IterableDataset, Prompter]:\n    \"\"\"Handle a GPTeacher dataset.\"\"\"\n    dataset_prompter = GPTeacherPrompter(dataset_prompt_style)\n    dataset_strategy = GPTeacherPromptTokenizingStrategy(\n        dataset_prompter,\n        tokenizer,\n        cfg.train_on_inputs,\n        cfg.sequence_len,\n    )\n    dataset_wrapper = wrap_dataset_for_tokenized_prompt(\n        dataset_strategy,\n        dataset,\n        **dataset_kwargs,\n    )\n    return dataset_wrapper, dataset_prompter\n\n\ndef _handle_reflection_dataset(\n    dataset_prompt_style: str | None,\n    tokenizer: PreTrainedTokenizer,\n    cfg: DictDefault,\n    dataset: Dataset | IterableDataset,\n    dataset_kwargs: dict[str, Any],\n) -> tuple[Dataset | IterableDataset, Prompter]:\n    \"\"\"Handle a Reflection dataset.\"\"\"\n    dataset_prompter = ReflectAlpacaPrompter(dataset_prompt_style)\n    dataset_strategy = AlpacaReflectionPTStrategy(\n        dataset_prompter,\n        
tokenizer,\n        cfg.train_on_inputs,\n        cfg.sequence_len,\n    )\n    dataset_wrapper = wrap_dataset_for_tokenized_prompt(\n        dataset_strategy,\n        dataset,\n        **dataset_kwargs,\n    )\n    return dataset_wrapper, dataset_prompter\n\n\nDATASET_HANDLERS = {\n    \"alpaca\": _handle_alpaca_dataset,\n    \"explainchoice\": _handle_explainchoice_dataset,\n    \"concisechoice\": _handle_concisechoice_dataset,\n    \"summarizetldr\": _handle_summarizetldr_dataset,\n    \"jeopardy\": _handle_jeopardy_dataset,\n    \"oasst\": _handle_oasst_dataset,\n    \"gpteacher\": _handle_gpteacher_dataset,\n    \"reflection\": _handle_reflection_dataset,\n}\n"
  },
  {
    "path": "src/axolotl/utils/datasets.py",
    "content": "\"\"\"helper functions for datasets\"\"\"\n\nimport os\n\nfrom axolotl.utils.logging import get_logger\n\nLOG = get_logger(__name__)\n\n\ndef get_default_process_count():\n    if axolotl_dataset_num_proc := os.environ.get(\"AXOLOTL_DATASET_NUM_PROC\"):\n        return int(axolotl_dataset_num_proc)\n    if axolotl_dataset_processes := os.environ.get(\"AXOLOTL_DATASET_PROCESSES\"):\n        LOG.warning(\n            \"AXOLOTL_DATASET_PROCESSES and `dataset_processes` are deprecated and will be \"\n            \"removed in a future version. Please use `dataset_num_proc` instead.\"\n        )\n        return int(axolotl_dataset_processes)\n    if runpod_cpu_count := os.environ.get(\"RUNPOD_CPU_COUNT\"):\n        return int(runpod_cpu_count)\n    return os.cpu_count()\n"
  },
  {
    "path": "src/axolotl/utils/dict.py",
    "content": "\"\"\"Module containing the DictDefault class\"\"\"\n\nfrom addict import Dict\n\n\nclass DictDefault(Dict):\n    \"\"\"\n    A Dict that returns None instead of returning empty Dict for missing keys.\n    \"\"\"\n\n    def __missing__(self, key):\n        return None\n\n    def __or__(self, other):\n        return DictDefault(super().__ror__(other))\n\n    def __setitem__(self, name, value):\n        # workaround for pickle/unpickle issues and __frozen not being available\n        try:\n            isFrozen = hasattr(self, \"__frozen\") and object.__getattribute__(\n                self, \"__frozen\"\n            )\n        except AttributeError:\n            isFrozen = False\n\n        if isFrozen and name not in super().keys():\n            raise KeyError(name)\n        super(Dict, self).__setitem__(name, value)\n        try:\n            p = object.__getattribute__(self, \"__parent\")\n            key = object.__getattribute__(self, \"__key\")\n        except AttributeError:\n            p = None\n            key = None\n        if p is not None:\n            p[key] = self\n            object.__delattr__(self, \"__parent\")\n            object.__delattr__(self, \"__key\")\n\n\ndef remove_none_values(obj):\n    \"\"\"\n    Remove null from a dictionary-like obj or list.\n    These can appear due to Dataset loading causing schema merge.\n    See https://github.com/axolotl-ai-cloud/axolotl/pull/2909\n    \"\"\"\n    if hasattr(obj, \"items\"):\n        return {k: remove_none_values(v) for k, v in obj.items() if v is not None}\n    if isinstance(obj, list):\n        return [remove_none_values(elem) for elem in obj]\n    return obj\n"
  },
  {
    "path": "src/axolotl/utils/distributed.py",
    "content": "\"\"\"Utilities for distributed functionality.\"\"\"\n\nimport os\nimport pickle  # nosec\nfrom contextlib import contextmanager\nfrom datetime import timedelta\n\nimport torch\nimport torch.distributed as dist\nfrom accelerate import PartialState\nfrom accelerate.utils import ParallelismConfig\nfrom transformers.utils.import_utils import (\n    is_torch_cuda_available,\n    is_torch_mps_available,\n    is_torch_npu_available,\n)\n\ndistributed_state = None\n\n\ndef get_device_type() -> torch.device:\n    device = torch.device(\"cpu\")\n    if is_torch_cuda_available():\n        device = torch.device(\"cuda\")\n    elif is_torch_mps_available():\n        device = torch.device(\"mps\")\n    elif is_torch_npu_available():\n        device = torch.device(\"npu\")\n    return device\n\n\ndef get_device_count() -> int:\n    cur_device = get_device_type()\n    if \"cuda\" in str(cur_device):\n        return torch.cuda.device_count()\n    if \"npu\" in str(cur_device):\n        return torch.npu.device_count()\n    return 1\n\n\ndef get_current_device() -> int:\n    cur_device = get_device_type()\n    if \"cuda\" in str(cur_device):\n        return torch.cuda.current_device()\n    if \"npu\" in str(cur_device):\n        return torch.npu.current_device()\n    return 0\n\n\ndef init_distributed_state():\n    global distributed_state\n    if distributed_state is None:\n        timeout = int(os.environ.get(\"AXOLOTL_NCCL_TIMEOUT\", 1800))\n        try:\n            distributed_state = PartialState(timeout=timedelta(seconds=timeout))\n        except ValueError:\n            pass\n\n\ndef get_distributed_state() -> PartialState | None:\n    return distributed_state\n\n\ndef is_distributed() -> bool:\n    \"\"\"Check if distributed training is initialized.\"\"\"\n    init_distributed_state()\n\n    if distributed_state is None:\n        return False\n\n    return distributed_state.use_distributed and distributed_state.initialized\n\n\ndef barrier():\n    \"\"\"\n   
 Acts as a barrier to wait for all processes. This ensures that all processes\n    reach the barrier before proceeding further.\n    \"\"\"\n    if is_distributed():\n        dist.barrier()\n\n\ndef is_main_process() -> bool:\n    \"\"\"\n    Check if the current process is the main process. If not in distributed mode,\n    always return `True`.\n\n    We use a simpler logic when the distributed state is not initialized: we just log\n    on the 0-th local rank.\n\n    Returns:\n        `True` if the current process is the main process, `False` otherwise.\n    \"\"\"\n    if get_distributed_state() is None:\n        return os.environ.get(\"LOCAL_RANK\", \"0\") == \"0\"\n    if not is_distributed():\n        return True\n    return dist.get_rank() == 0\n\n\ndef is_local_main_process() -> bool:\n    if get_distributed_state() is None:\n        return os.environ.get(\"LOCAL_RANK\", \"0\") == \"0\"\n    return PartialState().is_local_main_process\n\n\ndef get_world_size() -> int:\n    return int(os.getenv(\"WORLD_SIZE\", \"1\"))\n\n\ndef cleanup_distributed():\n    \"\"\"\n    Destroy process group if torch distributed is initialized. 
Called in training early\n    termination or when training successfully completes.\n    \"\"\"\n    # Ensure that all operations are completed before destroying the process group\n    if torch.cuda.is_available():\n        torch.cuda.synchronize()\n\n    if torch.xpu.is_available():\n        torch.xpu.synchronize()\n\n    # Destroy the process group\n    if torch.distributed.is_initialized():\n        torch.distributed.destroy_process_group()\n\n\n@contextmanager\ndef zero_first(is_main: bool):\n    \"\"\"\n    runs the wrapped context so that rank 0 runs first before other ranks\n    \"\"\"\n    if not is_main:  # other ranks wait first\n        barrier()\n    yield\n    if is_main:  # then rank 0 waits after it has run the context\n        barrier()\n\n\ndef gather_scalar_from_all_ranks(fn, world_size=1):\n    \"\"\"\n    Run a callable 'fn' on all ranks and gather the results on the specified rank.\n\n    Args:\n    - fn (callable): A function that computes the value. This should not have any side effects.\n    - rank (int, optional): The rank that gathers the values. 
Default is 0.\n    - world_size (int, optional): Total number of processes in the current distributed setup.\n\n    Returns:\n    - A list of computed values from all ranks if on the gathering rank, otherwise None.\n    \"\"\"\n    value_scalar = fn()\n    if not is_distributed():\n        return [value_scalar]\n    value_tensor = torch.tensor(\n        value_scalar, device=f\"{get_device_type()}:{get_current_device()}\"\n    ).float()\n\n    if not is_main_process():\n        dist.gather(value_tensor, dst=0)\n    else:\n        gathered_tensors = [torch.zeros_like(value_tensor) for _ in range(world_size)]\n        dist.gather(value_tensor, gather_list=gathered_tensors, dst=0)\n\n        # Convert tensors back to their original type (int or float)\n        gathered_values = []\n        for tensor in gathered_tensors:\n            if tensor == tensor.int():\n                gathered_values.append(int(tensor.item()))\n            else:\n                gathered_values.append(float(tensor.item()))\n        return gathered_values\n    return None\n\n\ndef broadcast_dict(vals: dict):\n    if not is_distributed():\n        return vals\n\n    cur_device = get_device_type()\n    if is_main_process():\n        data_byte = pickle.dumps(vals)\n        data_tensor = torch.ByteTensor(list(data_byte)).to(cur_device)\n        data_size = torch.IntTensor([len(data_byte)]).to(cur_device)\n    else:\n        data_tensor = torch.empty([1024], dtype=torch.uint8, device=cur_device)\n        data_size = torch.IntTensor([0]).to(cur_device)\n\n    dist.broadcast(data_size, 0)\n    if not is_main_process():\n        # resize\n        data_tensor = data_tensor.new_empty([data_size.item()])\n\n    dist.broadcast(data_tensor, 0)\n\n    if not is_main_process():\n        data_list = data_tensor.cpu().tolist()\n        data_byte = bytes(data_list[: data_size.item()])\n        vals = pickle.loads(data_byte)  # nosec\n\n    return vals\n\n\ndef compute_and_broadcast(fn):\n    \"\"\"\n    Compute 
a value using the function 'fn' only on the specified rank (default is 0).\n    The value is then broadcasted to all other ranks.\n\n    Args:\n    - fn (callable): A function that computes the value. This should not have any side effects.\n    - rank (int, optional): The rank that computes the value. Default is 0.\n\n    Returns:\n    - The computed value (int or float).\n    \"\"\"\n    cur_device = f\"{get_device_type()}:{get_current_device()}\"\n    if is_main_process():\n        value_scalar = fn()\n        value_tensor = torch.tensor(\n            value_scalar, device=cur_device, dtype=torch.float32\n        )\n    else:\n        value_tensor = torch.tensor(\n            0.0, device=cur_device, dtype=torch.float32\n        )  # Placeholder tensor\n\n    # Broadcast the tensor to all processes.\n    barrier()\n    dist.broadcast(value_tensor, src=0)\n\n    # Convert the tensor back to its original type (int or float)\n    if value_tensor == value_tensor.int():\n        return int(value_tensor.item())\n    return float(value_tensor.item())\n\n\ndef gather_from_all_ranks(fn, world_size=1):\n    \"\"\"\n    Run a callable 'fn' on all ranks and gather the results on the specified rank.\n\n    Args:\n    - fn (callable): A function that computes the value. This should not have any side effects.\n    - rank (int, optional): The rank that gathers the values. 
Default is 0.\n    - world_size (int, optional): Total number of processes in the current distributed setup.\n\n    Returns:\n    - A list of computed values from all ranks if on the gathering rank, otherwise None.\n    \"\"\"\n    value_scalar = fn()\n    value_tensor = torch.tensor(\n        value_scalar, device=f\"{get_device_type()}:{get_current_device()}\"\n    ).float()\n\n    # Placeholder tensor for gathering results\n    if is_main_process():\n        gathered_tensors = [torch.zeros_like(value_tensor) for _ in range(world_size)]\n    else:\n        gathered_tensors = None\n\n    dist.gather(value_tensor, gather_list=gathered_tensors, dst=0)\n\n    if is_main_process():\n        # Convert tensors back to their original type (int or float)\n        gathered_values = []\n        for tensor in gathered_tensors:\n            if tensor == tensor.int():\n                gathered_values.append(int(tensor.item()))\n            else:\n                gathered_values.append(float(tensor.item()))\n        return gathered_values\n    return None\n\n\ndef reduce_and_broadcast(fn1, fn2):\n    \"\"\"\n    Run a callable 'fn1' on all ranks, gather the results, reduce them using 'fn2',\n    and then broadcast the reduced result to all ranks.\n\n    Args:\n    - fn1 (callable): A function that computes the value on each rank.\n    - fn2 (callable): A reduction function that takes a list of values and returns a single value.\n    - world_size (int, optional): Total number of processes in the current distributed setup.\n\n    Returns:\n    - The reduced and broadcasted value.\n    \"\"\"\n\n    # Gather values from all ranks using fn1\n    if not is_distributed():\n        return fn2([fn1()])\n\n    gathered_values = gather_from_all_ranks(fn1, world_size=dist.get_world_size())\n\n    # Use compute_and_broadcast to compute the reduced value on the main process\n    # and then broadcast it to all ranks\n    return compute_and_broadcast(lambda: fn2(gathered_values))\n\n\ndef 
build_parallelism_config(cfg):\n    pc_kwargs = _get_parallel_config_kwargs(\n        get_world_size(),\n        cfg.tensor_parallel_size,\n        cfg.context_parallel_size,\n        cfg.dp_shard_size,\n        cfg.dp_replicate_size,\n        bool(cfg.fsdp or cfg.fsdp_config),\n    )\n\n    if pc_kwargs:\n        parallelism_config = ParallelismConfig(\n            **pc_kwargs,\n        )\n        device_mesh = parallelism_config.build_device_mesh(\"cuda\")\n\n        return parallelism_config, device_mesh\n    return None, None\n\n\ndef _get_parallel_config_kwargs(\n    world_size: int,\n    tensor_parallel_size: int = 1,\n    context_parallel_size: int = 1,\n    dp_shard_size: int | None = None,\n    dp_replicate_size: int | None = None,\n    is_fsdp: bool = False,\n):\n    pc_kwargs = {}\n    remaining_world_size = world_size\n\n    if tensor_parallel_size and tensor_parallel_size > 1:\n        pc_kwargs[\"tp_size\"] = tensor_parallel_size\n        remaining_world_size = remaining_world_size // tensor_parallel_size\n\n    if context_parallel_size and context_parallel_size > 1:\n        pc_kwargs[\"cp_size\"] = context_parallel_size\n        remaining_world_size = remaining_world_size // context_parallel_size\n\n    if dp_shard_size is None and dp_replicate_size in (None, 1):\n        if remaining_world_size > 1:\n            pc_kwargs[\"dp_shard_size\"] = remaining_world_size\n            remaining_world_size = 1\n\n    if dp_replicate_size and dp_replicate_size > 1:\n        pc_kwargs[\"dp_replicate_size\"] = dp_replicate_size\n        remaining_world_size = remaining_world_size // dp_replicate_size\n\n    if remaining_world_size > 1 and dp_shard_size and dp_shard_size > 1:\n        if not is_fsdp:\n            raise ValueError(\n                \"dp_shard_size was configured without a corresponding fsdp_config! 
\"\n                \"Please ensure you have configured FSDP using fsdp_config.\"\n            )\n        pc_kwargs[\"dp_shard_size\"] = dp_shard_size\n        remaining_world_size = remaining_world_size // dp_shard_size\n        if remaining_world_size > 1 and \"dp_replicate_size\" not in pc_kwargs:\n            pc_kwargs[\"dp_replicate_size\"] = remaining_world_size\n            remaining_world_size = 1\n\n    if remaining_world_size > 1:\n        if \"dp_shard_size\" not in pc_kwargs and is_fsdp:\n            pc_kwargs[\"dp_shard_size\"] = remaining_world_size\n            remaining_world_size = 1\n\n    if remaining_world_size > 1:\n        raise ValueError(\n            f\"The configured parallelisms are incompatible with the current world size ({get_world_size()})!\\n\"\n            f\"{pc_kwargs}\"\n        )\n\n    return pc_kwargs\n"
  },
  {
    "path": "src/axolotl/utils/environment.py",
    "content": "\"\"\"\nutils to get GPU info for the current environment\n\"\"\"\n\nimport os\nfrom importlib.metadata import version\n\nimport torch\nfrom accelerate.utils.environment import (\n    check_cuda_p2p_ib_support as accelerate_check_cuda_p2p_ib_support,\n)\nfrom packaging.version import Version, parse\n\nfrom axolotl.utils.logging import get_logger\n\nLOG = get_logger(__name__)\n\n\ndef check_cuda_p2p_ib_support():\n    if not accelerate_check_cuda_p2p_ib_support():\n        return False\n    if not check_cuda_p2p_support():\n        return False\n    return True\n\n\ndef check_cuda_p2p_support() -> bool:\n    try:\n        world_size = int(os.environ.get(\"WORLD_SIZE\", \"1\"))\n        local_rank = int(os.environ.get(\"LOCAL_RANK\", \"0\"))\n    except ValueError:\n        return True\n\n    if world_size > 1:\n        node_world_size = int(os.environ.get(\"NODE_WORLD_SIZE\", \"8\"))\n        local_other_rank = (local_rank // node_world_size) * node_world_size\n        local_other_rank += 1 if (local_rank % node_world_size) == 0 else 0\n        try:\n            can_p2p = torch.cuda.can_device_access_peer(local_rank, local_other_rank)\n        except AssertionError as exc:\n            # some sort of logic error in indexing processes, assume p2p is fine for now\n            LOG.warning(exc)\n            return True\n        return can_p2p\n\n    return True\n\n\ndef get_package_version(package: str) -> Version:\n    version_str = version(package)\n    return parse(version_str)\n\n\ndef is_package_version_ge(package: str, version_: str) -> bool:\n    package_version = get_package_version(package)\n    return package_version >= parse(version_)\n"
  },
  {
    "path": "src/axolotl/utils/freeze.py",
    "content": "\"\"\"\nmodule to freeze/unfreeze parameters by name\n\"\"\"\n\nimport re\nfrom typing import Callable, List, Tuple, Union\n\nfrom axolotl.utils.distributed import is_main_process\nfrom axolotl.utils.logging import get_logger\n\nLOG = get_logger(__name__)\n\n\ndef freeze_layers_except(model, regex_patterns):\n    \"\"\"\n    Freezes all layers of the given model except for the layers that match given regex patterns.\n    Periods in the patterns are treated as literal periods, not as wildcard characters\n    (the only exception is `.+`, which keeps its regex meaning).\n\n    Parameters:\n    - model (nn.Module): The PyTorch model to be modified.\n    - regex_patterns (str or list of str): Regex pattern(s) to match parameter names to keep unfrozen.\n      Note that you cannot use a single dot as a wildcard character in the patterns since it is reserved for separating layer names.\n      Patterns are applied with `re.match`, so they are anchored at the start of the parameter name\n      (a leading \"^\" is implied); end the pattern with \"$\" to match the entire name.\n      The range pattern part is optional and it is not compiled as a regex pattern which means you must put \"$\" before the range pattern if you want to match the entire layer name.\n      A range such as `[:32000]`, `[10:]`, `[5:20]` or `[7]` unfreezes only those indices along the first\n      dimension of the parameter; gradients of the remaining indices are zeroed by a gradient hook.\n      E.g., [\"^model.embed_tokens.weight$[:32000]\", \"model.layers.2[0-9]+.block_sparse_moe.gate.[a-z]+$\"]\n\n    Returns:\n    None; the model is modified in place.\n    \"\"\"\n    if isinstance(regex_patterns, str):\n        regex_patterns = [regex_patterns]\n\n    patterns = [LayerNamePattern(pattern) for pattern in regex_patterns]\n\n    # Unfreeze layers that match the regex patterns\n    for name, param in model.named_parameters():\n        param.requires_grad = False\n        unfrozen_ranges = []\n        for pattern in patterns:\n            if not pattern.match(name):\n                continue\n\n            param.requires_grad = True\n\n            if pattern.range is not None:\n                unfrozen_ranges.append(pattern.range)\n\n        merged_unfrozen_ranges = _merge_ranges(unfrozen_ranges, len(param))\n\n        if param.requires_grad and is_main_process():\n            unfrozen_ranges = (\n                f\" with ranges {merged_unfrozen_ranges}\"\n                if merged_unfrozen_ranges\n                else \"\"\n            )\n            LOG.debug(f\"Unfrozen {name}{unfrozen_ranges}\")\n\n        if not merged_unfrozen_ranges:\n            continue\n\n        # The ranges to freeze are the complement (inverse) of the merged unfrozen ranges\n        ranges_to_freeze = _invert_ranges(merged_unfrozen_ranges, len(param))\n\n        param.register_hook(_create_freeze_parameters_hook(ranges_to_freeze))\n\n    if is_main_process() and all(\n        not param.requires_grad for param in model.parameters()\n    ):\n        LOG.warning(\"All parameters are frozen. Model will not be trained.\")\n\n\ndef _invert_ranges(\n    given_ranges: List[Tuple[int, int]], layer_size: int\n) -> List[Tuple[int, int]]:\n    \"\"\"\n    Inverts a list of ranges to obtain the ranges not covered by the given ranges.\n\n    Parameters:\n    - given_ranges (List[Tuple[int, int]]): List of ranges to invert. Each range is represented as a tuple of start (inclusive) and end (exclusive) indices.\n    - layer_size (int): The length of the layer. E.g., len(model.layer.weight)\n\n    Returns:\n    - List[Tuple[int, int]]: List of inverted ranges, where each range is represented as a tuple of start (inclusive) and end (exclusive) indices.\n    \"\"\"\n    if not given_ranges:\n        return [(0, layer_size)]\n\n    inverted_ranges = []\n    current_start = 0\n\n    for start, end in sorted(given_ranges):\n        if start > current_start:\n            inverted_ranges.append((current_start, start))\n        current_start = max(current_start, end)\n\n    # Handle the case where the last given range does not reach layer_size\n    if current_start < layer_size:\n        inverted_ranges.append((current_start, layer_size))\n\n    return inverted_ranges\n\n\ndef _merge_ranges(\n    given_ranges: List[Tuple[int, Union[int, None]]], layer_size: int\n) -> List[Tuple[int, int]]:\n    \"\"\"\n    Merges overlapping (or touching) ranges and sorts the given ranges.\n\n    This function takes a list of ranges and merges any overlapping ranges. The ranges are represented\n    as tuples, where the first element is the start index (inclusive) and the second element is the end\n    index (exclusive). The end index can be None, indicating that the range extends to the end of the\n    sequence.\n\n    Parameters:\n    - given_ranges (List[Tuple[int, int | None]]): List of ranges to merge.\n    - layer_size (int): The length of the layer. E.g., len(model.layer.weight)\n\n    Returns:\n    - List[Tuple[int, int]]: List of merged ranges, as start (inclusive) and end (exclusive) indices.\n\n    Raises:\n    - ValueError: If a range has start < 0, start >= end, or (for a non-empty layer) end > layer_size.\n    \"\"\"\n    # End of each range can be determined now since we have layer_size\n    processed_ranges = [\n        (start, end if end is not None else layer_size) for start, end in given_ranges\n    ]\n    for start, end in processed_ranges:\n        if start < 0 or end > layer_size > 0 or start >= end:\n            raise ValueError(f\"invalid unfreeze range: start={start}, end={end}\")\n\n    # No need to merge if there's only one or no ranges\n    if len(processed_ranges) <= 1:\n        return processed_ranges\n\n    sorted_ranges = sorted(processed_ranges)\n\n    merged_ranges = [sorted_ranges[0]]\n    for start, end in sorted_ranges[1:]:\n        prev_start, prev_end = merged_ranges[-1]\n        if start <= prev_end:\n            merged_ranges[-1] = (prev_start, max(prev_end, end))\n        else:\n            merged_ranges.append((start, end))\n\n    return merged_ranges\n\n\ndef _create_freeze_parameters_hook(ranges_to_freeze: List[Tuple[int, int]]) -> Callable:\n    \"\"\"\n    Create a hook to freeze parameters in specified ranges by setting their gradients to zero.\n\n    This function takes a list of tuples representing the ranges of indices to freeze. Each tuple should contain\n    two integers representing the start (inclusive) and end (exclusive) indices of the range; the slices are taken\n    along the first dimension of the gradient and zeroed in place.\n\n    Parameters:\n    - ranges_to_freeze (List[Tuple[int, int]]): Ranges of indices to freeze.\n\n    Returns:\n    - Callable: A hook function to be used with `register_hook` on parameters (a tensor hook, not an nn.Module hook).\n\n    Example usage:\n    ```\n    ranges_to_freeze = [(0, 10), (20, 30)]\n    hook = _create_freeze_parameters_hook(ranges_to_freeze)\n    param.register_hook(hook)\n    ```\n    \"\"\"\n\n    def freeze_parameters_hook(gradients):\n        for start, end in ranges_to_freeze:\n            gradients[start:end].zero_()\n\n    return freeze_parameters_hook\n\n\nclass LayerNamePattern:\n    \"\"\"\n    Represents a regex pattern for layer names, potentially including a parameter index range.\n    \"\"\"\n\n    def __init__(self, pattern: str):\n        \"\"\"\n        Initializes a new instance of the LayerNamePattern class.\n\n        Parameters:\n        - pattern (str): The regex pattern for layer names, potentially including a parameter index range.\n        \"\"\"\n        self.raw_pattern = pattern\n        name_pattern, self.range = self._parse_pattern(pattern)\n        self.name_regex = re.compile(re.sub(r\"\\.(?!\\+)\", \"\\\\.\", name_pattern))\n\n    def match(self, name: str) -> bool:\n        \"\"\"\n        Checks if the regex pattern matches at the start of the given layer name (`re.match` semantics).\n\n        Parameters:\n        - name (str): The layer name to check.\n\n        Returns:\n        - bool: True if the layer name matches the pattern, False otherwise.\n        \"\"\"\n        return self.name_regex.match(name) is not None\n\n    def _parse_pattern(\n        self, pattern: str\n    ) -> Tuple[str, Union[Tuple[int, Union[int, None]], None]]:\n        \"\"\"\n        Extracts the range pattern from the given pattern.\n\n        Parameters:\n        - pattern (str): The pattern to extract the range from.\n\n        Returns:\n        - Tuple[str, Tuple[int, int | None] | None]: A tuple containing the regex pattern to match the layer name without the range pattern and the range of layer indices to match, if specified.\n        \"\"\"\n        match = re.match(r\"^(.+)\\[([0-9]*)(?::([0-9]*))?\\]$\", pattern)\n        if not match:\n            return pattern, None\n\n        base_pattern, start_part, end_part = match.groups()\n\n        if end_part is None and start_part.isdecimal():\n            index = int(start_part)\n            return base_pattern, (index, index + 1)\n\n        # [:end] or [start:] or [start:end]\n        start = int(start_part) if start_part else 0\n        end = int(end_part) if end_part else None\n\n        if end is not None and start >= end:\n            raise ValueError(\n                f\"Invalid range in layer name pattern: {pattern}.\"\n                \"End of range must be greater than start.\"\n            )\n        return base_pattern, (start, end)\n"
  },
  {
    "path": "src/axolotl/utils/generation/__init__.py",
    "content": "\"\"\"Generation utilities for monitoring during training.\"\"\"\n\nfrom .sft import format_generation_for_logging, generate_samples\n\n__all__ = [\"generate_samples\", \"format_generation_for_logging\"]\n"
  },
  {
    "path": "src/axolotl/utils/generation/sft.py",
    "content": "\"\"\"Sample generation utilities for SFT/Pretrain training.\"\"\"\n\nfrom typing import Any, List, Optional\n\nimport torch\nfrom accelerate.utils import extract_model_from_parallel\nfrom colorama import Fore, Style\n\nfrom axolotl.utils.logging import get_logger\n\nLOG = get_logger(__name__)\n\n\ndef generate_samples(\n    model: torch.nn.Module,\n    tokenizer: Any,\n    dataloader: Any,\n    num_generation_samples: int = 3,\n    max_new_tokens: int = 50,\n    temperature: float = 0.7,\n    top_p: Optional[float] = None,\n    top_k: Optional[int] = None,\n    do_sample: bool = True,\n    prompt_ratio: float = 0.5,\n) -> List[dict]:\n    \"\"\"\n    Generate samples from the model during training for monitoring.\n\n    Args:\n        model: The model to generate from\n        tokenizer: The tokenizer to use for encoding/decoding\n        dataloader: Dataloader to sample prompts from\n        num_generation_samples: Number of samples to generate\n        max_new_tokens: Maximum new tokens to generate\n        temperature: Sampling temperature; only passed to generate() when\n            do_sample is True. temperature=0.0 does NOT switch to greedy\n            decoding - set do_sample=False for that.\n        top_p: Nucleus sampling parameter (only used when do_sample is True)\n        top_k: Top-k sampling parameter (only used when do_sample is True)\n        do_sample: Whether to use sampling vs greedy decoding\n        prompt_ratio: Fraction (0.0-1.0) of each row's length (attention-mask\n            sum when available) used as the prompt, taken from the start of\n            the row (so right-padded batches are presumably expected). Rows\n            shorter than 5 tokens are skipped.\n\n    Returns:\n        List of dicts with 'prompt', 'generated', and 'full_text' keys.\n        Generation errors are logged and skipped rather than raised, and the\n        model's train/eval mode is restored afterwards.\n    \"\"\"\n    unwrapped_model = extract_model_from_parallel(model)\n\n    training = unwrapped_model.training\n    unwrapped_model.eval()\n\n    device = next(unwrapped_model.parameters()).device\n\n    generations = []\n\n    try:\n        with torch.no_grad():\n            samples_collected = 0\n\n            for batch in dataloader:\n                if samples_collected >= num_generation_samples:\n                    break\n\n                input_ids = batch[\"input_ids\"].to(device)\n                attention_mask = batch.get(\"attention_mask\")\n                if attention_mask is not None:\n                    attention_mask = attention_mask.to(device)\n                batch_size = input_ids.shape[0]\n\n                indices = torch.randperm(batch_size)[\n                    : num_generation_samples - samples_collected\n                ]\n\n                for idx in indices:\n                    if samples_collected >= num_generation_samples:\n                        break\n\n                    sequence = input_ids[idx]\n\n                    if attention_mask is not None:\n                        seq_len = attention_mask[idx].sum().item()\n                    else:\n                        seq_len = sequence.shape[0]\n\n                    if seq_len < 5:\n                        continue\n\n                    prompt_len = max(1, int(seq_len * prompt_ratio))\n                    prompt_ids = sequence[:prompt_len].unsqueeze(0)\n\n                    try:\n                        generation_config = {\n                            \"max_new_tokens\": max_new_tokens,\n                            \"do_sample\": do_sample,\n                            \"pad_token_id\": tokenizer.pad_token_id\n                            if tokenizer.pad_token_id is not None\n                            else tokenizer.eos_token_id,\n                        }\n\n                        if do_sample:\n                            generation_config[\"temperature\"] = temperature\n                            if top_p is not None:\n                                generation_config[\"top_p\"] = top_p\n                            if top_k is not None:\n                                generation_config[\"top_k\"] = top_k\n\n                        generated_ids = unwrapped_model.generate(\n                            prompt_ids, **generation_config\n                        )\n\n                        prompt_text = tokenizer.decode(\n                            prompt_ids[0], skip_special_tokens=True\n                        )\n                        generated_text = tokenizer.decode(\n                            generated_ids[0][prompt_len:], skip_special_tokens=True\n                        )\n                        full_text = tokenizer.decode(\n                            generated_ids[0], skip_special_tokens=True\n                        )\n\n                        generations.append(\n                            {\n                                \"prompt\": prompt_text,\n                                \"generated\": generated_text,\n                                \"full_text\": full_text,\n                            }\n                        )\n\n                        samples_collected += 1\n\n                    except Exception as e:\n                        LOG.warning(f\"Failed to generate sample: {e}\", exc_info=True)\n                        continue\n\n    except Exception as e:\n        LOG.warning(f\"Error during sample generation: {e}\", exc_info=True)\n\n    if training:\n        unwrapped_model.train()\n    else:\n        unwrapped_model.eval()\n\n    return generations\n\n\ndef format_generation_for_logging(\n    sample: dict, sample_idx: int, step: int\n) -> tuple[str, str]:\n    \"\"\"\n    Format a generation sample for pretty logging.\n\n    Args:\n        sample: Dict with 'prompt', 'generated', and 'full_text' keys\n        sample_idx: Index of the sample\n        step: Current training step\n\n    Returns:\n        Tuple of (console_text, wandb_text)\n    \"\"\"\n    console_text = (\n        f\"\\n{Style.BRIGHT}{Fore.CYAN}{'=' * 80}{Style.RESET_ALL}\\n\"\n        f\"{Style.BRIGHT}{Fore.GREEN}Sample {sample_idx + 1} (Step {step}){Style.RESET_ALL}\\n\"\n        f\"{Style.BRIGHT}{Fore.CYAN}{'=' * 80}{Style.RESET_ALL}\\n\"\n        f\"{Style.BRIGHT}{Fore.YELLOW}[PROMPT]{Style.RESET_ALL}\\n{sample['prompt']}\\n\\n\"\n        f\"{Style.BRIGHT}{Fore.MAGENTA}[GENERATED]{Style.RESET_ALL}\\n{sample['generated']}\\n\"\n        f\"{Style.BRIGHT}{Fore.CYAN}{'=' * 80}{Style.RESET_ALL}\\n\"\n    )\n    wandb_text = (\n        f\"\\n{'=' * 80}\\n\"\n        f\"Sample {sample_idx + 1} (Step {step})\\n\"\n        f\"{'=' * 80}\\n\"\n        f\"[PROMPT]\\n{sample['prompt']}\\n\\n\"\n        f\"[GENERATED]\\n{sample['generated']}\\n\"\n        f\"{'=' * 80}\\n\"\n    )\n\n    return console_text, wandb_text\n"
  },
  {
    "path": "src/axolotl/utils/import_helper.py",
    "content": "\"\"\"\nHelper for importing modules from strings\n\"\"\"\n\nimport importlib\n\n\ndef get_cls_from_module_str(module_str: str):\n    # Resolve a dotted path such as pkg.module.ClassName (e.g. a reward function):\n    # import pkg.module with importlib, then fetch ClassName from it.\n    if not isinstance(module_str, str) or not module_str.strip():\n        raise ValueError(\"module_str must be a non-empty string\")\n\n    parts = module_str.split(\".\")\n    if len(parts) < 2:\n        raise ValueError(f\"Invalid module string format: {module_str}\")\n\n    try:\n        cls_name = parts[-1]\n        module_path = \".\".join(parts[:-1])\n        mod = importlib.import_module(module_path)\n        mod_cls = getattr(mod, cls_name)\n        return mod_cls\n    except ImportError as e:\n        raise ImportError(f\"Failed to import module '{module_path}': {e}\") from e\n    except AttributeError as e:\n        raise AttributeError(\n            f\"Class '{cls_name}' not found in module '{module_path}': {e}\"\n        ) from e\n"
  },
  {
    "path": "src/axolotl/utils/logging.py",
    "content": "\"\"\"Logging helpers to only log on main process.\"\"\"\n\nimport functools\nimport logging\nimport warnings\n\nfrom axolotl.utils.distributed import is_main_process\n\n# Suppress noisy bitsandbytes warnings about dtype casting during quantization\nwarnings.filterwarnings(\n    \"ignore\",\n    message=\".*MatMul8bitLt: inputs will be cast from.*\",\n    category=UserWarning,\n)\n\n# Adapted from Accelerate\n# https://github.com/huggingface/accelerate/blob/main/src/accelerate/logging.py\n\n\nclass MultiProcessAdapter(logging.LoggerAdapter):\n    \"\"\"\n    Logger adapter for distributed logging, specifically to only log on main process.\n    \"\"\"\n\n    @staticmethod\n    def _should_log(main_process_only: bool):\n        return not main_process_only or is_main_process()\n\n    def log(self, level, msg, *args, **kwargs):\n        main_process_only = kwargs.pop(\"main_process_only\", True)\n        kwargs.setdefault(\"stacklevel\", 2)\n\n        if self.isEnabledFor(level) and self._should_log(main_process_only):\n            msg, kwargs = self.process(msg, kwargs)\n            self.logger.log(level, msg, *args, **kwargs)\n\n    @functools.lru_cache(maxsize=10)\n    def warning_once(self, *args, **kwargs):\n        \"\"\"\n        This method is identical to `logger.warning()`, but will emit the warning with the same message only once\n\n        Note: The cache is for the function arguments, so 2 different callers using the same arguments will hit the\n        cache. The assumption here is that all warning messages are unique across the code. If they aren't then need to\n        switch to another type of cache that includes the caller frame information in the hashing function.\n        \"\"\"\n        self.warning(*args, **kwargs)\n\n\ndef get_logger(name: str, log_level: str | None = None) -> MultiProcessAdapter:\n    logger = logging.getLogger(name)\n    logger.setLevel(logging.DEBUG)\n    return MultiProcessAdapter(logger, extra={})\n"
  },
  {
    "path": "src/axolotl/utils/lora.py",
    "content": "# Copyright 2025 Axolotl AI. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"\nmodule to get the state dict of a merged lora model\n\"\"\"\n\nimport torch\nfrom peft.tuners.tuners_utils import onload_layer\nfrom peft.utils import ModulesToSaveWrapper, _get_submodules\n\n\ndef get_lora_merged_state_dict(\n    model: torch.nn.Module,\n) -> dict:\n    r\"\"\"\n    Create and return a state_dict that has the LoRA deltas\n    merged into the base model’s weights.\n\n    Note: this is not side-effect free. Every LoRA-wrapped layer (and any\n    tuner layer inside `modules_to_save`) has `merge(safe_merge=True,\n    adapter_names=None)` called on it, which folds the adapter deltas into\n    that layer's base weights inside `model` itself; the returned tensors are\n    read from those merged layers.\n\n    Only `weight` is collected for LoRA-wrapped and `modules_to_save` modules;\n    `bias` is collected only for plain (non-wrapped) modules.\n\n    Arguments:\n        model (torch.nn.Module): A model that has LoRA/PEFT adapters attached.\n\n    Returns:\n        dict: A state_dict of the merged parameters.\n    \"\"\"\n\n    base_model_prefix = \"base_model.model.\"\n    state_dict = {}\n    key_list = [key for key, _ in model.named_modules() if model.prefix not in key]\n    for key in key_list:\n        try:\n            _, target, _ = _get_submodules(model, key)\n        except AttributeError:\n            continue\n        with onload_layer(target):\n            weight_key = key.replace(base_model_prefix, \"\") + \".weight\"\n            bias_key = key.replace(base_model_prefix, \"\") + \".bias\"\n            if hasattr(target, \"base_layer\"):\n                target.merge(safe_merge=True, adapter_names=None)\n                # get the state_dict of target.base_layer\n                layer_state_dict = target.base_layer.state_dict()\n                state_dict[weight_key] = layer_state_dict[\"weight\"]\n            elif isinstance(target, ModulesToSaveWrapper):\n                # save any additional trainable modules part of `modules_to_save`\n                new_module = target.modules_to_save[target.active_adapter]\n                if hasattr(new_module, \"base_layer\"):\n                    # check if the module is itself a tuner layer\n                    new_module.merge(safe_merge=True, adapter_names=None)\n                layer_state_dict = new_module.state_dict()\n                state_dict[weight_key] = layer_state_dict[\"weight\"]\n            elif hasattr(target, \"weight\"):\n                if any(\n                    skip in key\n                    for skip in [\n                        \".original_module\",\n                        \".modules_to_save\",\n                        \".base_layer\",\n                    ]\n                ):\n                    continue\n                layer_state_dict = target.state_dict()\n                state_dict[weight_key] = layer_state_dict[\"weight\"]\n                if hasattr(target, \"bias\") and \"bias\" in layer_state_dict.keys():\n                    state_dict[bias_key] = layer_state_dict[\"bias\"]\n    return state_dict\n"
  },
  {
    "path": "src/axolotl/utils/mistral/__init__.py",
    "content": "\"\"\"Init for `axolotl.utils.mistral` module.\"\"\"\n\nfrom axolotl.utils.mistral.mistral3_processor import Mistral3Processor\nfrom axolotl.utils.mistral.mistral_tokenizer import HFMistralTokenizer\n\n__all__ = [\"HFMistralTokenizer\", \"Mistral3Processor\"]\n"
  },
  {
    "path": "src/axolotl/utils/mistral/mistral3_processor.py",
    "content": "\"\"\"Processor for Mistral3 multimodal models with image support\"\"\"\n\nfrom typing import Any, Dict, Optional, Union\n\nimport torch\nfrom transformers import ProcessorMixin\nfrom transformers.feature_extraction_utils import BatchFeature\nfrom transformers.processing_utils import ProcessingKwargs\nfrom transformers.tokenization_utils_base import PreTokenizedInput, TextInput\n\nfrom axolotl.utils.mistral.mistral_tokenizer import HFMistralTokenizer\n\n\nclass Mistral3ProcessorKwargs(ProcessingKwargs):\n    _defaults: Dict[str, Dict[str, Any]] = {\n        \"text_kwargs\": {\n            \"padding\": True,\n        },\n        \"common_kwargs\": {\n            \"return_tensors\": \"pt\",\n            \"return_dict\": True,\n            \"tokenize\": True,\n        },\n    }\n\n\nclass Mistral3Processor(ProcessorMixin):\n    \"\"\"\n    Processor for Mistral3 multimodal models that handles text and images.\n    Wraps HFMistralTokenizer and adds image processing capabilities.\n    \"\"\"\n\n    def __init__(self, tokenizer: HFMistralTokenizer):\n        super().__init__(tokenizer)\n\n    @property\n    def audio_tokenizer(self) -> None:\n        \"\"\"Audio tokenizer is not supported. 
Dummy method to satisfy HuggingFace API.\"\"\"\n        return None\n\n    def _merge_kwargs(\n        self, processor_kwargs_class: Any, **kwargs: Any\n    ) -> Dict[str, Dict[str, Any]]:\n        \"\"\"Merge kwargs with defaults similar to ProcessorMixin\"\"\"\n        defaults = processor_kwargs_class._defaults\n        output_kwargs: Dict[str, Dict[str, Any]] = {}\n\n        for kwarg_type, default_values in defaults.items():\n            output_kwargs[kwarg_type] = {**default_values}\n\n        # Update with provided kwargs\n        for key, value in kwargs.items():\n            # Try to match key to appropriate kwarg type\n            if key in [\"padding\", \"truncation\", \"max_length\"]:\n                output_kwargs.setdefault(\"text_kwargs\", {}).update({key: value})\n            elif key in [\"return_tensors\", \"return_dict\", \"tokenize\"]:\n                output_kwargs.setdefault(\"common_kwargs\", {}).update({key: value})\n            else:\n                # Add to text_kwargs by default\n                output_kwargs.setdefault(\"text_kwargs\", {}).update({key: value})\n\n        return output_kwargs\n\n    def apply_chat_template(\n        self,\n        conversation: Union[list[dict[str, str]], list[list[dict[str, str]]]],\n        **kwargs: Any,\n    ) -> Union[BatchFeature, str, list[str]]:\n        \"\"\"\n        Apply chat template with image support for Mistral3.\n\n        Similar to VoxtralProcessor, this method extracts images from the conversation,\n        calls the tokenizer's apply_chat_template, then adds pixel_values and image_sizes\n        to the result.\n        \"\"\"\n        output_kwargs = self._merge_kwargs(Mistral3ProcessorKwargs, **kwargs)\n        text_kwargs = output_kwargs[\"text_kwargs\"]\n        common_kwargs = output_kwargs[\"common_kwargs\"]\n\n        return_tensors = common_kwargs.pop(\"return_tensors\", \"pt\")\n        if return_tensors != \"pt\":\n            raise ValueError(\n                
f\"{self.__class__.__name__} only supports `return_tensors='pt'`.\"\n            )\n\n        return_dict = common_kwargs.pop(\"return_dict\", False)\n        tokenize = common_kwargs.pop(\"tokenize\", False)\n\n        # Determine if batched\n        if isinstance(conversation, (list, tuple)) and (\n            isinstance(conversation[0], (list, tuple))\n            or hasattr(conversation[0], \"content\")\n        ):\n            is_batched = True\n            conversations = conversation\n        else:\n            is_batched = False\n            conversations = [conversation]  # type: ignore\n\n        # Call tokenizer's apply_chat_template\n        tokenizer_kwargs = {**text_kwargs, **common_kwargs}\n        tokenizer_kwargs[\"return_tensors\"] = return_tensors\n        tokenizer_kwargs[\"tokenize\"] = tokenize\n        tokenizer_kwargs[\"return_dict\"] = return_dict\n\n        encoded_instruct_inputs = self.tokenizer.apply_chat_template(\n            conversations,\n            **tokenizer_kwargs,\n        )\n\n        if tokenize:\n            if return_dict:\n                # The tokenizer already handles pixel_values, we just need to add image_sizes\n                if hasattr(encoded_instruct_inputs, \"items\"):\n                    data: Dict[str, Any] = dict(encoded_instruct_inputs)  # type: ignore\n                elif hasattr(encoded_instruct_inputs, \"data\"):\n                    data = encoded_instruct_inputs.data  # type: ignore\n                else:\n                    raise ValueError(\"Unknown data type\")\n\n                if \"pixel_values\" in data:\n                    pixel_values = data[\"pixel_values\"]\n\n                    # MistralTokenizer returns a Double, so we convert to fp32\n                    data[\"pixel_values\"] = pixel_values.to(dtype=torch.float32)\n\n                    # Always batched: [B, C, H, W] -> image_sizes: [B, 2]\n                    # Since tensor is homogeneous, all images have same H, W\n                
    batch_size = pixel_values.shape[0]\n                    image_sizes = torch.tensor([pixel_values.shape[-2:]] * batch_size)\n                    data[\"image_sizes\"] = image_sizes\n\n                return BatchFeature(data=data, tensor_type=return_tensors)\n\n        if not is_batched:\n            return encoded_instruct_inputs[0]\n\n        return encoded_instruct_inputs\n\n    def __call__(\n        self,\n        text: Optional[\n            Union[\n                TextInput, PreTokenizedInput, list[TextInput], list[PreTokenizedInput]\n            ]\n        ],\n        **kwargs: Any,\n    ) -> BatchFeature:\n        \"\"\"\n        Forward text processing to the tokenizer.\n        This method does not support images - use apply_chat_template instead.\n        \"\"\"\n        output_kwargs = self._merge_kwargs(Mistral3ProcessorKwargs, **kwargs)\n        text_kwargs = output_kwargs[\"text_kwargs\"]\n        common_kwargs = output_kwargs[\"common_kwargs\"]\n\n        out = self.tokenizer(text, **text_kwargs)\n        return BatchFeature(\n            data=out, tensor_type=common_kwargs.pop(\"return_tensors\", None)\n        )\n"
  },
  {
    "path": "src/axolotl/utils/mistral/mistral_tokenizer.py",
    "content": "\"\"\"Wrapper for MistralTokenizer from mistral-common\"\"\"\n\nimport os\nfrom typing import Optional\n\nimport numpy as np\nfrom mistral_common.protocol.instruct.validator import ValidationMode\nfrom mistral_common.tokens.tokenizers.utils import download_tokenizer_from_hf_hub\nfrom torch import Tensor\nfrom transformers.tokenization_mistral_common import MistralCommonBackend\nfrom transformers.tokenization_utils_base import VERY_LARGE_INTEGER\n\n\nclass HFMistralTokenizer(MistralCommonBackend):\n    \"\"\"\n    Wraps mistral_common.tokens.tokenizers.mistral.MistralTokenizer\n    and exposes HuggingFace API for special tokens.\n    \"\"\"\n\n    def __init__(self, name_or_path: str, **kwargs):\n        \"\"\"\n        Args:\n            name_or_path: The name or path to the tokenizer files or the repo id.\n            **kwargs: Additional keyword arguments passed to the parent class.\n        \"\"\"\n        kwargs.pop(\"mode\", None)\n\n        mode = ValidationMode.finetuning\n        super().__init__(**kwargs, mode=mode)\n\n        self._name_or_path = name_or_path\n\n        # set mode as is not set upstream\n        self._set_mode(mode)\n\n    @property\n    def name_or_path(self) -> str:\n        return self._name_or_path\n\n    @name_or_path.setter\n    def name_or_path(self, name_or_path: str) -> None:\n        self._name_or_path = name_or_path\n\n    @property\n    def chat_template(self) -> str | None:\n        \"\"\"Chat template is not supported. 
Dummy method to satisfy HuggingFace API.\"\"\"\n        return \"[This is a dummy chat template]\"\n\n    @chat_template.setter\n    def chat_template(self, chat_template: str | None) -> None:\n        pass\n\n    def _set_mode(self, mode: ValidationMode):\n        \"\"\"Set the mode of the MistralRequestValidator.\n\n        Args:\n            mode: The mode to set.\n\n        Raises:\n            RuntimeError: If the MistralRequestValidator does not have a _mode attribute.\n        \"\"\"\n        # Check if MistralRequestValidator has a _mode attribute.\n        # This is a private API and may change in the future.\n\n        from mistral_common.protocol.instruct.validator import MistralRequestValidator\n\n        if not (\n            hasattr(self.tokenizer, \"_chat_completion_request_validator\")\n            and isinstance(\n                self.tokenizer._chat_completion_request_validator,\n                MistralRequestValidator,\n            )\n            and hasattr(self.tokenizer._chat_completion_request_validator, \"_mode\")\n        ):\n            raise RuntimeError(\n                f\"Unable to switch mistral tokenizer to {mode.value} mode - \"\n                \"private API `_chat_completion_request_validator._mode` missing.\"\n            )\n\n        self.tokenizer._chat_completion_request_validator._mode = mode\n\n    def apply_chat_template(  # type: ignore\n        self,\n        conversation: list[dict] | list[list[dict]],\n        chat_template: str | None = None,\n        add_generation_prompt: bool = False,\n        **kwargs,\n    ) -> str | list[int]:\n        \"\"\"Patched fn to handle setting test mode, remove chat_template and add_generation_prompt kwarg\"\"\"\n\n        # pop unnecessary kwarg for mistral\n        kwargs.pop(\"real_last_index\", None)\n        kwargs.pop(\"add_special_tokens\", None)\n\n        try:\n            if add_generation_prompt:\n                self._set_mode(ValidationMode.test)\n\n            out = 
super().apply_chat_template(conversation, **kwargs)\n\n            return out  # type: ignore\n\n        finally:\n            if add_generation_prompt:\n                self._set_mode(ValidationMode.finetuning)\n\n    def decode(  # type: ignore\n        self,\n        token_ids: int | list[int] | np.ndarray | Tensor,\n        **kwargs,\n    ) -> str:\n        \"\"\"\n        Decode token_ids into str.\n\n        This overrides upstream.decode to convert int to list[int]\n        \"\"\"\n\n        if isinstance(token_ids, int):\n            token_ids = [token_ids]\n\n        return super().decode(token_ids, **kwargs)\n\n    @classmethod\n    def from_pretrained(\n        cls,\n        pretrained_model_name_or_path: str | os.PathLike,\n        *init_inputs,\n        mode: ValidationMode = ValidationMode.test,\n        cache_dir: Optional[str | os.PathLike] = None,\n        force_download: bool = False,\n        local_files_only: bool = False,\n        token: Optional[str | bool] = None,\n        revision: str = \"main\",\n        model_max_length: int = VERY_LARGE_INTEGER,\n        padding_side: str = \"left\",\n        truncation_side: str = \"right\",\n        model_input_names: Optional[list[str]] = None,\n        clean_up_tokenization_spaces: bool = False,\n        **kwargs,\n    ):\n        r\"\"\"\n        Patched fn to pass `name_or_path` and remove extra kwargs.\n\n        Instantiate a `MistralCommonBackend` from a predefined\n        tokenizer.\n\n        Args:\n            pretrained_model_name_or_path (`str` or `os.PathLike`):\n                Can be either:\n\n                - A string, the *model id* of a predefined tokenizer hosted inside a model repo on huggingface.co.\n                - A path to a *directory* containing the tokenizer config, for instance saved\n                  using the [`MistralCommonBackend.tokenization_mistral_common.save_pretrained`] method, e.g.,\n                  `./my_model_directory/`.\n            mode 
(`ValidationMode`, *optional*, defaults to `ValidationMode.test`):\n                Validation mode for the `MistralTokenizer` tokenizer.\n            cache_dir (`str` or `os.PathLike`, *optional*):\n                Path to a directory in which a downloaded predefined tokenizer vocabulary files should be cached if the\n                standard cache should not be used.\n            force_download (`bool`, *optional*, defaults to `False`):\n                Whether or not to force the (re-)download the vocabulary files and override the cached versions if they\n                exist.\n            token (`str` or *bool*, *optional*):\n                The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated\n                when running `hf auth login` (stored in `~/.huggingface`).\n            local_files_only (`bool`, *optional*, defaults to `False`):\n                Whether or not to only rely on local files and not to attempt to download any files.\n            revision (`str`, *optional*, defaults to `\"main\"`):\n                The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a\n                git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any\n                identifier allowed by git.\n            max_length (`int`, *optional*):\n                Controls the maximum length to use by one of the truncation/padding parameters.\n\n                If left unset or set to `None`, this will use the predefined model maximum length if a maximum length\n                is required by one of the truncation/padding parameters. If the model has no specific maximum input\n                length (like XLNet) truncation/padding to a maximum length will be deactivated.\n            padding_side (`str`, *optional*, defaults to `\"left\"`):\n                The side on which the model should have padding applied. 
Should be selected between ['right', 'left'].\n                Default value is picked from the class attribute of the same name.\n            truncation_side (`str`, *optional*, defaults to `\"right\"`):\n                The side on which the model should have truncation applied. Should be selected between ['right', 'left'].\n            model_input_names (`List[string]`, *optional*):\n                The list of inputs accepted by the forward pass of the model (like `\"token_type_ids\"` or\n                `\"attention_mask\"`). Default value is picked from the class attribute of the same name.\n            clean_up_tokenization_spaces (`bool`, *optional*, defaults to `False`):\n                Whether or not the model should cleanup the spaces that were added when splitting the input text during the\n                tokenization process.\n            kwargs (additional keyword arguments, *optional*):\n                Not supported by `MistralCommonBackend.from_pretrained`.\n                Will raise an error if used.\n        \"\"\"\n        if init_inputs:\n            raise ValueError(\n                \"`init_inputs` are not supported by `MistralCommonBackend.from_pretrained`.\"\n            )\n\n        # Delete trust_remote_code as it does nothing\n        kwargs.pop(\"trust_remote_code\", None)\n\n        # Delete tokenizer as it does nothing\n        kwargs.pop(\"tokenizer\", None)\n\n        # Handle kwargs and AutoTokenizer case\n        if kwargs and not kwargs.keys() == {\"_from_auto\"}:\n            raise ValueError(\n                f\"Kwargs {list(kwargs.keys())} are not supported by `MistralCommonBackend.from_pretrained`.\"\n            )\n\n        if not os.path.isfile(pretrained_model_name_or_path):\n            tokenizer_path = download_tokenizer_from_hf_hub(\n                repo_id=str(pretrained_model_name_or_path),\n                cache_dir=str(cache_dir),\n                token=token,\n                revision=revision,\n                
force_download=force_download,\n                local_files_only=local_files_only,\n            )\n        else:\n            tokenizer_path = str(pretrained_model_name_or_path)\n\n        return cls(\n            name_or_path=str(pretrained_model_name_or_path),\n            tokenizer_path=tokenizer_path,\n            mode=mode,\n            model_max_length=model_max_length,\n            padding_side=padding_side,\n            truncation_side=truncation_side,\n            model_input_names=model_input_names,\n            clean_up_tokenization_spaces=clean_up_tokenization_spaces,\n        )\n\n    def save_pretrained(self, *args, **kwargs) -> tuple[str, ...]:\n        \"\"\"\n        Patches to remove save_jinja_files from being passed onwards.\n        \"\"\"\n        kwargs.pop(\"save_jinja_files\", None)\n        return super().save_pretrained(*args, **kwargs)\n"
  },
  {
    "path": "src/axolotl/utils/mlflow_.py",
    "content": "\"\"\"Module for mlflow utilities\"\"\"\n\nimport os\n\nfrom axolotl.utils.dict import DictDefault\n\n\ndef setup_mlflow_env_vars(cfg: DictDefault):\n    for key in cfg.keys():\n        if key.startswith(\"mlflow_\") or key.startswith(\"hf_mlflow_\"):\n            value = cfg.get(key, \"\")\n\n            if value and isinstance(value, str) and len(value) > 0:\n                os.environ[key.upper()] = value\n\n    # Enable mlflow if experiment name is present\n    if cfg.mlflow_experiment_name and len(cfg.mlflow_experiment_name) > 0:\n        cfg.use_mlflow = True\n\n    # Enable logging hf artifacts in mlflow if value is truthy\n    if cfg.hf_mlflow_log_artifacts is True:\n        os.environ[\"HF_MLFLOW_LOG_ARTIFACTS\"] = \"true\"\n"
  },
  {
    "path": "src/axolotl/utils/model_shard_quant.py",
    "content": "\"\"\"\nmodule to handle loading model on cpu/meta device for FSDP\n\"\"\"\n\nimport os\nimport time\nfrom typing import List, Optional, Type, Union\n\nimport safetensors\nimport torch\nfrom accelerate import init_empty_weights\nfrom bitsandbytes.nn import Linear4bit, Params4bit\nfrom fastcore.parallel import parallel\nfrom torch import Tensor, nn\nfrom tqdm import tqdm\nfrom transformers import AutoModelForCausalLM\nfrom transformers.quantizers import AutoHfQuantizer\nfrom transformers.utils import SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME, hub\n\n\ndef _replace_linear(\n    model: nn.Module,\n    linear_replacement: Type[nn.Module],\n    quant_config: Union[dict, None] = None,\n    skip_modules=None,\n    **kwargs,\n):\n    \"\"\"\n    Replace linear modules with a new Linear module.\n    Parameters:\n        model (`torch.nn.Module`):\n            Input model or `torch.nn.Module` as the function is run recursively.\n        linear_replacement (`torch.nn.Module`):\n            The linear module that replaces the old one. Only expects standard arguments.\n            If other arguments need to be passed, use a lambda.\n        skip_modules (`List[str]`, *optional*, defaults to `lm_head`):\n            List of modules names not to convert. 
Defaults to `lm_head`.\n    \"\"\"\n    if skip_modules is None:\n        skip_modules = [\"lm_head\"]\n    for name, module in model.named_children():\n        if len(list(module.children())) > 0:\n            _replace_linear(\n                module, linear_replacement, quant_config, skip_modules, **kwargs\n            )\n\n        if isinstance(module, torch.nn.Linear) and name not in skip_modules:\n            if issubclass(linear_replacement, Linear4bit):\n                model._modules[name] = linear_replacement(\n                    module.in_features,\n                    module.out_features,\n                    module.bias is not None,\n                    **kwargs,\n                )\n            else:\n                raise ValueError(\n                    f\"Unsupported linear replacement: {type(linear_replacement)}\"\n                )\n    return model\n\n\ndef load_and_quantize(\n    module: nn.Module,\n    name: str,\n    value: Tensor,\n    device: torch.device = None,\n    dtype: torch.dtype = None,\n    skip_names: Optional[List[str]] = None,\n    to_cpu: bool = False,\n    to_meta: bool = False,\n    verbose: bool = False,\n    quant_method: str = \"bnb\",\n):\n    \"\"\"\n    Loads `value` tensor into submodule of `module`, optionally skipping `skip_names` and converting to `dtype`.\n\n    Quantizes `Params4bit` on `device` then places on \"cpu\" if to_cpu=True or \"meta\" if to_meta=True.\n    \"\"\"\n\n    if not skip_names:\n        skip_names = []\n\n    def place_on_device(value):\n        if to_meta:\n            device = \"meta\"\n        elif to_cpu:\n            device = \"cpu\"\n        return value.to(device=device, dtype=dtype)\n\n    if any(skip_name in name for skip_name in skip_names):\n        if verbose:\n            print(f\"Skipping {name} because it is in skip_names\")\n        return\n\n    module_key, _, value_key = name.rpartition(\".\")\n    try:\n        submodule = module.get_submodule(module_key)\n    except 
AttributeError as exc:\n        print(f\"Module {module_key} not found:\\n{exc}\")\n        return\n\n    try:\n        if quant_method == \"bnb\":\n            param = submodule.get_parameter(value_key)\n            if isinstance(param, Params4bit):\n                # With `sync_module_states=True`, a meta device Params4bit needs to be the same\n                # shape as the quantized Params4bit with an initialized quant_state. However,\n                # FSDP only syncs parameters and buffers, so the quant_state isn't copied. This\n                # workaround quantizes Params4bit to initialize quant_state on all ranks, then\n                # replaces Params4bit's data with a meta tensor to free memory on non-rank 0.\n                value = type(param)(\n                    value.to(device=device, dtype=dtype).data, **param.__dict__\n                ).cuda(device)\n                if to_meta:\n                    value = type(param)(value.data.to(\"meta\"), **value.__dict__)\n                elif to_cpu:\n                    value = type(param)(value.data.to(\"cpu\"), **value.__dict__)\n            else:\n                value = type(param)(place_on_device(value).data)\n\n    except AttributeError:\n        # it's a buffer\n        value = place_on_device(value)\n\n    setattr(submodule, value_key, value)\n\n\ndef n_loading_workers(quant_method: str, param_count: float):\n    devprops = torch.cuda.get_device_properties(torch.cuda.current_device())\n    left = int(os.cpu_count() / torch.cuda.device_count())\n    model_params_b = 70\n    right = int(\n        (4 if quant_method == \"hqq\" else 8)\n        * (devprops.total_memory / 1e9 / 40)\n        * (model_params_b / (param_count / 1e9))\n    )\n    return min(left, right)\n\n\ndef load_sharded_model(\n    model_name,\n    model_config,\n    cfg,\n    torch_dtype=torch.bfloat16,\n    low_memory=True,\n):\n    if (low_memory and cfg.local_rank == 0) or not low_memory:\n        model = 
AutoModelForCausalLM.from_pretrained(\n            model_name,\n            use_cache=False,\n            dtype=torch.float32,\n            _attn_implementation=model_config._attn_implementation,\n            trust_remote_code=cfg.trust_remote_code,\n        )\n        dtype = torch_dtype if not cfg.float32 else None\n        model.to(dtype=dtype, device=\"cpu\" if low_memory else cfg.local_rank)\n    else:\n        with init_empty_weights():\n            model = AutoModelForCausalLM.from_config(\n                model_config,\n                dtype=torch_dtype,\n                trust_remote_code=cfg.trust_remote_code,\n            )\n    return model\n\n\ndef load_sharded_model_quant(\n    model_name,\n    model_config,\n    cfg,\n    compute_dtype=torch.bfloat16,\n    quant_storage=torch.float32,\n    low_memory=True,\n    verbose=False,\n    loading_workers=2,\n    quantization_config=None,\n):\n    with init_empty_weights():\n        model = AutoModelForCausalLM.from_config(\n            model_config,\n            trust_remote_code=cfg.trust_remote_code,\n        )\n        if hasattr(model, \"transformer\"):\n            model.transformer = _replace_linear(\n                model.transformer,\n                Linear4bit,\n                compute_dtype=compute_dtype,\n                quant_type=\"nf4\",\n                quant_storage=quant_storage,\n                compress_statistics=True,  # bnb_4bit_use_double_quant\n                skip_modules=[\n                    \"lm_head\",\n                    \"embed_out\",\n                ],\n            )\n        else:\n            # this is the more common case with HF transformers\n            # TODO can we detect the model arch and dynamically set skip_modules\n            model.model = _replace_linear(\n                model.model,\n                Linear4bit,\n                compute_dtype=compute_dtype,\n                quant_type=\"nf4\",\n                quant_storage=quant_storage,\n                
compress_statistics=True,  # bnb_4bit_use_double_quant\n                skip_modules=[\n                    \"lm_head\",\n                    \"embed_out\",\n                ],\n            )\n    model.is_loaded_in_4bit = True\n\n    # Grab the safetensors files that hold the weights\n    try:\n        idx = hub.cached_file(model_name, SAFE_WEIGHTS_INDEX_NAME)\n        files, _ = hub.get_checkpoint_shard_files(model_name, idx)\n    except OSError:\n        try:\n            # This means the model doesn't have a model.safetensors.index.json because it is not sharded\n            files = []\n            files.append(hub.cached_file(model_name, SAFE_WEIGHTS_NAME))\n        except OSError as exc:\n            # This means the model probably doesn't have a safetensors file\n            raise exc\n\n    # Load in the weights, using our custom load_and_quantize method which quantizes Params4bit on the fly\n    # and then places each layer on CPU or meta if using low_memory to minimize GPU memory usage\n    def load_and_quantize_parallel(name_param, model, **kwargs):\n        name, param = name_param\n        load_and_quantize(model, name, param, **kwargs)\n\n    quant_method = \"bnb\"\n    param_count = sum((p.numel() for n, p in model.named_parameters()))\n\n    n_workers = (\n        n_loading_workers(quant_method, param_count)\n        if loading_workers == -1\n        else loading_workers\n    )\n    if cfg.local_rank == 0 and verbose:\n        print(f\"Using n_workers: {n_workers} for loading\")\n\n    start = time.time()\n    for filename in tqdm(\n        files,\n        desc=\"Loading & Quantizing Model Shards\",\n        disable=cfg.local_rank != 0,\n        position=0,\n    ):\n        weights = safetensors.torch.load_file(filename)\n        parallel(\n            load_and_quantize_parallel,\n            iter(weights.items()),\n            n_workers=n_workers,\n            threadpool=True,\n            model=model,\n            dtype=quant_storage,\n            
device=cfg.local_rank,\n            skip_names=[],\n            to_cpu=(low_memory and cfg.local_rank == 0),\n            to_meta=(low_memory and cfg.local_rank != 0),\n            verbose=verbose,\n            quant_method=quant_method,\n        )\n\n    # these attributes are needed to inform transformers/peft of the quantization\n    model.is_quantized = True\n    model.quantization_method = \"bitsandbytes\"\n    model.hf_quantizer = AutoHfQuantizer.from_config(quantization_config)\n\n    if cfg.local_rank == 0 and verbose:\n        print(f\"Loaded model weights in {time.time() - start:.3f} seconds\")\n    # cleanup any extra memory usage from parallel loading\n    torch.cuda.empty_cache()\n\n    return model\n"
  },
  {
    "path": "src/axolotl/utils/optimizers/__init__.py",
    "content": ""
  },
  {
    "path": "src/axolotl/utils/optimizers/adopt.py",
    "content": "\"\"\"\nCopied from https://github.com/iShohei220/adopt\n\nADOPT: Modified Adam Can Converge with Any β2 with the Optimal Rate (2024)\nTaniguchi, Shohei and Harada, Keno and Minegishi, Gouki and Oshima, Yuta and Jeong, Seong Cheol and Nagahara, Go and Iiyama, Tomoshi and Suzuki, Masahiro and Iwasawa, Yusuke and Matsuo, Yutaka\n\"\"\"\n\n# mypy: ignore-errors\n# flake8: noqa\n# mypy: allow-untyped-decorators\n# mypy: allow-untyped-defs\nfrom typing import Callable, List, Optional, Tuple, Union, cast\n\nimport torch\nfrom torch import Tensor\nfrom torch.optim.optimizer import (  # DeviceDict,; _capturable_doc,; _differentiable_doc,; _foreach_doc,; _fused_doc,; _maximize_doc,; _stack_if_compiling,\n    DeviceDict,\n    Optimizer,\n    ParamsT,\n    _capturable_doc,\n    _default_to_fused_or_foreach,\n    _device_dtype_check_for_fused,\n    _differentiable_doc,\n    _disable_dynamo_if_unsupported,\n    _foreach_doc,\n    _fused_doc,\n    _get_capturable_supported_devices,\n    _get_scalar_dtype,\n    _get_value,\n    _maximize_doc,\n    _stack_if_compiling,\n    _use_grad_for_differentiable,\n    _view_as_real,\n)\n\n__all__ = [\"ADOPT\", \"adopt\"]\n\n\nclass ADOPT(Optimizer):\n    def __init__(\n        self,\n        params: ParamsT,\n        lr: Union[float, Tensor] = 1e-3,\n        betas: Tuple[float, float] = (0.9, 0.9999),\n        eps: float = 1e-6,\n        clip_lambda: Optional[Callable[[int], float]] = lambda step: step**0.25,\n        weight_decay: float = 0.0,\n        decouple: bool = False,\n        *,\n        foreach: Optional[bool] = None,\n        maximize: bool = False,\n        capturable: bool = False,\n        differentiable: bool = False,\n        fused: Optional[bool] = None,\n    ):\n        if isinstance(lr, Tensor):\n            if foreach and not capturable:\n                raise ValueError(\n                    \"lr as a Tensor is not supported for capturable=False and foreach=True\"\n                )\n            if 
lr.numel() != 1:\n                raise ValueError(\"Tensor lr must be 1-element\")\n        if not 0.0 <= lr:\n            raise ValueError(f\"Invalid learning rate: {lr}\")\n        if not 0.0 <= eps:\n            raise ValueError(f\"Invalid epsilon value: {eps}\")\n        if not 0.0 <= betas[0] < 1.0:\n            raise ValueError(f\"Invalid beta parameter at index 0: {betas[0]}\")\n        if not 0.0 <= betas[1] < 1.0:\n            raise ValueError(f\"Invalid beta parameter at index 1: {betas[1]}\")\n        if not 0.0 <= weight_decay:\n            raise ValueError(f\"Invalid weight_decay value: {weight_decay}\")\n\n        self.clip_lambda = clip_lambda\n\n        defaults = dict(\n            lr=lr,\n            betas=betas,\n            eps=eps,\n            weight_decay=weight_decay,\n            decouple=decouple,\n            maximize=maximize,\n            foreach=foreach,\n            capturable=capturable,\n            differentiable=differentiable,\n            fused=fused,\n        )\n        super().__init__(params, defaults)\n\n        if fused:\n            # TODO: support fused\n            raise RuntimeError(\"`fused` is not currently supported\")\n\n            if differentiable:\n                raise RuntimeError(\"`fused` does not support `differentiable`\")\n            self._step_supports_amp_scaling = True\n            # TODO(crcrpar): [low prec params & their higher prec copy]\n            # Support AMP with FP16/BF16 model params which would need\n            # higher prec copy of params to do update math in higher prec to\n            # alleviate the loss of information.\n            if foreach:\n                raise RuntimeError(\"`fused` and `foreach` cannot be `True` together.\")\n\n    def __setstate__(self, state):\n        super().__setstate__(state)\n        for group in self.param_groups:\n            group.setdefault(\"maximize\", False)\n            group.setdefault(\"foreach\", None)\n            
group.setdefault(\"capturable\", False)\n            group.setdefault(\"differentiable\", False)\n            fused = group.setdefault(\"fused\", None)\n            for p in group[\"params\"]:\n                p_state = self.state.get(p, [])\n                if len(p_state) != 0 and not torch.is_tensor(p_state[\"step\"]):\n                    step_val = float(p_state[\"step\"])\n                    p_state[\"step\"] = (\n                        torch.tensor(\n                            step_val,\n                            dtype=_get_scalar_dtype(is_fused=fused),\n                            device=p.device,\n                        )\n                        if group[\"capturable\"] or group[\"fused\"]\n                        else torch.tensor(step_val, dtype=_get_scalar_dtype())\n                    )\n\n    def _init_group(\n        self,\n        group,\n        params_with_grad,\n        grads,\n        exp_avgs,\n        exp_avg_sqs,\n        state_steps,\n    ):\n        has_complex = False\n        for p in group[\"params\"]:\n            if p.grad is not None:\n                has_complex |= torch.is_complex(p)\n                params_with_grad.append(p)\n                if p.grad.is_sparse:\n                    raise RuntimeError(\"ADOPT does not support sparse gradients\")\n                grads.append(p.grad)\n\n                state = self.state[p]\n                # Lazy state initialization\n                if len(state) == 0:\n                    if group[\"fused\"]:\n                        _device_dtype_check_for_fused(p)\n                    # note(crcrpar): [special device hosting for step]\n                    # Deliberately host `step` on CPU if both capturable and fused are off.\n                    # This is because kernel launches are costly on CUDA and XLA.\n                    state[\"step\"] = (\n                        torch.zeros(\n                            (),\n                            
dtype=_get_scalar_dtype(is_fused=group[\"fused\"]),\n                            device=p.device,\n                        )\n                        if group[\"capturable\"] or group[\"fused\"]\n                        else torch.tensor(0.0, dtype=_get_scalar_dtype())\n                    )\n                    # Exponential moving average of gradient values\n                    state[\"exp_avg\"] = torch.zeros_like(\n                        p, memory_format=torch.preserve_format\n                    )\n                    # Exponential moving average of squared gradient values\n                    state[\"exp_avg_sq\"] = torch.zeros_like(\n                        p, memory_format=torch.preserve_format\n                    )\n\n                exp_avgs.append(state[\"exp_avg\"])\n                exp_avg_sqs.append(state[\"exp_avg_sq\"])\n\n                if group[\"differentiable\"] and state[\"step\"].requires_grad:\n                    raise RuntimeError(\n                        \"`requires_grad` is not supported for `step` in differentiable mode\"\n                    )\n\n                # Foreach without capturable does not support a tensor lr\n                if (\n                    group[\"foreach\"]\n                    and torch.is_tensor(group[\"lr\"])\n                    and not group[\"capturable\"]\n                ):\n                    raise RuntimeError(\n                        \"lr as a Tensor is not supported for capturable=False and foreach=True\"\n                    )\n\n                state_steps.append(state[\"step\"])\n        return has_complex\n\n    @_use_grad_for_differentiable\n    def step(self, closure=None):\n        \"\"\"Perform a single optimization step.\n\n        Args:\n            closure (Callable, optional): A closure that reevaluates the model\n                and returns the loss.\n        \"\"\"\n        self._cuda_graph_capture_health_check()\n\n        loss = None\n        if closure is not None:\n            
with torch.enable_grad():\n                loss = closure()\n\n        for group in self.param_groups:\n            params_with_grad: List[Tensor] = []\n            grads: List[Tensor] = []\n            exp_avgs: List[Tensor] = []\n            exp_avg_sqs: List[Tensor] = []\n            state_steps: List[Tensor] = []\n            beta1, beta2 = group[\"betas\"]\n\n            has_complex = self._init_group(\n                group,\n                params_with_grad,\n                grads,\n                exp_avgs,\n                exp_avg_sqs,\n                state_steps,\n            )\n\n            adopt(\n                params_with_grad,\n                grads,\n                exp_avgs,\n                exp_avg_sqs,\n                state_steps,\n                has_complex=has_complex,\n                beta1=beta1,\n                beta2=beta2,\n                lr=group[\"lr\"],\n                clip_lambda=self.clip_lambda,\n                weight_decay=group[\"weight_decay\"],\n                decouple=group[\"decouple\"],\n                eps=group[\"eps\"],\n                maximize=group[\"maximize\"],\n                foreach=group[\"foreach\"],\n                capturable=group[\"capturable\"],\n                differentiable=group[\"differentiable\"],\n                fused=group[\"fused\"],\n                grad_scale=getattr(self, \"grad_scale\", None),\n                found_inf=getattr(self, \"found_inf\", None),\n            )\n\n        return loss\n\n\ndef _single_tensor_adopt(\n    params: List[Tensor],\n    grads: List[Tensor],\n    exp_avgs: List[Tensor],\n    exp_avg_sqs: List[Tensor],\n    state_steps: List[Tensor],\n    grad_scale: Optional[Tensor],\n    found_inf: Optional[Tensor],\n    *,\n    has_complex: bool,\n    beta1: float,\n    beta2: float,\n    lr: Union[float, Tensor],\n    clip_lambda: Optional[Callable[[int], float]],\n    weight_decay: float,\n    decouple: bool,\n    eps: float,\n    maximize: bool,\n    capturable: 
bool,\n    differentiable: bool,\n):\n    assert grad_scale is None and found_inf is None\n\n    if torch.jit.is_scripting():\n        # this assert is due to JIT being dumb and not realizing that the ops below\n        # have overloads to handle both float and Tensor lrs, so we just assert it's\n        # a float since most people using JIT are using floats\n        assert isinstance(lr, float)\n\n    for i, param in enumerate(params):\n        grad = grads[i] if not maximize else -grads[i]\n        exp_avg = exp_avgs[i]\n        exp_avg_sq = exp_avg_sqs[i]\n        step_t = state_steps[i]\n\n        # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable]\n        if not torch._utils.is_compiling() and capturable:\n            capturable_supported_devices = _get_capturable_supported_devices()\n            assert (\n                param.device.type == step_t.device.type\n                and param.device.type in capturable_supported_devices\n            ), (\n                f\"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}.\"\n            )\n\n        step = step_t if capturable or differentiable else _get_value(step_t)\n\n        if weight_decay != 0 and not decouple:\n            grad = grad.add(param, alpha=weight_decay)\n\n        if torch.is_complex(param):\n            grad = torch.view_as_real(grad)\n            if exp_avg is not None:\n                exp_avg = torch.view_as_real(exp_avg)\n            if exp_avg_sq is not None:\n                exp_avg_sq = torch.view_as_real(exp_avg_sq)\n            param = torch.view_as_real(param)\n\n        if step == 0:\n            exp_avg_sq.addcmul_(grad, grad.conj())\n            # update step\n            step_t += 1\n            continue\n\n        if weight_decay != 0 and decouple:\n            param.add_(param, alpha=-lr * weight_decay)\n\n        denom = torch.clamp(exp_avg_sq.sqrt(), eps)\n        normed_grad = 
grad.div(denom)\n        if clip_lambda is not None:\n            clip = clip_lambda(step)\n            normed_grad.clamp_(-clip, clip)\n\n        exp_avg.lerp_(normed_grad, 1 - beta1)\n\n        param.add_(exp_avg, alpha=-lr)\n        exp_avg_sq.mul_(beta2).addcmul_(grad, grad.conj(), value=1 - beta2)\n\n        # update step\n        step_t += 1\n\n\ndef _multi_tensor_adopt(\n    params: List[Tensor],\n    grads: List[Tensor],\n    exp_avgs: List[Tensor],\n    exp_avg_sqs: List[Tensor],\n    state_steps: List[Tensor],\n    grad_scale: Optional[Tensor],\n    found_inf: Optional[Tensor],\n    *,\n    has_complex: bool,\n    beta1: float,\n    beta2: float,\n    lr: Union[float, Tensor],\n    clip_lambda: Optional[Callable[[int], float]],\n    weight_decay: float,\n    decouple: bool,\n    eps: float,\n    maximize: bool,\n    capturable: bool,\n    differentiable: bool,\n):\n    if len(params) == 0:\n        return\n\n    if isinstance(lr, Tensor) and not capturable:\n        raise RuntimeError(\n            \"lr as a Tensor is not supported for capturable=False and foreach=True\"\n        )\n\n    # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable]\n    if not torch._utils.is_compiling() and capturable:\n        capturable_supported_devices = _get_capturable_supported_devices(\n            supports_xla=False\n        )\n        assert all(\n            p.device.type == step.device.type\n            and p.device.type in capturable_supported_devices\n            for p, step in zip(params, state_steps)\n        ), (\n            f\"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}.\"\n        )\n\n    assert grad_scale is None and found_inf is None\n\n    assert not differentiable, \"_foreach ops don't support autograd\"\n\n    grouped_tensors = Optimizer._group_tensors_by_device_and_dtype(\n        [params, grads, exp_avgs, exp_avg_sqs, state_steps]  # type: 
ignore[list-item]\n    )\n    for (\n        device_params_,\n        device_grads_,\n        device_exp_avgs_,\n        device_exp_avg_sqs_,\n        device_state_steps_,\n    ), _ in grouped_tensors.values():\n        device_params = cast(List[Tensor], device_params_)\n        device_grads = cast(List[Tensor], device_grads_)\n        device_exp_avgs = cast(List[Tensor], device_exp_avgs_)\n        device_exp_avg_sqs = cast(List[Tensor], device_exp_avg_sqs_)\n        device_state_steps = cast(List[Tensor], device_state_steps_)\n\n        # Handle complex parameters\n        if has_complex:\n            _view_as_real(\n                device_params, device_grads, device_exp_avgs, device_exp_avg_sqs\n            )\n\n        if maximize:\n            device_grads = torch._foreach_neg(device_grads)  # type: ignore[assignment]\n\n        if weight_decay != 0 and not decouple:\n            # Re-use the intermediate memory (device_grads) already allocated for maximize\n            if maximize:\n                torch._foreach_add_(device_grads, device_params, alpha=weight_decay)\n            else:\n                device_grads = torch._foreach_add(  # type: ignore[assignment]\n                    device_grads, device_params, alpha=weight_decay\n                )\n\n        if device_state_steps[0] == 0:\n            torch._foreach_addcmul_(device_exp_avg_sqs, device_grads, device_grads)\n\n            # Update steps\n            # If steps are on CPU, foreach will fall back to the slow path, which is a for-loop calling t.add(1) over\n            # and over. 1 will then be wrapped into a Tensor over and over again, which is slower than if we just\n            # wrapped it once now. 
The alpha is required to assure we go to the right overload.\n            if not torch._utils.is_compiling() and device_state_steps[0].is_cpu:\n                torch._foreach_add_(\n                    device_state_steps, torch.tensor(1.0, device=\"cpu\"), alpha=1.0\n                )\n            else:\n                torch._foreach_add_(device_state_steps, 1)\n\n            continue\n\n        if weight_decay != 0 and decouple:\n            torch._foreach_add_(device_params, device_params, alpha=-lr * weight_decay)\n\n        exp_avg_sq_sqrt = torch._foreach_sqrt(device_exp_avg_sqs)\n        torch._foreach_maximum_(exp_avg_sq_sqrt, eps)\n\n        normed_grad = torch._foreach_div(device_grads, exp_avg_sq_sqrt)\n        if clip_lambda is not None:\n            clip = clip_lambda(device_state_steps[0])\n            torch._foreach_maximum_(normed_grad, -clip)\n            torch._foreach_minimum_(normed_grad, clip)\n\n        torch._foreach_lerp_(device_exp_avgs, normed_grad, 1 - beta1)\n\n        torch._foreach_add_(device_params, device_exp_avgs, alpha=-lr)\n        torch._foreach_mul_(device_exp_avg_sqs, beta2)\n        torch._foreach_addcmul_(\n            device_exp_avg_sqs, device_grads, device_grads, value=1 - beta2\n        )\n\n        # Update steps\n        # If steps are on CPU, foreach will fall back to the slow path, which is a for-loop calling t.add(1) over\n        # and over. 1 will then be wrapped into a Tensor over and over again, which is slower than if we just\n        # wrapped it once now. 
The alpha is required to assure we go to the right overload.\n        if not torch._utils.is_compiling() and device_state_steps[0].is_cpu:\n            torch._foreach_add_(\n                device_state_steps, torch.tensor(1.0, device=\"cpu\"), alpha=1.0\n            )\n        else:\n            torch._foreach_add_(device_state_steps, 1)\n\n\n@_disable_dynamo_if_unsupported(single_tensor_fn=_single_tensor_adopt)\ndef adopt(\n    params: List[Tensor],\n    grads: List[Tensor],\n    exp_avgs: List[Tensor],\n    exp_avg_sqs: List[Tensor],\n    state_steps: List[Tensor],\n    # kwonly args with defaults are not supported by functions compiled with torchscript issue #70627\n    # setting this as kwarg for now as functional API is compiled by torch/distributed/optim\n    foreach: Optional[bool] = None,\n    capturable: bool = False,\n    differentiable: bool = False,\n    fused: Optional[bool] = None,\n    grad_scale: Optional[Tensor] = None,\n    found_inf: Optional[Tensor] = None,\n    has_complex: bool = False,\n    *,\n    beta1: float,\n    beta2: float,\n    lr: Union[float, Tensor],\n    clip_lambda: Optional[Callable[[int], float]],\n    weight_decay: float,\n    decouple: bool,\n    eps: float,\n    maximize: bool,\n):\n    r\"\"\"Functional API that performs ADOPT algorithm computation.\"\"\"\n    # Respect when the user inputs False/True for foreach or fused. We only want to change\n    # the default when neither have been user-specified. Note that we default to foreach\n    # and pass False to use_fused. 
This is not a mistake--we want to give the fused impl\n    # bake-in time before making it the default, even if it is typically faster.\n    if fused is None and foreach is None:\n        _, foreach = _default_to_fused_or_foreach(\n            params, differentiable, use_fused=False\n        )\n        # Do not flip on foreach for the unsupported case where lr is a Tensor and capturable=False.\n        if foreach and isinstance(lr, Tensor) and not capturable:\n            foreach = False\n    if fused is None:\n        fused = False\n    if foreach is None:\n        foreach = False\n\n    # this check is slow during compilation, so we skip it\n    # if it's strictly needed we can add this check back in dynamo\n    if not torch._utils.is_compiling() and not all(\n        isinstance(t, torch.Tensor) for t in state_steps\n    ):\n        raise RuntimeError(\n            \"API has changed, `state_steps` argument must contain a list of singleton tensors\"\n        )\n\n    if foreach and torch.jit.is_scripting():\n        raise RuntimeError(\"torch.jit.script not supported with foreach optimizers\")\n    if fused and torch.jit.is_scripting():\n        raise RuntimeError(\"torch.jit.script not supported with fused optimizers\")\n\n    # if fused and not torch.jit.is_scripting():\n    #     func = _fused_adopt\n    # elif foreach and not torch.jit.is_scripting():\n    if foreach and not torch.jit.is_scripting():\n        func = _multi_tensor_adopt\n    else:\n        func = _single_tensor_adopt\n\n    func(\n        params,\n        grads,\n        exp_avgs,\n        exp_avg_sqs,\n        state_steps,\n        has_complex=has_complex,\n        beta1=beta1,\n        beta2=beta2,\n        lr=lr,\n        clip_lambda=clip_lambda,\n        weight_decay=weight_decay,\n        decouple=decouple,\n        eps=eps,\n        maximize=maximize,\n        capturable=capturable,\n        differentiable=differentiable,\n        grad_scale=grad_scale,\n        found_inf=found_inf,\n    
)\n"
  },
  {
    "path": "src/axolotl/utils/quantization.py",
    "content": "\"\"\"\nUtilities for quantization including QAT and PTQ using torchao.\n\"\"\"\n\nimport torch\nfrom packaging import version\nfrom torchao.core.config import AOBaseConfig\nfrom torchao.prototype.qat import MXFakeQuantizeConfig\nfrom torchao.quantization import quantize_\nfrom torchao.quantization.qat import (\n    QATConfig,\n)\nfrom torchao.quantization.quant_api import (\n    Float8DynamicActivationFloat8WeightConfig,\n    Float8DynamicActivationInt4WeightConfig,\n    Int8DynamicActivationInt4WeightConfig,\n)\n\nfrom axolotl.utils.schemas.enums import TorchAOQuantDType\n\nquantization_config_to_str = {\n    Int8DynamicActivationInt4WeightConfig: \"int8int4\",\n    Float8DynamicActivationFloat8WeightConfig: \"fp8fp8\",\n    Float8DynamicActivationInt4WeightConfig: \"fp8int4\",\n}\n\nif version.parse(torch.__version__) >= version.parse(\"2.8.0\"):\n    try:\n        from torchao.prototype.mx_formats import NVFP4InferenceConfig\n\n        quantization_config_to_str[NVFP4InferenceConfig] = \"nvfp4\"\n    except (ImportError, RuntimeError):\n        pass\n\n    # int4 weight config imports will fail on machines with fbgemm-gpu installed\n    # without a CUDA runtime available so we do this safely\n    try:\n        from torchao.quantization.quant_api import Int4WeightOnlyConfig\n\n        quantization_config_to_str[Int4WeightOnlyConfig] = \"int4\"\n    except (ImportError, RuntimeError):\n        pass\n\n    try:\n        from torchao.prototype.qat import MXFakeQuantizeConfig\n\n        quantization_config_to_str[MXFakeQuantizeConfig] = \"mxfp4\"\n    except ImportError:\n        pass\n\n\ndef get_quantization_config(\n    weight_dtype: TorchAOQuantDType,\n    activation_dtype: TorchAOQuantDType | None = None,\n    group_size: int | None = None,\n) -> AOBaseConfig:\n    \"\"\"\n    This function is used to build a post-training quantization config.\n\n    Args:\n        weight_dtype: The dtype to use for weight quantization.\n        
activation_dtype: The dtype to use for activation quantization.\n        group_size: The group size to use for weight quantization.\n\n    Returns:\n        The post-training quantization config.\n\n    Raises:\n        ValueError: If the activation dtype is not specified and the weight dtype is not int8 or int4,\n            or if the group size is not specified for int8 or int4 weight only quantization.\n    \"\"\"\n    if activation_dtype is None:\n        if weight_dtype == TorchAOQuantDType.int8:\n            raise ValueError(\"Int8WeightOnlyConfig is not supported by torchao QAT.\")\n        if weight_dtype == TorchAOQuantDType.int4:\n            from torchao.quantization.quant_api import Int4WeightOnlyConfig\n\n            if group_size is not None:\n                return Int4WeightOnlyConfig(group_size=group_size, version=2)\n            else:\n                return Int4WeightOnlyConfig(version=2)\n    if (\n        activation_dtype == TorchAOQuantDType.int4\n        and weight_dtype == TorchAOQuantDType.int4\n    ):\n        raise ValueError(\n            \"Int4DynamicActivationInt4WeightConfig is not supported by torchao QAT.\"\n        )\n    if (\n        activation_dtype == TorchAOQuantDType.int8\n        and weight_dtype == TorchAOQuantDType.int8\n    ):\n        raise ValueError(\n            \"Int8DynamicActivationInt8WeightConfig is not supported by torchao QAT.\"\n        )\n    if (\n        activation_dtype == TorchAOQuantDType.int8\n        and weight_dtype == TorchAOQuantDType.int4\n    ):\n        if group_size is not None:\n            return Int8DynamicActivationInt4WeightConfig(group_size=group_size)\n        else:\n            return Int8DynamicActivationInt4WeightConfig()\n    if (\n        activation_dtype == TorchAOQuantDType.float8_e4m3fn\n        and weight_dtype == TorchAOQuantDType.float8_e4m3fn\n    ):\n        return Float8DynamicActivationFloat8WeightConfig()\n    if (\n        activation_dtype == 
TorchAOQuantDType.float8_e4m3fn\n        and weight_dtype == TorchAOQuantDType.int4\n    ):\n        return Float8DynamicActivationInt4WeightConfig()\n    if weight_dtype == TorchAOQuantDType.nvfp4:\n        from torchao.prototype.mx_formats import NVFP4InferenceConfig\n\n        if group_size is not None and group_size != 16:\n            raise ValueError(\"NVFP4 quantization must use a group_size of 16\")\n        return NVFP4InferenceConfig()\n\n    if weight_dtype == TorchAOQuantDType.mxfp4:\n        from torchao.prototype.qat import MXFakeQuantizeConfig\n\n        # MXFP4 uses block_size=32 by default (vs NVFP4's 16)\n        block_size = group_size if group_size is not None else 32\n        if block_size != 32:\n            raise ValueError(\n                \"MXFP4 quantization must use a block_size (group_size) of 32\"\n            )\n\n        return MXFakeQuantizeConfig(dtype=torch.float4_e2m1fn_x2, block_size=block_size)\n\n    raise ValueError(\n        f\"Invalid activation/weight dtype combination: {activation_dtype}/{weight_dtype}\"\n    )\n\n\ndef quantize_model(\n    model,\n    weight_dtype: TorchAOQuantDType,\n    group_size: int | None = None,\n    activation_dtype: TorchAOQuantDType | None = None,\n    quantize_embedding: bool | None = None,\n):\n    \"\"\"\n    This function is used to quantize a model.\n\n    Args:\n        model: The model to quantize.\n        weight_dtype: The dtype to use for weight quantization.\n        group_size: The group size to use for weight quantization.\n        activation_dtype: The dtype to use for activation quantization.\n        quantize_embedding: Whether to quantize the model's embedding weights.\n\n    \"\"\"\n    linear_ptq_config = get_quantization_config(\n        weight_dtype=weight_dtype,\n        activation_dtype=activation_dtype,\n        group_size=group_size,\n    )\n    quantize_(model, linear_ptq_config)\n    if quantize_embedding:\n        # activation fake quantization is not supported for 
embedding layers\n        embedding_quantize_config = get_quantization_config(\n            weight_dtype=weight_dtype,\n            activation_dtype=None,\n            group_size=group_size,\n        )\n        quantize_(\n            model,\n            embedding_quantize_config,\n            filter_fn=lambda m, _: isinstance(m, torch.nn.Embedding),\n        )\n\n\ndef prepare_model_for_qat(\n    model,\n    weight_dtype: TorchAOQuantDType,\n    group_size: int | None = None,\n    activation_dtype: TorchAOQuantDType | None = None,\n    quantize_embedding: bool = False,\n):\n    \"\"\"\n    This function is used to prepare a model for QAT by swapping the model's linear\n    layers with fake quantized linear layers, and optionally the embedding weights with\n    fake quantized embedding weights.\n\n    Args:\n        model: The model to quantize.\n        weight_dtype: The dtype to use for weight quantization.\n        group_size: The group size to use for weight quantization.\n        activation_dtype: The dtype to use for activation quantization.\n        quantize_embedding: Whether to quantize the model's embedding weights.\n\n    Raises:\n        ValueError: If the activation/weight dtype combination is invalid.\n    \"\"\"\n    base_config = get_quantization_config(\n        weight_dtype=weight_dtype,\n        activation_dtype=activation_dtype,\n        group_size=group_size,\n    )\n    if isinstance(base_config, MXFakeQuantizeConfig):\n        qat_config = QATConfig(\n            activation_config=base_config,\n            weight_config=base_config,\n        )\n    else:\n        qat_config = QATConfig(base_config)\n    quantize_(model, qat_config)\n    if quantize_embedding:\n        # activation fake quantization is not supported for embedding layers\n        embedding_base_config = get_quantization_config(\n            weight_dtype=weight_dtype,\n            activation_dtype=None,\n            group_size=group_size,\n        )\n        if 
isinstance(embedding_base_config, MXFakeQuantizeConfig):\n            embedding_qat_config = QATConfig(\n                weight_config=embedding_base_config,\n            )\n        else:\n            embedding_qat_config = QATConfig(embedding_base_config)\n        quantize_(\n            model,\n            embedding_qat_config,\n            filter_fn=lambda m, _: isinstance(m, torch.nn.Embedding),\n        )\n\n\ndef convert_qat_model(\n    model,\n    quantize_embedding: bool = False,\n):\n    \"\"\"\n    This function converts a QAT model which has fake quantized layers back to the original model.\n    \"\"\"\n    config = QATConfig(step=\"convert\")\n    quantize_(model, config)\n    if quantize_embedding:\n        quantize_(\n            model,\n            config,\n            filter_fn=lambda m, _: isinstance(m, torch.nn.Embedding),\n        )\n"
  },
  {
    "path": "src/axolotl/utils/samplers/__init__.py",
    "content": "\"\"\"\naxolotl samplers module\n\"\"\"\n\nfrom .multipack import MultipackBatchSampler  # noqa: F401\nfrom .utils import get_dataset_lengths  # noqa: F401\n"
  },
  {
    "path": "src/axolotl/utils/samplers/multipack.py",
    "content": "\"\"\"\nMultipack Batch Sampler - An efficient batch sampler for packing variable-length sequences\ninto fixed-capacity batches to optimize memory usage and training throughput.\n\"\"\"\n\nimport gc\nimport math\nimport os\nimport time\nfrom concurrent.futures import ProcessPoolExecutor\nfrom multiprocessing import cpu_count, get_context\nfrom typing import Iterable, Iterator, Union\n\nimport numba\nimport numpy as np\nfrom torch.utils.data import BatchSampler, Sampler, SequentialSampler\n\nfrom axolotl.utils.distributed import reduce_and_broadcast\nfrom axolotl.utils.logging import get_logger\n\nLOG = get_logger(__name__)\n\n\n@numba.njit\ndef ffd_check(sequence_lengths: np.ndarray, bin_capacity: int, num_bins: int) -> bool:\n    \"\"\"First-fit-decreasing bin packing algorithm check.\n\n    Checks if sequences with the given lengths could fit in the specified number of\n    bins.\n\n    Args:\n        sequence_lengths: Array of sequence lengths.\n        bin_capacity: Maximum capacity of each bin.\n        num_bins: Number of bins available.\n\n    Returns:\n        `True` if all sequences can be packed, `False` otherwise.\n    \"\"\"\n    # Sort sequence lengths in descending order for optimal packing\n    sequence_lengths = np.sort(sequence_lengths)[::-1]\n    # Initialize all bins with full capacity\n    bins = np.full((num_bins,), bin_capacity, dtype=sequence_lengths.dtype)\n\n    # Try to place each sequence in the first bin it fits\n    for size in sequence_lengths:\n        not_found = True\n        for idx in range(num_bins):\n            if bins[idx] >= size:\n                bins[idx] -= size\n                not_found = False\n                break\n\n        # If no bin could fit this sequence, packing failed\n        if not_found:\n            return False\n\n    return True\n\n\n@numba.njit\ndef pack_group(\n    sequence_lengths: np.ndarray,\n    group_offset: int,\n    bin_capacity: int,\n    max_bins: int,\n    bin_size: int,\n    
safe_mode: bool = True,\n) -> list[list[int]]:\n    \"\"\"Pack a group of sequences into bins using First-Fit Decreasing algorithm.\n\n    Args:\n        sequence_lengths: Array of sequence lengths.\n        group_offset: Offset to apply to indices when returning results.\n        bin_capacity: Maximum capacity of each bin.\n        max_bins: Maximum number of bins to use.\n        bin_size: Maximum number of sequences per bin.\n        safe_mode: If True, use a more conservative packing approach.\n\n    Returns:\n        List of bins, where each bin contains indices of sequences assigned to it.\n    \"\"\"\n    bins_remaining_space: list = []  # Tracks remaining capacity in each bin\n    bins_assigned_sequences: list = []  # Tracks sequence indices assigned to each bin\n\n    for seq_id, size in enumerate(sequence_lengths):\n        global_idx = seq_id + group_offset\n\n        # Try to place sequence in existing bins\n        add_new_bin = True\n        for bin_idx, _ in enumerate(bins_remaining_space):\n            if (\n                bins_remaining_space[bin_idx] >= size\n                and len(bins_assigned_sequences[bin_idx]) < bin_size\n            ):\n                bins_remaining_space[bin_idx] -= size\n                bins_assigned_sequences[bin_idx].append(global_idx)\n                add_new_bin = False\n                break\n\n        # Create a new bin if needed and if we haven't reached the limit\n        if add_new_bin:\n            if len(bins_remaining_space) >= max_bins and safe_mode:\n                # In safe mode, skip items that would exceed max_bins\n                continue\n            bins_remaining_space.append(bin_capacity - size)\n            bins_assigned_sequences.append([global_idx])\n\n            # Safety check to avoid infinite bins\n            if len(bins_remaining_space) > len(sequence_lengths):\n                break\n\n    return bins_assigned_sequences\n\n\ndef _process_group(\n    args: tuple[np.ndarray, int, int, 
int, int, bool],\n) -> list[list[int]]:\n    \"\"\"Standalone function for multiprocessing.\"\"\"\n    group_lengths, start_idx, bin_capacity, max_bins, bin_size, safe_mode = args\n    return pack_group(\n        group_lengths, start_idx, bin_capacity, max_bins, bin_size, safe_mode\n    )\n\n\ndef pack_parallel(\n    sequence_lengths: np.ndarray,\n    bin_capacity: int,\n    group_size: int,\n    bin_size: int,\n    num_processes: int | None = None,\n    safe_mode: bool = True,\n    mp_start_method: str | None = \"fork\",\n) -> list[list[int]]:\n    \"\"\"Pack sequences into bins using parallel processing.\n\n    Args:\n        sequence_lengths: Array of sequence lengths.\n        bin_capacity: Maximum capacity of each bin as total number of tokens.\n        group_size: Number of sequences to process in each group.\n        bin_size: Maximum number of bins to use.\n        num_processes: Number of parallel processes to use.\n        safe_mode: If True, use a more conservative packing approach.\n        mp_start_method: Multiprocessing start method ('fork', 'spawn', 'forkserver').\n                         'spawn' is often safer with Numba/PyTorch.\n                         Set to None to use system default.\n    Returns:\n        List of bins, where each bin contains indices of sequences assigned to it.\n    \"\"\"\n    num_items = len(sequence_lengths)\n    if num_processes is None:\n        num_processes = max(1, min(num_items // group_size, cpu_count(), 16))\n\n    # Create tasks for parallel processing\n    tasks = []\n    for i in range(0, num_items, group_size):\n        group_lengths = sequence_lengths[i : i + group_size]\n        max_bins = len(group_lengths)  # Allow as many bins as items in the group\n        tasks.append((group_lengths, i, bin_capacity, max_bins, bin_size, safe_mode))\n\n    # Process groups in parallel\n    all_bins = []\n\n    mp_ctx = None\n    if mp_start_method:\n        try:\n            mp_ctx = get_context(mp_start_method)\n      
  except ValueError:\n            LOG.warning(\n                f\"Failed to get multiprocessing context '{mp_start_method}'. \"\n                f\"Falling back to default. Available: {get_context().get_all_start_methods()}\"\n            )\n            mp_ctx = (\n                None  # Fallback to default context if specified one is not available\n            )\n\n    if num_processes == 1:\n        LOG.debug(\"Using single process for pack_parallel, running sequentially.\")\n        for task_args in tasks:\n            group_bins = _process_group(task_args)\n            all_bins.extend(group_bins)\n    else:\n        # Use ProcessPoolExecutor only if num_processes > 1\n        # Pass mp_context if available\n        with ProcessPoolExecutor(\n            max_workers=num_processes, mp_context=mp_ctx\n        ) as executor:\n            for group_bins in executor.map(_process_group, tasks):\n                all_bins.extend(group_bins)\n\n    return all_bins\n\n\n@numba.njit\ndef allocate_sequentially(\n    sequence_lengths: np.ndarray, rank: int, bin_capacity: int, num_ranks: int\n) -> tuple[list[list[int]], int, int]:\n    \"\"\"Sequential allocator that preserves example order.\n\n    Args:\n        sequence_lengths: The lengths of all examples.\n        rank: The current rank (for distributed training).\n        bin_capacity: The capacity of each bin (maximum sequence length).\n        num_ranks: Number of ranks (processes / GPUs).\n\n    Returns:\n        rank_batches: List of batches for the current rank.\n        total_tokens_used: Number of actual example tokens.\n        total_token_slots: Maximum theoretical number of example tokens (number of bins\n            * bin capacity).\n    \"\"\"\n    result = []\n    total_used = 0\n\n    # First, do sequential packing into bins\n    all_bins = []\n    current_bin = [0 for i in range(0)]  # numba hint\n    remaining_capacity = bin_capacity\n\n    for idx, size in enumerate(sequence_lengths):\n        if size 
<= remaining_capacity:\n            # Example fits in current bin\n            current_bin.append(idx)\n            remaining_capacity -= size\n            total_used += size\n        else:\n            # Example doesn't fit, start a new bin\n            if current_bin:  # Add non-empty bin to all_bins\n                all_bins.append(current_bin)\n            current_bin = [idx]\n            remaining_capacity = bin_capacity - size\n            total_used += size\n\n    # Add the last bin if not empty\n    if current_bin:\n        all_bins.append(current_bin)\n\n    # Assign bins to ranks - each rank gets every n-th bin\n    for bin_idx in range(rank, len(all_bins), num_ranks):\n        result.append(all_bins[bin_idx])\n\n    return result, total_used, len(all_bins) * bin_capacity\n\n\nclass MultipackBatchSampler(BatchSampler):\n    \"\"\"Batch sampler class for efficient packing of variable-length sequences\n\n    This sampler packs sequences into fixed-capacity bins (batches) to maximize\n    GPU memory utilization and training throughput by reducing padding.\n\n    It supports both parallel packing (using FFD algorithm) and\n    sequential packing (preserving original sequence order).\n    \"\"\"\n\n    _batches: list[list[list[int]]] | None = None\n    _len_across_ranks: int | None = None\n\n    def __init__(\n        self,\n        sampler: Union[Sampler[int], Iterable[int]],\n        batch_size: int,  # Number of bins per batch\n        batch_max_len: int,  # Maximum sequence length (bin capacity)\n        lengths: np.ndarray,  # Sequence lengths\n        bin_size: int,  # The max number of samples that can be packed in a single bin\n        packing_efficiency_estimate: float = 1.0,  # Initial efficiency estimate\n        drop_last: bool = True,  # Whether to drop final batches (might be incomplete)\n        num_count_samples: int = 4,  # Number of times to estimate batch count\n        sequential: bool = False,  # Whether to use sequential packing\n        
group_size: int = 100_000,  # Size of groups for parallel packing\n        num_processes: int | None = None,  # Number of processes for parallel packing\n        safe_mode: bool = True,  # Conservative packing to prevent training instability\n        mp_start_method: str = \"fork\",\n        **kwargs,\n    ):\n        super().__init__(sampler, batch_size, drop_last)\n        self.batch_size = batch_size\n        self.batch_max_len = batch_max_len\n        self.lengths = np.array(lengths, dtype=np.int32)\n        self.packing_efficiency_estimate = packing_efficiency_estimate or 1.0\n        self.sequential = sequential\n        self.group_size = group_size\n        self.bin_size = bin_size\n        self.num_processes = num_processes\n        self.safe_mode = safe_mode\n        self.mp_start_method = mp_start_method\n\n        assert isinstance(self.lengths, np.ndarray)\n\n        self.epoch = 0\n\n        # Efficiency statistics tracking\n        self.total_tokens_used = 0\n        self.total_token_slots = 0\n\n        # The number of times to calculate batches to determine minimum packed dataset length\n        world_size = int(os.environ.get(\"WORLD_SIZE\", \"1\"))\n        self.num_count_samples = (\n            1 if world_size >= num_count_samples else num_count_samples\n        )\n\n        if self.sequential and not isinstance(sampler, SequentialSampler):\n            LOG.warning(\n                \"using sequential sample packing with non-sequential sampler, did you want to also enable curriculum_sampling?\"\n            )\n\n    def set_epoch(self, epoch: int):\n        \"\"\"Set the epoch number, used for reproducible shuffling across epochs\"\"\"\n        self.epoch = epoch\n        self._batches = None  # Invalidate batch cache\n\n    def generate_batches(self, set_stats: bool = False) -> list[list[list[int]]]:\n        \"\"\"Generate packed batches for training.\n\n        Args:\n            set_stats: Whether to update efficiency statistics.\n\n        
Returns:\n            List of batches, where each batch contains multiple bins, and each bin\n                contains multiple sequence indices.\n        \"\"\"\n        if self._batches is not None:\n            return self._batches\n\n        # Get indices from the sampler\n        indices = [idx for idx in self.sampler]\n\n        # Get lengths of the selected sequences\n        lengths = self.lengths[indices]\n\n        # Pack sequences into bins using either sequential or parallel packing\n        if self.sequential:\n            bins, total_used, total_slots = allocate_sequentially(\n                lengths,\n                rank=0,\n                bin_capacity=self.batch_max_len,\n                num_ranks=1,\n            )\n            # Map bin indices back to original indices\n            bins = [[indices[b_idx] for b_idx in bin_indices] for bin_indices in bins]\n        else:\n            # Use parallel packing\n            num_processes = self.num_processes or 1\n            all_bins = pack_parallel(\n                lengths,\n                bin_capacity=self.batch_max_len,\n                group_size=self.group_size,\n                bin_size=self.bin_size or self.batch_max_len,\n                num_processes=min(4, num_processes) if num_processes else 4,\n                safe_mode=self.safe_mode,\n                mp_start_method=self.mp_start_method,\n            )\n\n            # Map bin indices back to original indices\n            bins = [\n                [indices[b_idx] for b_idx in bin_indices] for bin_indices in all_bins\n            ]\n\n            # Calculate efficiency statistics\n            total_used = lengths.sum()\n            total_slots = len(all_bins) * self.batch_max_len\n            del all_bins\n\n        # Group bins into batches (each batch contains batch_size bins)\n        batches = [\n            bins[i : i + self.batch_size] for i in range(0, len(bins), self.batch_size)\n        ]\n\n        # Drop last batch if 
requested and it's incomplete\n        if self.drop_last and len(batches[-1]) < self.batch_size:\n            batches = batches[:-1]\n            # Adjust total_slots if we dropped a batch\n            if not self.sequential:\n                total_slots -= (self.batch_size - len(batches[-1])) * self.batch_max_len\n\n        # Update statistics if requested\n        if set_stats:\n            self.total_tokens_used += total_used\n            self.total_token_slots += total_slots\n\n        self._batches = batches\n        gc.collect()\n        return batches\n\n    def __iter__(self) -> Iterator[list[list[int]]]:\n        \"\"\"Return an iterator over batches.\n\n        The batches are truncated to match the minimum number of batches across all\n        ranks to ensure distributed training balance.\n        \"\"\"\n        batches = self.generate_batches(set_stats=True)\n        if self._len_across_ranks:\n            # Truncate batches to ensure all ranks have the same number of batches\n            batches = batches[: self._len_across_ranks]\n        return iter(batches)\n\n    def efficiency(self) -> float:\n        \"\"\"Calculate the packing efficiency (ratio of tokens used to total token slots).\n        Higher is better - 1.0 would mean perfect packing with no wasted space.\n        \"\"\"\n        if self.total_token_slots == 0:\n            self.generate_batches(set_stats=True)\n        if self.total_token_slots == 0:\n            return 0.0\n        # Return a Python float instead of potentially a numpy float\n        return float(self.total_tokens_used / self.total_token_slots)\n\n    def gather_efficiency(self) -> float:\n        \"\"\"Gather and synchronize packing efficiency estimates across all distributed\n        ranks.\n\n        Returns:\n            A conservative efficiency estimate based on the measurements.\n        \"\"\"\n\n        def calc_sample_packing_eff_est(estimates: list[float]):\n            LOG.debug(f\"sample_packing_eff_est 
across ranks: {repr(estimates)}\")\n            # Use 99.7% of max observed efficiency as a safe estimate\n            max_eff = max(float(eff) for eff in estimates)\n            return math.floor(0.997 * max_eff)\n\n        # Gather efficiency from all ranks and apply the calculation function\n        sample_packing_actual_eff_all = reduce_and_broadcast(\n            lambda: float(self.efficiency()),\n            calc_sample_packing_eff_est,\n        )\n\n        # Quantize to 0.5% intervals for stability\n        sample_packing_eff_est = (\n            math.ceil(sample_packing_actual_eff_all * 200.0) / 200.0\n        )\n        return sample_packing_eff_est\n\n    def gather_len_batches(self, num: int) -> int:\n        \"\"\"Gather and synchronize batch counts across all distributed ranks. Returns\n        the minimum number of batches available on any rank.\n        \"\"\"\n\n        def calc_min_len(estimates: list[int]) -> int:\n            LOG.info(f\"gather_len_batches: {repr(estimates)}\")\n            return math.floor(min(estimates))\n\n        # Find minimum batch count across ranks to ensure balance\n        min_len_batches = reduce_and_broadcast(lambda: num, calc_min_len)\n        return min_len_batches\n\n    def __len__(self) -> int:\n        \"\"\"Return the total number of batches that will be yielded by this sampler.\n\n        This is calculated as the minimum number of batches available on any rank to\n        ensure balanced distributed training.\n        \"\"\"\n        if self._batches is None:\n            self._batches = self.generate_batches(set_stats=True)\n\n        if self._len_across_ranks is None:\n            # Sample multiple times to get stable estimate\n            _sampled_lens = []\n            for _ in range(self.num_count_samples):\n                self._batches = None  # Reset cached batches\n                # log timer for generating batches\n                start_time = time.time()\n                
_sampled_lens.append(len(self.generate_batches(set_stats=False)))\n                LOG.debug(f\"generate_batches time: {time.time() - start_time}\")\n            len_batches = min(_sampled_lens)\n\n            # Gather minimum across all ranks\n            if self._len_across_ranks is None:\n                self._len_across_ranks = self.gather_len_batches(len_batches)\n            else:\n                self._len_across_ranks = min(\n                    self._len_across_ranks, self.gather_len_batches(len_batches)\n                )\n\n        return self._len_across_ranks\n"
  },
  {
    "path": "src/axolotl/utils/samplers/utils.py",
    "content": "\"\"\"\nhelper util to calculate dataset lengths\n\"\"\"\n\nimport numpy as np\n\n\ndef get_dataset_lengths(dataset, from_arrow=False):\n    if \"length\" in dataset.column_names:\n        lengths = np.array(dataset[\"length\"])\n    elif \"position_ids\" in dataset.column_names:\n        position_ids = dataset[\"position_ids\"]\n        lengths = np.array([x[-1] + 1 for x in position_ids])\n    else:\n        if from_arrow:\n            input_ids = dataset.data.column(\"input_ids\")\n            lengths = np.vectorize(len)(np.array(input_ids, dtype=object))\n        else:\n            input_ids = dataset[\"input_ids\"]\n            lengths = np.array([len(seq) for seq in input_ids])\n    return lengths\n"
  },
  {
    "path": "src/axolotl/utils/schedulers.py",
    "content": "\"\"\"Module for custom LRScheduler class\"\"\"\n\nimport math\nfrom functools import partial\nfrom typing import Sequence\n\nfrom torch import Tensor\nfrom torch.optim import Optimizer\nfrom torch.optim.lr_scheduler import LambdaLR, LRScheduler\n\n\nclass RexLR(LRScheduler):\n    \"\"\"\n    Reflected Exponential (REX) learning rate scheduler.\n\n    - Original implementation: https://github.com/IvanVassi/REX_LR\n    - Original license: Apache 2.0\n    - Based on: https://arxiv.org/abs/2107.04197\n\n    Args:\n        optimizer (torch.optim.Optimizer): The optimizer to schedule the learning rate for.\n        max_lr (float): The maximum learning rate.\n        min_lr (float): The minimum learning rate.\n        total_steps (int): The total number of training steps.\n        num_warmup_steps (int): The number of warmup steps.\n        last_step (int): The index of last step.\n    \"\"\"\n\n    def __init__(\n        self, optimizer, max_lr, min_lr, total_steps=0, num_warmup_steps=0, last_step=0\n    ):\n        if min_lr > max_lr:\n            raise ValueError(\n                f'Value of \"min_lr\" should be less than value of \"max_lr\". 
Got min_lr={min_lr} and max_lr={max_lr}'\n            )\n        if num_warmup_steps > total_steps:\n            raise ValueError(\n                f\"num_warmup_steps ({num_warmup_steps}) must be less than or equal to total_steps ({total_steps}).\"\n            )\n\n        self.min_lr = min_lr\n        self.max_lr = max_lr\n        self.total_steps = total_steps\n        self.num_warmup_steps = num_warmup_steps\n        self.last_step = max(last_step - 1, 0)\n\n        # Ensure each parameter group has an \"initial_lr\" key to avoid issues when resuming.\n        for group in optimizer.param_groups:\n            initial_lr = group[\"lr\"]\n            if isinstance(initial_lr, Tensor):\n                initial_lr = initial_lr.clone()\n            group.setdefault(\"initial_lr\", initial_lr)\n        # Pass self.last_step as last_epoch to the parent.\n        super().__init__(optimizer, last_epoch=self.last_step)\n\n    @property\n    def last_step(self):\n        return self.last_epoch\n\n    @last_step.setter\n    def last_step(self, value):\n        self.last_epoch = value\n\n    def get_lr(self):\n        # Warmup phase: if defined, increase lr linearly from 0 to max_lr.\n        if 1 <= self.last_step <= self.num_warmup_steps:\n            return [\n                base_lr * self.last_step / self.num_warmup_steps\n                for base_lr in self.base_lrs\n            ]\n\n        # Post-warmup phase: adjust step relative to the end of warmup.\n        step_after = self.last_step - self.num_warmup_steps\n        remaining_steps = self.total_steps - self.num_warmup_steps\n\n        # Avoid LR spiking\n        if step_after >= remaining_steps or step_after == -1 or remaining_steps <= 0:\n            return [self.min_lr for _ in self.base_lrs]\n\n        mod_iter = step_after % remaining_steps\n        z = (remaining_steps - mod_iter) / remaining_steps\n        rex_factor = self.min_lr / self.max_lr + (1.0 - self.min_lr / self.max_lr) * (\n            z / 
(0.1 + 0.9 * z)\n        )\n        return [base_lr * rex_factor for base_lr in self.base_lrs]\n\n\nclass InterpolatingLogScheduler(LRScheduler):\n    \"\"\"\n    A scheduler that interpolates learning rates in a logarithmic fashion\n    \"\"\"\n\n    def __init__(self, optimizer, num_steps, min_lr, max_lr, last_epoch=-1):\n        \"\"\"A scheduler that interpolates learning rates in a logarithmic fashion\n\n        Args:\n        - optimizer: pytorch optimizer\n        - num_steps: int, the number of steps over which to increase from the min_lr to the max_lr\n        - min_lr: float, the minimum learning rate\n        - max_lr: float, the maximum learning rate\n\n        Usage:\n            fc = nn.Linear(1,1)\n            optimizer = optim.Adam(fc.parameters())\n            lr_scheduler = InterpolatingLogScheduler(optimizer, num_steps=400, min_lr=1e-6, max_lr=1e-4)\n        \"\"\"\n        self.num_steps = num_steps\n        self.min_lr = min_lr\n        self.max_lr = max_lr\n        self.q = (max_lr / min_lr) ** (1 / (num_steps - 1))\n        super().__init__(optimizer, last_epoch)\n\n    def get_lr(self):\n        if self.last_epoch <= 0:\n            lrs = [self.min_lr for base_lr in self.base_lrs]\n        elif self.last_epoch < self.num_steps:\n            lrs = [\n                self.min_lr * (self.q ** (self.last_epoch - 1))\n                for base_lr in self.base_lrs\n            ]\n        else:\n            lrs = [self.max_lr for base_lr in self.base_lrs]\n\n        return lrs\n\n\ndef _get_cosine_schedule_with_quadratic_warmup_lr_lambda(\n    current_step: int,\n    *,\n    num_warmup_steps: int,\n    num_training_steps: int,\n    num_cycles: float,\n):\n    if current_step < num_warmup_steps:\n        return (float(current_step) / float(max(1, num_warmup_steps))) ** 2\n    progress = float(current_step - num_warmup_steps) / float(\n        max(1, num_training_steps - num_warmup_steps)\n    )\n    return max(\n        0.0, 0.5 * (1.0 + 
math.cos(math.pi * float(num_cycles) * 2.0 * progress))\n    )\n\n\ndef get_cosine_schedule_with_quadratic_warmup(\n    optimizer: Optimizer,\n    num_warmup_steps: int,\n    num_training_steps: int,\n    num_cycles: float = 0.5,\n    last_epoch: int = -1,\n):\n    \"\"\"\n    Create a schedule with a learning rate that decreases following the values of the cosine function from the\n    initial lr set in the optimizer to 0, after a warmup period during which it increases quadratically between 0 and the\n    initial lr set in the optimizer.\n\n    Args:\n        optimizer ([`~torch.optim.Optimizer`]):\n            The optimizer for which to schedule the learning rate.\n        num_warmup_steps (`int`):\n            The number of steps for the warmup phase.\n        num_training_steps (`int`):\n            The total number of training steps.\n        num_cycles (`float`, *optional*, defaults to 0.5):\n            The number of waves in the cosine schedule (the default is to just decrease from the max value to 0\n            following a half-cosine).\n        last_epoch (`int`, *optional*, defaults to -1):\n            The index of the last epoch when resuming training.\n\n    Return:\n        `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.\n    \"\"\"\n\n    lr_lambda = partial(\n        _get_cosine_schedule_with_quadratic_warmup_lr_lambda,\n        num_warmup_steps=num_warmup_steps,\n        num_training_steps=num_training_steps,\n        num_cycles=num_cycles,\n    )\n    return LambdaLR(optimizer, lr_lambda, last_epoch)\n\n\ndef _get_cosine_schedule_with_min_lr_lambda(\n    current_step: int,\n    *,\n    num_warmup_steps: int,\n    num_training_steps: int,\n    min_lr_ratio: float,\n):\n    # Warm up\n    if current_step < num_warmup_steps:\n        return float(current_step) / float(max(1, num_warmup_steps))\n\n    # Cosine learning rate decay\n    progress = float(current_step - num_warmup_steps) / float(\n        max(1, num_training_steps - 
num_warmup_steps)\n    )\n    scaling = 0.5 * (1.0 + math.cos(math.pi * progress))\n    return (1 - min_lr_ratio) * scaling + min_lr_ratio\n\n\ndef get_cosine_schedule_with_min_lr(\n    optimizer: Optimizer,\n    num_warmup_steps: int,\n    num_training_steps: int,\n    min_lr_ratio: float = 0.0,\n):\n    \"\"\"\n    Create a learning rate schedule which has:\n        - linear warmup from 0 -> `max_lr` over `num_warmup_steps`\n        - cosine learning rate annealing from `max_lr` -> `min_lr` over `num_training_steps`\n    \"\"\"\n\n    lr_lambda = partial(\n        _get_cosine_schedule_with_min_lr_lambda,\n        num_warmup_steps=num_warmup_steps,\n        num_training_steps=num_training_steps,\n        min_lr_ratio=min_lr_ratio,\n    )\n    return LambdaLR(optimizer, lr_lambda)\n\n\ndef _get_cosine_schedule_with_warmup_decay_constant_lr_lambda(\n    current_step: int,\n    *,\n    num_warmup_steps: int,\n    num_training_steps: int,\n    constant_lr_ratio: float,\n    min_lr_ratio: float,\n    num_cycles: float,\n):\n    if current_step < num_warmup_steps:\n        return float(current_step) / float(max(1, num_warmup_steps))\n\n    num_constant_steps = int(num_training_steps * constant_lr_ratio)\n    current_step = min(current_step, num_constant_steps)\n\n    progress = float(current_step - num_warmup_steps) / float(\n        max(1, num_constant_steps - num_warmup_steps)\n    )\n\n    return (\n        max(\n            0,\n            (1 - min_lr_ratio)\n            * 0.5\n            * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress)),\n        )\n        + min_lr_ratio\n    )\n\n\ndef get_cosine_schedule_with_warmup_decay_constant(\n    optimizer: Optimizer,\n    num_warmup_steps: int,\n    num_training_steps: int,\n    constant_lr_ratio: float,\n    min_lr_ratio: float,\n    num_cycles: float = 0.5,\n    last_epoch: int = -1,\n):\n    \"\"\"\n    Implementation of Continual Pre-Training of Large Language Models: How to (re)warm your model? 
(https://arxiv.org/pdf/2308.04014.pdf)\n    Create a schedule with a learning rate that, after a warmup period during which it increases linearly between 0 and the\n    initial lr set in the optimizer, follows the values of the cosine function from the initial lr down towards\n    min_lr_ratio * initial lr until step num_training_steps * constant_lr_ratio, and then holds that value constant\n    (min_lr_ratio * initial lr with the default num_cycles=0.5).\n\n    Args:\n        optimizer ([`~torch.optim.Optimizer`]):\n            The optimizer for which to schedule the learning rate.\n        num_warmup_steps (`int`):\n            The number of steps for the warmup phase.\n        num_training_steps (`int`):\n            The total number of training steps.\n        constant_lr_ratio (`float`):\n            Fraction of num_training_steps at which the cosine decay ends; the learning rate is held constant afterwards.\n        min_lr_ratio (`float`):\n            The minimum learning rate as a fraction of the initial learning rate; the cosine function decays towards this value.\n        num_cycles (`float`, *optional*, defaults to 0.5):\n            The number of waves in the cosine schedule (the default is to just decrease from the max value to the minimum\n            value following a half-cosine).\n        last_epoch (`int`, *optional*, defaults to -1):\n            The index of the last epoch when resuming training.\n\n    Return:\n        `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.\n    \"\"\"\n\n    lr_lambda = partial(\n        _get_cosine_schedule_with_warmup_decay_constant_lr_lambda,\n        num_warmup_steps=num_warmup_steps,\n        num_training_steps=num_training_steps,\n        constant_lr_ratio=constant_lr_ratio,\n        min_lr_ratio=min_lr_ratio,\n        num_cycles=num_cycles,\n    )\n    return LambdaLR(optimizer, lr_lambda, last_epoch)\n\n\nclass JaggedLRRestartScheduler(LRScheduler):\n    \"\"\"Wraps another scheduler to apply per-lora-restart learning rate warmups.\"\"\"\n\n    def __init__(\n        self,\n        optimizer: Optimizer,\n        
inner_schedule: LRScheduler,\n        jagged_restart_steps: int,\n        jagged_restart_warmup_steps: int,\n        jagged_restart_anneal_steps: int = 1,\n        min_lr_scale: float = 0.001,\n    ) -> None:\n        self.inner_schedule = inner_schedule\n        self.restarts_steps = jagged_restart_steps\n        self.warmup_steps = jagged_restart_warmup_steps\n        self.anneal_steps = jagged_restart_anneal_steps\n        self.min_lr_scale = min_lr_scale\n        super().__init__(optimizer, inner_schedule.last_epoch)\n\n    def get_lr(self) -> float | Sequence[float]:\n        self.inner_schedule.last_epoch = self.last_epoch\n\n        original = self.inner_schedule.get_lr()\n        step = self.last_epoch\n\n        if step < self.restarts_steps - self.anneal_steps:\n            scale = 1\n        else:\n            per_restart_progress = step % self.restarts_steps\n            if per_restart_progress < self.warmup_steps:\n                cycle_t = min(1.0, (per_restart_progress) / self.warmup_steps)\n            elif per_restart_progress > (self.restarts_steps - self.anneal_steps):\n                cycle_t = min(\n                    1.0,\n                    (self.restarts_steps - per_restart_progress) / self.anneal_steps,\n                )\n            else:\n                cycle_t = 1\n            scale = cycle_t * (1 - self.min_lr_scale) + self.min_lr_scale\n\n        if isinstance(original, Sequence):\n            return [lr * scale for lr in original]\n\n        return original * scale\n"
  },
  {
    "path": "src/axolotl/utils/schemas/__init__.py",
    "content": ""
  },
  {
    "path": "src/axolotl/utils/schemas/config.py",
    "content": "\"\"\"Module with Pydantic models for configuration.\"\"\"\n\nfrom typing import Annotated, Any, Literal\n\nfrom accelerate.utils import is_fp8_available\nfrom annotated_types import MinLen\nfrom packaging import version\nfrom pydantic import (\n    BaseModel,\n    Field,\n    StringConstraints,\n    field_serializer,\n    model_validator,\n)\n\nfrom axolotl.utils.datasets import get_default_process_count\nfrom axolotl.utils.logging import get_logger\nfrom axolotl.utils.schemas.datasets import (\n    DatasetConfig,\n    DPODataset,\n    KTODataset,\n    PretrainingDataset,\n    SFTDataset,\n    StepwiseSupervisedDataset,\n)\nfrom axolotl.utils.schemas.deprecated import DeprecatedParameters, RemappedParameters\nfrom axolotl.utils.schemas.dynamic_checkpoint import DynamicCheckpointConfig\nfrom axolotl.utils.schemas.enums import ChatTemplate, RingAttnFunc, RLType\nfrom axolotl.utils.schemas.fsdp import FSDPConfig\nfrom axolotl.utils.schemas.integrations import (\n    CometConfig,\n    GradioConfig,\n    LISAConfig,\n    MLFlowConfig,\n    OpenTelemetryConfig,\n    RayConfig,\n    TrackioConfig,\n    WandbConfig,\n)\nfrom axolotl.utils.schemas.internal import EnvCapabilities, GPUCapabilities\nfrom axolotl.utils.schemas.model import (\n    ModelInputConfig,\n    ModelOutputConfig,\n    SpecialTokensConfig,\n)\nfrom axolotl.utils.schemas.multimodal import MultiModalConfig\nfrom axolotl.utils.schemas.peft import LoraConfig, ReLoRAConfig\nfrom axolotl.utils.schemas.quantization import PTQConfig, QATConfig\nfrom axolotl.utils.schemas.training import HyperparametersConfig, JaggedLRConfig\nfrom axolotl.utils.schemas.trl import TRLConfig\nfrom axolotl.utils.schemas.validation import ValidationMixin\nfrom axolotl.utils.schemas.vllm import VllmConfig\n\nLOG = get_logger(__name__)\n\n\nclass AxolotlInputConfig(\n    ModelInputConfig,\n    ModelOutputConfig,\n    LoraConfig,\n    ReLoRAConfig,\n    JaggedLRConfig,\n    HyperparametersConfig,\n    WandbConfig,\n    
MLFlowConfig,\n    CometConfig,\n    TrackioConfig,\n    OpenTelemetryConfig,\n    LISAConfig,\n    GradioConfig,\n    RayConfig,\n    MultiModalConfig,\n    RemappedParameters,\n    DeprecatedParameters,\n    ValidationMixin,\n    BaseModel,\n):\n    \"\"\"Wrapper of all config options.\"\"\"\n\n    model_config = {\"populate_by_name\": True}\n\n    strict: bool | None = Field(\n        default=False,\n        json_schema_extra={\"description\": \"Allow overwrite yml config using from cli\"},\n    )\n    resume_from_checkpoint: str | None = Field(\n        default=None,\n        json_schema_extra={\"description\": \"Resume from a specific checkpoint dir\"},\n    )\n    auto_resume_from_checkpoints: bool | None = Field(\n        default=None,\n        json_schema_extra={\n            \"description\": \"If resume_from_checkpoint isn't set and you simply want it to start where it left off. Be careful with this being turned on between different models.\"\n        },\n    )\n    resize_token_embeddings_to_32x: bool | None = Field(\n        default=None,\n        json_schema_extra={\n            \"description\": \"Resize the model embeddings when new tokens are added to multiples of 32. This is reported to improve training speed on some models\"\n        },\n    )\n    mean_resizing_embeddings: bool | None = False\n    # optionally shrink the embeddings when the tokenizer vocab size is smaller\n    shrink_embeddings: bool | None = Field(\n        default=None,\n        json_schema_extra={\n            \"description\": \"Whether to shrink the embeddings to len(tokenizer). By default, we won't shrink.\"\n        },\n    )\n    embeddings_skip_upcast: bool | None = Field(\n        default=None,\n        json_schema_extra={\n            \"description\": \"Don't upcast the embeddings to float32 when using PEFT. 
Useful for low-VRAM GPUs\"\n        },\n    )\n    reinit_weights: bool | None = Field(\n        default=None,\n        json_schema_extra={\n            \"description\": \"Reinitialize model weights randomly instead of loading pretrained weights\"\n        },\n    )\n\n    trainer_cls: str | None = Field(\n        default=None,\n        json_schema_extra={\n            \"description\": \"module to custom trainer class to use for training\"\n        },\n    )\n\n    rl: RLType | None = Field(\n        default=None,\n        json_schema_extra={\n            \"description\": \"Use RL training: 'dpo', 'ipo', 'kto', 'simpo', 'orpo', 'grpo'\"\n        },\n    )\n    trl: TRLConfig | None = Field(\n        default_factory=lambda: TRLConfig(),\n    )\n    vllm: VllmConfig | None = Field(\n        default_factory=lambda: VllmConfig(),\n    )\n    qat: QATConfig | None = None\n    quantization: PTQConfig | None = None\n    reward_model: bool | None = Field(\n        default=None,\n        json_schema_extra={\"description\": \"Reward modelling: `True` or `False`\"},\n    )\n    dynamic_checkpoint: DynamicCheckpointConfig | None = Field(\n        default=None,\n        json_schema_extra={\n            \"description\": \"Configuration for dynamic checkpointing (trigger by file or signal). \"\n            \"Set 'enabled: true' to activate this feature.\"\n        },\n    )\n    process_reward_model: bool | None = Field(\n        default=None,\n        json_schema_extra={\n            \"description\": \"Process reward modelling: `True` or `False`\"\n        },\n    )\n    center_rewards_coefficient: float | None = Field(\n        default=None,\n        json_schema_extra={\n            \"description\": \"Coefficient to incentivize the reward model to output mean-zero rewards (proposed by https://huggingface.co/papers/2312.09244, Eq. 2). 
Recommended value: `0.01`.\"\n        },\n    )\n    num_labels: int | None = None\n    # Whether to use weighting in DPO trainer.\n    # If `None`, default is `False` in the trainer.\n    dpo_use_weighting: bool | None = Field(\n        default=None,\n        json_schema_extra={\n            \"description\": \"Whether to perform weighting in DPO trainer\"\n        },\n    )\n    dpo_label_smoothing: float | None = None\n    dpo_norm_loss: bool | None = None\n\n    dpo_use_liger_kernel: bool | None = Field(\n        default=None,\n        json_schema_extra={\"description\": \"Whether to use Liger kernel for DPO loss.\"},\n    )\n\n    dpo_padding_free: bool | None = None\n\n    datasets: (\n        Annotated[\n            list[SFTDataset | DPODataset | KTODataset | StepwiseSupervisedDataset],\n            MinLen(1),\n        ]\n        | None\n    ) = Field(\n        default=None,\n        json_schema_extra={\n            \"description\": \"A list of one or more datasets to finetune the model with\"\n        },\n    )\n\n    test_datasets: (\n        Annotated[\n            list[SFTDataset | DPODataset | KTODataset | StepwiseSupervisedDataset],\n            MinLen(1),\n        ]\n        | None\n    ) = Field(\n        default=None,\n        json_schema_extra={\n            \"description\": \"A list of one or more datasets to eval the model with. You can use either test_datasets, or val_set_size, but not both.\"\n        },\n    )\n    shuffle_merged_datasets: bool | None = Field(\n        default=True,\n        json_schema_extra={\n            \"description\": \"If false, the datasets will not be shuffled and will keep their original order in `datasets`. The same applies to the `test_datasets` option and the `pretraining_dataset` option. 
Default is true.\"\n        },\n    )\n    shuffle_before_merging_datasets: bool | None = Field(\n        default=False,\n        json_schema_extra={\n            \"description\": \"If true, each dataset in `datasets` will be shuffled before merging. This allows curriculum learning strategies to be applied at the dataset level. Default is false.\"\n        },\n    )\n    dataset_prepared_path: str | None = Field(\n        default=None,\n        json_schema_extra={\n            \"description\": \"Axolotl attempts to save the dataset as an arrow after packing the data together so subsequent training attempts load faster, relative path\"\n        },\n    )\n    dataset_shard_num: int | None = Field(\n        default=None, json_schema_extra={\"description\": \"Num shards for whole dataset\"}\n    )\n    dataset_shard_idx: int | None = Field(\n        default=None,\n        json_schema_extra={\"description\": \"Index of shard to use for whole dataset\"},\n    )\n    skip_prepare_dataset: bool | None = False\n    num_dataset_shards_to_save: int | None = Field(\n        default=None,\n        json_schema_extra={\n            \"description\": \"Number of shards to save the prepared dataset\"\n        },\n    )\n\n    pretraining_dataset: (\n        Annotated[list[PretrainingDataset | SFTDataset], MinLen(1)] | None\n    ) = Field(\n        default=None,\n        json_schema_extra={\n            \"description\": \"Set to HF dataset for type: 'completion' for streaming instead of pre-tokenize\"\n        },\n    )\n    dataset_processes: int | None = Field(\n        default=None,\n        deprecated=\"Use `dataset_num_proc` instead. This parameter will be removed in a future version.\",\n        json_schema_extra={\n            \"description\": (\n                \"The maximum number of processes to use while preprocessing your input dataset. 
This defaults to `os.cpu_count()` if not set.\\n\"\n                \"For Runpod VMs, it will default to number of vCPUs via RUNPOD_CPU_COUNT.\"\n            )\n        },\n    )\n    dataset_num_proc: int | None = Field(\n        default=None,\n        json_schema_extra={\n            \"description\": (\n                \"The maximum number of processes to use while preprocessing your input dataset. This defaults to `os.cpu_count()` if not set.\\n\"\n                \"For Runpod VMs, it will default to number of vCPUs via RUNPOD_CPU_COUNT.\"\n            )\n        },\n    )\n\n    dataset_exact_deduplication: bool | None = Field(\n        default=None,\n        json_schema_extra={\n            \"description\": \"Deduplicates datasets and test_datasets with identical entries\"\n        },\n    )\n    dataset_keep_in_memory: bool | None = Field(\n        default=None,\n        json_schema_extra={\n            \"description\": \"Keep dataset in memory while preprocessing. Only needed if cached dataset is taking too much storage\"\n        },\n    )\n    dataloader_pin_memory: bool | None = None\n    dataloader_num_workers: int | None = None\n    dataloader_prefetch_factor: int | None = None\n    dataloader_drop_last: bool | None = None\n\n    accelerator_config: dict[str, Any] | None = None\n\n    remove_unused_columns: bool | None = None\n\n    push_dataset_to_hub: str | None = Field(\n        default=None,\n        json_schema_extra={\n            \"description\": \"Push prepared dataset to hub - repo_org/repo_name\"\n        },\n    )\n    hf_use_auth_token: bool | None = Field(\n        default=None,\n        json_schema_extra={\n            \"description\": \"Whether to use hf `use_auth_token` for loading datasets. Useful for fetching private datasets. 
Required to be true when used in combination with `push_dataset_to_hub`\"\n        },\n    )\n\n    device: Any | None = None\n    device_map: Any | None = Field(\n        default=None,\n        json_schema_extra={\n            \"description\": \"Passed through to transformers when loading the model when launched without accelerate. Use `sequential` when training w/ model parallelism to limit memory\"\n        },\n    )\n    world_size: int | None = None\n    local_rank: int | None = Field(\n        default=None,\n        json_schema_extra={\n            \"description\": \"Don't mess with this, it's here for accelerate and torchrun\"\n        },\n    )\n    ddp: bool | None = None\n\n    seed: int | None = Field(\n        default=None, json_schema_extra={\"description\": \"Seed for reproducibility\"}\n    )\n    ddp_timeout: int | None = Field(\n        default=None,\n        json_schema_extra={\"description\": \"Advanced DDP Arguments - timeout\"},\n    )\n    ddp_bucket_cap_mb: int | None = Field(\n        default=None,\n        json_schema_extra={\"description\": \"Advanced DDP Arguments - bucket cap in MB\"},\n    )\n    ddp_broadcast_buffers: bool | None = Field(\n        default=None,\n        json_schema_extra={\"description\": \"Advanced DDP Arguments - broadcast buffers\"},\n    )\n    ddp_find_unused_parameters: bool | None = None\n\n    do_causal_lm_eval: bool | None = Field(\n        default=None,\n        json_schema_extra={\n            \"description\": \"Whether to run causal language model evaluation for metrics in `eval_causal_lm_metrics`\"\n        },\n    )\n    eval_causal_lm_metrics: list[str] | None = Field(\n        default=None,\n        json_schema_extra={\n            \"description\": \"HF evaluate metrics used during evaluation. 
Default is ['sacrebleu', 'comet', 'ter', 'chrf', 'perplexity']\"\n        },\n    )\n    do_bench_eval: bool | None = None\n    bench_dataset: str | None = None\n    bench_split: str | None = None\n    metric_for_best_model: str | None = None\n    greater_is_better: bool | None = None\n\n    loss_watchdog_threshold: float | None = Field(\n        default=None,\n        json_schema_extra={\n            \"description\": \"High loss value, indicating the learning has broken down (a good estimate is ~2 times the loss at the start of training)\"\n        },\n    )\n    loss_watchdog_patience: int | None = Field(\n        default=None,\n        json_schema_extra={\n            \"description\": \"Number of high-loss steps in a row before the trainer aborts (default: 3)\"\n        },\n    )\n\n    gc_steps: int | None = Field(\n        default=None,\n        json_schema_extra={\n            \"description\": \"Run garbage collection every `gc_steps` steps. -1 will run on epoch end and before evaluations. Default is 0 (disabled).\"\n        },\n    )\n\n    bf16: Literal[\"auto\"] | bool | None = Field(\n        default=\"auto\",\n        json_schema_extra={\n            \"description\": \"Use CUDA bf16. bool or 'full' for `bf16_full_eval`, or 'auto' for automatic detection. require >=ampere\"\n        },\n    )\n    fp16: bool | None = Field(\n        default=None, json_schema_extra={\"description\": \"Use CUDA fp16\"}\n    )\n    fp8: bool | None = Field(\n        default=None,\n        json_schema_extra={\n            \"description\": \"Enable FP8 mixed precision training using TorchAO. Best \"\n            \"used in combination with torch.compile.\"\n        },\n    )\n    fp8_enable_fsdp_float8_all_gather: bool | None = Field(\n        default=None,\n        json_schema_extra={\n            \"description\": \"Enable FSDP float8 all-gather optimization for FP8 training. 
Can \"\n            \"improve training speed by 10-15% when FSDP is enabled.\"\n        },\n    )\n    bfloat16: bool | None = Field(\n        default=None,\n        json_schema_extra={\n            \"description\": \"No AMP (automatic mixed precision) - require >=ampere\"\n        },\n    )  # for non-AMP cases\n    float16: bool | None = Field(\n        default=None,\n        json_schema_extra={\"description\": \"No AMP (automatic mixed precision)\"},\n    )  # for non-AMP cases\n    tf32: Literal[\"auto\"] | bool | None = Field(\n        default=\"auto\",\n        json_schema_extra={\n            \"description\": \"bool to use CUDA tf32 or 'auto' for automatic detection - require >=ampere\"\n        },\n    )\n    float32: bool | None = None\n\n    gradient_checkpointing: Literal[\"offload\", \"offload_disk\"] | bool | None = Field(\n        default=False,\n        json_schema_extra={\n            \"description\": \"Whether to use gradient checkpointing. Available options are: true, false, 'offload', 'offload_disk'. https://huggingface.co/docs/transformers/v4.18.0/en/performance#gradient-checkpointing\"\n        },\n    )\n    gradient_checkpointing_kwargs: dict[str, Any] | None = Field(\n        default=None,\n        json_schema_extra={\n            \"description\": \"Additional kwargs to pass to the trainer for gradient checkpointing\"\n        },\n    )\n    activation_offloading: Literal[\"legacy\", \"disk\"] | bool | None = Field(\n        default=False,\n        json_schema_extra={\n            \"description\": \"Whether to offload activations. Available options are: true, false, 'legacy', 'disk'.\"\n        },\n    )\n\n    unfrozen_parameters: list[str] | None = Field(\n        default=None,\n        json_schema_extra={\n            \"description\": \"List of regex patterns for parameter names to keep unfrozen. \"\n            \"All other parameters will be frozen via requires_grad=False. \"\n            \"Note: range-based patterns (e.g. 
embed_tokens.weight$[:32000]) use gradient \"\n            \"zeroing rather than a true freeze, so weight decay will still apply to the \"\n            \"frozen portion and optimizer states are allocated for the full parameter.\"\n        },\n    )\n\n    sequence_len: int = Field(\n        default=512,\n        json_schema_extra={\n            \"description\": \"The maximum length of an input to train with, this should typically be less than 2048 as most models have a token/context limit of 2048\"\n        },\n    )\n    excess_length_strategy: Literal[\"drop\", \"truncate\", \"raise\"] | None = Field(\n        default=None,\n        json_schema_extra={\n            \"description\": \"What to do when a tokenized row exceeds sequence_len. 'drop' removes the row; 'truncate' slices tensors to sequence_len; 'raise' raises a ValueError. Defaults to 'drop' for backward compatibility.\"\n        },\n    )\n    eval_sequence_len: int | None = Field(\n        default=None,\n        json_schema_extra={\n            \"description\": \"The maximum length of an input for evaluation. If not specified, defaults to sequence_len\"\n        },\n    )\n    min_sample_len: int | None = None\n    max_prompt_len: int | None = Field(\n        default=None,\n        json_schema_extra={\"description\": \"maximum prompt length for RL training\"},\n    )\n    sample_packing: bool | None = Field(\n        default=None,\n        json_schema_extra={\n            \"description\": \"Use efficient multi-packing with block diagonal attention and per sequence position_ids. Recommend set to 'true'\"\n        },\n    )\n    sample_packing_group_size: int | None = Field(\n        default=100_000,\n        json_schema_extra={\n            \"description\": \"The number of samples packed at a time. 
Increasing the following values helps with packing, but usually only slightly (<1%).\"\n        },\n    )\n    sample_packing_bin_size: int | None = Field(\n        default=200,\n        json_schema_extra={\n            \"description\": \"The number of samples which can be packed into one sequence. Increase if using a large sequence_len with many short samples.\"\n        },\n    )\n    sample_packing_sequentially: bool | None = Field(\n        default=None,\n        json_schema_extra={\"description\": \"Whether to pack samples sequentially\"},\n    )\n    sample_packing_mp_start_method: str | None = Field(\n        default=None,\n        json_schema_extra={\n            \"description\": \"The multiprocessing start method to use for packing. Should be 'fork', 'spawn' or 'forkserver'\"\n        },\n    )\n    eval_sample_packing: bool | None = Field(\n        default=None,\n        json_schema_extra={\n            \"description\": \"Set to 'false' if getting errors during eval with sample_packing on\"\n        },\n    )\n    pad_to_sequence_len: bool | None = Field(\n        default=None,\n        json_schema_extra={\n            \"description\": \"Pad inputs so each step uses constant sized buffers. This will reduce memory fragmentation and may prevent OOMs, by re-using memory more efficiently. 
Defaults to True if `sample_packing` enabled\"\n        },\n    )\n    curriculum_sampling: bool | None = Field(\n        default=None,\n        json_schema_extra={\n            \"description\": \"Whether to use sequential sampling for curriculum learning\"\n        },\n    )\n    multipack_real_batches: bool | None = None\n\n    batch_flattening: Literal[\"auto\"] | bool | None = Field(\n        default=None,\n        json_schema_extra={\n            \"description\": \"Use batch flattening for speedups when not using sample_packing\"\n        },\n    )\n\n    # for PoSE context length extension\n    use_pose: bool | None = None\n    pose_split_on_token_ids: list[int] | None = None\n    pose_max_context_len: int | None = None\n    pose_num_chunks: int | None = None\n\n    # Deprecated: Use streaming_multipack_buffer_size instead\n    pretrain_multipack_buffer_size: int | None = Field(\n        default=None,\n        deprecated=\"Deprecated in v0.13.0, will be removed in v0.14.0. Use streaming_multipack_buffer_size instead\",\n    )\n    pretrain_multipack_attn: bool | None = Field(\n        default=True,\n        json_schema_extra={\n            \"description\": \"whether to prevent cross attention for packed sequences during pretraining\",\n        },\n    )\n    pretraining_sample_concatenation: bool | None = Field(\n        default=None,\n        json_schema_extra={\n            \"description\": \"whether to concatenate samples during pretraining\",\n        },\n    )\n\n    streaming: bool | None = Field(\n        default=None,\n        json_schema_extra={\"description\": \"Use streaming mode for loading datasets\"},\n    )\n    streaming_multipack_buffer_size: int | None = Field(\n        default=10_000,\n        json_schema_extra={\n            \"description\": \"Buffer size for multipack streaming datasets\"\n        },\n    )\n\n    xformers_attention: bool | None = Field(\n        default=None,\n        json_schema_extra={\n            \"description\": 
\"Whether to use xformers attention patch https://github.com/facebookresearch/xformers\"\n        },\n    )\n    sdp_attention: bool | None = Field(\n        default=None,\n        json_schema_extra={\n            \"description\": \"Whether to use scaled-dot-product attention https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html\"\n        },\n    )\n    s2_attention: bool | None = Field(\n        default=None,\n        json_schema_extra={\n            \"description\": \"Shifted-sparse attention (only llama) - https://arxiv.org/pdf/2309.12307.pdf\"\n        },\n    )\n    flex_attention: bool | None = None\n    flex_attn_compile_kwargs: dict[str, Any] | None = None\n    flash_attention: bool | None = Field(\n        default=None,\n        json_schema_extra={\n            \"description\": \"Whether to use flash attention patch https://github.com/Dao-AILab/flash-attention\"\n        },\n    )\n    flash_attn_cross_entropy: bool | None = Field(\n        default=None,\n        json_schema_extra={\n            \"description\": \"Whether to use flash-attention cross entropy implementation - advanced use only\"\n        },\n    )\n    flash_attn_rms_norm: bool | None = Field(\n        default=None,\n        json_schema_extra={\n            \"description\": \"Whether to use flash-attention rms norm implementation - advanced use only\"\n        },\n    )\n    flash_attn_fuse_mlp: bool | None = Field(\n        default=None,\n        json_schema_extra={\n            \"description\": \"Whether to fuse part of the MLP into a single operation\"\n        },\n    )\n    flash_optimum: bool | None = Field(\n        default=None,\n        json_schema_extra={\"description\": \"Whether to use bettertransformers\"},\n    )\n    sage_attention: bool | None = Field(\n        default=None,\n        json_schema_extra={\n            \"description\": \"Whether to use SageAttention https://github.com/thu-ml/SageAttention\"\n        },\n    )\n\n    
eager_attention: bool | None = None\n\n    attn_implementation: str | None = Field(\n        default=None,\n        json_schema_extra={\n            \"description\": \"Specify a custom attention implementation, used mostly for kernels.\"\n        },\n    )\n\n    experts_implementation: str | None = Field(\n        default=None,\n        json_schema_extra={\n            \"description\": \"Which experts implementation to use for MoE models.\"\n        },\n    )\n\n    quantize_moe_experts: bool = Field(\n        default=False,\n        json_schema_extra={\n            \"description\": \"Quantize MoE expert weights on load to reduce VRAM. \"\n            \"Requires adapter (lora/qlora) with load_in_4bit or load_in_8bit. \"\n            \"Requires CUDA (not compatible with ROCm or other backends). \"\n            \"Note: total parameter count may be reported incorrectly when enabled \"\n            \"(trainable param count is correct).\"\n        },\n    )\n\n    scaling_softmax: bool | None = Field(\n        default=None,\n        json_schema_extra={\n            \"description\": \"Whether to use Scaled Softmax (SSMax) attention. Ref: https://arxiv.org/abs/2501.19399\"\n        },\n    )\n    scaling_softmax_factor: float | None = Field(\n        default=None,\n        json_schema_extra={\n            \"description\": \"Scaling factor for SSMax attention. Default is 0.43\"\n        },\n    )\n    scaling_softmax_bias: float | None = Field(\n        default=None,\n        json_schema_extra={\n            \"description\": \"Bias for SSMax attention. Default is 0.0. 
Note: The paper recommends bias=0 for better length generalization.\"\n        },\n    )\n\n    unsloth_cross_entropy_loss: bool | None = None\n    unsloth_lora_mlp: bool | None = None\n    unsloth_lora_qkv: bool | None = None\n    unsloth_lora_o: bool | None = None\n    unsloth_rms_norm: bool | None = None\n    unsloth_rope: bool | None = None\n\n    lora_mlp_kernel: bool | None = Field(\n        default=None,\n        json_schema_extra={\n            \"description\": \"Apply custom LoRA autograd functions and activation function Triton kernels for speed and memory savings. See: https://docs.axolotl.ai/docs/lora_optims.html\"\n        },\n    )\n    lora_qkv_kernel: bool | None = Field(\n        default=None,\n        json_schema_extra={\n            \"description\": \"Apply custom LoRA autograd functions and activation function Triton kernels for speed and memory savings. See: https://docs.axolotl.ai/docs/lora_optims.html\"\n        },\n    )\n    lora_o_kernel: bool | None = Field(\n        default=None,\n        json_schema_extra={\n            \"description\": \"Apply custom LoRA autograd functions and activation function Triton kernels for speed and memory savings. 
See: https://docs.axolotl.ai/docs/lora_optims.html\"\n        },\n    )\n\n    chunked_cross_entropy: bool | None = Field(\n        default=None,\n        json_schema_extra={\n            \"description\": \"Whether to use chunked cross entropy loss for memory efficiency\"\n        },\n    )\n    chunked_cross_entropy_num_chunks: int | None = Field(\n        default=None,\n        json_schema_extra={\n            \"description\": \"Number of chunks to use for chunked cross entropy loss\"\n        },\n    )\n    use_eaft: bool | None = Field(\n        default=None,\n        json_schema_extra={\n            \"description\": \"Enable Entropy-Aware Focal Training loss (EAFT)\"\n        },\n    )\n    eaft_alpha: float | None = Field(\n        default=1.0,\n        json_schema_extra={\n            \"description\": \"Exponent for entropy weighting in EAFT (default: 1.0)\"\n        },\n    )\n    eaft_k: int | None = Field(\n        default=20,\n        json_schema_extra={\n            \"description\": \"Number of top logits for entropy approximation (default: 20)\"\n        },\n    )\n\n    tiled_mlp: bool | None = Field(\n        default=None,\n        json_schema_extra={\n            \"description\": \"Whether to use ALST tiled mlp for memory efficient long context\"\n        },\n    )\n\n    tiled_mlp_num_shards: int | None = Field(\n        default=None,\n        json_schema_extra={\n            \"description\": \"Number of shards to use for ALST tiled mlp. If unset, it will be set based on seqlen/hidden_size\"\n        },\n    )\n\n    tiled_mlp_use_original_mlp: bool | None = Field(\n        default=True,\n        json_schema_extra={\n            \"description\": \"Whether to use original mlp for ALST tiled mlp. 
Otherwise uses a generic MLP based on llama.\"\n        },\n    )\n\n    llama4_linearized_experts: bool | None = None\n\n    deepspeed: str | dict[str, Any] | None = Field(\n        default=None,\n        json_schema_extra={\n            \"description\": \"Deepspeed config path. e.g., deepspeed_configs/zero3.json\"\n        },\n    )\n    deepcompile: bool | None = Field(\n        default=None,\n        json_schema_extra={\n            \"description\": \"Whether to use deepcompile for faster training with deepspeed\"\n        },\n    )\n    fsdp: list[str] | None = Field(\n        default=None,\n        json_schema_extra={\"description\": \"FSDP configuration\"},\n        deprecated=\"Configuring FSDP using `fsdp` is deprecated. Please use `fsdp_config` instead. \",\n    )\n    fsdp_config: FSDPConfig | None = Field(\n        default=None, json_schema_extra={\"description\": \"FSDP configuration options\"}\n    )\n    fsdp_version: int | None = Field(\n        default=None,\n        json_schema_extra={\"description\": \"FSDP version\"},\n    )\n    fsdp_final_state_dict_type: (\n        Literal[\"FULL_STATE_DICT\", \"LOCAL_STATE_DICT\", \"SHARDED_STATE_DICT\"] | None\n    ) = Field(\n        default=None,\n        deprecated=\"Configuring FSDP final state dict type using `fsdp_final_state_dict_type` is deprecated. Please use `fsdp_config.final_state_dict_type` instead.\",\n    )\n\n    val_set_size: float | None = Field(\n        default=0.0,\n        json_schema_extra={\n            \"description\": \"How much of the dataset to set aside as evaluation. 1 = 100%, 0.50 = 50%, etc. 0 for no eval.\"\n        },\n    )\n\n    dp_shard_size: int | None = Field(\n        default=None,\n        json_schema_extra={\n            \"description\": \"Number of devices to shard across. 
If not set, will use all available devices.\"\n        },\n    )\n    dp_replicate_size: int | None = Field(\n        default=None,\n        json_schema_extra={\"description\": \"Number of devices to replicate across.\"},\n    )\n    sequence_parallel_degree: int | None = Field(\n        default=None,\n        json_schema_extra={\n            \"description\": \"Deprecated: use `context_parallel_size` instead\"\n        },\n    )\n    context_parallel_size: int | None = Field(\n        default=None,\n        json_schema_extra={\n            \"description\": \"Set to a divisor of the number of GPUs available to split sequences into chunks of equal size. Use in long context training to prevent OOM when sequences cannot fit into a single GPU's VRAM. E.g., if 4 GPUs are available, set this value to 2 to split each sequence into two equal-sized subsequences, or set to 4 to split into four equal-sized subsequences. See https://docs.axolotl.ai/docs/sequence_parallelism.html for more details.\"\n        },\n    )\n    heads_k_stride: int | None = Field(\n        default=None,\n        json_schema_extra={\n            \"description\": \"Optional; strides across the key dimension. Larger values use more memory but should make training faster. Must evenly divide the number of KV heads in your model.\"\n        },\n    )\n    ring_attn_func: RingAttnFunc | None = Field(\n        default=None,\n        json_schema_extra={\n            \"description\": \"One of 'varlen_llama3', 'batch_ring', 'batch_zigzag', 'batch_stripe'. Defaults to 'varlen_llama3' in the sample packing case, and 'batch_ring' in the non-sample packing case.\"\n        },\n    )\n    tensor_parallel_size: int | None = Field(\n        default=None,\n        json_schema_extra={\n            \"description\": \"Number of tensor parallel processes in TP group. 
Only supported with DeepSpeed AutoTP.\"\n        },\n    )\n    special_tokens: SpecialTokensConfig | None = Field(\n        default=None,\n        json_schema_extra={\n            \"description\": \"Add or change special tokens. If you add tokens here, you don't need to add them to the `tokens` list.\"\n        },\n    )\n    tokens: list[str] | None = Field(\n        default=None,\n        json_schema_extra={\"description\": \"Add extra tokens to the tokenizer\"},\n    )\n    added_tokens_overrides: dict[int, str] | None = Field(\n        default=None,\n        json_schema_extra={\n            \"description\": \"Mapping token_id to new_token_string to override reserved added_tokens in the tokenizer. Only works for tokens that are not part of the base vocab (aka are added_tokens). Can be checked if they exist in tokenizer.json added_tokens.\"\n        },\n    )\n\n    torch_compile: Literal[\"auto\"] | bool | None = Field(\n        default=None,\n        json_schema_extra={\n            \"description\": \"Whether to use torch.compile and which backend to use. setting to `auto` will enable torch compile when torch>=2.6.0\"\n        },\n    )\n    torch_compile_backend: str | None = Field(\n        default=None,\n        json_schema_extra={\"description\": \"Backend to use for torch.compile\"},\n    )\n    torch_compile_mode: Literal[\"default\", \"reduce-overhead\", \"max-autotune\"] | None = (\n        None\n    )\n\n    max_steps: int | None = Field(\n        default=None,\n        json_schema_extra={\n            \"description\": \"Maximum number of iterations to train for. It precedes num_epochs which means that if both are set, num_epochs will not be guaranteed. e.g., when 1 epoch is 1000 steps => `num_epochs: 2` and `max_steps: 100` will train for 100 steps\"\n        },\n    )\n    warmup_steps: int | None = Field(\n        default=None,\n        json_schema_extra={\n            \"description\": \"Number of warmup steps. 
Cannot use with warmup_ratio\"\n        },\n    )\n    warmup_ratio: float | None = Field(\n        default=None,\n        json_schema_extra={\"description\": \"Warmup ratio. Cannot use with warmup_steps\"},\n    )\n    eval_steps: int | float | None = Field(\n        default=None,\n        json_schema_extra={\n            \"description\": \"Leave empty to eval at each epoch, integer for every N steps. float for fraction of total steps\"\n        },\n    )\n    evals_per_epoch: int | None = Field(\n        default=None,\n        json_schema_extra={\n            \"description\": \"Number of times per epoch to run evals, mutually exclusive with eval_steps\"\n        },\n    )\n    eval_strategy: str | None = Field(\n        default=None,\n        json_schema_extra={\n            \"description\": \"Set to `no` to skip evaluation, `epoch` at end of each epoch, leave empty to infer from `eval_steps`\"\n        },\n    )\n\n    save_steps: int | float | None = Field(\n        default=None,\n        json_schema_extra={\n            \"description\": \"Leave empty to save at each epoch, integer for every N steps. 
float for fraction of total steps\"\n        },\n    )\n    saves_per_epoch: int | None = Field(\n        default=None,\n        json_schema_extra={\n            \"description\": \"Number of times per epoch to save a checkpoint, mutually exclusive with save_steps\"\n        },\n    )\n    save_strategy: str | None = Field(\n        default=None,\n        json_schema_extra={\n            \"description\": \"Set to `no` to skip checkpoint saves, `epoch` at end of each epoch, `best` when better result is achieved, leave empty to infer from `save_steps`\"\n        },\n    )\n    save_total_limit: int | None = Field(\n        default=None, json_schema_extra={\"description\": \"Checkpoints saved at a time\"}\n    )\n    save_first_step: bool | None = Field(\n        default=None,\n        json_schema_extra={\n            \"description\": \"Whether to checkpoint a model after the first step of training. Defaults to False.\"\n        },\n    )\n\n    logging_steps: int | None = Field(\n        default=None, json_schema_extra={\"description\": \"Logging frequency\"}\n    )\n    early_stopping_patience: int | None = Field(\n        default=None,\n        json_schema_extra={\n            \"description\": \"Stop training after this many evaluation losses have increased in a row. https://huggingface.co/transformers/v4.2.2/_modules/transformers/trainer_callback.html#EarlyStoppingCallback\"\n        },\n    )\n    load_best_model_at_end: bool | None = False\n    save_only_model: bool | None = Field(\n        default=False,\n        json_schema_extra={\n            \"description\": \"Save only the model weights, skipping the optimizer. 
Using this means you can't resume from checkpoints.\"\n        },\n    )\n    use_tensorboard: bool | None = Field(\n        default=None, json_schema_extra={\"description\": \"Use tensorboard for logging\"}\n    )\n    profiler_steps: int | None = Field(\n        default=None,\n        json_schema_extra={\n            \"description\": \"Enable the pytorch profiler to capture the first N steps of training to the output_dir. see https://pytorch.org/blog/understanding-gpu-memory-1/ for more information. Snapshots can be visualized @ https://pytorch.org/memory_viz\"\n        },\n    )\n    profiler_steps_start: int | None = Field(\n        default=0,\n        json_schema_extra={\n            \"description\": \"Which step to start the profiler at. Useful for only capturing a few steps mid-run.\"\n        },\n    )\n    include_tokens_per_second: bool | None = Field(\n        default=None,\n        json_schema_extra={\n            \"description\": \"bool of whether to report tokens per second at the end of training. This is not supported with pre-training datasets.\"\n        },\n    )\n    include_tkps: bool | None = Field(\n        default=True,\n        json_schema_extra={\n            \"description\": \"bool of whether to report tokens per second per-gpu during training by measuring throughput of non-padding tokens.\"\n        },\n    )\n    neftune_noise_alpha: float | None = Field(\n        default=None,\n        json_schema_extra={\n            \"description\": \"NEFT https://arxiv.org/abs/2310.05914, set this to a number (paper default is 5) to add noise to embeddings. Currently only supported on Llama and Mistral\"\n        },\n    )\n\n    orpo_alpha: float | None = Field(\n        default=None,\n        json_schema_extra={\n            \"description\": \"Parameter controlling the relative ratio loss weight in the ORPO loss. 
Passed to `beta` in `ORPOConfig` due to trl mapping.\"\n        },\n    )\n    rpo_alpha: float | None = Field(\n        default=None,\n        json_schema_extra={\n            \"description\": \"Weighting of NLL term in loss from RPO paper\"\n        },\n    )\n    simpo_gamma: float | None = Field(\n        default=None,\n        json_schema_extra={\"description\": \"Target reward margin for the SimPO loss\"},\n    )\n    cpo_alpha: float | None = Field(\n        default=None, json_schema_extra={\"description\": \"Weight of the BC regularizer\"}\n    )\n\n    kto_desirable_weight: float | None = Field(\n        default=None,\n        json_schema_extra={\"description\": \"Factor for desirable loss term in KTO loss\"},\n    )\n    kto_undesirable_weight: float | None = Field(\n        default=None,\n        json_schema_extra={\n            \"description\": \"Factor for undesirable loss term in KTO loss\"\n        },\n    )\n    rl_beta: float | None = Field(\n        default=None,\n        json_schema_extra={\"description\": \"The beta parameter for the RL training\"},\n    )\n\n    max_memory: dict[int | Literal[\"cpu\", \"disk\"], int | str] | None = Field(\n        default=None,\n        json_schema_extra={\n            \"description\": \"Defines the max memory usage per gpu on the system. 
Passed through to transformers when loading the model.\"\n        },\n    )\n    gpu_memory_limit: int | str | None = Field(\n        default=None,\n        json_schema_extra={\n            \"description\": \"Limit the memory for all available GPUs to this amount (if an integer, expressed in gigabytes); default: unset\"\n        },\n    )\n    low_cpu_mem_usage: bool | None = Field(\n        default=None,\n        json_schema_extra={\"description\": \"Whether to use low_cpu_mem_usage\"},\n    )\n\n    chat_template: (\n        ChatTemplate\n        | Annotated[str, StringConstraints(pattern=\"^tokenizer_default_fallback_\")]\n    ) | None = Field(\n        default=None,\n        json_schema_extra={\n            \"description\": \"The name of the chat template to use for training, following values are supported: tokenizer_default: Uses the chat template that is available in the tokenizer_config.json. If the chat template is not available in the tokenizer, it will raise an error. This is the default value. alpaca/inst/chatml/gemma/cohere/llama3/phi_3/deepseek_v2/jamba: These chat templates are available in the axolotl codebase at src/axolotl/utils/chat_templates.py. tokenizer_default_fallback_*: where * is the name of the chat template to fallback to. E.g. tokenizer_default_fallback_chatml. This is useful when the chat template is not available in the tokenizer. jinja: Uses a custom jinja template for the chat template. The custom jinja template should be provided in the chat_template_jinja field. The selected chat template will be saved to the tokenizer_config.json for easier inferencing\"\n        },\n    )\n    chat_template_jinja: str | None = Field(\n        default=None,\n        json_schema_extra={\n            \"description\": \"Custom jinja template or path to jinja file for chat template. This will be only used if chat_template is set to `jinja` or `null` (in which case chat_template is automatically set to `jinja`). 
Default is null.\"\n        },\n    )\n    chat_template_kwargs: dict[str, Any] | None = Field(\n        default=None,\n        json_schema_extra={\n            \"description\": \"Additional kwargs to pass to the chat template. This is useful for customizing the chat template. For example, you can pass `thinking=False` to add a generation prompt to the chat template.\"\n        },\n    )\n    eot_tokens: list[str] | None = Field(\n        default=None,\n        json_schema_extra={\n            \"description\": \"Custom EOT (End-of-Turn) tokens to mask/unmask during training. These tokens mark the boundaries between conversation turns. For example: ['/INST', '</s>', '[/SYSTEM_PROMPT]']. If not specified, defaults to just the model's eos_token. This is useful for templates that use multiple delimiter tokens.\"\n        },\n    )\n    default_system_message: str | None = Field(\n        default=None,\n        json_schema_extra={\n            \"description\": \"Changes the default system message. Currently only supports chatml.\"\n        },\n    )\n\n    fix_untrained_tokens: int | list[int] | None = Field(\n        default=None,\n        json_schema_extra={\n            \"description\": (\n                \"Token index or indices to adjust embedding weights to the mean of the other tokens. 
\"\n                \"This is useful when the model has untrained embeddings.\"\n            )\n        },\n    )\n\n    # INTERNALS - document for now, generally not set externally\n    is_preprocess: bool | None = None\n    preprocess_iterable: bool | None = None\n\n    total_num_tokens: int | None = Field(\n        default=None,\n        json_schema_extra={\"description\": \"Total number of tokens - internal use\"},\n    )\n    total_supervised_tokens: int | None = None\n    sample_packing_eff_est: float | None = Field(\n        default=None,\n        json_schema_extra={\n            \"description\": \"You can set these packing optimizations AFTER starting a training at least once. The trainer will provide recommended values for these values.\"\n        },\n    )\n    axolotl_config_path: str | None = None\n\n    is_falcon_derived_model: bool | None = Field(\n        default=None,\n        json_schema_extra={\n            \"description\": \"Internal use only - Used to identify which the model is based on\"\n        },\n    )\n    is_llama_derived_model: bool | None = Field(\n        default=None,\n        json_schema_extra={\n            \"description\": \"Internal use only - Used to identify which the model is based on\"\n        },\n    )\n    is_mistral_derived_model: bool | None = Field(\n        default=None,\n        json_schema_extra={\n            \"description\": \"Internal use only - Used to identify which the model is based on. Please note that if you set this to true, `padding_side` will be set to 'left' by default\"\n        },\n    )\n    is_qwen_derived_model: bool | None = Field(\n        default=None,\n        json_schema_extra={\n            \"description\": \"Internal use only - Used to identify which the model is based on\"\n        },\n    )\n\n    plugins: list[str] | None = Field(\n        default=None,\n        json_schema_extra={\n            \"description\": \"Add plugins to extend the pipeline. 
See `src/axolotl/integrations` for the available plugins or doc below for more details. https://docs.axolotl.ai/docs/custom_integrations.html\"\n        },\n    )\n    generate_samples: bool | None = Field(\n        default=False,\n        json_schema_extra={\n            \"description\": \"Enable sample generation during training for monitoring\"\n        },\n    )\n    num_generation_samples: int | None = Field(\n        default=3,\n        json_schema_extra={\n            \"description\": \"Number of samples to generate at each interval\"\n        },\n    )\n    generation_max_new_tokens: int | None = Field(\n        default=50,\n        json_schema_extra={\"description\": \"Maximum new tokens to generate per sample\"},\n    )\n    generation_temperature: float | None = Field(\n        default=0.7,\n        json_schema_extra={\n            \"description\": \"Temperature for sample generation (0.0 = greedy)\"\n        },\n    )\n    generation_top_p: float | None = Field(\n        default=None,\n        json_schema_extra={\"description\": \"Nucleus sampling parameter for generation\"},\n    )\n    generation_top_k: int | None = Field(\n        default=None,\n        json_schema_extra={\"description\": \"Top-k sampling parameter for generation\"},\n    )\n    generation_prompt_ratio: float | None = Field(\n        default=0.5,\n        json_schema_extra={\"description\": \"Ratio of input to use as prompt (0.0-1.0)\"},\n    )\n    generation_do_sample: bool | None = Field(\n        default=True,\n        json_schema_extra={\n            \"description\": \"Whether to use sampling (vs greedy decoding)\"\n        },\n    )\n\n    @field_serializer(\"datasets\")\n    def datasets_serializer(\n        self, ds_configs: list[DatasetConfig] | None\n    ) -> list[dict[str, Any]] | None:\n        if ds_configs:\n            return [ds_config.model_dump(exclude_none=True) for ds_config in ds_configs]\n        return None\n\n    @model_validator(mode=\"before\")\n    
@classmethod\n    def warn_peft_trainable_token_to_fix_untrained(cls, data):\n        if (\n            peft_trainable_token_indices := data.get(\"peft_trainable_token_indices\")\n        ) and (fix_untrained_tokens := data.get(\"fix_untrained_tokens\")):\n            if isinstance(fix_untrained_tokens, int):\n                fix_untrained_tokens = (fix_untrained_tokens,)\n\n            if isinstance(peft_trainable_token_indices, int):\n                peft_trainable_token_indices = (peft_trainable_token_indices,)\n\n            for untrained_token_id in fix_untrained_tokens:\n                if untrained_token_id not in peft_trainable_token_indices:\n                    LOG.warning_once(\n                        f\"Token {untrained_token_id} is fixed via `fix_untrained_tokens`, yet not in `peft_trainable_token_indices: ` list. \"\n                        \"Please add it, otherwise the token won't be trained on.\"\n                    )\n        return data\n\n    @model_validator(mode=\"before\")\n    @classmethod\n    def check_sageattn_wo_sample_packing(cls, data):\n        if (not data.get(\"sample_packing\", False)) and data.get(\"sage_attention\"):\n            if not data.get(\"pad_to_sequence_len\", False):\n                LOG.warning(\n                    \"We recommend turning on `pad_to_sequence_len` for SageAttention without packing.\"\n                    \"This is because there has been signs that the loss explodes after a few steps.\"\n                )\n        return data\n\n    @model_validator(mode=\"before\")\n    @classmethod\n    def check_sageattn_fft(cls, data):\n        if (not data.get(\"adapter\", False)) and data.get(\"sage_attention\"):\n            LOG.warning(\n                \"We found loss to drop to 0 with SageAttention full finetuning.\"\n                \"Please observe the loss, otherwise switch to LoRA/QLoRA or another attention method.\"\n            )\n        return data\n\n\nclass 
AxolotlConfigWCapabilities(AxolotlInputConfig):\n    \"\"\"Wrapper to valdiate GPU capabilities with the configured options\"\"\"\n\n    capabilities: GPUCapabilities\n    env_capabilities: EnvCapabilities\n\n    @model_validator(mode=\"after\")\n    def check_bf16(self):\n        if self.capabilities.bf16:\n            if not self.bf16 and not self.bfloat16:\n                LOG.info(\n                    \"bf16 support detected, but not enabled for this configuration.\"\n                )\n        else:\n            if (\n                not self.merge_lora\n                and not self.is_preprocess\n                and (self.bf16 is True or self.bfloat16 is True)\n            ):\n                raise ValueError(\n                    \"bf16 requested, but AMP is not supported on this GPU. Requires Ampere series or above.\"\n                )\n        return self\n\n    @model_validator(mode=\"after\")\n    def check_tf32(self):\n        if self.tf32 == \"auto\":\n            self.tf32 = self.capabilities.tf32\n        return self\n\n    @model_validator(mode=\"after\")\n    def check_fp8(self):\n        if self.fp8 and not self.capabilities.fp8:\n            raise ValueError(\"fp8 requested, but fp8 is not supported on this GPU\")\n        elif self.fp8 and self.capabilities.fp8 and not is_fp8_available():\n            raise ValueError(\n                \"fp8 requested, but missing one of ms-amp, transformers-engine or torchao.\"\n            )\n        return self\n\n    @model_validator(mode=\"before\")\n    @classmethod\n    def check_sample_packing_w_sdpa_bf16(cls, data):\n        is_sm_90: bool = (\n            data[\"capabilities\"]\n            and data[\"capabilities\"].get(\"compute_capability\") == \"sm_90\"\n        )\n        if (\n            data.get(\"sample_packing\")\n            and data.get(\"sdp_attention\")\n            and (data.get(\"bfloat16\") or data.get(\"bf16\"))\n            and not is_sm_90\n        ):\n            # 
https://github.com/pytorch/pytorch/blob/1b03423526536b5f3d35bdfa95ccc6197556cf9b/test/test_transformers.py#L2440-L2450\n            LOG.warning(\n                \"sample_packing & torch sdpa with bf16 is unsupported may results in 0.0 loss. \"\n                \"This may work on H100s.\"\n            )\n\n        return data\n\n    @model_validator(mode=\"before\")\n    @classmethod\n    def check_compute_capability_w_sageattn(cls, data):\n        if (\n            data.get(\"sage_attention\")\n            and data.get(\"capabilities\")\n            and data.get(\"capabilities\").get(\"compute_capability\")\n            not in [\"sm_80\", \"sm_86\", \"sm_89\", \"sm_90\", \"sm_120\"]\n        ):\n            raise ValueError(\n                \"SageAttention supports compute capability between sm_80 and sm_120. \"\n                \"Please use a different attention implementation.\"\n            )\n        return data\n\n    @model_validator(mode=\"before\")\n    @classmethod\n    def check_multigpu_unsloth(cls, data):\n        if (\n            data.get(\"unsloth_lora_mlp\")\n            or data.get(\"unsloth_lora_qkv\")\n            or data.get(\"unsloth_lora_o\")\n        ):\n            capabilities = data.get(\"capabilities\")\n            if capabilities and capabilities.get(\"n_gpu\", 0) > 1:\n                raise ValueError(\n                    \"unsloth_lora_mlp, unsloth_lora_qkv, and unsloth_lora_o are not compatible with multi-GPU training.\"\n                )\n        return data\n\n    @model_validator(mode=\"before\")\n    @classmethod\n    def check_multigpu_lora_kernels(cls, data):\n        if (\n            data.get(\"lora_mlp_kernel\")\n            or data.get(\"lora_qkv_kernel\")\n            or data.get(\"lora_o_kernel\")\n        ):\n            capabilities = data.get(\"capabilities\")\n            is_fsdp = data.get(\"fsdp_config\") is not None\n            is_fsdp2 = is_fsdp and str(data.get(\"fsdp_version\")) == \"2\"\n\n            if 
capabilities and capabilities.get(\"n_gpu\", 0) > 1 and not is_fsdp2:\n                if is_fsdp:\n                    raise ValueError(\n                        \"lora_mlp_kernel, lora_qkv_kernel, and lora_o_kernel are not compatible with FSDP1.\"\n                    )\n        return data\n\n    @model_validator(mode=\"before\")\n    @classmethod\n    def check_quantize_moe_experts(cls, data):\n        if data.get(\"quantize_moe_experts\"):\n            if data.get(\"lora_target_linear\"):\n                raise ValueError(\n                    \"lora_target_linear is not compatible with quantize_moe_experts. \"\n                    \"Use lora_target_parameters to target expert weights instead.\"\n                )\n            if data.get(\"adapter\") not in (\"lora\", \"qlora\"):\n                raise ValueError(\"quantize_moe_experts requires adapter: lora or qlora\")\n            if not (data.get(\"load_in_4bit\") or data.get(\"load_in_8bit\")):\n                raise ValueError(\n                    \"quantize_moe_experts requires load_in_4bit or load_in_8bit\"\n                )\n            if (\n                data.get(\"capabilities\")\n                and data[\"capabilities\"].get(\"compute_capability\")\n                and not data[\"capabilities\"][\"compute_capability\"].startswith(\"sm_\")\n            ):\n                raise ValueError(\n                    \"quantize_moe_experts requires CUDA (not compatible with ROCm or other backends)\"\n                )\n        return data\n\n    @model_validator(mode=\"before\")\n    @classmethod\n    def check_auto_enable_lora_kernels(cls, data):\n        # Only proceed if using LoRA or QLoRA adapter\n        if data.get(\"rl\"):\n            # RL trainers not tested so don't enable kernels by default\n            return data\n        if data.get(\"adapter\") in [\"lora\", \"qlora\"]:\n            # Skip if already set, using unsloth optimizations, or using 8-bit\n            unsloth_fields = 
[\"unsloth_lora_mlp\", \"unsloth_lora_qkv\", \"unsloth_lora_o\"]\n            kernel_fields = [\"lora_mlp_kernel\", \"lora_qkv_kernel\", \"lora_o_kernel\"]\n            if (\n                any(data.get(k) is not None for k in kernel_fields)\n                or any(data.get(k) for k in unsloth_fields)\n                or data.get(\"adapter\") == \"lora\"\n                and data.get(\"load_in_8bit\")\n            ):\n                return data\n\n            # Skip if trust_remote_code is enabled, as lora kernels are not compatible\n            if data.get(\"trust_remote_code\"):\n                return data\n\n            # Skip if dropout is not 0, as auto enabling it would just disable it during runtime patch checks\n            if data.get(\"lora_dropout\") != 0:\n                return data\n\n            # Check multi-GPU compatibility\n            capabilities = data.get(\"capabilities\")\n            is_multi_gpu = capabilities and capabilities.get(\"n_gpu\", 0) > 1\n            is_fsdp = data.get(\"fsdp_config\") is not None\n            is_fsdp2 = is_fsdp and str(data.get(\"fsdp_version\")) == \"2\"\n\n            if (\n                not is_multi_gpu\n                or (is_multi_gpu and not is_fsdp)\n                or (is_multi_gpu and is_fsdp2)\n            ):\n                # Auto-enable kernels if not explicitly set by user\n                if data.get(\"lora_mlp_kernel\") is None:\n                    data[\"lora_mlp_kernel\"] = True\n\n                if data.get(\"lora_qkv_kernel\") is None:\n                    data[\"lora_qkv_kernel\"] = True\n\n                if data.get(\"lora_o_kernel\") is None:\n                    data[\"lora_o_kernel\"] = True\n\n                LOG.warning(\n                    \"Auto-enabling LoRA kernel optimizations for faster training. \"\n                    + \"Please explicitly set `lora_*_kernel` config values to `false` to disable. 
\"\n                    + \"See https://docs.axolotl.ai/docs/lora_optims.html for more info.\"\n                )\n\n        return data\n\n    @model_validator(mode=\"before\")\n    @classmethod\n    def check_adopt_torch_version(cls, data):\n        if (data.get(\"optimizer\") is not None) and (\"adopt\" in data.get(\"optimizer\")):\n            env_capabilities = data.get(\"env_capabilities\", {})\n            torch_version = env_capabilities.get(\"torch_version\")\n\n            if torch_version is None:\n                import torch\n\n                torch_version = str(torch.__version__).split(\"+\", maxsplit=1)[0]\n\n            if version.parse(torch_version) < version.parse(\"2.5.1\"):\n                raise ValueError(\n                    \"ADOPT optimizer is incompatible with torch version < 2.5.1\"\n                )\n        return data\n\n    @model_validator(mode=\"before\")\n    @classmethod\n    def check_flex_torch_version(cls, data):\n        if (data.get(\"flex_attention\") is not None) and (data.get(\"flex_attention\")):\n            env_capabilities = data.get(\"env_capabilities\", {})\n            torch_version = env_capabilities.get(\"torch_version\")\n\n            if torch_version is None:\n                import torch\n\n                torch_version = str(torch.__version__).split(\"+\", maxsplit=1)[0]\n\n            if version.parse(torch_version) < version.parse(\"2.6.0\"):\n                raise ValueError(\n                    \"Flex attention is not supported on torch version < 2.6.0\"\n                )\n        return data\n\n    @model_validator(mode=\"before\")\n    @classmethod\n    def check_torch_compile_auto(cls, data):\n        if data.get(\"torch_compile\") == \"auto\":\n            env_capabilities = data.get(\"env_capabilities\", {})\n            if env_capabilities.get(\"torch_version\"):\n                if version.parse(\n                    env_capabilities.get(\"torch_version\")\n                ) >= 
version.parse(\"2.5.1\"):\n                    LOG.info(\n                        \"torch.compile is available, setting torch_compile to True\"\n                    )\n                    data[\"torch_compile\"] = True\n                else:\n                    data[\"torch_compile\"] = False\n            else:\n                data[\"torch_compile\"] = False\n        return data\n\n    @model_validator(mode=\"before\")\n    @classmethod\n    def check_beta_and_trl_beta_match(cls, data):\n        if data.get(\"beta\") and data.get(\"trl\", {}).get(\"beta\"):\n            if data[\"beta\"] != data[\"trl\"][\"beta\"]:\n                raise ValueError(\"beta and trl.beta must match or one must be removed\")\n        return data\n\n    @model_validator(mode=\"after\")\n    def check_min_torch_version(self):\n        if self.env_capabilities and self.env_capabilities.torch_version:\n            torch_version = self.env_capabilities.torch_version\n            if version.parse(torch_version) < version.parse(\"2.6.0\"):\n                LOG.warning(\n                    f\"torch=={torch_version} not be supported. 
Please upgrade to torch>=2.6.0.\"\n                )\n\n        return self\n\n    @model_validator(mode=\"before\")\n    @classmethod\n    def check_qat_config(cls, data):\n        qat_cfg = data.get(\"qat\", {})\n        if not qat_cfg:\n            return data\n\n        if data.get(\"peft\"):\n            raise ValueError(\"QAT and PEFT cannot be used together.\")\n\n        if data.get(\"load_in_8bit\"):\n            raise ValueError(\"QAT and load_in_8bit cannot be used together.\")\n\n        if data.get(\"load_in_4bit\"):\n            raise ValueError(\"QAT and load_in_4bit cannot be used together.\")\n\n        env_capabilities = data.get(\"env_capabilities\", {})\n        torch_version = env_capabilities.get(\"torch_version\")\n\n        if torch_version is None:\n            import torch\n\n            torch_version = str(torch.__version__).split(\"+\", maxsplit=1)[0]\n\n        if version.parse(torch_version) < version.parse(\"2.6.0\"):\n            raise ValueError(\"QAT is not supported on torch version < 2.6.0\")\n\n        return data\n\n    @model_validator(mode=\"before\")\n    @classmethod\n    def check_fsdp_torch_version(cls, data):\n        env_capabilities = data.get(\"env_capabilities\", {})\n        torch_version = env_capabilities.get(\"torch_version\")\n\n        if torch_version is None:\n            import torch\n\n            torch_version = str(torch.__version__).split(\"+\", maxsplit=1)[0]\n\n        if data.get(\"fsdp_config\") and str(data.get(\"fsdp_version\")) == \"2\":\n            if version.parse(torch_version) < version.parse(\"2.7.0\"):\n                raise ValueError(\"FSDP2 is not supported on torch version < 2.7.0\")\n\n        return data\n\n    @model_validator(mode=\"before\")\n    @classmethod\n    def default_dataloader_opts(cls, data):\n        if (\n            data.get(\"dataloader_num_workers\") is None\n            and data.get(\"dataloader_pin_memory\") is None\n            and 
data.get(\"dataloader_prefetch_factor\") is None\n        ):\n            data[\"dataloader_num_workers\"] = data.get(\"capabilities\").get(\"n_gpu\", 1)\n            data[\"dataloader_pin_memory\"] = True\n            data[\"dataloader_prefetch_factor\"] = 256\n\n        return data\n\n    @model_validator(mode=\"before\")\n    @classmethod\n    def default_dataset_num_proc(cls, data):\n        if data.get(\"dataset_processes\") is not None:\n            if data.get(\"dataset_num_proc\") is None:\n                data[\"dataset_num_proc\"] = data[\"dataset_processes\"]\n                LOG.warning(\n                    \"dataset_processes is deprecated and will be removed in a future version. \"\n                    \"Please use dataset_num_proc instead.\"\n                )\n            else:\n                LOG.warning(\n                    \"Both dataset_processes and dataset_num_proc are set. \"\n                    \"Using dataset_num_proc and ignoring dataset_processes.\"\n                )\n            del data[\"dataset_processes\"]\n        elif data.get(\"dataset_num_proc\") is None:\n            data[\"dataset_num_proc\"] = get_default_process_count()\n        return data\n\n    @model_validator(mode=\"before\")\n    @classmethod\n    def check_deduplication_with_streaming(cls, data):\n        if data.get(\"dataset_exact_deduplication\") and (\n            data.get(\"streaming\") or data.get(\"pretraining_dataset\")\n        ):\n            raise NotImplementedError(\n                \"dataset_exact_deduplication is not available for streaming datasets. \"\n            )\n        return data\n\n    @model_validator(mode=\"before\")\n    @classmethod\n    def check_deduplication_with_skip_prepare(cls, data):\n        if data.get(\"dataset_exact_deduplication\") and data.get(\"skip_prepare_dataset\"):\n            raise ValueError(\n                \"dataset_exact_deduplication=True has no effect when \"\n                \"skip_prepare_dataset=True. 
Deduplication runs as part of the \"\n                \"prepare pipeline, which is skipped. Either set \"\n                \"skip_prepare_dataset: false or disable \"\n                \"dataset_exact_deduplication.\"\n            )\n        return data\n"
  },
  {
    "path": "src/axolotl/utils/schemas/datasets.py",
    "content": "\"\"\"Pydantic models for datasets-related configuration\"\"\"\n\nfrom typing import Literal\n\nfrom pydantic import BaseModel, Field, model_validator\n\nfrom axolotl.utils.schemas.enums import ChatTemplate\nfrom axolotl.utils.schemas.utils import handle_legacy_message_fields_logic\n\n\nclass UserDefinedPrompterType(BaseModel):\n    \"\"\"Structure for user defined prompt types\"\"\"\n\n    system_prompt: str | None = Field(\n        default=None,\n        json_schema_extra={\"description\": \"Custom user instruction prompt\"},\n    )\n    system_format: str | None = Field(\n        default=None,\n        json_schema_extra={\"description\": \"Use {system} as key to be replaced\"},\n    )\n    field_system: str | None = None\n    field_instruction: str | None = None\n    field_input: str | None = None\n    field_output: str | None = None\n\n    format: str | None = Field(\n        default=None,\n        json_schema_extra={\n            \"description\": \"Customizable to be single line or multi-line. Use {instruction}/{input} as key to be replaced. 'format' can include {input}\"\n        },\n    )\n    no_input_format: str | None = Field(\n        default=None,\n        json_schema_extra={\"description\": \"'no_input_format' cannot include {input}\"},\n    )\n\n\nclass SFTDataset(BaseModel):\n    \"\"\"SFT configuration subset\"\"\"\n\n    path: str | None = Field(\n        default=None,\n        json_schema_extra={\n            \"description\": \"HuggingFace dataset repo | s3:// | gs:// | path to local file or directory\"\n        },\n    )\n    split: str | None = Field(\n        default=None,\n        json_schema_extra={\"description\": \"name of dataset split to load from\"},\n    )\n    type: str | UserDefinedPrompterType | None = Field(\n        default=None,\n        json_schema_extra={\n            \"description\": \"The type of prompt to use for training. 
[alpaca, gpteacher, oasst, reflection]\"\n        },\n    )\n    input_transform: str | None = None\n    shards: int | None = Field(\n        default=None,\n        json_schema_extra={\n            \"description\": \"split dataset into N pieces (use with shards_idx)\"\n        },\n    )\n    shards_idx: int | None = Field(\n        default=None,\n        json_schema_extra={\"description\": \"the index of sharded dataset to use\"},\n    )\n    preprocess_shards: int | None = Field(\n        default=None,\n        json_schema_extra={\n            \"description\": \"process dataset in N sequential chunks for memory efficiency (exclusive with `shards`)\"\n        },\n    )\n    conversation: str | None = None\n    # Do not make this too strict or it will break the validator to choose different dataset class\n    chat_template: ChatTemplate | str | None = Field(\n        default=None,\n        json_schema_extra={\n            \"description\": \"The name of the chat template to use for training, following values are supported: tokenizer_default: Uses the chat template that is available in the tokenizer_config.json. If the chat template is not available in the tokenizer, it will raise an error. This is the default. alpaca/inst/chatml/gemma/cohere/llama3/phi_3/deepseek_v2/jamba: These chat templates are available in the axolotl codebase at src/axolotl/utils/chat_templates.py. tokenizer_default_fallback_*: where * is the name of the chat template to fallback to if the tokenizer does not have a chat template else default to tokenizer. E.g. tokenizer_default_fallback_chatml. jinja: Uses a custom jinja template for the chat template. The custom jinja template should be provided in the chat_template_jinja field.\"\n        },\n    )\n    chat_template_jinja: str | None = Field(\n        default=None,\n        json_schema_extra={\n            \"description\": \"Custom jinja chat template or path to jinja file. 
Used only if `chat_template: jinja` or empty.\"\n        },\n    )\n    data_files: str | list[str] | None = Field(\n        default=None, json_schema_extra={\"description\": \"path to source data files\"}\n    )\n    input_format: str | None = None\n    name: str | None = Field(\n        default=None,\n        json_schema_extra={\"description\": \"name of dataset configuration to load\"},\n    )\n    ds_type: str | None = Field(\n        default=None,\n        json_schema_extra={\"description\": \"defines the datatype when path is a file\"},\n    )\n    field: str | None = Field(\n        default=None,\n        json_schema_extra={\n            \"description\": \"For `completion` datasets only, uses the provided field instead of `text` column\"\n        },\n    )\n    field_human: str | None = None\n    field_model: str | None = None\n    field_messages: str | None = Field(\n        default=None,\n        json_schema_extra={\n            \"description\": 'Key containing the messages (default: \"messages\")'\n        },\n    )\n    field_tools: str | None = Field(\n        default=None,\n        json_schema_extra={\n            \"description\": 'Key containing the tools (default: \"tools\"). 
Must be a list[dict] and follow [JSON schema](https://json-schema.org/learn/getting-started-step-by-step).'\n        },\n    )\n    field_thinking: str | None = Field(\n        default=None,\n        json_schema_extra={\n            \"description\": 'Key containing the reasoning trace (default: \"reasoning_content\").'\n        },\n    )\n    template_thinking_key: str | None = Field(\n        default=None,\n        json_schema_extra={\n            \"description\": \"The key the chat template expects that indicates the reasoning trace.\"\n        },\n    )\n    # deprecated, use message_property_mappings\n    message_field_role: str | None = None\n    # deprecated, use message_property_mappings\n    message_field_content: str | None = None\n    message_property_mappings: dict[str, str] | None = Field(\n        default=None,\n        json_schema_extra={\n            \"description\": \"Mapping of properties from the input dataset to the chat template. (default: message_property_mappings={'role':'role', 'content':'content'}) If a property exists in the template but not in this mapping, the system will attempt to load it directly from the message using the property name as the key. Example: In the mapping below, 'from' is loaded from input dataset and used as 'role', while 'value' is loaded and used as 'content' in the chat template.\"\n        },\n    )\n    message_field_training: str | None = Field(\n        default=None,\n        json_schema_extra={\n            \"description\": \"The key in the message turn that indicates via boolean whether tokens of a turn should be considered for training. Useful to selectively train on certain turns besides the `roles_to_train`.\"\n        },\n    )\n    message_field_training_detail: str | None = Field(\n        default=None,\n        json_schema_extra={\n            \"description\": \"The key in the message turn that contains the training details. Useful to selectively train on certain tokens in a turn. 
The value of the key is a List[Dict] containing `begin_offset` (start character index in content), `end_offset` (end character index in content), and `train` (boolean whether to train).\"\n        },\n    )\n    split_thinking: bool | None = Field(\n        default=None,\n        json_schema_extra={\n            \"description\": \"(for Qwen3 template only) Whether to split the assistant content based on a reasoning trace inside delimited tags\"\n        },\n    )\n    logprobs_field: str | None = None\n    temperature: float | None = None\n    roles_to_train: list[str] | None = Field(\n        default=None,\n        json_schema_extra={\n            \"description\": \"Roles to train on. The tokens from these roles will be considered for the loss.\"\n        },\n    )\n    train_on_eos: Literal[\"all\", \"turn\", \"last\"] | None = Field(\n        default=None,\n        json_schema_extra={\n            \"description\": \"Which EOS tokens to train on in the conversation. Possible values are: all: train on all EOS tokens, turn (default): train on the EOS token at the end of each trainable turn, last: train on the last EOS token in the conversation\"\n        },\n    )\n    roles: dict[str, list[str]] | None = Field(\n        default=None,\n        json_schema_extra={\n            \"description\": 'Roles mapping in the messages. The format is {target_role: [source_roles]}. All source roles will be mapped to the target role. The default is: user: [\"human\", \"user\"], assistant: [\"gpt\", \"assistant\"], system: [\"system\"], tool: [\"tool\"]'\n        },\n    )\n    drop_system_message: bool | None = Field(\n        default=None,\n        json_schema_extra={\n            \"description\": \"Whether to drop the system turn from the dataset. Only works with chat_template. This does not drop the default system message from chat_template if it exists. 
If you wish to, we recommend using a custom jinja template with the default system message removed or adding a system turn with empty content.\"\n        },\n    )\n    trust_remote_code: bool | None = Field(\n        default=False,\n        json_schema_extra={\"description\": \"Trust remote code for untrusted source\"},\n    )\n    revision: str | None = Field(\n        default=None,\n        json_schema_extra={\n            \"description\": \"The specific revision of the dataset to use when loading from the Hugging Face Hub. This can be a commit hash, tag, or branch name. If not specified, the latest version will be used. This parameter is ignored for local datasets.\"\n        },\n    )\n\n    @model_validator(mode=\"before\")\n    @classmethod\n    def handle_legacy_message_fields(cls, data):\n        \"\"\"Handle backwards compatibility between legacy message field mapping and new property mapping system.\"\"\"\n        return handle_legacy_message_fields_logic(data)\n\n    @model_validator(mode=\"before\")\n    @classmethod\n    def check_chat_template_config(cls, data):\n        if isinstance(data, BaseModel):\n            data = data.model_dump()\n\n        # Set chat_template to tokenizer_default if not set\n        if data.get(\"type\") == \"chat_template\" and not data.get(\"chat_template\"):\n            data[\"chat_template\"] = ChatTemplate.tokenizer_default\n\n        # if chat_template is set to jinja, chat_template_jinja is required\n        if data.get(\"chat_template\") == ChatTemplate.jinja and not data.get(\n            \"chat_template_jinja\"\n        ):\n            raise ValueError(\n                \"chat_template_jinja is required when chat_template is set to jinja\"\n            )\n\n        # If chat_template_jinja is set, set chat_template to jinja\n        if data.get(\"chat_template_jinja\") and not data.get(\"chat_template\"):\n            data[\"chat_template\"] = ChatTemplate.jinja\n\n        return data\n\n\nclass 
PretrainingDataset(BaseModel):\n    \"\"\"Pretraining dataset configuration subset\"\"\"\n\n    name: str | None = None\n    path: str | None = None\n    split: str | None = \"train\"\n    text_column: str | None = \"text\"\n    type: str | None = \"pretrain\"\n    trust_remote_code: bool | None = False\n    data_files: str | None = None\n    skip: int | None = None\n\n\nclass UserDefinedDPOType(BaseModel):\n    \"\"\"User defined typing for DPO\"\"\"\n\n    field_system: str | None = None\n    field_prompt: str | None = None\n    field_chosen: str | None = None\n    field_rejected: str | None = None\n    prompt_format: str | None = None\n    chosen_format: str | None = None\n    rejected_format: str | None = None\n\n\nclass DPODataset(BaseModel):\n    \"\"\"DPO configuration subset\"\"\"\n\n    path: str | None = None\n    split: str | None = None\n    type: UserDefinedDPOType | str | None = None\n    data_files: list[str] | None = None\n    revision: str | None = None\n    field_messages: str | None = None\n\n\nclass StepwiseSupervisedDataset(BaseModel):\n    \"\"\"Stepwise supervised dataset configuration subset\"\"\"\n\n    path: str | None = None\n    split: str | None = None\n    data_files: list[str] | None = None\n    revision: str | None = None\n    step_separator: str | None = None\n    max_completion_length: int | None = None\n    train_on_last_step_only: bool | None = None\n\n\nclass UserDefinedKTOType(BaseModel):\n    \"\"\"User defined typing for KTO\"\"\"\n\n    field_system: str | None = None\n    field_prompt: str | None = None\n    field_completion: str | None = None\n    field_label: bool | None = None\n    prompt_format: str | None = None\n    completion_format: str | None = None\n\n\nclass KTODataset(BaseModel):\n    \"\"\"KTO configuration subset\"\"\"\n\n    path: str | None = None\n    split: str | None = None\n    type: UserDefinedKTOType | str | None = None\n    data_files: list[str] | None = None\n    trust_remote_code: bool | None = 
False\n    revision: str | None = None\n\n\nDatasetConfig = SFTDataset | DPODataset | KTODataset | StepwiseSupervisedDataset\n"
  },
  {
    "path": "src/axolotl/utils/schemas/deprecated.py",
    "content": "\"\"\"Pydantic models for deprecated and remapped configuration parameters\"\"\"\n\nfrom typing import Any\n\nfrom pydantic import BaseModel, Field, field_validator\n\nfrom axolotl.utils.logging import get_logger\n\nLOG = get_logger(__name__)\n\n\nclass DeprecatedParameters(BaseModel):\n    \"\"\"configurations that are deprecated\"\"\"\n\n    max_packed_sequence_len: int | None = None\n    rope_scaling: Any | None = None\n    noisy_embedding_alpha: float | None = None\n    dpo_beta: float | None = None\n    evaluation_strategy: str | None = None\n    eval_table_size: int | None = None\n    eval_max_new_tokens: int | None = None\n    dpo_use_logits_to_keep: bool | None = None\n    dpo_generate_during_eval: bool | None = None\n\n    @field_validator(\"max_packed_sequence_len\")\n    @classmethod\n    def validate_max_packed_sequence_len(cls, max_packed_sequence_len):\n        if max_packed_sequence_len:\n            raise DeprecationWarning(\"`max_packed_sequence_len` is no longer supported\")\n        return max_packed_sequence_len\n\n    @field_validator(\"rope_scaling\")\n    @classmethod\n    def validate_rope_scaling(cls, rope_scaling):\n        if rope_scaling:\n            raise DeprecationWarning(\n                \"`rope_scaling` is no longer supported, it should now be be a key under `model_config`\"\n            )\n        return rope_scaling\n\n    @field_validator(\"noisy_embedding_alpha\")\n    @classmethod\n    def validate_noisy_embedding_alpha(cls, noisy_embedding_alpha):\n        if noisy_embedding_alpha:\n            LOG.warning(\"noisy_embedding_alpha is deprecated, use neftune_noise_alpha\")\n        return noisy_embedding_alpha\n\n    @field_validator(\"dpo_beta\")\n    @classmethod\n    def validate_dpo_beta(cls, dpo_beta):\n        if dpo_beta is not None:\n            LOG.warning(\"dpo_beta is deprecated, use rl_beta instead\")\n        return dpo_beta\n\n    @field_validator(\"evaluation_strategy\")\n    @classmethod\n    
def validate_evaluation_strategy(cls, evaluation_strategy):\n        if evaluation_strategy is not None:\n            LOG.warning(\"evaluation_strategy is deprecated, use eval_strategy instead\")\n        return evaluation_strategy\n\n    @field_validator(\"eval_table_size\")\n    @classmethod\n    def validate_eval_table_size(cls, eval_table_size):\n        if eval_table_size is not None:\n            LOG.warning(\n                \"eval_table_size is deprecated and superseded by generate_samples config. \"\n                \"Please use generate_samples: true and num_generation_samples instead. \"\n                \"The LogPredictionCallback is replaced by the new sample generation feature.\"\n            )\n        return eval_table_size\n\n    @field_validator(\"eval_max_new_tokens\")\n    @classmethod\n    def validate_eval_max_new_tokens(cls, eval_max_new_tokens):\n        if eval_max_new_tokens is not None:\n            LOG.warning(\n                \"eval_max_new_tokens is deprecated and superseded by generate_samples config. 
\"\n                \"Please use generation_max_new_tokens instead.\"\n            )\n        return eval_max_new_tokens\n\n    @field_validator(\"dpo_use_logits_to_keep\")\n    @classmethod\n    def validate_dpo_use_logits_to_keep(cls, dpo_use_logits_to_keep):\n        if dpo_use_logits_to_keep is not None:\n            raise DeprecationWarning(\n                \"`dpo_use_logits_to_keep` is no longer supported, \"\n                \"it has been removed in TRL >= 0.29.0\"\n            )\n        return dpo_use_logits_to_keep\n\n    @field_validator(\"dpo_generate_during_eval\")\n    @classmethod\n    def validate_dpo_generate_during_eval(cls, dpo_generate_during_eval):\n        if dpo_generate_during_eval is not None:\n            raise DeprecationWarning(\n                \"`dpo_generate_during_eval` is no longer supported, \"\n                \"it has been removed in TRL >= 0.29.0\"\n            )\n        return dpo_generate_during_eval\n\n\nclass RemappedParameters(BaseModel):\n    \"\"\"Parameters that have been remapped to other names\"\"\"\n\n    overrides_of_model_config: dict[str, Any] | None = Field(\n        default=None,\n        alias=\"model_config\",\n        json_schema_extra={\n            \"description\": \"optional overrides to the base model configuration\"\n        },\n    )\n    overrides_of_model_kwargs: dict[str, Any] | None = Field(\n        default=None,\n        alias=\"model_kwargs\",\n        json_schema_extra={\n            \"description\": \"optional overrides the base model loading from_pretrained\"\n        },\n    )\n    type_of_model: str | None = Field(\n        default=None,\n        alias=\"model_type\",\n        json_schema_extra={\n            \"description\": \"If you want to specify the type of model to load, AutoModelForCausalLM is a good choice too\"\n        },\n    )\n    revision_of_model: str | None = Field(\n        default=None,\n        alias=\"model_revision\",\n        json_schema_extra={\n            
\"description\": \"Specify a model revision to load from the Hugging Face Hub\"\n        },\n    )\n"
  },
  {
    "path": "src/axolotl/utils/schemas/dynamic_checkpoint.py",
    "content": "\"\"\"Schema for dynamic checkpoint configuration.\"\"\"\n\nfrom pydantic import BaseModel, Field\n\n\nclass DynamicCheckpointConfig(BaseModel):\n    \"\"\"Configuration for dynamic checkpoint triggering during training.\"\"\"\n\n    enabled: bool = Field(\n        default=False,\n        json_schema_extra={\n            \"description\": \"Enable dynamic checkpoint triggering during training. \"\n            \"Create a file 'axolotl_checkpoint.save' (or the name set in `trigger_file_path`) in the configured `output_dir` to trigger.\"\n        },\n    )\n    check_interval: int = Field(\n        default=10,\n        ge=1,\n        json_schema_extra={\n            \"description\": \"Check for trigger file every N steps (reduces I/O overhead). \"\n            \"Default: 10\"\n        },\n    )\n    trigger_file_path: str = Field(\n        default=\"\",\n        json_schema_extra={\n            \"description\": \"Custom trigger filename (optional). \"\n            \"If not specified, defaults to 'axolotl_checkpoint.save'. \"\n            \"Specify a filename (not a full path) to override the default.\"\n        },\n    )\n"
  },
  {
    "path": "src/axolotl/utils/schemas/enums.py",
    "content": "\"\"\"Enums for Axolotl input config\"\"\"\n\nfrom enum import Enum\n\nimport torch\n\n\nclass TorchAOQuantDType(Enum):\n    int4 = torch.int4\n    int8 = torch.int8\n    float8_e4m3fn = torch.float8_e4m3fn\n    nvfp4 = \"nvfp4\"\n    mxfp4 = \"mxfp4\"\n\n    def from_string(str):\n        if str == \"int4\":\n            return TorchAOQuantDType.int4\n        if str == \"int8\":\n            return TorchAOQuantDType.int8\n        if str in [\"float8_e4m3fn\", \"fp8\", \"float8\"]:\n            return TorchAOQuantDType.float8_e4m3fn\n        if str == \"nvfp4\":\n            return TorchAOQuantDType.nvfp4\n        if str == \"mxfp4\":\n            return TorchAOQuantDType.mxfp4\n\n\nclass RLType(str, Enum):\n    \"\"\"RL trainer type configuration subset\"\"\"\n\n    DPO = \"dpo\"\n    GDPO = \"gdpo\"\n    GRPO = \"grpo\"\n    IPO = \"ipo\"\n    ORPO = \"orpo\"\n    KTO = \"kto\"\n    SIMPO = \"simpo\"\n\n\nclass ChatTemplate(str, Enum):\n    \"\"\"Chat templates configuration subset\"\"\"\n\n    alpaca = \"alpaca\"\n    chatml = \"chatml\"\n    mistral_v1 = \"mistral_v1\"\n    mistral_v2v3 = \"mistral_v2v3\"\n    mistral_v3_tekken = \"mistral_v3_tekken\"\n    mistral_v7_tekken = \"mistral_v7_tekken\"\n    gemma = \"gemma\"\n    cohere = \"cohere\"\n    llama3 = \"llama3\"\n    llama3_2_vision = \"llama3_2_vision\"\n    llama4 = \"llama4\"\n    phi_3 = \"phi_3\"\n    phi_35 = \"phi_35\"\n    deepseek_v2 = \"deepseek_v2\"\n    deepseek_v3 = \"deepseek_v3\"\n    jamba = \"jamba\"\n    jinja = \"jinja\"\n    qwen_25 = \"qwen_25\"\n    qwen3 = \"qwen3\"\n    qwen3_5 = \"qwen3_5\"\n    falcon_h1 = \"falcon_h1\"\n    tokenizer_default = \"tokenizer_default\"\n    exaone = \"exaone\"\n    exaone4 = \"exaone4\"\n    metharme = \"metharme\"\n    pixtral = \"pixtral\"\n    llava = \"llava\"\n    qwen2_vl = \"qwen2_vl\"\n    gemma3 = \"gemma3\"\n    gemma3n = \"gemma3n\"\n    command_a = \"command_a\"\n    command_a_tool_use = \"command_a_tool_use\"\n    
command_a_rag = \"command_a_rag\"\n    aya = \"aya\"\n\n\nclass CustomSupportedOptimizers(str, Enum):\n    \"\"\"Custom supported optimizers\"\"\"\n\n    optimi_adamw = \"optimi_adamw\"\n    ao_adamw_4bit = \"ao_adamw_4bit\"\n    ao_adamw_8bit = \"ao_adamw_8bit\"\n    ao_adamw_fp8 = \"ao_adamw_fp8\"\n    adopt_adamw = \"adopt_adamw\"\n    came_pytorch = \"came_pytorch\"\n    muon = \"muon\"\n    dion = \"dion\"\n    flash_adamw = \"flash_adamw\"\n    flash_adam = \"flash_adam\"\n    flash_sgd = \"flash_sgd\"\n    flash_sgdw = \"flash_sgdw\"\n    flash_lion = \"flash_lion\"\n\n\nclass RingAttnFunc(str, Enum):\n    \"\"\"Enum class for supported `ring-flash-attn` implementations\"\"\"\n\n    VARLEN_LLAMA3 = \"varlen_llama3\"\n    BATCH_RING = \"batch_ring\"\n    # VARLEN_RING = \"varlen_ring\"\n    # VARLEN_ZIGZAG = \"varlen_zigzag\"\n    # BATCH_ZIGZAG = \"batch_zigzag\"\n    # BATCH_STRIPE = \"batch_stripe\"\n"
  },
  {
    "path": "src/axolotl/utils/schemas/fsdp.py",
    "content": "\"\"\"\nFSDP Configuration Schema\n\"\"\"\n\nfrom typing import Literal\n\nfrom pydantic import AliasChoices, BaseModel, Field\n\n\nclass FSDPConfig(BaseModel):\n    \"\"\"\n    FSDP Configuration Schema\n    \"\"\"\n\n    fsdp_version: int | None = Field(\n        validation_alias=AliasChoices(\"fsdp_version\", \"version\"),\n        default=None,\n        json_schema_extra={\"description\": \"FSDP version\"},\n    )\n    activation_checkpointing: bool | None = Field(\n        default=None,\n        description=\"Enable activation checkpointing to reduce memory usage during forward passes\",\n    )\n    offload_params: bool | None = Field(\n        default=None,\n        description=\"Offload parameters to CPU to reduce GPU memory usage\",\n    )\n    sync_module_states: bool | None = Field(\n        default=None,\n        description=\"Synchronize module states across all processes\",\n    )\n    cpu_ram_efficient_loading: bool | None = Field(\n        default=None,\n        description=\"Enable CPU RAM efficient loading to reduce memory usage during model loading\",\n    )\n    cpu_offload_pin_memory: bool | None = Field(\n        default=None,\n        description=\"Disabling this enables swap memory usage for resource-constrained setups when offload_params is enabled.\",\n    )\n    use_orig_params: bool | None = Field(\n        default=None,\n        description=\"Use original parameters instead of flattened parameters\",\n    )\n\n    state_dict_type: (\n        Literal[\"FULL_STATE_DICT\", \"LOCAL_STATE_DICT\", \"SHARDED_STATE_DICT\"] | None\n    ) = Field(\n        default=None,\n        description=\"Type of state dict to use for saving/loading checkpoints\",\n    )\n    final_state_dict_type: (\n        Literal[\"FULL_STATE_DICT\", \"LOCAL_STATE_DICT\", \"SHARDED_STATE_DICT\"] | None\n    ) = Field(\n        default=None,\n        description=\"Final state dict type to use after training completion\",\n    )\n\n    auto_wrap_policy: 
Literal[\"TRANSFORMER_BASED_WRAP\", \"SIZE_BASED_WRAP\"] | None = (\n        Field(\n            default=None,\n            description=\"Policy for automatically wrapping modules with FSDP\",\n        )\n    )\n    transformer_layer_cls_to_wrap: str | None = Field(\n        default=None,\n        description=\"Class name of transformer layers to wrap (e.g., 'LlamaDecoderLayer')\",\n    )\n\n    reshard_after_forward: bool | None = Field(\n        default=None,\n        description=\"Reshard parameters after forward pass to save memory\",\n    )\n    mixed_precision_policy: str | None = Field(\n        default=None,\n        description=\"Mixed precision policy for FSDP (e.g., 'fp16', 'bf16')\",\n    )\n"
  },
  {
    "path": "src/axolotl/utils/schemas/integrations.py",
    "content": "\"\"\"Pydantic models for Axolotl integrations\"\"\"\n\nfrom typing import Any\n\nfrom pydantic import BaseModel, Field, model_validator\n\nfrom axolotl.utils.logging import get_logger\n\nLOG = get_logger(__name__)\n\n\nclass MLFlowConfig(BaseModel):\n    \"\"\"MLFlow configuration subset\"\"\"\n\n    use_mlflow: bool | None = None\n    mlflow_tracking_uri: str | None = Field(\n        default=None, json_schema_extra={\"description\": \"URI to mlflow\"}\n    )\n    mlflow_experiment_name: str | None = Field(\n        default=None, json_schema_extra={\"description\": \"Your experiment name\"}\n    )\n    mlflow_run_name: str | None = Field(\n        default=None, json_schema_extra={\"description\": \"Your run name\"}\n    )\n    hf_mlflow_log_artifacts: bool | None = Field(\n        default=None,\n        json_schema_extra={\n            \"description\": \"set to true to copy each saved checkpoint on each save to mlflow artifact registry\"\n        },\n    )\n\n\nclass LISAConfig(BaseModel):\n    \"\"\"LISA configuration subset\"\"\"\n\n    lisa_n_layers: int | None = Field(\n        default=None,\n        json_schema_extra={\"description\": \"the number of activate layers in LISA\"},\n    )\n    lisa_step_interval: int | None = Field(\n        default=None,\n        json_schema_extra={\"description\": \"how often to switch layers in LISA\"},\n    )\n    lisa_layers_attribute: str | None = Field(\n        default=\"model.layers\",\n        json_schema_extra={\"description\": \"path under the model to access the layers\"},\n    )\n\n\nclass WandbConfig(BaseModel):\n    \"\"\"Wandb configuration subset\"\"\"\n\n    use_wandb: bool | None = None\n    wandb_name: str | None = Field(\n        default=None,\n        json_schema_extra={\"description\": \"Set the name of your wandb run\"},\n    )\n    wandb_run_id: str | None = Field(\n        default=None, json_schema_extra={\"description\": \"Set the ID of your wandb run\"}\n    )\n    wandb_mode: str | 
None = Field(\n        default=None,\n        json_schema_extra={\n            \"description\": '\"offline\" to save run metadata locally and not sync to the server, \"disabled\" to turn off wandb'\n        },\n    )\n    wandb_project: str | None = Field(\n        default=None, json_schema_extra={\"description\": \"Your wandb project name\"}\n    )\n    wandb_entity: str | None = Field(\n        default=None,\n        json_schema_extra={\"description\": \"A wandb Team name if using a Team\"},\n    )\n    wandb_watch: str | None = None\n    wandb_log_model: str | None = Field(\n        default=None,\n        json_schema_extra={\n            \"description\": '\"checkpoint\" to log model to wandb Artifacts every `save_steps` or \"end\" to log only at the end of training'\n        },\n    )\n\n    @model_validator(mode=\"before\")\n    @classmethod\n    def check_wandb_run(cls, data):\n        if data.get(\"wandb_run_id\") and not data.get(\"wandb_name\"):\n            data[\"wandb_name\"] = data.get(\"wandb_run_id\")\n\n            LOG.warning(\n                \"wandb_run_id sets the ID of the run. If you would like to set the name, please use wandb_name instead.\"\n            )\n\n        return data\n\n\nclass CometConfig(BaseModel):\n    \"\"\"Comet configuration subset\"\"\"\n\n    use_comet: bool | None = Field(\n        default=None,\n        json_schema_extra={\"description\": \"Enable or disable Comet integration.\"},\n    )\n    comet_api_key: str | None = Field(\n        default=None,\n        json_schema_extra={\n            \"description\": \"API key for Comet. Recommended to set via `comet login`.\"\n        },\n    )\n    comet_workspace: str | None = Field(\n        default=None,\n        json_schema_extra={\n            \"description\": \"Workspace name in Comet. 
Defaults to the user's default workspace.\"\n        },\n    )\n    comet_project_name: str | None = Field(\n        default=None,\n        json_schema_extra={\n            \"description\": \"Project name in Comet. Defaults to Uncategorized.\"\n        },\n    )\n    comet_experiment_key: str | None = Field(\n        default=None,\n        json_schema_extra={\n            \"description\": \"Identifier for the experiment. Used to append data to an existing experiment or control the key of new experiments. Default to a random key.\"\n        },\n    )\n    comet_mode: str | None = Field(\n        default=None,\n        json_schema_extra={\n            \"description\": 'Create a new experiment (\"create\") or log to an existing one (\"get\"). Default (\"get_or_create\") auto-selects based on configuration.'\n        },\n    )\n    comet_online: bool | None = Field(\n        default=None,\n        json_schema_extra={\n            \"description\": \"Set to True to log data to Comet server, or False for offline storage. 
Default is True.\"\n        },\n    )\n    comet_experiment_config: dict[str, Any] | None = Field(\n        default=None,\n        json_schema_extra={\n            \"description\": \"Dictionary for additional configuration settings, see the doc for more details.\"\n        },\n    )\n\n\nclass GradioConfig(BaseModel):\n    \"\"\"Gradio configuration subset\"\"\"\n\n    gradio_title: str | None = None\n    gradio_share: bool | None = None\n    gradio_server_name: str | None = None\n    gradio_server_port: int | None = None\n    gradio_max_new_tokens: int | None = None\n    gradio_temperature: float | None = None\n\n\nclass RayConfig(BaseModel):\n    \"\"\"Ray launcher configuration subset\"\"\"\n\n    use_ray: bool = Field(default=False)\n    ray_run_name: str | None = Field(\n        default=None,\n        json_schema_extra={\n            \"help\": \"The training results will be saved at `saves/ray_run_name`.\"\n        },\n    )\n    ray_num_workers: int = Field(\n        default=1,\n        json_schema_extra={\n            \"help\": \"The number of workers for Ray training. Default is 1 worker.\"\n        },\n    )\n    resources_per_worker: dict = Field(\n        default_factory=lambda: {\"GPU\": 1},\n        json_schema_extra={\n            \"help\": \"The resources per worker for Ray training. 
Default is to use 1 GPU per worker.\"\n        },\n    )\n\n\nclass OpenTelemetryConfig(BaseModel):\n    \"\"\"OpenTelemetry configuration subset\"\"\"\n\n    use_otel_metrics: bool | None = Field(\n        default=False,\n        json_schema_extra={\n            \"description\": \"Enable OpenTelemetry metrics collection and Prometheus export\"\n        },\n    )\n    otel_metrics_host: str | None = Field(\n        default=\"localhost\",\n        json_schema_extra={\n            \"title\": \"OpenTelemetry Metrics Host\",\n            \"description\": \"Host to bind the OpenTelemetry metrics server to\",\n        },\n    )\n    otel_metrics_port: int | None = Field(\n        default=8000,\n        json_schema_extra={\n            \"description\": \"Port for the Prometheus metrics HTTP server\"\n        },\n    )\n\n\nclass TrackioConfig(BaseModel):\n    \"\"\"Trackio configuration subset\"\"\"\n\n    use_trackio: bool | None = None\n    trackio_project_name: str | None = Field(\n        default=None,\n        json_schema_extra={\"description\": \"Your trackio project name\"},\n    )\n    trackio_run_name: str | None = Field(\n        default=None,\n        json_schema_extra={\"description\": \"Set the name of your trackio run\"},\n    )\n    trackio_space_id: str | None = Field(\n        default=None,\n        json_schema_extra={\n            \"description\": \"Hugging Face Space ID to sync dashboard to (optional, runs locally if not provided)\"\n        },\n    )\n"
  },
  {
    "path": "src/axolotl/utils/schemas/internal/__init__.py",
    "content": "\"\"\"module for gpu capabilities\"\"\"\n\nfrom typing import Optional\n\nfrom pydantic import BaseModel, Field\n\n\nclass GPUCapabilities(BaseModel):\n    \"\"\"model to manage the gpu capabilities statically\"\"\"\n\n    bf16: bool = Field(default=False)\n    fp8: bool = Field(default=False)\n    tf32: bool = Field(default=False)\n    n_gpu: int = Field(default=1)\n    n_node: int = Field(default=1)\n    compute_capability: Optional[str] = Field(default=None)\n\n\nclass EnvCapabilities(BaseModel):\n    \"\"\"model to manage the environment capabilities statically\"\"\"\n\n    torch_version: Optional[str] = Field(default=None)\n"
  },
  {
    "path": "src/axolotl/utils/schemas/model.py",
    "content": "\"\"\"Pydantic models for model input / output, etc. configuration\"\"\"\n\nfrom typing import Any, Literal\n\nfrom pydantic import BaseModel, Field, field_validator\n\nfrom axolotl.utils.logging import get_logger\n\nLOG = get_logger(__name__)\n\n\nclass ModelInputConfig(BaseModel):\n    \"\"\"Model configuration subset\"\"\"\n\n    model_config = {\"protected_namespaces\": ()}\n\n    base_model: str = Field(\n        json_schema_extra={\n            \"description\": \"This is the huggingface model that contains *.pt, *.safetensors, or *.bin files. This can also be a relative path to a model on disk\"\n        }\n    )\n    base_model_config: str | None = Field(\n        default=None,\n        json_schema_extra={\n            \"description\": \"If the base_model repo on hf hub doesn't include configuration .json files, You can set that here, or leave this empty to default to base_model\"\n        },\n    )\n    cls_model_config: str | None = Field(\n        default=None,\n        json_schema_extra={\n            \"description\": \"transformers config class (e.g., 'LlamaConfig', 'MistralConfig'). 
Defaults to AutoConfig.\"\n        },\n    )\n    tokenizer_config: str | None = Field(\n        default=None,\n        json_schema_extra={\n            \"description\": \"Optional tokenizer configuration path in case you want to use a different tokenizer than the one defined in the base model\"\n        },\n    )\n    tokenizer_use_fast: bool | None = Field(\n        default=None,\n        json_schema_extra={\n            \"description\": \"use_fast option for tokenizer loading from_pretrained, default to True\"\n        },\n    )\n    tokenizer_legacy: bool | None = Field(\n        default=None,\n        json_schema_extra={\n            \"description\": \"Whether to use the legacy tokenizer setting, defaults to True\"\n        },\n    )\n    tokenizer_use_mistral_common: bool | None = Field(\n        default=None,\n        json_schema_extra={\n            \"description\": \"Whether to use mistral-common tokenizer. If set to True, it will use the mistral-common tokenizer.\"\n        },\n    )\n    tokenizer_type: str | None = Field(\n        default=None,\n        json_schema_extra={\n            \"description\": \"Corresponding tokenizer for the model AutoTokenizer is a good choice\"\n        },\n    )\n    processor_type: str | None = Field(\n        default=None, json_schema_extra={\"description\": \"transformers processor class\"}\n    )\n    tokenizer_save_jinja_files: bool | None = Field(\n        default=True,  # match the default behavior from transformers\n        json_schema_extra={\n            \"description\": \"Whether to save jinja files for tokenizer, transformers default is True\"\n        },\n    )\n    trust_remote_code: bool | None = Field(\n        default=None,\n        json_schema_extra={\"description\": \"Trust remote code for untrusted source\"},\n    )\n\n    experimental_skip_move_to_device: bool | None = Field(\n        default=True,\n        json_schema_extra={\n            \"description\": \"Don't move the model to the device before 
sharding. Set to `false` to revert to legacy behavior.\"\n        },\n    )\n\n    use_kernels: bool | None = Field(\n        default=None,\n        json_schema_extra={\"description\": \"Use custom kernels, e.g. MegaBlocks.\"},\n    )\n\n    model_quantization_config: Literal[\"Mxfp4Config\"] | None = Field(\n        default=None,\n        json_schema_extra={\"description\": \"Model loading quantization config\"},\n    )\n    model_quantization_config_kwargs: dict[str, Any] | None = Field(\n        default=None,\n        json_schema_extra={\"description\": \"kwargs for model quantization config\"},\n    )\n\n    @field_validator(\"trust_remote_code\")\n    @classmethod\n    def hint_trust_remote_code(cls, trust_remote_code):\n        if trust_remote_code:\n            LOG.warning(\n                \"`trust_remote_code` is set to true. Please make sure that you reviewed the remote code/model.\"\n            )\n        return trust_remote_code\n\n\nclass ModelOutputConfig(BaseModel):\n    \"\"\"model save configuration subset\"\"\"\n\n    output_dir: str = Field(\n        default=\"./model-out\",\n        json_schema_extra={\"description\": \"Where to save the full-finetuned model to\"},\n    )\n    hub_model_id: str | None = Field(\n        default=None, json_schema_extra={\"description\": \"push checkpoints to hub\"}\n    )\n    hub_strategy: str | None = Field(\n        default=None,\n        json_schema_extra={\"description\": \"how to push checkpoints to hub\"},\n    )\n    hub_revision: str | None = Field(\n        default=None,\n        json_schema_extra={\n            \"description\": \"branch/revision to push to on hub (default: main)\"\n        },\n    )\n    save_safetensors: bool | None = Field(\n        default=True,\n        json_schema_extra={\n            \"description\": \"Whether to save the model using safetensors format. 
Defaults to True.\"\n        },\n    )\n\n    @field_validator(\"save_safetensors\")\n    @classmethod\n    def validate_save_safetensors(cls, v):\n        if v is False:\n            raise ValueError(\n                \"save_safetensors=False is not supported in Transformers V5. \"\n                \"Transformers V5 always uses safetensors format for model serialization. \"\n                \"This field is deprecated and will be removed in a future version.\"\n            )\n        # Allow None and True, will default to True if None\n        return True if v is None else v\n\n\nclass SpecialTokensConfig(BaseModel):\n    \"\"\"Special tokens configuration subset\"\"\"\n\n    bos_token: str | None = None\n    eos_token: str | None = None\n    pad_token: str | None = None\n    unk_token: str | None = None\n    additional_special_tokens: list[str] | None = None\n"
  },
  {
    "path": "src/axolotl/utils/schemas/multimodal.py",
    "content": "\"\"\"Pydantic models for multimodal-related configuration\"\"\"\n\nfrom typing import Literal\n\nfrom PIL.Image import Resampling\nfrom pydantic import BaseModel, Field, field_validator\n\n\nclass MultiModalConfig(BaseModel):\n    \"\"\"Multi-modal configuration subset\"\"\"\n\n    image_size: int | tuple[int, int] | None = Field(\n        default=None,\n        json_schema_extra={\n            \"description\": (\n                \"The size of the image to resize to. It can be an integer (resized into padded-square image) or a tuple (width, height).\"\n                \"If not provided, we will attempt to load from preprocessor.size, otherwise, images won't be resized.\"\n            )\n        },\n    )\n    image_resize_algorithm: (\n        Literal[\"bilinear\", \"bicubic\", \"lanczos\"] | Resampling | None\n    ) = Field(\n        default=None,\n        json_schema_extra={\n            \"description\": \"The resampling algorithm to use for image resizing. Default is bilinear. 
Please refer to PIL.Image.Resampling for more details.\"\n        },\n    )\n\n    @field_validator(\"image_resize_algorithm\", mode=\"before\")\n    @classmethod\n    def convert_image_resize_algorithm(cls, image_resize_algorithm):\n        \"\"\"\n        Convert the image resize algorithm to a PIL.Image.Resampling enum.\n        \"\"\"\n        if isinstance(image_resize_algorithm, str):\n            image_resize_algorithm = image_resize_algorithm.lower()\n            if image_resize_algorithm == \"bilinear\":\n                image_resize_algorithm = Resampling.BILINEAR\n            elif image_resize_algorithm == \"bicubic\":\n                image_resize_algorithm = Resampling.BICUBIC\n            elif image_resize_algorithm == \"lanczos\":\n                image_resize_algorithm = Resampling.LANCZOS\n            else:\n                raise ValueError(\n                    f\"Invalid image resize algorithm: {image_resize_algorithm}\"\n                )\n        return image_resize_algorithm\n"
  },
  {
    "path": "src/axolotl/utils/schemas/peft.py",
    "content": "\"\"\"Pydantic models for PEFT-related configuration\"\"\"\n\nfrom typing import Any, Literal\n\nfrom pydantic import BaseModel, Field, field_validator, model_validator\n\n\nclass LoftQConfig(BaseModel):\n    \"\"\"LoftQ configuration subset\"\"\"\n\n    loftq_bits: int = Field(\n        default=4, json_schema_extra={\"description\": \"typically 4 bits\"}\n    )\n    # loftq_iter: int = Field(default=1, json_schema_extra={\"description\": \"Alternating iterations for LoftQ\"})\n\n\nclass PeftConfig(BaseModel):\n    \"\"\"peftq configuration subset\"\"\"\n\n    loftq_config: LoftQConfig | None = Field(\n        default=None,\n        json_schema_extra={\n            \"description\": \"Configuration options for loftq initialization for LoRA\"\n        },\n    )\n\n\nclass LoraConfig(BaseModel):\n    \"\"\"Peft / LoRA configuration subset\"\"\"\n\n    load_in_8bit: bool | None = Field(\n        default=False,\n        json_schema_extra={\n            \"description\": \"This will attempt to quantize the model down to 8 bits and use adam 8 bit optimizer\"\n        },\n    )\n    load_in_4bit: bool | None = Field(\n        default=False, json_schema_extra={\"description\": \"Use bitsandbytes 4 bit\"}\n    )\n\n    adapter: Literal[\"lora\", \"qlora\", \"llama-adapter\"] | None = Field(\n        default=None,\n        json_schema_extra={\n            \"description\": \"If you want to use 'lora', 'qlora', or 'llama-adapter', or leave blank to train all parameters in original model\"\n        },\n    )\n    lora_model_dir: str | None = Field(\n        default=None,\n        json_schema_extra={\n            \"description\": \"If you already have a lora model trained that you want to load, put that here. This means after training, if you want to test the model, you should set this to the value of `output_dir`. 
Note that if you merge an adapter to the base model, a new subdirectory `merged` will be created under the `output_dir`.\"\n        },\n    )\n    lora_r: int | None = None\n    lora_alpha: int | None = None\n    lora_fan_in_fan_out: bool | None = None\n    lora_target_modules: str | list[str] | None = None\n    lora_target_parameters: str | list[str] | None = None\n    lora_target_linear: bool | None = Field(\n        default=None,\n        json_schema_extra={\"description\": \"If true, will target all linear modules\"},\n    )\n    lora_modules_to_save: list[str] | None = Field(\n        default=None,\n        json_schema_extra={\n            \"description\": \"If you added new tokens to the tokenizer, you may need to save some LoRA modules because they need to know the new tokens. For LLaMA and Mistral, you need to save `embed_tokens` and `lm_head`. It may vary for other models. `embed_tokens` converts tokens to embeddings, and `lm_head` converts embeddings to token probabilities.\"\n        },\n    )\n    lora_dropout: float | None = 0.0\n    peft_layers_to_transform: list[int] | None = Field(\n        default=None,\n        json_schema_extra={\n            \"description\": \"The layer indices to transform, otherwise, apply to all layers\"\n        },\n    )\n    peft_layers_pattern: list[str] | None = None\n    peft: PeftConfig | None = None\n    peft_use_dora: bool | None = Field(\n        default=None, json_schema_extra={\"description\": \"Whether to use DoRA.\"}\n    )\n    peft_use_rslora: bool | None = Field(\n        default=None, json_schema_extra={\"description\": \"Whether to use RSLoRA.\"}\n    )\n    peft_layer_replication: list[tuple[int, int]] | None = Field(\n        default=None,\n        json_schema_extra={\"description\": \"List of layer indices to replicate.\"},\n    )\n    peft_init_lora_weights: bool | str | None = Field(\n        default=None,\n        json_schema_extra={\n            \"description\": \"How to initialize LoRA weights. 
Default to True which is MS original implementation.\"\n        },\n    )\n    peft_trainable_token_indices: list[int] | dict[str, list[int]] | None = Field(\n        default=None,\n        json_schema_extra={\n            \"description\": (\n                \"A list of token indices to fine-tune on the `embed_tokens` layer.\\n\"\n                \"Otherwise, a dict mapping an embedding layer name to its trainable token indices.\\n\"\n                \"See https://huggingface.co/docs/peft/v0.17.0/en/developer_guides/lora#efficiently-train-tokens-alongside-lora\"\n            )\n        },\n    )\n    peft_ensure_weight_tying: bool | None = Field(\n        default=None,\n        json_schema_extra={\n            \"description\": (\n                \"Whether to tie adapter weights for tied model weights. \"\n                \"See https://github.com/huggingface/peft/issues/2864\"\n            )\n        },\n    )\n    peft_autocast_adapter_dtype: bool | None = Field(\n        default=None,\n        json_schema_extra={\n            \"description\": \"Whether to upcast the LoRA adapter to fp32. This is enabled by default in PEFT.\"\n        },\n    )\n\n    qlora_sharded_model_loading: bool | None = Field(\n        default=False,\n        json_schema_extra={\n            \"description\": \"load qlora model in sharded format for FSDP using answer.ai technique.\"\n        },\n    )\n    lora_on_cpu: bool | None = Field(\n        default=None,\n        json_schema_extra={\n            \"description\": \"Do the LoRA/PEFT loading on CPU -- this is required if the base model is so large it takes up most or all of the available GPU VRAM, e.g. 
during a model and LoRA merge\"\n        },\n    )\n    gptq: bool | None = Field(\n        default=None,\n        json_schema_extra={\n            \"description\": \"Whether you are training a 4-bit GPTQ quantized model\"\n        },\n    )\n    bnb_config_kwargs: dict[str, Any] | None = Field(\n        default=None,\n        json_schema_extra={\n            \"description\": \"optional overrides to the bnb 4bit quantization configuration\"\n        },\n    )\n\n    loraplus_lr_ratio: float | None = Field(\n        default=None,\n        json_schema_extra={\n            \"description\": \"loraplus learning rate ratio lr_B / lr_A. Recommended value is 2^4.\"\n        },\n    )\n    loraplus_lr_embedding: float | None = Field(\n        default=1e-6,\n        json_schema_extra={\n            \"description\": \"loraplus learning rate for lora embedding layers. Default value is 1e-6.\"\n        },\n    )\n\n    merge_lora: bool | None = None\n\n    @model_validator(mode=\"before\")\n    @classmethod\n    def validate_adapter(cls, data):\n        if (\n            not data.get(\"adapter\")\n            and not data.get(\"inference\")\n            and (data.get(\"load_in_8bit\") or data.get(\"load_in_4bit\"))\n        ):\n            raise ValueError(\n                \"load_in_8bit and load_in_4bit are not supported without setting an adapter for training.\"\n                \"If you want to full finetune, please turn off load_in_8bit and load_in_4bit.\"\n            )\n        return data\n\n    @model_validator(mode=\"after\")\n    def validate_qlora(self):\n        if self.adapter == \"qlora\":\n            if self.merge_lora:\n                # can't merge qlora if loaded in 8bit or 4bit\n                if self.load_in_8bit:\n                    raise ValueError(\"Can't merge qlora if loaded in 8bit\")\n\n                if self.gptq:\n                    raise ValueError(\"Can't merge qlora if gptq\")\n\n                if self.load_in_4bit:\n                    
raise ValueError(\"Can't merge qlora if loaded in 4bit\")\n\n            else:\n                if self.load_in_8bit:\n                    raise ValueError(\"Can't load qlora in 8bit\")\n\n                if self.gptq:\n                    raise ValueError(\"Can't load qlora if gptq\")\n\n                if not self.load_in_4bit:\n                    raise ValueError(\"Require cfg.load_in_4bit to be True for qlora\")\n        return self\n\n    @field_validator(\"loraplus_lr_embedding\")\n    @classmethod\n    def convert_loraplus_lr_embedding(cls, loraplus_lr_embedding):\n        if loraplus_lr_embedding and isinstance(loraplus_lr_embedding, str):\n            loraplus_lr_embedding = float(loraplus_lr_embedding)\n        return loraplus_lr_embedding\n\n    @model_validator(mode=\"before\")\n    @classmethod\n    def validate_lora_dropout(cls, data):\n        if data.get(\"adapter\") is not None and data.get(\"lora_dropout\") is None:\n            data[\"lora_dropout\"] = 0.0\n        return data\n\n    @model_validator(mode=\"after\")\n    def validate_lora_target_parameters_dropout(self):\n        if (\n            self.lora_target_parameters\n            and self.lora_dropout\n            and self.lora_dropout != 0.0\n        ):\n            raise ValueError(\n                \"lora_dropout must be 0 when lora_target_parameters is set. \"\n                \"PEFT's ParamWrapper does not support lora_dropout != 0.\"\n            )\n        return self\n\n\nclass ReLoRAConfig(BaseModel):\n    \"\"\"ReLoRA configuration subset\"\"\"\n\n    relora: bool | None = Field(\n        default=None,\n        json_schema_extra={\n            \"description\": \"Whether to use ReLoRA. 
Use with jagged_restart_*steps options.\"\n        },\n    )\n    relora_prune_ratio: float | None = Field(\n        default=None,\n        json_schema_extra={\n            \"description\": \"threshold for optimizer magnitude when pruning\"\n        },\n    )\n    relora_cpu_offload: bool | None = Field(\n        default=None,\n        json_schema_extra={\n            \"description\": \"True to perform lora weight merges on cpu during restarts, for modest gpu memory savings\"\n        },\n    )\n"
  },
  {
    "path": "src/axolotl/utils/schemas/quantization.py",
    "content": "\"\"\"\nQAT Config Schema\n\"\"\"\n\nfrom typing import Any\n\nfrom pydantic import BaseModel, Field, field_validator\n\nfrom axolotl.utils.schemas.enums import TorchAOQuantDType\n\n\ndef validate_ao_dtype(v: Any) -> TorchAOQuantDType | None:\n    if v is None:\n        return None\n    if v == \"int4\":\n        return TorchAOQuantDType.int4\n    if v == \"int8\":\n        return TorchAOQuantDType.int8\n    if v in [\"float8_e4m3fn\", \"fp8\", \"float8\"]:\n        return TorchAOQuantDType.float8_e4m3fn\n    if v == \"nvfp4\":\n        return TorchAOQuantDType.nvfp4\n    if v == \"mxfp4\":\n        return TorchAOQuantDType.mxfp4\n\n    raise ValueError(\n        f\"Invalid dtype: '{v}'. Must be one of: {[e.name for e in TorchAOQuantDType] + ['fp8', 'float8']}\"\n    )\n\n\nclass QATConfig(BaseModel):\n    \"\"\"\n    QAT Config Schema\n    \"\"\"\n\n    activation_dtype: TorchAOQuantDType | None = Field(\n        default=None,\n        description=\"Fake quantization layout to use for activation quantization.\",\n    )\n    weight_dtype: TorchAOQuantDType = Field(\n        default=TorchAOQuantDType.int8,\n        description=\"Fake quantization layout to use for weight quantization.\",\n    )\n    quantize_embedding: bool | None = Field(\n        default=False, description=\"Quantize embedding\"\n    )\n    group_size: int | None = Field(\n        default=32,\n        description=\"The number of elements in each group for per-group fake quantization\",\n    )\n    fake_quant_after_n_steps: int | None = Field(\n        default=None, description=\"The number of steps to apply fake quantization after\"\n    )\n\n    @field_validator(\"activation_dtype\", \"weight_dtype\", mode=\"before\")\n    @classmethod\n    def validate_dtype(cls, v: Any) -> TorchAOQuantDType | None:\n        return validate_ao_dtype(v)\n\n\nclass PTQConfig(BaseModel):\n    \"\"\"\n    PTQ Config Schema\n    \"\"\"\n\n    weight_dtype: TorchAOQuantDType = Field(\n        
default=TorchAOQuantDType.int8,\n        description=\"Fake quantization layout to use for weight quantization.\",\n    )\n    activation_dtype: TorchAOQuantDType | None = Field(\n        default=None,\n        description=\"Fake quantization layout to use for activation quantization.\",\n    )\n    quantize_embedding: bool | None = Field(\n        default=None, description=\"Whether to quantize the embedding layer.\"\n    )\n    group_size: int | None = Field(\n        default=32,\n        description=\"The number of elements in each group for per-group fake quantization\",\n    )\n\n    @field_validator(\"activation_dtype\", \"weight_dtype\", mode=\"before\")\n    @classmethod\n    def validate_dtype(cls, v: Any) -> TorchAOQuantDType | None:\n        return validate_ao_dtype(v)\n"
  },
  {
    "path": "src/axolotl/utils/schemas/training.py",
    "content": "\"\"\"Pydantic models for training hyperparameters\"\"\"\n\nfrom typing import Any, Literal\n\nfrom pydantic import BaseModel, Field, field_validator\nfrom transformers import SchedulerType\nfrom transformers.training_args import OptimizerNames\n\nfrom axolotl.utils.logging import get_logger\nfrom axolotl.utils.schemas.enums import CustomSupportedOptimizers\n\nLOG = get_logger(__name__)\n\n\nclass LrGroup(BaseModel):\n    \"\"\"Custom learning rate group configuration\"\"\"\n\n    name: str\n    modules: list[str]\n    lr: float\n\n\nclass HyperparametersConfig(BaseModel):\n    \"\"\"Training hyperparams configuration subset\"\"\"\n\n    gradient_accumulation_steps: int | None = Field(\n        default=1,\n        json_schema_extra={\n            \"description\": \"If greater than 1, backpropagation will be skipped and the gradients will be accumulated for the given number of steps.\"\n        },\n    )\n    micro_batch_size: int | None = Field(\n        default=1,\n        json_schema_extra={\n            \"description\": \"The number of samples to include in each batch. This is the number of samples sent to each GPU. Batch size per gpu = micro_batch_size * gradient_accumulation_steps\"\n        },\n    )\n    batch_size: int | None = Field(\n        default=None,\n        json_schema_extra={\n            \"description\": \"Total batch size, we do not recommended setting this manually\"\n        },\n    )\n    eval_batch_size: int | None = Field(\n        default=None,\n        json_schema_extra={\n            \"description\": \"per gpu micro batch size for evals, defaults to value of micro_batch_size\"\n        },\n    )\n\n    auto_find_batch_size: bool | None = Field(\n        default=None,\n        json_schema_extra={\n            \"description\": \"whether to find batch size that fits in memory. 
Passed to underlying transformers Trainer\"\n        },\n    )\n\n    train_on_inputs: bool | None = Field(\n        default=False,\n        json_schema_extra={\n            \"description\": \"Whether to mask out or include the human's prompt from the training labels\"\n        },\n    )\n    group_by_length: bool | None = Field(\n        default=None,\n        json_schema_extra={\n            \"description\": \"Group similarly sized data to minimize padding. May be slower to start, as it must download and sort the entire dataset. Note that training loss may have an oscillating pattern with this enabled.\"\n        },\n    )\n\n    learning_rate: str | float\n    embedding_lr: float | None = None\n    embedding_lr_scale: float | None = None\n    weight_decay: float | None = Field(\n        default=0.0, json_schema_extra={\"description\": \"Specify weight decay\"}\n    )\n    optimizer: (OptimizerNames | CustomSupportedOptimizers) | None = Field(\n        default=OptimizerNames.ADAMW_TORCH_FUSED,\n        json_schema_extra={\"description\": \"Specify optimizer\"},\n    )\n    optim_args: (str | dict[str, Any]) | None = Field(\n        default=None,\n        json_schema_extra={\n            \"description\": \"Dictionary of arguments to pass to the optimizer\"\n        },\n    )\n    optim_target_modules: (list[str] | Literal[\"all_linear\"]) | None = Field(\n        default=None,\n        json_schema_extra={\n            \"description\": \"The target modules to optimize, i.e. 
the module names that you would like to train, right now this is used only for GaLore algorithm\"\n        },\n    )\n    torchdistx_path: str | None = Field(\n        default=None,\n        json_schema_extra={\n            \"description\": \"Path to torch distx for optim 'adamw_anyprecision'\"\n        },\n    )\n    lr_scheduler: (\n        SchedulerType | Literal[\"one_cycle\"] | Literal[\"rex\"]\n    ) | None = SchedulerType.COSINE\n    lr_scheduler_kwargs: dict[str, Any] | None = Field(\n        default=None,\n        json_schema_extra={\n            \"description\": \"Specify a scheduler and kwargs to use with the optimizer\"\n        },\n    )\n    lr_quadratic_warmup: bool | None = None\n    cosine_min_lr_ratio: float | None = Field(\n        default=None,\n        json_schema_extra={\n            \"description\": \"decay lr to some percentage of the peak lr, e.g. cosine_min_lr_ratio=0.1 for 10% of peak lr\"\n        },\n    )\n    cosine_constant_lr_ratio: float | None = Field(\n        default=None,\n        json_schema_extra={\n            \"description\": \"freeze lr at some percentage of the step, e.g. 
cosine_constant_lr_ratio=0.8 means start cosine_min_lr at 80% of training step\"\n        },\n    )\n    lr_div_factor: float | None = Field(\n        default=None, json_schema_extra={\"description\": \"Learning rate div factor\"}\n    )\n    lr_groups: list[LrGroup] | None = None\n\n    adam_epsilon: float | None = Field(\n        default=None, json_schema_extra={\"description\": \"adamw hyperparams\"}\n    )\n    adam_epsilon2: float | None = Field(\n        default=None, json_schema_extra={\"description\": \"only used for CAME Optimizer\"}\n    )\n    adam_beta1: float | None = Field(\n        default=None, json_schema_extra={\"description\": \"adamw hyperparams\"}\n    )\n    adam_beta2: float | None = Field(\n        default=None, json_schema_extra={\"description\": \"adamw hyperparams\"}\n    )\n    adam_beta3: float | None = Field(\n        default=None, json_schema_extra={\"description\": \"only used for CAME Optimizer\"}\n    )\n\n    dion_lr: float | None = Field(\n        default=None, json_schema_extra={\"description\": \"Dion Optimizer learning rate\"}\n    )\n    dion_momentum: float | None = Field(\n        default=None, json_schema_extra={\"description\": \"Dion Optimizer momentum\"}\n    )\n    dion_rank_fraction: float | None = Field(\n        default=1.0,\n        json_schema_extra={\n            \"description\": \"Dion Optimizer: r/d fraction for low-rank approximation. Used to compute the low-rank dimension.\"\n        },\n    )\n    dion_rank_multiple_of: int | None = Field(\n        default=1,\n        json_schema_extra={\n            \"description\": \"Dion Optimizer: Round up the low-rank dimension to a multiple of this number. 
This may be useful to ensure even sharding.\"\n        },\n    )\n\n    max_grad_norm: float | None = Field(\n        default=None, json_schema_extra={\"description\": \"Gradient clipping max norm\"}\n    )\n    num_epochs: float = Field(default=1.0)\n\n    @field_validator(\"batch_size\")\n    @classmethod\n    def hint_batch_size_set(cls, batch_size):\n        if batch_size:\n            LOG.warning(\n                \"%s\\n%s\",\n                \"batch_size is not recommended. Please use gradient_accumulation_steps instead.\",\n                \"To calculate the equivalent gradient_accumulation_steps, divide batch_size / micro_batch_size / number of gpus.\",\n            )\n        return batch_size\n\n    @field_validator(\"learning_rate\")\n    @classmethod\n    def convert_learning_rate(cls, learning_rate):\n        if learning_rate and isinstance(learning_rate, str):\n            learning_rate = float(learning_rate)\n        return learning_rate\n\n\nclass JaggedLRConfig(BaseModel):\n    \"\"\"JaggedLR configuration subset, can be used w/ ReLoRA training\"\"\"\n\n    jagged_restart_steps: int | None = Field(\n        default=None,\n        json_schema_extra={\"description\": \"how often to reset for jagged restarts\"},\n    )\n    jagged_restart_warmup_steps: int | None = Field(\n        default=None,\n        json_schema_extra={\n            \"description\": \"how many warmup steps to take after reset for jagged restarts\"\n        },\n    )\n    jagged_restart_anneal_steps: int | None = Field(\n        default=None,\n        json_schema_extra={\n            \"description\": \"how many anneal steps to take before reset for jagged restarts\"\n        },\n    )\n"
  },
  {
    "path": "src/axolotl/utils/schemas/trl.py",
    "content": "\"\"\"Pydantic models for TRL trainer configuration\"\"\"\n\nfrom typing import Literal\n\nfrom pydantic import BaseModel, Field\n\n\nclass TRLConfig(BaseModel):\n    \"\"\"\n    Input args for TRL.\n    \"\"\"\n\n    beta: float | None = Field(\n        default=None,\n        json_schema_extra={\n            \"description\": \"Beta parameter for the RL training. Same as `rl_beta`. Use\"\n        },\n    )\n    max_completion_length: int | None = Field(\n        default=None,\n        json_schema_extra={\n            \"description\": \"Maximum length of the completion for RL training.\"\n        },\n    )\n\n    # GRPO specific args\n    # Ref: https://github.com/huggingface/trl/blob/26d86757a7c7e24e397ea44f57ecce6031dfac01/trl/trainer/grpo_config.py#L23\n    use_vllm: bool = Field(\n        default=False,\n        json_schema_extra={\"description\": \"Whether to use VLLM for RL training.\"},\n    )\n    vllm_mode: Literal[\"server\", \"colocate\"] | None = Field(\n        default=None,\n        json_schema_extra={\n            \"description\": \"VLLM mode to use, one of 'server' or 'colocate'\"\n        },\n    )\n    vllm_server_host: str | None = Field(\n        default=\"0.0.0.0\",  # nosec B104\n        json_schema_extra={\"description\": \"Host of the vLLM server to connect to.\"},\n    )\n    vllm_server_port: int | None = Field(\n        default=8000,\n        json_schema_extra={\"description\": \"Port of the vLLM server to connect to.\"},\n    )\n    vllm_server_timeout: int | None = Field(\n        default=None,\n        json_schema_extra={\n            \"description\": \"Total timeout (in seconds) to wait for the vLLM server to respond.\"\n        },\n    )\n    vllm_guided_decoding_regex: str | None = Field(\n        default=None,\n        json_schema_extra={\"description\": \"Regex for vLLM guided decoding.\"},\n    )\n\n    reward_funcs: list[str] | None = Field(\n        default=None,\n        json_schema_extra={\n            
\"description\": \"List of reward functions to load. Paths must be importable from current dir.\"\n        },\n    )\n    reward_weights: list[float] | None = Field(\n        default=None,\n        json_schema_extra={\n            \"description\": \"List of reward weights for the reward functions.\"\n        },\n    )\n    num_generations: int | None = Field(\n        default=None,\n        json_schema_extra={\"description\": \"Number of generations to sample.\"},\n    )\n    log_completions: bool | None = Field(\n        default=False,\n        json_schema_extra={\"description\": \"Whether to log completions.\"},\n    )\n    num_completions_to_print: int | None = Field(\n        default=None,\n        json_schema_extra={\n            \"description\": \"Number of completions to print when log_completions is True.\"\n        },\n    )\n    importance_sampling_level: Literal[\"sequence\", \"token\"] | None = Field(\n        default=None,\n        json_schema_extra={\n            \"description\": \"Controls whether importance sampling ratios are computed at the `'token'` or `'sequence'` level. 
\"\n            \"For GSPO, use `sequence`, default is None which corresponds to the original GRPO paper.\"\n        },\n    )\n\n    sync_ref_model: bool | None = Field(\n        default=False,\n        json_schema_extra={\"description\": \"Whether to sync the reference model.\"},\n    )\n    ref_model_mixup_alpha: float | None = Field(\n        default=0.9,\n        json_schema_extra={\"description\": \"Mixup alpha for the reference model.\"},\n    )\n    ref_model_sync_steps: int | None = Field(\n        default=64,\n        json_schema_extra={\"description\": \"Sync steps for the reference model.\"},\n    )\n    scale_rewards: bool = Field(\n        default=True,\n        json_schema_extra={\n            \"description\": \"Whether to scale rewards by their standard deviation.\"\n        },\n    )\n\n    temperature: float | None = Field(\n        default=None,\n        json_schema_extra={\"description\": \"Sampling temperature for the GRPO policy.\"},\n    )\n    top_p: float | None = Field(\n        default=None,\n        json_schema_extra={\n            \"description\": \"Top-p sampling probability for the generation policy.\"\n        },\n    )\n    top_k: int | None = Field(\n        default=None,\n        json_schema_extra={\"description\": \"Top-k sampling for the generation policy.\"},\n    )\n    min_p: float | None = Field(\n        default=None,\n        json_schema_extra={\n            \"description\": \"Minimum probability for the generation policy.\"\n        },\n    )\n    repetition_penalty: float | None = Field(\n        default=None,\n        json_schema_extra={\n            \"description\": \"Penalty for tokens that appear in prompt and generated text.\"\n        },\n    )\n    num_iterations: int | None = Field(\n        default=None,\n        json_schema_extra={\n            \"description\": \"Number of iterations per batch (μ) for GRPO.\"\n        },\n    )\n    epsilon: float | None = Field(\n        default=None,\n        
json_schema_extra={\n            \"description\": \"Epsilon value for clipping in the GRPO algorithm.\"\n        },\n    )\n    epsilon_high: float | None = Field(\n        default=None,\n        json_schema_extra={\n            \"description\": \"Upper-bound epsilon value for clipping in the GRPO algorithm.\"\n        },\n    )\n    use_liger_loss: bool | None = Field(\n        default=None,\n        json_schema_extra={\"description\": \"Whether to use Liger loss for GRPO.\"},\n    )\n    loss_type: str | None = Field(\n        default=None,\n        json_schema_extra={\n            \"description\": \"Loss formulation to use. Supported values: grpo, bnpo, dr_grpo.\"\n        },\n    )\n    mask_truncated_completions: bool = Field(\n        default=False,\n        json_schema_extra={\n            \"description\": \"Whether to exclude truncated completions from loss calculation.\"\n        },\n    )\n    vllm_enable_sleep_mode: bool | None = Field(\n        default=None,\n        json_schema_extra={\n            \"description\": \"Enable sleep mode for vLLM to offload VRAM when idle\"\n        },\n    )\n    rollout_func: str | None = Field(\n        default=None,\n        json_schema_extra={\n            \"description\": \"Path to custom rollout function. Must be importable from current dir.\"\n        },\n    )\n    multi_objective_aggregation: (\n        Literal[\"sum_then_normalize\", \"normalize_then_sum\"] | None\n    ) = Field(\n        default=None,\n        json_schema_extra={\n            \"description\": \"Multi-objective reward aggregation strategy. \"\n            \"'sum_then_normalize' (GRPO default): weights and sums rewards first, then normalizes. 
\"\n            \"'normalize_then_sum' (GDPO): normalizes each reward independently, then sums.\"\n        },\n    )\n\n    # Async GRPO fields\n    use_data_producer: bool = Field(\n        default=False,\n        json_schema_extra={\n            \"description\": \"Use the GRPODataProducer protocol for online data generation.\"\n        },\n    )\n    async_prefetch: bool = Field(\n        default=False,\n        json_schema_extra={\n            \"description\": \"Generate rollouts in a background thread while training on the previous rollout.\"\n        },\n    )\n    prefetch_depth: int | None = Field(\n        default=None,\n        json_schema_extra={\n            \"description\": \"Number of rollouts to prefetch ahead of training.\"\n        },\n    )\n    vllm_sync_interval: int | None = Field(\n        default=None,\n        json_schema_extra={\n            \"description\": \"Sync model weights to vLLM every N optimizer steps (async mode only).\"\n        },\n    )\n    streaming_partial_batch: bool | None = Field(\n        default=None,\n        json_schema_extra={\n            \"description\": \"Score prompt groups incrementally instead of the full batch at once.\"\n        },\n    )\n    streaming_min_groups: int | None = Field(\n        default=None,\n        json_schema_extra={\n            \"description\": \"Minimum prompt groups to score per streaming chunk.\"\n        },\n    )\n    vllm_importance_sampling_correction: bool | None = Field(\n        default=None,\n        json_schema_extra={\n            \"description\": \"Apply IS correction for distribution mismatch between vLLM and training model.\"\n        },\n    )\n    vllm_importance_sampling_mode: (\n        Literal[\"token_truncate\", \"token_mask\", \"sequence_truncate\", \"sequence_mask\"]\n        | None\n    ) = Field(\n        default=None,\n        json_schema_extra={\n            \"description\": \"IS mode: token_truncate, token_mask, sequence_truncate, or sequence_mask.\"\n        
},\n    )\n    vllm_importance_sampling_cap: float | None = Field(\n        default=None,\n        json_schema_extra={\"description\": \"Cap C for IS ratio clipping/masking.\"},\n    )\n    off_policy_mask_threshold: float | None = Field(\n        default=None,\n        json_schema_extra={\n            \"description\": \"KL threshold for off-policy sequence masking (OPSM). None = disabled.\"\n        },\n    )\n    use_bias_correction_kl: bool | None = Field(\n        default=None,\n        json_schema_extra={\"description\": \"Apply IS correction to KL divergence term.\"},\n    )\n\n    reward_num_workers: int = Field(\n        default=1,\n        json_schema_extra={\n            \"description\": \"Number of persistent subprocess workers for parallel reward computation. Each worker has its \"\n            \"own main thread so signal.alarm() (used by math_verify) works correctly. Work is sharded across \"\n            \"workers by prompt groups. Only used with use_data_producer=True and non-nn.Module reward functions.\"\n        },\n    )\n    replay_buffer_size: int = Field(\n        default=0,\n        json_schema_extra={\n            \"description\": \"[Experimental, disabled by default] Size of the replay buffer for storing high-signal rollout \"\n            \"groups. When > 0, groups with reward variance are cached and used to replace zero-signal groups \"\n            \"(where all rewards are identical). Set to 0 to disable. Only used with use_data_producer=True.\"\n        },\n    )\n    replay_recompute_logps: bool = Field(\n        default=True,\n        json_schema_extra={\n            \"description\": \"When True (default), recompute old_per_token_logps for replayed groups using the current \"\n            \"training model. This fixes the importance sampling mismatch that occurs when replaying stale data. 
\"\n            \"Only relevant when replay_buffer_size > 0.\"\n        },\n    )\n    reroll_start_fraction: float = Field(\n        default=1.0,\n        json_schema_extra={\n            \"description\": \"Fraction of total training steps after which deferred re-rolling begins. Zero-signal prompts \"\n            \"(where all rewards in a group are identical) are buffered and re-injected into later batches when the \"\n            \"model is more likely to solve them. Set to 1.0 to disable. Only used with use_data_producer=True.\"\n        },\n    )\n    reroll_max_groups: int = Field(\n        default=1,\n        json_schema_extra={\n            \"description\": \"Maximum number of prompt groups to replace with re-roll candidates per batch. Higher values \"\n            \"increase data utilization but reduce prompt diversity. Only used with use_data_producer=True.\"\n        },\n    )\n    skip_zero_advantage_batches: bool = Field(\n        default=True,\n        json_schema_extra={\n            \"description\": \"When True, skip gradient computation for micro-batches where all advantages are zero (no learning \"\n            \"signal). This avoids the forward/backward pass entirely when no learning signal is present. The step is \"\n            \"logged with skipped_zero_adv_batches=1 for monitoring.\"\n        },\n    )\n    vllm_lora_sync: bool = Field(\n        default=False,\n        json_schema_extra={\n            \"description\": \"Sync LoRA adapter to vLLM via filesystem instead of merging + NCCL broadcast. \"\n            \"Auto-selects vllm_serve_lora serve module. Syncs only LoRA adapter weights vs full merged model.\"\n        },\n    )\n"
  },
  {
    "path": "src/axolotl/utils/schemas/utils.py",
    "content": "\"\"\"Utilities for Axolotl Pydantic models\"\"\"\n\nfrom axolotl.utils.logging import get_logger\n\nLOG = get_logger(__name__)\n\n\ndef handle_legacy_message_fields_logic(data: dict) -> dict:\n    \"\"\"\n    Handle backwards compatibility between legacy message field mapping and new property mapping system.\n\n    Previously, the config only supported mapping 'role' and 'content' fields via dedicated config options:\n    - message_field_role: Mapped to the role field\n    - message_field_content: Mapped to the content field\n\n    The new system uses message_property_mappings to support arbitrary field mappings:\n    message_property_mappings:\n        role: source_role_field\n        content: source_content_field\n        additional_field: source_field\n\n    Args:\n        data: Dictionary containing configuration data\n\n    Returns:\n        Updated dictionary with message field mappings consolidated\n\n    Raises:\n        ValueError: If there are conflicts between legacy and new mappings\n    \"\"\"\n    data = data.copy()  # Create a copy to avoid modifying the original\n\n    if data.get(\"message_property_mappings\") is None:\n        data[\"message_property_mappings\"] = {}\n\n    # Check for conflicts and handle role\n    if \"message_field_role\" in data:\n        LOG.warning(\n            \"message_field_role is deprecated, use message_property_mappings instead. 
\"\n            f\"Example: message_property_mappings: {{role: {data['message_field_role']}}}\"\n        )\n        if (\n            \"role\" in data[\"message_property_mappings\"]\n            and data[\"message_property_mappings\"][\"role\"] != data[\"message_field_role\"]\n        ):\n            raise ValueError(\n                f\"Conflicting message role fields: message_field_role='{data['message_field_role']}' \"\n                f\"conflicts with message_property_mappings.role='{data['message_property_mappings']['role']}'\"\n            )\n        data[\"message_property_mappings\"][\"role\"] = data[\"message_field_role\"] or \"role\"\n\n        del data[\"message_field_role\"]\n    elif \"role\" not in data[\"message_property_mappings\"]:\n        data[\"message_property_mappings\"][\"role\"] = \"role\"\n\n    # Check for conflicts and handle content\n    if \"message_field_content\" in data:\n        LOG.warning(\n            \"message_field_content is deprecated, use message_property_mappings instead. \"\n            f\"Example: message_property_mappings: {{content: {data['message_field_content']}}}\"\n        )\n        if (\n            \"content\" in data[\"message_property_mappings\"]\n            and data[\"message_property_mappings\"][\"content\"]\n            != data[\"message_field_content\"]\n        ):\n            raise ValueError(\n                f\"Conflicting message content fields: message_field_content='{data['message_field_content']}' \"\n                f\"conflicts with message_property_mappings.content='{data['message_property_mappings']['content']}'\"\n            )\n        data[\"message_property_mappings\"][\"content\"] = (\n            data[\"message_field_content\"] or \"content\"\n        )\n\n        del data[\"message_field_content\"]\n    elif \"content\" not in data[\"message_property_mappings\"]:\n        data[\"message_property_mappings\"][\"content\"] = \"content\"\n\n    return data\n"
  },
  {
    "path": "src/axolotl/utils/schemas/validation.py",
    "content": "\"\"\"Module with validation methods for config pydantic model.\"\"\"\n\nimport json\nimport sys\nimport tempfile\nfrom pathlib import Path\n\nfrom pydantic import (\n    field_validator,\n    model_validator,\n)\nfrom transformers.utils.import_utils import is_torch_npu_available\n\nfrom axolotl.utils.logging import get_logger\nfrom axolotl.utils.schemas.enums import ChatTemplate, RingAttnFunc, RLType\n\nLOG = get_logger(__name__)\n\nSUPPORTED_METRICS = {\"sacrebleu\", \"comet\", \"ter\", \"chrf\", \"perplexity\"}\n\n\nclass DatasetValidationMixin:\n    \"\"\"Validation methods related to dataset configuration.\"\"\"\n\n    @field_validator(\"seed\", mode=\"after\")\n    @classmethod\n    def set_default_seed(cls, seed):\n        if seed is None:\n            LOG.info(\"`seed` not set in config; setting to 42\")\n            seed = 42\n        return seed\n\n    @field_validator(\"datasets\", mode=\"before\")\n    @classmethod\n    def deprecate_sharegpt_datasets(cls, datasets):\n        for _, ds_cfg in enumerate(datasets):\n            ds_type = (\n                ds_cfg.get(\"type\")\n                if isinstance(ds_cfg, dict)\n                else getattr(ds_cfg, \"type\", None)\n            )\n            if not ds_type:\n                continue\n\n            if isinstance(ds_type, dict):\n                continue\n\n            if isinstance(ds_type, str) and ds_type.startswith(\"sharegpt\"):\n                raise ValueError(\n                    \"`type: sharegpt.*` is deprecated. 
Please use `type: chat_template` instead.\"\n                )\n\n        return datasets\n\n    @model_validator(mode=\"before\")\n    @classmethod\n    def check_dataset_or_pretraining_dataset(cls, data):\n        if data.get(\"datasets\") is None and data.get(\"pretraining_dataset\") is None:\n            raise ValueError(\"either datasets or pretraining_dataset is required\")\n        return data\n\n    @model_validator(mode=\"before\")\n    @classmethod\n    def check_pretraining_streaming_deprecation(cls, data):\n        # TODO(djsaunde): remove this check + implement change for 0.13.0 release\n        if data.get(\"pretraining_dataset\") and not data.get(\"streaming\"):\n            LOG.warning(\n                \"Setting `pretraining_dataset` without explicitly setting `streaming: \"\n                \"true` is deprecated. In a future release, streaming will not be \"\n                \"automatically enabled when using pretraining_dataset. Please \"\n                \"explicitly set `streaming: true` in your configuration to maintain \"\n                \"current behavior.\"\n            )\n        return data\n\n    @model_validator(mode=\"before\")\n    @classmethod\n    def check_push_ds_auth(cls, data):\n        if (\n            data.get(\"push_dataset_to_hub\")\n            and data.get(\"hf_use_auth_token\") is not True\n        ):\n            raise ValueError(\n                \"Require cfg.hf_use_auth_token to be True for push_dataset_to_hub\"\n            )\n        return data\n\n    @model_validator(mode=\"before\")\n    @classmethod\n    def check_val_w_test_datasets(cls, data):\n        if data.get(\"test_datasets\") and data.get(\"val_set_size\"):\n            raise ValueError(\n                \"non-zero val_set_size should not be used with test_datasets configuration\"\n            )\n        return data\n\n    @model_validator(mode=\"before\")\n    @classmethod\n    def check_test_datasets_bench(cls, data):\n        if (\n            
data.get(\"do_bench_eval\")\n            and not data.get(\"test_datasets\")\n            and not data.get(\"val_set_size\")\n        ):\n            LOG.warning(\n                \"`do_bench_eval` needs a test dataset to run evals, adding an empty test_dataset.\"\n            )\n            data[\"test_datasets\"] = [{\"path\": \"axolotl-ai-co/empty-test-ds\"}]\n        return data\n\n    @model_validator(mode=\"before\")\n    @classmethod\n    def check_eval_packing(cls, data):\n        # TODO also should check test_datasets and val_set_size as we can skip\n        # if there are no eval datasets/splits\n        if (\n            data.get(\"sample_packing\")\n            and data.get(\"eval_table_size\")\n            and data.get(\"eval_sample_packing\") is not False\n        ):\n            raise ValueError(\n                \"eval_table_size and eval_sample_packing are not supported together with sample_packing. Please set 'eval_sample_packing' to false.\"\n            )\n        if (\n            data.get(\"sample_packing\")\n            and data.get(\"eval_sample_packing\") is None\n            and not data.get(\"eval_table_size\")\n        ):\n            LOG.info(\n                \"explicitly setting `eval_sample_packing` to match `sample_packing`\",\n            )\n            data[\"eval_sample_packing\"] = True\n\n        if (\n            data.get(\"sample_packing\")\n            and data.get(\"eval_sample_packing\") is False\n            and data.get(\"remove_unused_columns\") is None\n        ):\n            LOG.info(\n                \"setting `remove_unused_columns: false` for when sample_packing and eval_sample_packing don't match\"\n            )\n            data[\"remove_unused_columns\"] = False\n\n        return data\n\n    @model_validator(mode=\"before\")\n    @classmethod\n    def check_mm_prepare(cls, data):\n        if data.get(\"skip_prepare_dataset\"):\n            if data.get(\"remove_unused_columns\") is None:\n                
LOG.info(\n                    \"setting `remove_unused_columns: false` for skip_prepare_dataset\"\n                )\n                data[\"remove_unused_columns\"] = False\n\n        return data\n\n\nclass AttentionValidationMixin:\n    \"\"\"Validation methods related to attention mechanisms.\"\"\"\n\n    @model_validator(mode=\"before\")\n    @classmethod\n    def check_attention_fields(cls, data):\n        fields = (\n            \"xformers_attention\",\n            \"sdp_attention\",\n            # \"s2_attention\",  # requires both FA and this to be enabled\n            \"flash_attention\",\n            \"flex_attention\",\n            \"sage_attention\",\n        )\n        non_empty_count = sum(1 for field in fields if data.get(field))\n\n        if non_empty_count > 1:\n            raise ValueError(f\"Only one of {', '.join(fields)} must be set\")\n        return data\n\n    @model_validator(mode=\"before\")\n    @classmethod\n    def check_sample_packing_without_attention(cls, data):\n        if (\n            data.get(\"sample_packing\")\n            and not data.get(\"flash_attention\")\n            and not data.get(\"sdp_attention\")\n            and not data.get(\"flex_attention\")\n            and not data.get(\"xformers_attention\")\n            and not data.get(\"sage_attention\")\n        ):\n            LOG.warning(\n                \"sample_packing without flash, sdp, xformers, sage, or flex attention does not handle cross sample decontamination.\"\n            )\n        return data\n\n    @model_validator(mode=\"before\")\n    @classmethod\n    def check_sample_packing_with_s2attn(cls, data):\n        if data.get(\"sample_packing\") and data.get(\"s2_attention\"):\n            raise ValueError(\n                \"Received `sample_packing=true` and `s2_attention=true`; however, \\\n                shifted-sparse attention does not currently support sample packing.\"\n            )\n        return data\n\n    
@model_validator(mode=\"before\")\n    @classmethod\n    def check_scaling_softmax_requires_flex(cls, data):\n        if data.get(\"scaling_softmax\") and not data.get(\"flex_attention\"):\n            raise ValueError(\n                \"scaling_softmax requires flex_attention: true\\n\"\n                \"Add 'flex_attention: true' to your config file.\\n\"\n            )\n        return data\n\n\nclass TrainingValidationMixin:\n    \"\"\"Validation methods related to training configuration.\"\"\"\n\n    @model_validator(mode=\"before\")\n    @classmethod\n    def check_batch_size_fields(cls, data):\n        fields = (\"micro_batch_size\", \"gradient_accumulation_steps\", \"batch_size\")\n        non_empty_count = sum(1 for field in fields if data.get(field))\n\n        if non_empty_count < 2:\n            raise ValueError(f\"At least two of {', '.join(fields)} must be set\")\n        return data\n\n    @model_validator(mode=\"before\")\n    @classmethod\n    def hint_sample_packing_padding(cls, data):\n        if data.get(\"sample_packing\"):\n            pad_to_sequence_len = data.get(\"pad_to_sequence_len\")\n            if pad_to_sequence_len is False:\n                LOG.warning(\n                    \"`pad_to_sequence_len: true` is recommended when using sample_packing\"\n                )\n            elif pad_to_sequence_len is None:\n                LOG.info(\n                    \"Setting `pad_to_sequence_len: true` to prevent memory leaks when sample_packing\"\n                )\n                data[\"pad_to_sequence_len\"] = True\n        return data\n\n    @model_validator(mode=\"before\")\n    @classmethod\n    def hint_reward_model_pad(cls, data):\n        if data.get(\"reward_model\") and not data.get(\"pad_to_sequence_len\"):\n            LOG.warning(\n                \"`pad_to_sequence_len: true` is recommended when using reward_model\"\n            )\n            if data.get(\"pad_to_sequence_len\") is None:\n                
data[\"pad_to_sequence_len\"] = True\n        return data\n\n    @model_validator(mode=\"before\")\n    @classmethod\n    def set_reward_model_defaults(cls, data):\n        if data.get(\"reward_model\"):\n            if data.get(\"num_labels\") is None:\n                data[\"num_labels\"] = 1\n            if not (data.get(\"type_of_model\") or data.get(\"model_type\")):\n                data[\"model_type\"] = \"AutoModelForSequenceClassification\"\n\n        if data.get(\"process_reward_model\"):\n            if data.get(\"num_labels\") is None:\n                data[\"num_labels\"] = 2\n            if not (data.get(\"type_of_model\") or data.get(\"model_type\")):\n                data[\"model_type\"] = \"AutoModelForTokenClassification\"\n\n        return data\n\n    @model_validator(mode=\"before\")\n    @classmethod\n    def check_gas_bsz(cls, data):\n        if data.get(\"gradient_accumulation_steps\") and data.get(\"batch_size\"):\n            raise ValueError(\n                \"please set only one of gradient_accumulation_steps or batch_size\"\n            )\n        return data\n\n    @model_validator(mode=\"before\")\n    @classmethod\n    def hint_eval_train_mbsz(cls, data):\n        if (\n            data.get(\"eval_batch_size\")\n            and data.get(\"micro_batch_size\")\n            and data.get(\"eval_batch_size\") != data.get(\"micro_batch_size\")\n        ):\n            LOG.warning(\n                \"eval_batch_size != micro_batch_size. 
This can lead to VRAM instability.\"\n            )\n        return data\n\n    @model_validator(mode=\"before\")\n    @classmethod\n    def check_warmup(cls, data):\n        if data.get(\"warmup_steps\") and data.get(\"warmup_ratio\"):\n            raise ValueError(\"warmup_steps and warmup_ratio are mutually exclusive\")\n        return data\n\n    @model_validator(mode=\"before\")\n    @classmethod\n    def check_saves(cls, data):\n        if (\n            data.get(\"save_strategy\")\n            and data.get(\"save_steps\")\n            and data.get(\"save_strategy\") != \"steps\"\n        ):\n            raise ValueError(\n                \"save_strategy and save_steps mismatch. Please set save_strategy to 'steps' or remove save_steps.\"\n            )\n        if data.get(\"saves_per_epoch\") and data.get(\"save_steps\"):\n            raise ValueError(\n                \"save_steps and saves_per_epoch are mutually exclusive and cannot be used together.\"\n            )\n        return data\n\n    @model_validator(mode=\"before\")\n    @classmethod\n    def check_push_save(cls, data):\n        if data.get(\"hub_model_id\") and (\n            data.get(\"save_strategy\") not in [\"steps\", \"epoch\", None]\n        ):\n            LOG.warning(\n                \"hub_model_id is set without any models being saved. To save a model, set save_strategy.\"\n            )\n        return data\n\n    @model_validator(mode=\"before\")\n    @classmethod\n    def check_evals(cls, data):\n        if (\n            data.get(\"eval_strategy\")\n            and data.get(\"eval_steps\")\n            and data.get(\"eval_strategy\") != \"steps\"\n        ):\n            raise ValueError(\n                \"eval_strategy and eval_steps mismatch. 
Please set eval_strategy to 'steps' or remove eval_steps.\"\n            )\n\n        if (\n            data.get(\"val_set_size\") == 0\n            and (data.get(\"eval_steps\") or data.get(\"eval_strategy\"))\n            and not data.get(\"test_datasets\")\n            and data.get(\"eval_strategy\") != \"no\"\n        ):\n            raise ValueError(\n                \"eval_steps and eval_strategy are not supported with val_set_size == 0\"\n            )\n        if data.get(\"evals_per_epoch\") and data.get(\"eval_steps\"):\n            raise ValueError(\n                \"eval_steps and evals_per_epoch are mutually exclusive and cannot be used together.\"\n            )\n        if (\n            data.get(\"evals_per_epoch\")\n            and data.get(\"eval_strategy\")\n            and data.get(\"eval_strategy\") != \"steps\"\n        ):\n            raise ValueError(\n                \"eval_strategy must be empty or set to `steps` when used with evals_per_epoch.\"\n            )\n\n        if data.get(\"do_bench_eval\") and not (\n            data.get(\"evals_per_epoch\") or data.get(\"eval_steps\")\n        ):\n            raise ValueError(\n                \"do_bench_eval requires evals_per_epoch or eval_steps to be set.\"\n            )\n        return data\n\n    @model_validator(mode=\"before\")\n    @classmethod\n    def check_neftune(cls, data):\n        if data.get(\"noisy_embedding_alpha\") and not data.get(\"neftune_noise_alpha\"):\n            data[\"neftune_noise_alpha\"] = data[\"noisy_embedding_alpha\"]\n            del data[\"noisy_embedding_alpha\"]\n        elif data.get(\"noisy_embedding_alpha\") and data.get(\"neftune_noise_alpha\"):\n            raise ValueError(\n                \"noisy_embedding_alpha is deprecated, use neftune_noise_alpha; both are set, please remove the deprecated noisy_embedding_alpha setting\"\n            )\n        return data\n\n    @model_validator(mode=\"before\")\n    @classmethod\n    def 
check_multipack_buffer_size(cls, data):\n        if data.get(\"pretrain_multipack_buffer_size\") and not data.get(\n            \"streaming_multipack_buffer_size\"\n        ):\n            LOG.warning(\n                \"`pretrain_multipack_buffer_size` is deprecated in v0.13.0, will be \"\n                \"removed in v0.14.0. Use `streaming_multipack_buffer_size` instead.\"\n            )\n            data[\"streaming_multipack_buffer_size\"] = data[\n                \"pretrain_multipack_buffer_size\"\n            ]\n            del data[\"pretrain_multipack_buffer_size\"]\n        elif data.get(\"pretrain_multipack_buffer_size\") and data.get(\n            \"streaming_multipack_buffer_size\"\n        ):\n            raise ValueError(\n                \"pretrain_multipack_buffer_size is deprecated, use \"\n                \"streaming_multipack_buffer_size; both are set, please remove the \"\n                \"deprecated pretrain_multipack_buffer_size setting\"\n            )\n        return data\n\n    @model_validator(mode=\"after\")\n    def check_fft_possible_bad_config(self):\n        if (\n            not (self.bf16 or self.bfloat16)\n            and (self.fp16 or self.float16)\n            and not self.adapter\n            and not self.flash_attention\n            and self.sample_packing\n        ):\n            LOG.warning(\n                \"Full fine tune w/o FA2 w/ sample packing and fp16/float16 is likely to raise errors. 
Try LoRA.\"\n            )\n            # ValueError: Attempting to unscale FP16 gradients.\n            # OR\n            # RuntimeError: expected mat1 and mat2 to have the same dtype, but got: float != c10::Half\n        return self\n\n    @model_validator(mode=\"before\")\n    @classmethod\n    def check_fp8_config(cls, data):\n        if data.get(\"fp8\") and not data.get(\"torch_compile\"):\n            LOG.warning(\n                \"torch_compile is strongly recommended for FP8 training in order to \"\n                \"see speed improvements. Please consider setting `torch_compile: \"\n                \"true` in your config.\"\n            )\n        fsdp_config = data.get(\"fsdp_config\") or {}\n        if data.get(\"fp8\") and (\n            fsdp_config.get(\"activation_checkpointing\", False) is True\n            or fsdp_config.get(\"fsdp_activation_checkpointing\", False) is True\n        ):\n            LOG.warning(\n                \"FP8 + FSDP2 + activation checkpointing may be slower than BF16 \"\n                \"training. 
Please considering setting `activation_checkpointing: false` \"\n                \"in your FSDP config.\"\n            )\n        if (\n            data.get(\"fp8_enable_fsdp_float8_all_gather\")\n            and not data.get(\"fsdp_version\", None) == 2\n        ):\n            raise ValueError(\n                \"fp8_enable_fsdp_float8_all_gather requires FSDP2 (fsdp_version: 2) \"\n                \"to be used.\"\n            )\n\n        return data\n\n    @model_validator(mode=\"before\")\n    @classmethod\n    def check_use_reentrant_mismatch(cls, data):\n        if (\n            data.get(\"unfrozen_parameters\")\n            and data.get(\"gradient_checkpointing_kwargs\")\n            and data.get(\"gradient_checkpointing_kwargs\", {}).get(\"use_reentrant\")\n            is True\n        ):\n            # https://github.com/huggingface/transformers/issues/21381\n            raise ValueError(\n                \"`use_reentrant` must be false when used with partially frozen model.\"\n            )\n        return data\n\n    @model_validator(mode=\"before\")\n    @classmethod\n    def check_eval_strategy(cls, data):\n        if (\n            data.get(\"evaluation_strategy\") is not None\n            and data.get(\"eval_strategy\") is None\n        ):\n            LOG.info(\n                \"explicitly setting `eval_strategy` from the `evaluation_strategy`\"\n            )\n            data[\"eval_strategy\"] = data.get(\"evaluation_strategy\")\n        return data\n\n    @model_validator(mode=\"before\")\n    @classmethod\n    def check_causal_lm_evals(cls, data):\n        if data.get(\"do_causal_lm_eval\") and data.get(\"eval_sample_packing\"):\n            raise ValueError(\n                \"do_causal_lm_eval is enabled, eval_sample_packing must be set to False\"\n            )\n\n        if data.get(\"eval_causal_lm_metrics\"):\n            if not isinstance(data.get(\"eval_causal_lm_metrics\"), list):\n                raise 
ValueError(\"eval_causal_lm_metrics must be a list\")\n            # only [\"sacrebleu\", \"comet\", \"ter\", \"chrf\"] supported\n            if set(data.get(\"eval_causal_lm_metrics\")) - SUPPORTED_METRICS:\n                raise ValueError(\n                    f\"eval_causal_lm_metrics must be one of {SUPPORTED_METRICS}\"\n                )\n        return data\n\n    @model_validator(mode=\"before\")\n    @classmethod\n    def check_tokenizer_use_mistral_common(cls, data):\n        if data.get(\"tokenizer_use_mistral_common\") is None:\n            if any(\n                \"magistral\" in name.lower()\n                for name in [\n                    data.get(\"base_model\", \"\"),\n                    data.get(\"base_model_config\", \"\"),\n                    data.get(\"tokenizer_config\", \"\"),\n                ]\n            ):\n                LOG.warning(\n                    \"tokenizer_use_mistral_common auto inferred to True for Magistral models. Please set it to True explicitly if you want to use mistral-common tokenizer.\"\n                )\n                data[\"tokenizer_use_mistral_common\"] = True\n\n        return data\n\n    @field_validator(\"tokenizer_use_mistral_common\", mode=\"after\")\n    @classmethod\n    def check_mistral_common_import(cls, tokenizer_use_mistral_common):\n        if tokenizer_use_mistral_common:\n            import importlib.util\n\n            if importlib.util.find_spec(\"mistral_common\") is None:\n                raise ImportError(\n                    \"mistral-common is required for mistral models. 
Please install it with `pip install axolotl` or `pip install -e .`.\"\n                )\n\n        return tokenizer_use_mistral_common\n\n    @model_validator(mode=\"before\")\n    @classmethod\n    def check_mistral_common_incompatible_options(cls, data):\n        if not data.get(\"tokenizer_use_mistral_common\"):\n            return data\n\n        # NOTE: mistral-common tokenizer is not compatible with editing tokenizer at the moment\n\n        if data.get(\"added_tokens_overrides\"):\n            raise ValueError(\n                \"added_tokens_overrides is not supported with mistral-common tokenizer\"\n            )\n\n        if data.get(\"special_tokens\"):\n            raise ValueError(\n                \"special_tokens override is not supported with mistral-common tokenizer\"\n            )\n\n        if data.get(\"tokens\"):\n            raise ValueError(\n                \"tokens override is not supported with mistral-common tokenizer\"\n            )\n\n        if data.get(\"chat_template\"):\n            raise ValueError(\n                \"Setting chat_template is not supported with mistral-common tokenizer\"\n            )\n\n        return data\n\n    @model_validator(mode=\"before\")\n    @classmethod\n    def pretrain_with_tps(cls, data):\n        if data.get(\"pretraining_dataset\") and data.get(\n            \"include_tokens_per_second\", False\n        ):\n            # combining these would raise `TypeError: cannot pickle 'dict_keys' object`\n            # due to trying to count the number of tokens total in the dataset\n            raise ValueError(\n                \"pretraining_dataset and include_tokens_per_second cannot be used together.\"\n            )\n\n        return data\n\n\nclass LoRAValidationMixin:\n    \"\"\"Validation methods related to LoRA/QLoRA configuration.\"\"\"\n\n    @model_validator(mode=\"before\")\n    @classmethod\n    def check_lr_groups(cls, data):\n        if data.get(\"lr_groups\") and 
data.get(\"loraplus_lr_ratio\"):\n            raise ValueError(\"lr_groups and loraplus_lr_ratio cannot be used together.\")\n        return data\n\n    @model_validator(mode=\"before\")\n    @classmethod\n    def check_frozen(cls, data):\n        if (\n            data.get(\"adapter\")\n            and data.get(\"peft_layers_to_transform\")\n            and data.get(\"unfrozen_parameters\")\n        ):\n            raise ValueError(\n                \"`unfrozen_parameters` used with `peft_layers_to_transform` can have unexpected behavior.\"\n            )\n        return data\n\n    @model_validator(mode=\"before\")\n    @classmethod\n    def check_peft_layers_pattern(cls, data):\n        if data.get(\"peft_layers_pattern\") and not data.get(\"peft_layers_to_transform\"):\n            raise ValueError(\n                \"peft_layers_pattern requires peft_layers_to_transform to be set\"\n            )\n        return data\n\n    @model_validator(mode=\"before\")\n    @classmethod\n    def check_qlora_unsloth(cls, data):\n        if (\n            data.get(\"unsloth_lora_mlp\")\n            or data.get(\"unsloth_lora_qkv\")\n            or data.get(\"unsloth_lora_o\")\n        ):\n            if data.get(\"adapter\") == \"lora\" and data.get(\"load_in_8bit\"):\n                raise ValueError(\n                    \"unsloth_lora_mlp, unsloth_lora_qkv, and unsloth_lora_o are not compatible with 8-bit LoRA\"\n                )\n        return data\n\n    @model_validator(mode=\"before\")\n    @classmethod\n    def check_lora_axolotl_unsloth(cls, data):\n        is_lora_kernel = any(\n            data.get(k) for k in [\"lora_mlp_kernel\", \"lora_qkv_kernel\", \"lora_o_kernel\"]\n        )\n        is_unsloth_lora = any(\n            data.get(k)\n            for k in [\"unsloth_lora_mlp\", \"unsloth_lora_qkv\", \"unsloth_lora_o\"]\n        )\n        if is_lora_kernel and is_unsloth_lora:\n            raise ValueError(\n                \"both lora_mlp_kernel and 
unsloth_lora_mlp cannot be true (similarly for lora_qkv_kernel, lora_o_kernel)\"\n            )\n        return data\n\n    @model_validator(mode=\"after\")\n    def check_fused_lora(self):\n        if self.adapter in [\"lora\", \"qlora\"] and self.flash_attn_fuse_mlp:\n            raise ValueError(\"Fused modules are not supported with LoRA/QLoRA\")\n        return self\n\n    @model_validator(mode=\"before\")\n    @classmethod\n    def warn_qlora_zero3_w_use_reentrant(cls, data):\n        if (\n            data.get(\"adapter\") == \"qlora\"\n            and data.get(\"gradient_checkpointing_kwargs\", {})\n            and data.get(\"gradient_checkpointing_kwargs\", {}).get(\"use_reentrant\")\n            is False\n            and data.get(\"deepspeed\", \"\") is not None\n            and \"zero3\" in data.get(\"deepspeed\", \"\")\n        ):\n            # may result in:\n            # torch.utils.checkpoint.CheckpointError: torch.utils.checkpoint:\n            # Recomputed values for the following tensors have different metadata\n            # than during the forward pass.\n            LOG.warning(\n                \"qlora + zero3 with use_reentrant: false may result in a CheckpointError about recomputed values\"\n            )\n        return data\n\n    @model_validator(mode=\"before\")\n    @classmethod\n    def check_lora_kernels_8bit(cls, data):\n        if (\n            data.get(\"lora_mlp_kernel\")\n            or data.get(\"lora_qkv_kernel\")\n            or data.get(\"lora_o_kernel\")\n        ):\n            if data.get(\"adapter\") == \"lora\" and data.get(\"load_in_8bit\"):\n                raise ValueError(\n                    \"lora_mlp_kernel, lora_qkv_kernel, and lora_o_kernel are not \"\n                    \"compatible with 8-bit LoRA a the moment.\"\n                )\n        return data\n\n    @model_validator(mode=\"before\")\n    @classmethod\n    def check_lora_kernels_dora(cls, data):\n        if (\n            
data.get(\"lora_mlp_kernel\")\n            or data.get(\"lora_qkv_kernel\")\n            or data.get(\"lora_o_kernel\")\n        ) and data.get(\"peft_use_dora\"):\n            raise ValueError(\n                \"lora_mlp_kernel, lora_qkv_kernel, and lora_o_kernel are not \"\n                \"compatible with DoRA at the moment.\"\n            )\n        return data\n\n    @model_validator(mode=\"before\")\n    @classmethod\n    def check_lora_kernels_trust_remote_code(cls, data):\n        if (\n            data.get(\"lora_mlp_kernel\")\n            or data.get(\"lora_qkv_kernel\")\n            or data.get(\"lora_o_kernel\")\n        ) and data.get(\"trust_remote_code\"):\n            raise ValueError(\n                \"lora_mlp_kernel, lora_qkv_kernel, and lora_o_kernel are not \"\n                \"compatible with trust_remote_code. Please disable trust_remote_code \"\n                \"or explicitly set lora_*_kernel to false.\"\n            )\n        return data\n\n\nclass RLValidationMixin:\n    \"\"\"Validation methods related to RL training configuration.\"\"\"\n\n    @model_validator(mode=\"before\")\n    @classmethod\n    def check_sample_packing_w_rl(cls, data):\n        if data.get(\"sample_packing\") and data.get(\"rl\"):\n            raise ValueError(\"`sample_packing: true` does not work with RLHF training\")\n        return data\n\n    @model_validator(mode=\"before\")\n    @classmethod\n    def check_kto_config(cls, data):\n        if data.get(\"rl\") == \"kto\":\n            if data.get(\"sample_packing\") or data.get(\"eval_sample_packing\"):\n                raise ValueError(\"sample_packing is not supported with kto\")\n\n            if data.get(\"remove_unused_columns\") is not False:\n                raise ValueError(\"Set `remove_unused_columns: False` when using kto\")\n        return data\n\n    @model_validator(mode=\"before\")\n    @classmethod\n    def check_grpo_liger_sequence_parallel(cls, data):\n        if (\n            
data.get(\"rl\") == \"grpo\"\n            and data.get(\"trl\", {})\n            and data.get(\"trl\").get(\"use_liger_loss\")\n            and data.get(\"context_parallel_size\", 1) > 1\n        ):\n            raise ValueError(\"GRPO + SP + Liger not currently supported\")\n        return data\n\n    @model_validator(mode=\"before\")\n    @classmethod\n    def check_rl_config_gradient_checkpointing(cls, data):\n        # TODO: SalmanMohammadi\n        # Distributed RL with QLoRA + gradient checkpointing\n        # and use_reentrant = True is broken upstream in TRL\n\n        if (\n            data.get(\"rl\")\n            and data.get(\"gradient_checkpointing\")\n            and data.get(\"gradient_checkpointing_kwargs\")\n            and data.get(\"gradient_checkpointing_kwargs\").get(\"use_reentrant\")\n            and data.get(\"load_in_4bit\")\n            and data.get(\"adapter\") == \"qlora\"\n            and data.get(\"capabilities\")\n            and data.get(\"capabilities\").get(\"n_gpu\", 1) > 1\n        ):\n            raise ValueError(\n                \"The `use_reentrant: True` implementation of gradient checkpointing \"\n                \"is not supported for distributed RL training with QLoRA. 
Please set \"\n                \"`use_reentrant: False` in `gradient_checkpointing_kwargs`.\"\n            )\n        return data\n\n    @model_validator(mode=\"before\")\n    @classmethod\n    def check_gdpo(cls, data):\n        if (\n            data.get(\"rl\") == \"gdpo\"\n            and data.get(\"trl\", {}).get(\"multi_objective_aggregation\")\n            == \"sum_then_normalize\"\n        ):\n            raise ValueError(\n                \"`multi_objective_aggregation` value set as `sum_then_normalize` => GRPO, but GDPO was selected\"\n            )\n        return data\n\n\nclass OptimizationValidationMixin:\n    \"\"\"Validation methods related to optimization and performance.\"\"\"\n\n    @model_validator(mode=\"after\")\n    def check_adamw_optimizer_params(self):\n        if any([self.adam_beta1, self.adam_beta2, self.adam_epsilon]) and (\n            not self.optimizer or \"adamw\" not in str(self.optimizer).lower()\n        ):\n            LOG.warning(\"adamw hyperparameters found, but no adamw optimizer set\")\n        return self\n\n    @staticmethod\n    def _resolve_fsdp_version(data):\n        \"\"\"Resolve FSDP version from top-level fsdp_version or fsdp_config.fsdp_version.\"\"\"\n        fsdp_version = data.get(\"fsdp_version\")\n        if fsdp_version is None:\n            fsdp_version = data.get(\"fsdp_config\", {}).get(\"fsdp_version\", 1)\n        return fsdp_version\n\n    @model_validator(mode=\"before\")\n    @classmethod\n    def check_muon_deepspeed_fsdp(cls, data):\n        if data.get(\"optimizer\") == \"muon\":\n            if data.get(\"deepspeed\"):\n                raise ValueError(\n                    \"Muon optimizer is currently incompatible with DeepSpeed\"\n                )\n            if data.get(\"fsdp\") or data.get(\"fsdp_config\"):\n                fsdp_version = cls._resolve_fsdp_version(data)\n                if str(fsdp_version) != \"2\":\n                    raise ValueError(\n                        \"Muon 
optimizer is only compatible with FSDP2. Set fsdp_version: 2 to use Muon with FSDP.\"\n                    )\n        return data\n\n    @model_validator(mode=\"before\")\n    @classmethod\n    def check_flashoptim_deepspeed_fsdp(cls, data):\n        optimizer = data.get(\"optimizer\") or \"\"\n        if str(optimizer).startswith(\"flash_\"):\n            if data.get(\"deepspeed\"):\n                raise ValueError(\n                    f\"{optimizer} optimizer is incompatible with DeepSpeed. \"\n                    \"Flash optimizers only support DDP and FSDP2.\"\n                )\n            if data.get(\"fsdp\") or data.get(\"fsdp_config\"):\n                fsdp_version = cls._resolve_fsdp_version(data)\n                if str(fsdp_version) != \"2\":\n                    raise ValueError(\n                        f\"{optimizer} optimizer is only compatible with FSDP2. \"\n                        \"Set fsdp_version: 2 to use flash optimizers with FSDP.\"\n                    )\n        return data\n\n    @model_validator(mode=\"before\")\n    @classmethod\n    def check_batch_flattening_fa(cls, data):\n        if data.get(\"batch_flattening\"):\n            batch_flattening_auto = data.get(\"batch_flattening\") == \"auto\"\n            if not data.get(\"flash_attention\") and not batch_flattening_auto:\n                raise ValueError(\"batch_flattening requires flash attention\")\n            if data.get(\"sample_packing\") and not batch_flattening_auto:\n                raise ValueError(\"batch_flattening not compatible with sample_packing\")\n            if data.get(\"micro_batch_size\") == 1 and not batch_flattening_auto:\n                LOG.warning(\"batch_flattening has no effect with micro_batch_size == 1\")\n\n            if (\n                batch_flattening_auto\n                and data.get(\"flash_attention\")\n                and not data.get(\"sample_packing\")\n                and data.get(\"micro_batch_size\") > 1\n            ):\n         
       data[\"batch_flattening\"] = True\n            elif batch_flattening_auto:\n                data[\"batch_flattening\"] = False\n\n        return data\n\n    @model_validator(mode=\"before\")\n    @classmethod\n    def check_xentropy_patch_conflicts(cls, data):\n        if data.get(\"flash_attn_cross_entropy\") and data.get(\n            \"unsloth_cross_entropy_loss\"\n        ):\n            raise ValueError(\n                \"flash_attn_cross_entropy and unsloth_cross_entropy_loss cannot be both enabled\"\n            )\n        return data\n\n    @model_validator(mode=\"before\")\n    @classmethod\n    def check_cross_entropy_conflicts(cls, data):\n        \"\"\"Check for mutual exclusivity between cross entropy patch options.\n\n        Only one of the following can be enabled at a time:\n        - cut_cross_entropy (CutCrossEntropyPlugin)\n        - chunked_cross_entropy\n        - liger_cross_entropy (LigerPlugin)\n        - liger_fused_linear_cross_entropy (LigerPlugin)\n        \"\"\"\n        ce_options = {\n            \"cut_cross_entropy\": data.get(\"cut_cross_entropy\"),\n            \"chunked_cross_entropy\": data.get(\"chunked_cross_entropy\"),\n            \"liger_cross_entropy\": data.get(\"liger_cross_entropy\"),\n            \"liger_fused_linear_cross_entropy\": data.get(\n                \"liger_fused_linear_cross_entropy\"\n            ),\n        }\n\n        enabled_options = [k for k, v in ce_options.items() if v]\n\n        if len(enabled_options) > 1:\n            raise ValueError(\n                f\"Only one cross entropy optimization can be enabled at a time. \"\n                f\"Found {len(enabled_options)} enabled: {', '.join(enabled_options)}. 
\"\n                \"Please disable all but one.\"\n            )\n        return data\n\n    @model_validator(mode=\"before\")\n    @classmethod\n    def check_fsdp_version(cls, data):\n        fsdp_config = data.get(\"fsdp_config\", {})\n        if fsdp_config and str(data.get(\"fsdp_version\")) != \"2\":\n            LOG.info(\n                \"FSDP1 will be deprecated in an upcoming release of Axolotl.\"\n                \"We recommend that you use FSDP version 2 for better performance and compatibility. \"\n                \"Please see this link for more details: https://docs.axolotl.ai/docs/multi-gpu.html#sec-fsdp \"\n                \"For more details on migrating your config. \"\n            )\n        return data\n\n    @model_validator(mode=\"before\")\n    @classmethod\n    def check_fsdp2_cpu_offload_pin_memory(cls, data):\n        if not (fsdp_config := data.get(\"fsdp_config\")):\n            return data\n\n        if fsdp_config.get(\"cpu_offload_pin_memory\") is False:\n            if str(data.get(\"fsdp_version\")) != \"2\":\n                raise ValueError(\n                    \"FSDP1 does not support disabling cpu_offload_pin_memory, please set `fsdp_version` to 2\"\n                )\n            if not fsdp_config.get(\"offload_params\"):\n                raise ValueError(\n                    \"disabling cpu_offload_pin_memory requires enabling offload_params\"\n                )\n        return data\n\n    @model_validator(mode=\"before\")\n    @classmethod\n    def check_fsdp2_base_model_quant_rl(cls, data):\n        if data.get(\"fsdp_version\") == 2 and data.get(\"rl\") in [\n            RLType.DPO,\n            RLType.KTO,\n            RLType.ORPO,\n            RLType.IPO,\n        ]:\n            if data.get(\"load_in_8bit\") or data.get(\"load_in_4bit\"):\n                raise ValueError(\n                    f\"FSDP2 does not support load_in_8bit or load_in_4bit with {data.get('rl')}. 
Please use DeepSpeed or set `fsdp_version` to 1.\"\n                )\n\n        return data\n\n    @model_validator(mode=\"before\")\n    @classmethod\n    def check_fsdp_config_kwargs_prefix(cls, data):\n        if fsdp_config := data.get(\"fsdp_config\"):\n            should_fix = False\n            for key, _ in fsdp_config.items():\n                if key.startswith(\"fsdp_\"):\n                    should_fix = True\n                    LOG.warning_once(\n                        \"Configuring FSDP fields with the `fsdp_` prefix is deprecated. \"\n                        \"Please omit the `fsdp_` prefix from the any fields in `fsdp_config`.\"\n                    )\n            if should_fix:\n                update_fsdp_config = {}\n                for key, value in fsdp_config.items():\n                    if key.startswith(\"fsdp_\") and key != \"fsdp_version\":\n                        update_fsdp_config[key.replace(\"fsdp_\", \"\")] = value\n                    else:\n                        update_fsdp_config[key] = value\n                data[\"fsdp_config\"] = update_fsdp_config\n        return data\n\n    @model_validator(mode=\"before\")\n    @classmethod\n    def check_fsdp_version_in_fsdp_config(cls, data):\n        fsdp_config = data.get(\"fsdp_config\") or {}\n        fsdp_version = data.get(\"fsdp_version\", None)\n        if not fsdp_version and fsdp_config and fsdp_config.get(\"version\"):\n            fsdp_cfg_version = fsdp_config.pop(\"version\")\n            data[\"fsdp_version\"] = fsdp_cfg_version\n            data[\"fsdp_config\"][\"fsdp_version\"] = fsdp_cfg_version\n        elif not fsdp_version and fsdp_config and fsdp_config.get(\"fsdp_version\"):\n            data[\"fsdp_version\"] = fsdp_config.get(\"fsdp_version\")\n        if fsdp_version and fsdp_config and not fsdp_config.get(\"fsdp_version\"):\n            data[\"fsdp_config\"][\"fsdp_version\"] = fsdp_version\n        return data\n\n    @model_validator(mode=\"after\")\n    
def check_fsdp_offload_w_8bit_optimizer(self):\n        if (\n            hasattr(self, \"fsdp_config\")\n            and self.fsdp_config\n            and self.optimizer\n            and \"8bit\" in self.optimizer.value\n            and self.fsdp_config.offload_params\n            and str(self.fsdp_version) != \"2\"\n        ):\n            raise ValueError(\n                f\"FSDP Offload not compatible with {str(self.optimizer.value)}\"\n            )\n        return self\n\n    @model_validator(mode=\"after\")\n    def check_fsdp2_w_8bit_optimizer(self):\n        if (\n            hasattr(self, \"fsdp_config\")\n            and self.fsdp_config\n            and self.optimizer\n            and \"8bit\" in self.optimizer.value\n            and str(self.fsdp_version) == \"2\"\n        ):\n            if self.optimizer in [\"adamw_8bit\", \"adamw_bnb_8bit\"]:\n                # CUDA ops errors with bnb 8bit optimizer + FSDP2\n                raise ValueError(\n                    f\"FSDP2 not compatible with {self.optimizer.value}, use `adamw_torch_8bit` instead\"\n                )\n\n        return self\n\n    @model_validator(mode=\"before\")\n    @classmethod\n    def check_tensor_parallel_size_update_ds_json(cls, data):\n        tensor_parallel_size = data.get(\"tensor_parallel_size\")\n        if tensor_parallel_size is not None and tensor_parallel_size > 1:\n            if data.get(\"deepspeed\"):\n                with open(data.get(\"deepspeed\"), \"r\", encoding=\"utf-8\") as ds_fin:\n                    ds_config = json.load(ds_fin)\n                    should_save = False\n                    if \"tensor_parallel\" not in ds_config:\n                        ds_config[\"tensor_parallel\"] = {\n                            \"autotp_size\": tensor_parallel_size\n                        }\n                        should_save = True\n                    if (\n                        \"gather_16bit_weights_on_model_save\"\n                        not in 
ds_config[\"zero_optimization\"]\n                    ):\n                        ds_config[\"zero_optimization\"][\n                            \"gather_16bit_weights_on_model_save\"\n                        ] = True\n                        should_save = True\n                    if should_save:\n                        temp_dir = tempfile.mkdtemp()\n                        with open(\n                            Path(temp_dir) / \"autotp_ds.json\", \"w\", encoding=\"utf-8\"\n                        ) as ds_fout:\n                            json.dump(ds_config, ds_fout, indent=4)\n                        data[\"deepspeed\"] = str(Path(temp_dir) / \"autotp_ds.json\")\n\n        return data\n\n    @model_validator(mode=\"before\")\n    @classmethod\n    def check_deepcompile(cls, data):\n        deepcompile = data.get(\"deepcompile\")\n        if deepcompile:\n            if not data.get(\"deepspeed\"):\n                raise ValueError(\"DeepCompile is only supported with DeepSpeed\")\n            with open(data.get(\"deepspeed\"), \"r\", encoding=\"utf-8\") as ds_fin:\n                ds_config = json.load(ds_fin)\n                if \"compile\" not in ds_config:\n                    ds_config[\"compile\"] = {\"deepcompile\": True}\n                    temp_dir = tempfile.mkdtemp()\n                    with open(\n                        Path(temp_dir) / \"deepcompile_ds.json\", \"w\", encoding=\"utf-8\"\n                    ) as ds_fout:\n                        json.dump(ds_config, ds_fout, indent=4)\n                    data[\"deepspeed\"] = str(Path(temp_dir) / \"deepcompile_ds.json\")\n\n        return data\n\n\nclass SystemValidationMixin:\n    \"\"\"Validation methods related to system and hardware configuration.\"\"\"\n\n    @model_validator(mode=\"before\")\n    @classmethod\n    def check_mem_mismatch(cls, data):\n        if (\n            data.get(\"max_memory\") is not None\n            and data.get(\"gpu_memory_limit\") is not None\n        ):\n     
       raise ValueError(\n                \"max_memory and gpu_memory_limit are mutually exclusive and cannot be used together.\"\n            )\n        return data\n\n    @model_validator(mode=\"before\")\n    @classmethod\n    def check_fsdp_deepspeed(cls, data):\n        if data.get(\"deepspeed\") and data.get(\"fsdp\"):\n            raise ValueError(\"deepspeed and fsdp cannot be used together.\")\n        return data\n\n    @model_validator(mode=\"before\")\n    @classmethod\n    def check_model_quantization_config_vs_bnb(cls, data):\n        if data.get(\"model_quantization_config\"):\n            if data.get(\"load_in_8bit\") or data.get(\"load_in_4bit\"):\n                raise ValueError(\n                    \"model_quantization_config and load_in_8bit or load_in_4bit cannot be used together.\"\n                )\n        return data\n\n    @model_validator(mode=\"before\")\n    @classmethod\n    def check_npu_config(cls, data):\n        if is_torch_npu_available():\n            # check attention config\n            attn_list = [\"flash_attention\", \"sdp_attention\", \"s2_attention\"]\n            for attn in attn_list:\n                if data.get(attn):\n                    raise NotImplementedError(\n                        f\"{attn} is currently not supported in Ascend npu, please disable this configuration.\"\n                    )\n\n            # check quant config\n            if data.get(\"optimizer\") is not None and \"bit\" in data.get(\"optimizer\"):\n                optimizer = data.get(\"optimizer\")\n                raise NotImplementedError(\n                    f\"{optimizer} is currently not supported in Ascend npu, choose another one please.\"\n                )\n\n            quant_list = [\"load_in_8bit\", \"load_in_4bit\"]\n            for quant in quant_list:\n                if data.get(quant):\n                    raise NotImplementedError(\n                        f\"Quantification is currently not supported in Ascend npu, 
please disable {quant}.\"\n                    )\n\n            # check dtype config\n            if data.get(\"tf32\"):\n                raise NotImplementedError(\n                    \"tf32 dtype is currently not supported in Ascend npu, please disable this configuration\"\n                )\n\n        return data\n\n\nclass ChatTemplateValidationMixin:\n    \"\"\"Validation methods related to chat template configuration.\"\"\"\n\n    @model_validator(mode=\"before\")\n    @classmethod\n    def check_chat_template_config(cls, data):\n        # if chat_template is set to jinja, chat_template_jinja is required\n        if data.get(\"chat_template\") == ChatTemplate.jinja and not data.get(\n            \"chat_template_jinja\"\n        ):\n            raise ValueError(\n                \"chat_template_jinja is required when chat_template is set to jinja\"\n            )\n\n        # If chat_template_jinja is set, set chat_template to jinja\n        if data.get(\"chat_template_jinja\") and not data.get(\"chat_template\"):\n            data[\"chat_template\"] = ChatTemplate.jinja\n\n        return data\n\n\nclass PretrainingValidationMixin:\n    \"\"\"Validation methods related to pretraining configuration.\"\"\"\n\n    @model_validator(mode=\"before\")\n    @classmethod\n    def check_pretraining_w_max_steps(cls, data):\n        if data.get(\"pretraining_dataset\") and not data.get(\"max_steps\"):\n            raise ValueError(\n                \"max_steps must be set when using iterable pretraining_dataset, Trainer can't infer length and schedule optimizer/learning rate without it!\"\n            )\n        return data\n\n    @model_validator(mode=\"before\")\n    @classmethod\n    def check_pretraining_w_group_by_length(cls, data):\n        if data.get(\"pretraining_dataset\") and data.get(\"group_by_length\"):\n            LOG.warning(\n                \"You probably want to disable group_by_length as it will force a streamed dataset to download completely.\"\n    
        )\n        return data\n\n    @model_validator(mode=\"before\")\n    @classmethod\n    def check_pretraining_split_batches_accelerate(cls, data):\n        # alternatively set ACCELERATE_SPLIT_BATCHES=False\n        if data.get(\"pretraining_dataset\"):\n            accelerator_config = data.get(\"accelerator_config\", {})\n            if not accelerator_config:\n                data[\"accelerator_config\"] = {\n                    \"split_batches\": False,\n                    \"dispatch_batches\": False,\n                }\n            else:\n                if accelerator_config.get(\"split_batches\") is None:\n                    data[\"accelerator_config\"][\"split_batches\"] = False\n                if accelerator_config.get(\"dispatch_batches\") is None:\n                    data[\"accelerator_config\"][\"dispatch_batches\"] = False\n        return data\n\n    @model_validator(mode=\"before\")\n    @classmethod\n    def check_pretraining_w_val_set_size(cls, data):\n        if data.get(\"pretraining_dataset\") and data.get(\"val_set_size\"):\n            raise ValueError(\n                \"val_set_size is not supported with pretraining_dataset. \"\n                \"Use test_datasets to specify evaluation datasets for pretraining.\"\n            )\n        return data\n\n    @model_validator(mode=\"before\")\n    @classmethod\n    def check_streaming_w_val_set_size(cls, data):\n        if data.get(\"streaming\") and data.get(\"val_set_size\"):\n            raise ValueError(\n                \"val_set_size is not supported with streaming datasets. 
\"\n                \"Use test_datasets to specify evaluation datasets when streaming is enabled.\"\n            )\n        return data\n\n    @model_validator(mode=\"before\")\n    @classmethod\n    def check_streaming_w_max_steps(cls, data):\n        if data.get(\"streaming\") and not data.get(\"max_steps\"):\n            raise ValueError(\n                \"max_steps must be set when using streaming datasets. \"\n                \"Trainer cannot infer dataset length for iterable datasets.\"\n            )\n        return data\n\n    @model_validator(mode=\"before\")\n    @classmethod\n    def check_streaming_w_multiple_datasets(cls, data):\n        if (\n            data.get(\"streaming\")\n            and data.get(\"sample_packing\")\n            and data.get(\"datasets\")\n            and len(data.get(\"datasets\")) > 1\n        ):\n            raise NotImplementedError(\n                \"Sample packing with multiple streaming datasets is not yet supported\"\n            )\n        return data\n\n\nclass ModelCompatibilityValidationMixin:\n    \"\"\"Validation methods for specific model compatibility.\"\"\"\n\n    @model_validator(mode=\"after\")\n    def check_falcon_fsdp(self):\n        if (self.base_model and \"falcon\" in self.base_model.lower()) and self.fsdp:\n            raise ValueError(\"FSDP is not supported for falcon models\")\n        return self\n\n    @model_validator(mode=\"after\")\n    def check_mpt_checkpointing(self):\n        if (\n            self.base_model and \"mpt\" in self.base_model.lower()\n        ) and self.gradient_checkpointing:\n            raise ValueError(\"gradient_checkpointing is not supported for MPT models\")\n        return self\n\n    @model_validator(mode=\"after\")\n    def check_gradient_checkpointing_w_offload(self):\n        if self.gradient_checkpointing == \"offload\":\n            LOG.warning(\n                \"`offload` is deprecated for gradient_checkpointing, use `activation_offloading: true` or 
`activation_offloading: legacy`\"\n            )\n            self.gradient_checkpointing = True\n            LOG.warning(\n                \"`offload` now uses a new stream implementation; to use the previous implementation, use `activation_offloading: legacy`\"\n            )\n            self.activation_offloading = True\n        if self.gradient_checkpointing == \"offload_disk\":\n            LOG.warning(\n                \"`offload_disk` is deprecated for gradient_checkpointing, use `activation_offloading: disk`\"\n            )\n            self.gradient_checkpointing = True\n            self.activation_offloading = \"disk\"\n        return self\n\n    @model_validator(mode=\"after\")\n    def check_activation_offloading_wo_gc(self):\n        if self.activation_offloading and not self.gradient_checkpointing:\n            raise ValueError(\"activation_offloading requires gradient_checkpointing\")\n        return self\n\n    @model_validator(mode=\"after\")\n    def check_better_transformers(self):\n        if self.flash_optimum is True:\n            if self.adapter:\n                LOG.warning(\n                    \"BetterTransformers probably doesn't work with PEFT adapters\"\n                )\n            if self.fp16 or self.bf16:\n                raise ValueError(\"AMP is not supported with BetterTransformer\")\n            if self.float16 is not True and self.bfloat16 is not True:\n                LOG.warning(\n                    \"You should probably set bfloat16 or float16 to true to \"\n                    \"load the model in float16 for BetterTransformers\"\n                )\n        return self\n\n    @model_validator(mode=\"before\")\n    @classmethod\n    def check_gptq_w_revision(cls, data):\n        if data.get(\"gptq\") and data.get(\"revision_of_model\"):\n            raise ValueError(\n                \"revision_of_model is not supported for GPTQ models. 
\"\n                + \"Please download the model from HuggingFace Hub manually for correct branch, \"\n                + \"point to its path, and remove revision_of_model from the config.\"\n            )\n        return data\n\n    @model_validator(mode=\"before\")\n    @classmethod\n    def check_gpt_oss_fsdp_loading(cls, data):\n        if data.get(\"model_quantization_config\", \"\") == \"Mxfp4Config\":\n            fsdp_config = data.get(\"fsdp_config\") or {}\n            if fsdp_config.get(\"cpu_ram_efficient_loading\", False) is True:\n                raise ValueError(\n                    \"FSDP cpu_ram_efficient_loading is not supported for Mxfp4Config model quantization.\"\n                )\n        return data\n\n\nclass ComplexValidationMixin:\n    \"\"\"Complex validation methods that involve multiple systems.\"\"\"\n\n    @field_validator(\"neftune_noise_alpha\")\n    @classmethod\n    def validate_neftune_noise_alpha(cls, neftune_noise_alpha):\n        if neftune_noise_alpha is not None and neftune_noise_alpha <= 0.0:\n            raise ValueError(\"neftune_noise_alpha must be > 0.0\")\n        return neftune_noise_alpha\n\n    @model_validator(mode=\"after\")\n    def check_rl_beta(self):\n        if self.dpo_beta and not self.rl_beta:\n            self.rl_beta = self.dpo_beta\n            del self.dpo_beta\n        return self\n\n    @model_validator(mode=\"after\")\n    def check_simpo_warmup(self):\n        if self.rl is RLType.SIMPO and self.warmup_ratio:\n            raise ValueError(\n                \"warmup_ratio is not supported with the simpo trainer. 
Please use `warmup_steps` instead\"\n            )\n        return self\n\n    @model_validator(mode=\"after\")\n    def check_relora(self):\n        if self.relora:\n            if not self.jagged_restart_steps:\n                raise ValueError(\"jagged_restart_steps must be set to use ReLoRA\")\n            if self.adapter not in (\"lora\", \"qlora\"):\n                raise ValueError(\"cfg.adapter must be lora or qlora to use ReLoRA\")\n\n            if self.fsdp or self.fsdp_config:\n                raise ValueError(\"fsdp not supported with ReLoRA\")\n\n            if self.deepspeed:\n                raise ValueError(\"deepspeed not supported with ReLoRA\")\n\n            if self.lr_scheduler == \"one_cycle\":\n                raise ValueError(\n                    \"ReLoRA is not compatible with the one_cycle scheduler\"\n                )\n\n            if self.flash_attn_fuse_mlp:\n                raise ValueError(\"Fused modules are not supported with ReLoRA\")\n        return self\n\n    @model_validator(mode=\"after\")\n    def check_early_stopping(self):\n        if self.early_stopping_patience:\n            if not self.save_steps or not self.eval_steps:\n                raise ValueError(\n                    \"`early_stopping_patience` requires save_steps and eval_steps to be set. 
eval_steps should evenly divide save_steps.\"\n                )\n            if self.save_steps % self.eval_steps != 0:\n                raise ValueError(\n                    \"`early_stopping_patience` requires that eval_steps should evenly divide save_steps.\"\n                )\n        return self\n\n    @model_validator(mode=\"after\")\n    def check_tensor_parallel_size(self):\n        if not self.tensor_parallel_size:\n            self.tensor_parallel_size = 1\n        return self\n\n    @model_validator(mode=\"after\")\n    def check_context_parallel_size(self):\n        if self.sequence_parallel_degree and not self.context_parallel_size:\n            LOG.warning(\n                \"`sequence_parallel_degree` is deprecated, use `context_parallel_size`\"\n            )\n            self.context_parallel_size = self.sequence_parallel_degree\n        if not self.context_parallel_size:\n            self.context_parallel_size = 1\n        elif self.context_parallel_size > 1:\n            if not self.flash_attention:\n                raise ValueError(\n                    \"flash_attention: true must be set with context_parallel_size > 1\"\n                )\n\n            if self.sample_packing and self.micro_batch_size > 1:\n                raise ValueError(\n                    \"micro_batch_size must be set to 1 when sample_packing is enabled \"\n                    \"due to a `ring-flash-attn` requirement\"\n                )\n\n            try:\n                import transformers.modeling_flash_attention_utils\n                from transformers.utils import is_flash_attn_greater_or_equal\n\n                transformers.modeling_flash_attention_utils._flash_supports_window = (\n                    True\n                )\n                sys.modules[\n                    \"transformers.modeling_flash_attention_utils\"\n                ]._flash_supports_window = True\n                sys.modules[\n                    
\"transformers.modeling_flash_attention_utils\"\n                ]._flash_supports_window_size = True\n                sys.modules[\n                    \"transformers.modeling_flash_attention_utils\"\n                ].is_flash_attn_greater_or_equal = is_flash_attn_greater_or_equal\n                import ring_flash_attn  # noqa: F401  # Required after monkey-patching\n            except ImportError as exception:\n                raise ImportError(\n                    \"context_parallel_size > 1 but ring_flash_attn is not installed. \"\n                    \"Please install it with `pip install axolotl[ring-flash-attn] \"\n                    \"or `pip install ring-flash-attn>=0.1.4`.\"\n                ) from exception\n\n            LOG.warning(\n                \"Sequence parallelism (SP) is enabled with \"\n                f\"context_parallel_size={self.context_parallel_size}. \"\n                \"Please note that logged losses may differ slightly to the non-SP \"\n                \"losses due to transformers Trainer implementation details. 
\"\n                \"Please see https://github.com/axolotl-ai-cloud/axolotl/pull/2495#issuecomment-2784022042 \"\n                \"for more details.\"\n            )\n\n        return self\n\n    @model_validator(mode=\"after\")\n    def validate_ring_attn_func(self):\n        if getattr(self, \"context_parallel_size\", 1) == 1:\n            return self\n\n        if self.ring_attn_func is not None:\n            self.ring_attn_func = RingAttnFunc(self.ring_attn_func)\n        else:\n            # Default ring attention function selection\n            sample_packing = getattr(self, \"sample_packing\", False)\n            self.ring_attn_func = (\n                RingAttnFunc.VARLEN_LLAMA3\n                if sample_packing\n                else RingAttnFunc.BATCH_RING\n            )\n\n        return self\n\n    def hint_gradient_checkpointing_dpo_lora_ddp(self):\n        if (\n            (self.gradient_checkpointing is True or self.gradient_checkpointing is None)\n            and self.capabilities\n            and self.capabilities.get(\"n_gpu\", 1) > 1\n            and self.adapter in (\"lora\", \"qlora\")\n            and self.rl == RLType.DPO\n            and not self.fsdp\n            and not self.deepspeed\n        ):\n            LOG.warning(\n                \"gradient_checkpointing with DPO + DDP + LoRA is not recommended.\"\n            )\n        return self\n\n\nclass DistributedValidationMixin:\n    \"\"\"validation for distributed training.\"\"\"\n\n    @model_validator(mode=\"after\")\n    def check_tensor_parallel_optimizer(self):\n        if self.tensor_parallel_size > 1:\n            if self.optimizer in [\"paged_adamw_8bit\", \"adamw_8bit\", \"adamw_bnb_8bit\"]:\n                raise ValueError(\n                    \"tensor_parallel_size is not supported with paged_adamw_8bit, adamw_8bit, and adamw_bnb_8bit optimizers\"\n                )\n\n        return self\n\n\nclass GRPOVllmValidationMixin:\n    \"\"\"Validation mixin for vllm when using 
GRPO.\"\"\"\n\n    @model_validator(mode=\"after\")\n    def check_vllm_mode_set(self):\n        if self.trl and self.trl.use_vllm and not self.trl.vllm_mode:\n            LOG.warning(\n                \"vllm_mode must be set to either `server` or `colocate` when using vllm, using default value `server`\"\n            )\n            self.trl.vllm_mode = \"server\"\n        return self\n\n\nclass ValidationMixin(\n    DatasetValidationMixin,\n    AttentionValidationMixin,\n    TrainingValidationMixin,\n    LoRAValidationMixin,\n    RLValidationMixin,\n    OptimizationValidationMixin,\n    SystemValidationMixin,\n    ChatTemplateValidationMixin,\n    PretrainingValidationMixin,\n    ModelCompatibilityValidationMixin,\n    ComplexValidationMixin,\n    GRPOVllmValidationMixin,\n):\n    \"\"\"Full validation mixin for Axolotl configuration.\"\"\"\n"
  },
  {
    "path": "src/axolotl/utils/schemas/vllm.py",
    "content": "\"\"\"\nPydantic models for VLLM configuration, used primarily for RL training with TRL + grpo\n\"\"\"\n\nfrom pydantic import BaseModel, Field\n\n\nclass VllmConfig(BaseModel):\n    \"\"\"\n    Configuration for VLLM server\n    \"\"\"\n\n    device: str | None = Field(\n        default=\"auto\",\n        json_schema_extra={\"description\": \"Device to use for VLLM\"},\n    )\n    tensor_parallel_size: int | None = Field(\n        default=None,\n        json_schema_extra={\"description\": \"Tensor parallel size for VLLM\"},\n    )\n    data_parallel_size: int | None = Field(\n        default=None,\n        json_schema_extra={\"description\": \"Data parallel size for VLLM\"},\n    )\n    gpu_memory_utilization: float | None = Field(\n        default=0.9,\n        json_schema_extra={\"description\": \"GPU memory utilization for VLLM\"},\n    )\n    dtype: str | None = Field(\n        default=\"auto\",\n        json_schema_extra={\"description\": \"Data type for VLLM\"},\n    )\n    max_model_len: int | None = Field(\n        default=None,\n        json_schema_extra={\n            \"description\": \"Maximum length of the model context for VLLM\"\n        },\n    )\n    enable_prefix_caching: bool | None = Field(\n        default=None,\n        json_schema_extra={\"description\": \"Enable prefix caching for VLLM\"},\n    )\n    host: str | None = Field(\n        default=\"0.0.0.0\",  # nosec B104\n        json_schema_extra={\"description\": \"Host for the vLLM server to start on\"},\n    )\n    port: int | None = Field(\n        default=8000,\n        json_schema_extra={\"description\": \"Port of the vLLM server to start on\"},\n    )\n\n    enable_reasoning: bool | None = Field(\n        default=None,\n        json_schema_extra={\"description\": \"Enable reasoning for VLLM\"},\n    )\n    reasoning_parser: str | None = Field(\n        default=None,\n        json_schema_extra={\"description\": \"Reasoning parser for VLLM\"},\n    )\n    serve_module: 
str | None = Field(\n        default=None,\n        json_schema_extra={\n            \"description\": \"Python module for vLLM serve script. Set to 'axolotl.scripts.vllm_serve_lora' \"\n            \"for native LoRA support, or leave None for default TRL serve.\"\n        },\n    )\n"
  },
  {
    "path": "src/axolotl/utils/tee.py",
    "content": "\"\"\"\nUtilities for managing the debug log file and providing a file-only stream for logging\nhandlers.\n\"\"\"\n\nfrom __future__ import annotations\n\nimport io\nimport os\nimport sys\nimport threading\nfrom pathlib import Path\nfrom typing import TextIO, cast\n\n_lock = threading.Lock()\n_file_handle: io.TextIOWrapper | None = None\n_log_path: str | None = None\n_tee_installed: bool = False\n_orig_stdout: TextIO | None = None\n_orig_stderr: TextIO | None = None\n\n\nclass _FileOnlyWriter(io.TextIOBase):\n    \"\"\"A stream-like object that writes only to the tee file.\n\n    Before the file is prepared, writes are dropped (no-op).\n    \"\"\"\n\n    def write(self, s: str) -> int:  # type: ignore[override]\n        with _lock:\n            if _file_handle is not None:\n                _file_handle.write(s)\n                return len(s)\n            return len(s)\n\n    def flush(self) -> None:  # type: ignore[override]\n        with _lock:\n            if _file_handle is not None:\n                try:\n                    _file_handle.flush()\n                except Exception:\n                    pass\n\n\nfile_only_stream: io.TextIOBase = _FileOnlyWriter()\n\n\nclass _StreamTee(io.TextIOBase):\n    \"\"\"A minimal tee that mirrors writes to the debug log file.\n\n    Installed only after the debug log is prepared; no buffering.\n    \"\"\"\n\n    def __init__(self, stream: io.TextIOBase):\n        self._stream = stream\n\n    def write(self, s: str) -> int:  # type: ignore[override]\n        with _lock:\n            n = self._stream.write(s)\n            if _file_handle is not None:\n                _file_handle.write(s)\n            return n\n\n    def flush(self) -> None:  # type: ignore[override]\n        with _lock:\n            self._stream.flush()\n            if _file_handle is not None:\n                try:\n                    _file_handle.flush()\n                except Exception:\n                    pass\n\n    @property\n    
def encoding(self):  # type: ignore[override]\n        return getattr(self._stream, \"encoding\", None)\n\n    @property\n    def errors(self):  # type: ignore[override]\n        return getattr(self._stream, \"errors\", None)\n\n    def isatty(self):  # type: ignore[override]\n        return getattr(self._stream, \"isatty\", lambda: False)()\n\n    def fileno(self):  # type: ignore[override]\n        if hasattr(self._stream, \"fileno\"):\n            return self._stream.fileno()\n        raise OSError(\"Underlying stream has no fileno\")\n\n\ndef prepare_debug_log(cfg, filename: str = \"debug.log\") -> str:\n    \"\"\"\n    Prepare the debug log.\n\n    Creates the output directory, handles append/truncate logic based on cfg, and opens\n    the debug log file for subsequent writes via file-only handlers.\n    \"\"\"\n    global _file_handle, _log_path, _tee_installed\n\n    with _lock:\n        # If already initialized, reuse existing path\n        if _log_path is not None:\n            return _log_path\n\n        output_dir = cfg.output_dir\n        os.makedirs(output_dir, exist_ok=True)\n\n        log_path = Path(output_dir) / filename\n        append = bool(\n            cfg.get(\"resume_from_checkpoint\") or cfg.get(\"auto_resume_from_checkpoints\")\n        )\n\n        if not append:\n            log_path.unlink(missing_ok=True)\n\n        fh = open(log_path, \"a\", encoding=\"utf-8\")\n        fh.flush()\n\n        _file_handle = fh\n        _log_path = str(log_path)\n\n        # Install a tee so stdout/stderr are mirrored to the debug file\n        # Allow disabling via env for testing or advanced usage.\n        tee_enabled = os.getenv(\"AXOLOTL_TEE_STDOUT\", \"1\").lower() not in {\n            \"0\",\n            \"false\",\n            \"no\",\n        }\n        if tee_enabled and not _tee_installed:\n            # Save originals so we can restore later (e.g., tests)\n            global _orig_stdout, _orig_stderr\n            _orig_stdout = 
sys.stdout\n            _orig_stderr = sys.stderr\n            sys.stdout = _StreamTee(cast(io.TextIOBase, sys.stdout))\n            sys.stderr = _StreamTee(cast(io.TextIOBase, sys.stderr))\n            _tee_installed = True\n\n        return _log_path\n\n\ndef close_debug_log() -> None:\n    \"\"\"Flush and close the debug log and uninstall the stdout/stderr tee.\n\n    Safe to call even if not initialized.\n    \"\"\"\n    global _file_handle, _log_path, _tee_installed, _orig_stdout, _orig_stderr\n    with _lock:\n        # Restore original stdout/stderr if we installed a tee\n        if _tee_installed:\n            if _orig_stdout is not None:\n                sys.stdout = _orig_stdout\n            if _orig_stderr is not None:\n                sys.stderr = _orig_stderr\n            _tee_installed = False\n            _orig_stdout = None\n            _orig_stderr = None\n\n        # Close the file handle if open\n        if _file_handle is not None:\n            try:\n                _file_handle.flush()\n                _file_handle.close()\n            except Exception:\n                pass\n            finally:\n                _file_handle = None\n        _log_path = None\n"
  },
  {
    "path": "src/axolotl/utils/tokenization.py",
    "content": "\"\"\"Module for tokenization utilities\"\"\"\n\nfrom termcolor import colored\n\nfrom axolotl.utils.logging import get_logger\n\nLOG = get_logger(__name__)\n\n\ndef check_dataset_labels(\n    dataset,\n    tokenizer,\n    num_examples=5,\n    text_only=False,\n    rl_mode=False,\n):\n    # the dataset is already shuffled, so let's just check the first 5 elements\n    for idx in range(num_examples):\n        if not rl_mode:\n            check_example_labels(dataset[idx], tokenizer, text_only=text_only)\n        else:\n            check_rl_example_labels(dataset[idx], tokenizer, text_only=text_only)\n\n\ndef check_example_labels(example, tokenizer, text_only=False):\n    # Get the input_ids, labels, and attention_mask from the dataset\n    input_ids = example[\"input_ids\"]\n    labels = example[\"labels\"]\n    target_mask = example.pop(\"target_mask\", None)\n\n    # You can compare the input_ids and labels element-wise\n    # Remember to ignore positions with IGNORE_TOKEN_ID (if you use it) or attention_mask equal to 0\n    colored_tokens = []\n    for _, (input_id, label_id) in enumerate(zip(input_ids, labels, strict=False)):\n        decoded_input_token = tokenizer.decode(input_id)\n        # Choose the color based on whether the label has the ignore value or not\n        color = \"red\" if label_id == -100 else (\"yellow\" if label_id == 0 else \"green\")\n        colored_token = colored(decoded_input_token, color) + (\n            not text_only and colored(f\"({label_id}, {input_id})\", \"white\") or \"\"\n        )\n        colored_tokens.append(colored_token)\n\n    delimiter = \"\" if text_only else \" \"\n    LOG.info(delimiter.join(colored_tokens))\n    LOG.info(\"\\n\\n\\n\")\n    target_labels_count = sum(label_id != -100 for label_id in labels)\n    total_len = len(input_ids)\n    LOG.info(f\"Total input len: {total_len}\")\n    LOG.info(f\"Count of labels: {target_labels_count}\")\n    if target_mask:\n        target_mask_positions = 
sum(m[0] for m in target_mask)\n        LOG.info(f\"Number of positions in target_mask: {target_mask_positions}\")\n\n    return \" \".join(colored_tokens)\n\n\ndef color_token_for_rl_debug(decoded_token, encoded_token, color, text_only):\n    \"\"\"Helper function to color tokens based on their type.\"\"\"\n    colored_text = colored(decoded_token, color)\n    return (\n        colored_text\n        if text_only\n        else f\"{colored_text}{colored(f'({encoded_token})', 'white')}\"\n    )\n\n\ndef process_tokens_for_rl_debug(tokens, color, tokenizer, text_only):\n    \"\"\"Helper function to process and color tokens.\"\"\"\n    colored_tokens = [\n        color_token_for_rl_debug(tokenizer.decode(token), token, color, text_only)\n        for token in tokenizer.encode(tokens, add_special_tokens=False)\n    ]\n    return colored_tokens\n\n\ndef check_rl_example_labels(example, tokenizer, text_only=False):\n    field_prompt, field_chosen, field_rejected, field_completion = (\n        \"prompt\",\n        \"chosen\",\n        \"rejected\",\n        \"completion\",\n    )\n\n    input_tokens = example[field_prompt]\n\n    labels_chosen = example.get(field_chosen)\n    labels_rejected = example.get(field_rejected)\n    labels_completion = example.get(field_completion)\n\n    # Create a delimiter based on text_only flag\n    delimiter = \"\" if text_only else \" \"\n\n    # Process and color each type of token\n    colored_tokens = process_tokens_for_rl_debug(\n        input_tokens, \"yellow\", tokenizer, text_only\n    )\n\n    # Process tokens\n    if labels_completion is None:\n        colored_chosens = process_tokens_for_rl_debug(\n            labels_chosen, \"green\", tokenizer, text_only\n        )\n        colored_rejecteds = process_tokens_for_rl_debug(\n            labels_rejected, \"red\", tokenizer, text_only\n        )\n    else:\n        colored_completion = process_tokens_for_rl_debug(\n            labels_completion, \"green\", tokenizer, text_only\n     
   )\n\n    # Logging information\n    LOG.info(f\"INPUT PROMPT: {delimiter.join(colored_tokens)}\\n\\n\")\n\n    if labels_completion is None:\n        LOG.info(f\"CHOSEN RESPONSE: {delimiter.join(colored_chosens)}\\n\\n\")\n        LOG.info(f\"REJECTED RESPONSE: {delimiter.join(colored_rejecteds)}\\n\\n\\n\")\n    else:\n        LOG.info(f\"COMPLETION RESPONSE: {delimiter.join(colored_completion)}\\n\\n\\n\")\n\n    return delimiter.join(colored_tokens)\n"
  },
  {
    "path": "src/axolotl/utils/trackio_.py",
    "content": "\"\"\"Module for trackio utilities\"\"\"\n\nimport os\n\nfrom axolotl.utils.dict import DictDefault\n\n\ndef setup_trackio_env_vars(cfg: DictDefault):\n    for key in cfg.keys():\n        if key.startswith(\"trackio_\"):\n            value = cfg.get(key, \"\")\n\n            if value and isinstance(value, str) and len(value) > 0:\n                os.environ[key.upper()] = value\n\n    if cfg.trackio_project_name and len(cfg.trackio_project_name) > 0:\n        cfg.use_trackio = True\n"
  },
  {
    "path": "src/axolotl/utils/train.py",
    "content": "\"\"\"Training utils for checkpoints\"\"\"\n\nfrom pathlib import Path\n\nfrom axolotl.utils.dict import DictDefault\nfrom axolotl.utils.logging import get_logger\n\nLOG = get_logger(__name__)\n\n\ndef determine_last_checkpoint(cfg: DictDefault, update: bool = True) -> str | None:\n    \"\"\"\n    Determine the checkpoint to resume from based on configuration.\n\n    Args:\n        cfg: Dictionary mapping `axolotl` config keys to values.\n        update: Whether to update the config with the determined checkpoint\n\n    Returns:\n        Path to the checkpoint to resume from, or `None` if not resuming.\n    \"\"\"\n    last_checkpoint = None\n    checkpoints = sorted(\n        (\n            p\n            for p in Path(cfg.output_dir).glob(\"checkpoint-*\")\n            if p.name.split(\"-\")[-1].isdigit()\n        ),\n        key=lambda p: int(p.name.split(\"-\")[-1]),\n    )\n    if checkpoints:\n        last_checkpoint = str(checkpoints[-1])\n        if not update:\n            LOG.info(f\"Resuming from last checkpoint at {last_checkpoint}\")\n            return last_checkpoint\n\n    if (\n        cfg.resume_from_checkpoint is None\n        and cfg.auto_resume_from_checkpoints\n        and last_checkpoint is not None\n    ):\n        cfg.resume_from_checkpoint = last_checkpoint\n        LOG.info(\n            \"Using auto-resume functionality to resume from checkpoint at \"\n            f\"{cfg.resume_from_checkpoint}\"\n        )\n    return cfg.resume_from_checkpoint\n"
  },
  {
    "path": "src/axolotl/utils/trainer.py",
    "content": "\"\"\"Module containing the Trainer class and related functions\"\"\"\n\nimport json\nimport math\nimport os\nimport random\nfrom contextlib import contextmanager\nfrom functools import partial\nfrom tempfile import NamedTemporaryFile\nfrom typing import List, Optional\n\nimport numpy as np\nimport torch\nimport torch.cuda\nfrom datasets import IterableDataset, disable_caching, enable_caching\nfrom torch.utils.data import DataLoader, RandomSampler, SequentialSampler\nfrom transformers.utils import is_torch_bf16_gpu_available\n\nfrom axolotl.utils.dict import DictDefault\nfrom axolotl.utils.distributed import init_distributed_state, reduce_and_broadcast\nfrom axolotl.utils.environment import check_cuda_p2p_ib_support\nfrom axolotl.utils.logging import get_logger\nfrom axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths\n\nLOG = get_logger(__name__)\n\n\n@torch.jit.script\ndef weighted_cross_entropy(\n    logits: torch.Tensor, labels: torch.Tensor, weights: torch.Tensor\n):\n    # Flatten the logits, labels, and weights tensors\n    logits = logits.view(\n        -1, logits.size(-1)\n    )  # logits becomes of shape [batch_size*sequence_length, vocab_size]\n    labels = labels.view(-1)  # labels becomes of shape [batch_size*sequence_length]\n    weights = weights.view(-1)  # weights becomes of shape [batch_size*sequence_length]\n\n    # Compute the unweighted cross entropy loss\n    losses = torch.nn.functional.cross_entropy(logits, labels, reduction=\"none\")\n\n    # Apply the weights to the losses and compute their sum\n    return (weights * losses).sum()\n\n\n@torch.jit.script\ndef create_weighted_mask(labels: torch.Tensor):\n    # Check if the tensor is 2D. 
If not, unsqueeze it to make it 2D\n    if len(labels.shape) == 1:\n        labels = labels.unsqueeze(0)\n\n    weights = torch.zeros_like(labels).float()\n    for i in range(labels.shape[0]):\n        mask = labels[i] != -100\n\n        # Create a tensor to track group ids\n        group_ids = torch.zeros_like(labels[i]).int()\n        curr_group_id = 0\n\n        for j in range(1, len(labels[i])):\n            if mask[j] and not mask[j - 1]:  # switch from masked to unmasked label\n                curr_group_id += 1  # start new group\n            group_ids[j] = (\n                curr_group_id if mask[j] else 0\n            )  # assign group id if unmasked label\n\n        # Count only unmasked labels in each group\n        group_counts = torch.bincount(group_ids[mask])\n\n        mask_weights = torch.zeros_like(labels[i]).float()\n        mask_weights[mask] = 1.0 / group_counts[group_ids[mask]]\n\n        weights[i] = mask_weights\n\n    return weights.squeeze()  # squeeze the output to match the input dimension\n\n\ndef trainer_weighted_loss(model_output, labels, shift_labels=True):\n    logits = (\n        model_output[\"logits\"] if isinstance(model_output, dict) else model_output[0]\n    )\n    if shift_labels:\n        logits = logits[..., :-1, :].contiguous()\n        labels = labels[..., 1:].contiguous()\n\n    weights = create_weighted_mask(labels)\n    return weighted_cross_entropy(logits, labels, weights)\n\n\n@contextmanager\ndef disable_datasets_caching():\n    try:\n        disable_caching()\n        yield\n    finally:\n        enable_caching()\n\n\ndef add_position_ids(sample):\n    \"\"\"\n    Handle both single-example and batched data.\n    - single example: sample['input_ids'] is a list[int]\n    - batched data: sample['input_ids'] is a list[list[int]]\n    \"\"\"\n    # Return sample unchanged if \"input_ids\" is not present, or is empty\n    if \"input_ids\" not in sample or not sample[\"input_ids\"]:\n        return sample\n\n    input_ids 
= sample[\"input_ids\"]\n\n    # If first element is an int, it’s a single example\n    # If first element is a list, it’s a batch\n    if isinstance(input_ids[0], int):\n        # ---- SINGLE EXAMPLE ----\n        seq_len = len(input_ids)\n        # Position IDs for a single example\n        # As a list\n        sample[\"position_ids\"] = list(range(seq_len))\n        sample[\"length\"] = seq_len\n\n    else:\n        # ---- BATCHED EXAMPLES ----\n        # input_ids is a list of lists\n        position_ids_batch = []\n        lengths_batch = []\n        for seq in input_ids:\n            seq_len = len(seq)\n            position_ids_batch.append(list(range(seq_len)))\n            lengths_batch.append(seq_len)\n\n        # Now store them back\n        sample[\"position_ids\"] = position_ids_batch\n        sample[\"length\"] = lengths_batch\n\n    return sample\n\n\ndef add_pose_position_ids(\n    sample,\n    max_context_len=32768,\n    split_on_token_ids: Optional[List[int]] = None,\n    chunks: int = 2,\n):\n    \"\"\"\n    use the PoSE technique to extend the context length by randomly skipping\n    positions in the context. We only want to skip right before tokens in\n    the split_on_token_ids list. We should attempt to randomly distribute\n    the skips, but we don't need the final position_ids to be the full\n    context_len. 
There may be multiple turns in the context, so we want to\n    make sure we take into account the maximum possible number of skips\n    remaining in each sample.\n    \"\"\"\n\n    input_ids = sample[\"input_ids\"]\n    sample_len = len(input_ids)\n    max_skips = max_context_len - sample_len\n\n    if split_on_token_ids is None:\n        split_on_token_ids = []\n\n    if split_on_token_ids:\n        split_indices = [\n            i for i, token_id in enumerate(input_ids) if token_id in split_on_token_ids\n        ]\n    else:\n        chunk_len = sample_len // chunks\n        split_indices = [i * chunk_len for i in range(1, chunks)]\n    split_indices.append(len(input_ids))  # make sure we go to the end of the sample\n    if split_indices[0] < 2:\n        # drop the first split index if it's too close to the beginning\n        split_indices = split_indices[1:]\n\n    position_ids = []\n    prev_index = 0\n    total_skips = 0\n\n    for split_index in split_indices:\n        num_skips = (\n            random.randint(0, max_skips)  # nosec B311\n            if prev_index != 0 and max_skips\n            else 0\n        )\n        max_skips -= num_skips\n        total_skips += num_skips\n\n        segment_position_ids = list(\n            range(prev_index + total_skips, split_index + total_skips)\n        )\n\n        position_ids.extend(segment_position_ids)\n        prev_index = split_index\n\n    sample[\"sequence_len\"] = position_ids[-1]\n    position_ids = torch.tensor(position_ids)\n\n    sample[\"position_ids\"] = position_ids\n    sample[\"length\"] = len(position_ids)\n    assert len(position_ids) == len(input_ids)\n\n    return sample\n\n\ndef add_length(sample):\n    sample[\"length\"] = len(sample[\"input_ids\"])\n    return sample\n\n\ndef filter_sequences_by_length(\n    sample, sequence_len=2048, min_sequence_len=2, raise_on_drop=False\n):\n    \"\"\"\n    Filter sequences outside valid length range [min_sequence_len, sequence_len].\n\n    Drops 
samples that are either too short (< min_sequence_len) or too long (> sequence_len).\n\n    Works for both single-example (list[int]) or batched (list[list[int]]).\n\n    If raise_on_drop is set, the code raises a ValueError if a sample is\n    encountered that is too long and would have been dropped.\n    \"\"\"\n    min_sequence_len = min_sequence_len or 2\n\n    input_ids = sample[\"input_ids\"]\n\n    # Edge case: if input_ids is empty\n    if not input_ids:\n        # Decide if you want to drop or keep empty. Let's drop.\n        return False\n\n    # Check if single example or batched by looking at the first element\n    if isinstance(input_ids[0], int):\n        # Single example (input_ids is a list of int)\n        length = len(input_ids)\n        if raise_on_drop and length > sequence_len:\n            raise ValueError(\n                f\"Sequence encountered with {length} tokens, which exceeds the maximum {sequence_len}.\"\n            )\n        return min_sequence_len <= length <= sequence_len\n\n    # Batched (input_ids is a list of lists)\n    results = []\n    for seq in input_ids:\n        length = len(seq)\n        if raise_on_drop and length > sequence_len:\n            raise ValueError(\n                f\"Sequence encountered with {length} tokens, which exceeds the maximum {sequence_len}.\"\n            )\n        results.append(min_sequence_len <= length <= sequence_len)\n    return results\n\n\ndef process_datasets_for_packing(cfg, train_dataset, eval_dataset):\n    drop_attn_mask = cfg.model_config_type in [\"mamba\", \"gemma3\"]\n    if drop_attn_mask:\n        LOG.info(\"dropping attention_mask column\")\n        train_dataset = train_dataset.remove_columns(\"attention_mask\")\n        if eval_dataset:\n            eval_dataset = eval_dataset.remove_columns(\"attention_mask\")\n\n    if cfg.model_config_type in [\"falcon\", \"mistral\"]:\n        LOG.info(\"dropping token_type_ids column if it exists\")\n        if \"token_type_ids\" in 
train_dataset.column_names:\n            train_dataset = train_dataset.remove_columns(\"token_type_ids\")\n        if eval_dataset and \"token_type_ids\" in eval_dataset.column_names:\n            eval_dataset = eval_dataset.remove_columns(\"token_type_ids\")\n\n    def drop_no_trainable_tokens(sample):\n        \"\"\"\n        Drop samples if all labels are -100 (i.e., zero trainable tokens).\n        Works for both single-example or batched input.\n        \"\"\"\n        labels = sample[\"labels\"]\n        if not labels:\n            return True\n\n        # Check if single example or batch\n        # If first element is an int, we assume a single example\n        # If it's a list, we assume we're dealing with a batch\n        if isinstance(labels[0], int):\n            # Single example: return a single bool\n            return np.any(labels != -100)\n\n        # Batched: 'labels' is a list of lists\n        # Return a list of booleans, one per sub-list\n        results = [np.any(row_labels != -100) for row_labels in labels]\n        return results\n\n    try:\n        prior_len = len(train_dataset)\n    except TypeError:\n        # handle iterable datasets case\n        prior_len = None\n    filter_map_kwargs = {}\n    if not isinstance(train_dataset, IterableDataset):\n        filter_map_kwargs[\"num_proc\"] = cfg.dataset_num_proc\n        filter_map_kwargs[\"load_from_cache_file\"] = not cfg.is_preprocess\n\n    drop_long_kwargs = {}\n    if filter_map_kwargs:\n        drop_long_kwargs[\"desc\"] = \"Drop Samples with Zero Trainable Tokens\"\n    train_dataset = train_dataset.filter(\n        drop_no_trainable_tokens,\n        batched=True,\n        **filter_map_kwargs,\n        **drop_long_kwargs,\n    )\n    if prior_len:\n        dropped = prior_len - len(train_dataset)\n        if dropped:\n            LOG.warning(\n                f\"Dropped {dropped} samples with no trainable tokens from train dataset\"\n            )\n\n    if eval_dataset:\n        
try:\n            prior_len = len(eval_dataset)\n        except TypeError:\n            # handle iterable datasets case\n            prior_len = None\n        eval_dataset = eval_dataset.filter(\n            drop_no_trainable_tokens,\n            **filter_map_kwargs,\n            **drop_long_kwargs,\n        )\n        if prior_len:\n            dropped = prior_len - len(eval_dataset)\n            if dropped:\n                LOG.warning(\n                    f\"Dropped {dropped} samples with no trainable tokens from eval dataset\"\n                )\n\n    if cfg.group_by_length:\n        train_dataset = train_dataset.map(\n            add_length,\n            num_proc=cfg.dataset_num_proc,\n            load_from_cache_file=not cfg.is_preprocess,\n            desc=\"Group By Length\",\n        )\n\n    if cfg.use_pose:\n        pose_kwargs = {}\n        if cfg.pose_num_chunks is not None:\n            pose_kwargs[\"chunks\"] = cfg.pose_num_chunks\n        pose_fn = partial(\n            add_pose_position_ids,\n            max_context_len=cfg.pose_max_context_len,\n            split_on_token_ids=cfg.pose_split_on_token_ids,\n            **pose_kwargs,\n        )\n        train_dataset = train_dataset.map(\n            pose_fn,\n            num_proc=cfg.dataset_num_proc,\n            load_from_cache_file=not cfg.is_preprocess,\n            desc=\"Add position_id column (PoSE)\",\n        )\n        train_dataset = train_dataset.sort(\"sequence_len\")\n        if cfg.eval_sample_packing is not False:\n            if eval_dataset:\n                eval_dataset = eval_dataset.map(\n                    pose_fn,\n                    num_proc=cfg.dataset_num_proc,\n                    load_from_cache_file=not cfg.is_preprocess,\n                    desc=\"Add position_id column (PoSE)\",\n                )\n    elif cfg.sample_packing:\n        drop_long_kwargs = {}\n        if filter_map_kwargs:\n            drop_long_kwargs[\"desc\"] = \"Add position_id column (Sample 
Packing)\"\n        train_dataset = train_dataset.map(\n            add_position_ids,\n            batched=True,\n            **filter_map_kwargs,\n            **drop_long_kwargs,\n        )\n        if cfg.eval_sample_packing:\n            if eval_dataset:\n                eval_dataset = eval_dataset.map(\n                    add_position_ids,\n                    **filter_map_kwargs,\n                    **drop_long_kwargs,\n                )\n\n    return train_dataset, eval_dataset\n\n\ndef process_pretraining_datasets_for_packing(\n    train_dataset, sequence_len, skip_position_ids=True, drop_attention_mask=False\n):\n    drop_outside_range = partial(filter_sequences_by_length, sequence_len=sequence_len)\n\n    train_dataset = train_dataset.filter(\n        drop_outside_range,\n        desc=\"Dropping Long Sequences\",\n        load_from_cache_file=False,\n    )\n    if not skip_position_ids:\n        train_dataset = train_dataset.map(\n            add_position_ids,\n            batched=True,\n            desc=\"Add position_id column (Pretraining Sample Packing)\",\n        )\n    if drop_attention_mask:\n        train_dataset = train_dataset.remove_columns(\"attention_mask\")\n\n    return train_dataset\n\n\ndef calculate_total_num_steps(cfg, train_dataset, update=True):\n    if (\n        not cfg.total_num_tokens\n        and not cfg.skip_prepare_dataset\n        and not cfg.reward_model\n    ):\n        total_num_tokens = np.sum(\n            train_dataset.select_columns(\"input_ids\")\n            .to_pandas()[\"input_ids\"]\n            .apply(len)\n            .values\n        )\n        LOG.debug(f\"total_num_tokens: {total_num_tokens:_}\")\n        if update:\n            cfg.total_num_tokens = total_num_tokens\n\n    skip_estimates = cfg.model_config_type == \"mamba\"\n\n    if (\n        not skip_estimates\n        and not cfg.total_supervised_tokens\n        and not cfg.skip_prepare_dataset\n        and not cfg.reward_model\n    ):\n        
total_supervised_tokens = (\n            train_dataset.data.column(\"labels\")\n            .to_pandas()\n            .apply(lambda x: np.sum(np.array(x) != -100))\n            .sum()\n        )\n        LOG.debug(f\"`total_supervised_tokens: {total_supervised_tokens:_}`\")\n        if update:\n            cfg.total_supervised_tokens = total_supervised_tokens\n\n    if not skip_estimates and cfg.sample_packing:\n        # we have to drop anything longer then sequence len otherwise\n        # flash attention with position ids fails\n\n        if cfg.sample_packing_eff_est:\n            total_num_steps = (\n                # match count to len est in dataloader\n                int(\n                    math.floor(\n                        0.99\n                        * cfg.total_num_tokens\n                        / cfg.sample_packing_eff_est\n                        / cfg.sequence_len\n                        // cfg.batch_size\n                    )\n                    - 1\n                )\n                * cfg.num_epochs\n            )\n            LOG.debug(\n                f\"total_num_tokens: {cfg.total_num_tokens:_}, total_num_steps: {total_num_steps:_}\"\n            )\n        else:\n            if cfg.flash_attention and not cfg.multipack_real_batches:\n                sampler_batch_size = 1\n                batch_max_len = cfg.micro_batch_size * cfg.sequence_len\n            else:\n                sampler_batch_size = cfg.micro_batch_size\n                batch_max_len = cfg.sequence_len\n            if cfg.curriculum_sampling:\n                sampler = SequentialSampler(train_dataset)\n            else:\n                sampler = RandomSampler(train_dataset)\n            sampler = MultipackBatchSampler(\n                sampler=sampler,\n                lengths=get_dataset_lengths(train_dataset),\n                batch_size=sampler_batch_size,\n                batch_max_len=batch_max_len,\n                group_size=cfg.sample_packing_group_size,\n 
               bin_size=cfg.sample_packing_bin_size,\n                sequential=cfg.sample_packing_sequentially,\n                drop_last=True,\n                num_processes=cfg.dataset_num_proc,\n                mp_start_method=cfg.sample_packing_mp_start_method or \"fork\",\n            )\n\n            data_loader = DataLoader(\n                train_dataset.remove_columns([\"length\"]),\n                batch_sampler=sampler,\n            )\n            data_loader_len = max(\n                1, len(data_loader) * cfg.micro_batch_size // cfg.batch_size\n            )\n            LOG.debug(f\"data_loader_len: {data_loader_len}\")\n            # FIXME: is there a bug here somewhere? the total num steps depends\n            # on the agreed on value for sample_packing_eff_est\n            total_num_steps = int(math.floor(data_loader_len * cfg.num_epochs))\n            if cfg.dataloader_drop_last:\n                # drop the last batch for each epoch\n                total_num_steps -= int(math.ceil(cfg.num_epochs))\n\n            def calc_sample_packing_eff_est(estimates: List[float]):\n                LOG.info(f\"sample_packing_eff_est across ranks: {repr(estimates)}\")\n                return max(estimates)\n\n            sample_packing_actual_eff_all = reduce_and_broadcast(\n                lambda: sampler.efficiency(),\n                calc_sample_packing_eff_est,\n            )\n            sample_packing_eff_est = (\n                math.ceil(sample_packing_actual_eff_all * 100.0) / 100.0\n            )\n            if update:\n                cfg.sample_packing_eff_est = sample_packing_eff_est\n            LOG.debug(f\"sample_packing_eff_est: {cfg.sample_packing_eff_est}\")\n    else:\n        total_num_steps = int(\n            math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size)\n        )\n    LOG.debug(f\"total_num_steps: {total_num_steps}\")\n    return total_num_steps\n\n\ndef setup_torch_compile_env(cfg):\n    if cfg.torch_compile:\n   
     if not cfg.torch_compile_backend:\n            os.environ[\"ACCELERATE_DYNAMO_BACKEND\"] = \"INDUCTOR\"\n        else:\n            os.environ[\"ACCELERATE_DYNAMO_BACKEND\"] = cfg.torch_compile_backend.upper()\n\n\ndef setup_deepspeed_env(cfg, stage=None):\n    from transformers.integrations.deepspeed import HfTrainerDeepSpeedConfig\n\n    from axolotl.utils.distributed import distributed_state\n\n    if distributed_state and distributed_state.initialized:\n        raise RuntimeError(\n            \"Distributed State already initialized before Deepspeed setup\"\n        )\n\n    os.environ[\"ACCELERATE_USE_DEEPSPEED\"] = \"true\"\n    if isinstance(cfg.deepspeed, DictDefault):\n        with NamedTemporaryFile(\n            mode=\"w\", delete=False, suffix=\".json\", prefix=\"deepspeed_config_\"\n        ) as temp_file:\n            temp_file.write(json.dumps(cfg.deepspeed.to_dict(), indent=4))\n            temp_file.close()\n            cfg.deepspeed = str(temp_file.name)\n    os.environ[\"ACCELERATE_DEEPSPEED_CONFIG_FILE\"] = cfg.deepspeed\n    os.environ[\"ACCELERATE_GRADIENT_ACCUMULATION_STEPS\"] = str(\n        cfg.gradient_accumulation_steps\n    )\n    if stage:\n        os.environ[\"ACCELERATE_DEEPSPEED_ZERO_STAGE\"] = str(stage)\n        if stage == 3:\n            os.environ[\"ACCELERATE_DEEPSPEED_ZERO3_INIT\"] = \"true\"\n\n    device_count = torch.cuda.device_count()\n    if device_count == 1:\n        os.environ.setdefault(\"WORLD_SIZE\", \"1\")\n        os.environ.setdefault(\"LOCAL_RANK\", \"0\")\n        os.environ.setdefault(\"MASTER_ADDR\", \"0.0.0.0\")  # nosec B104\n        os.environ.setdefault(\"MASTER_PORT\", \"29500\")\n\n    # NOTE(djsaunde): The distribued state cannot be initialized prior to the\n    # ACCELERATE_USE_DEEPSPEED assignment, but it must be initialized some time prior\n    # to model load.\n    if (\n        int(os.environ.get(\"WORLD_SIZE\", \"1\")) == 1\n        and os.environ.get(\"AXOLOTL_IS_PREPROCESS\", \"0\") != 
\"1\"\n        and cfg.use_ray is not True\n    ):\n        os.environ[\"WORLD_SIZE\"] = \"1\"  # force it in case not set\n        os.environ[\"LOCAL_RANK\"] = \"0\"  # force it in case not set\n        os.environ[\"RANK\"] = os.environ.get(\"LOCAL_RANK\", \"0\")\n        import deepspeed.comm as dist\n\n        dist.init_distributed(\n            dist_backend=\"nccl\", auto_mpi_discovery=False, dist_init_required=True\n        )\n    init_distributed_state()\n\n    # If we don't assign this, it doesn't actually get set in the accelerate weakref\n    _ = HfTrainerDeepSpeedConfig(cfg.deepspeed)\n\n\ndef setup_fsdp_envs(cfg):\n    os.environ[\"ACCELERATE_USE_FSDP\"] = \"true\"\n\n    # TODO @SalmanMohammadi remove FSDP1 args in 0.12\n    if str(cfg.fsdp_version) == \"2\":\n        os.environ[\"FSDP_VERSION\"] = \"2\"\n    if cfg.fsdp_config.activation_checkpointing:\n        os.environ[\"FSDP_ACTIVATION_CHECKPOINTING\"] = \"true\"\n    if cfg.fsdp_config.offload_params:\n        os.environ[\"FSDP_OFFLOAD_PARAMS\"] = \"true\"\n    if cfg.fsdp_config.sync_module_states:\n        os.environ[\"FSDP_SYNC_MODULE_STATES\"] = \"true\"\n    if cfg.fsdp_config.cpu_ram_efficient_loading:\n        os.environ[\"FSDP_CPU_RAM_EFFICIENT_LOADING\"] = \"true\"\n    if cfg.fsdp_config.use_orig_params:\n        os.environ[\"FSDP_USE_ORIG_PARAMS\"] = \"true\"\n    if cfg.fsdp_config.state_dict_type:\n        os.environ[\"FSDP_STATE_DICT_TYPE\"] = cfg.fsdp_config.state_dict_type\n    if cfg.fsdp_config.cpu_offload_pin_memory is not None:\n        os.environ[\"FSDP_CPU_OFFLOAD_PIN_MEMORY\"] = str(\n            cfg.fsdp_config.cpu_offload_pin_memory\n        ).lower()\n    if cfg.fsdp_config.auto_wrap_policy:\n        os.environ[\"FSDP_AUTO_WRAP_POLICY\"] = cfg.fsdp_config.auto_wrap_policy\n    if cfg.fsdp_config.transformer_layer_cls_to_wrap:\n        os.environ[\"FSDP_TRANSFORMER_CLS_TO_WRAP\"] = (\n            cfg.fsdp_config.transformer_layer_cls_to_wrap\n        )\n    if 
cfg.fsdp_config.reshard_after_forward:\n        os.environ[\"FSDP_RESHARD_AFTER_FORWARD\"] = \"true\"\n\n\ndef setup_parallelism_envs(cfg):\n    set_accelerate_parallelism_config = False\n    if cfg.tensor_parallel_size and cfg.tensor_parallel_size > 1:\n        set_accelerate_parallelism_config = True\n        os.environ[\"PARALLELISM_CONFIG_TP_SIZE\"] = str(cfg.tensor_parallel_size)\n    if cfg.dp_shard_size and cfg.dp_shard_size > 1:\n        set_accelerate_parallelism_config = True\n        os.environ[\"PARALLELISM_CONFIG_DP_SHARD_SIZE\"] = str(cfg.dp_shard_size)\n    if cfg.dp_replicate_size and cfg.dp_replicate_size > 1:\n        set_accelerate_parallelism_config = True\n        os.environ[\"PARALLELISM_CONFIG_DP_REPLICATE_SIZE\"] = str(cfg.dp_replicate_size)\n    if cfg.context_parallel_size and cfg.context_parallel_size > 1:\n        set_accelerate_parallelism_config = True\n        os.environ[\"PARALLELISM_CONFIG_CP_SIZE\"] = str(cfg.context_parallel_size)\n        os.environ[\"ACCELERATE_ALLOW_CP_STANDALONE\"] = \"true\"\n        from axolotl.monkeypatch.accelerate.parallelism_config import patch_prepare_cp\n\n        patch_prepare_cp()\n    if set_accelerate_parallelism_config:\n        os.environ[\"ACCELERATE_USE_PARALLELISM_CONFIG\"] = \"true\"\n\n\ndef prepare_optim_env(cfg):\n    if not check_cuda_p2p_ib_support():\n        if os.getenv(\"NCCL_P2P_DISABLE\") is None:\n            LOG.warning(\"P2P support not detected, setting `NCCL_P2P_DISABLE=1`\")\n            os.environ[\"NCCL_P2P_DISABLE\"] = \"1\"\n    # TODO @SalmanMohammadi remove the cfg.fsdp check in 0.12\n    if cfg.fsdp or cfg.fsdp_config:\n        cfg.fsdp = True if not cfg.fsdp else cfg.fsdp\n        setup_fsdp_envs(cfg)\n    elif cfg.deepspeed:\n        stage = None\n        deepspeed_config = None\n        # check if the cfg.deepspeed is a file\n        if isinstance(cfg.deepspeed, DictDefault):\n            deepspeed_config = cfg.deepspeed\n        elif 
os.path.isfile(cfg.deepspeed):\n            # parse with json\n            with open(cfg.deepspeed, \"r\", encoding=\"utf-8\") as fin:\n                deepspeed_config = json.load(fin)\n        if deepspeed_config:\n            stage = deepspeed_config.get(\"zero_optimization\", {}).get(\"stage\", None)\n        setup_deepspeed_env(cfg, stage=stage)\n\n    setup_parallelism_envs(cfg)\n    setup_torch_compile_env(cfg)\n\n    if cfg.fp8:\n        os.environ[\"ACCELERATE_MIXED_PRECISION\"] = \"fp8\"\n    elif (cfg.bf16 == \"auto\" and is_torch_bf16_gpu_available()) or cfg.bf16 is True:\n        os.environ[\"ACCELERATE_MIXED_PRECISION\"] = \"bf16\"\n    elif cfg.fp16:\n        os.environ[\"ACCELERATE_MIXED_PRECISION\"] = \"fp16\"\n    else:\n        os.environ[\"ACCELERATE_MIXED_PRECISION\"] = \"no\"\n\n\ndef setup_trainer(\n    cfg,\n    train_dataset,\n    eval_dataset,\n    model,\n    tokenizer,\n    processor,\n    total_num_steps,\n    model_ref=None,\n    peft_config=None,\n):\n    \"\"\"\n    Helper method for instantiating and building a (causal or RLHF) trainer.\n\n    Args:\n        cfg: Axolotl config object containing training parameters.\n        train_dataset: Dataset to use for training.\n        eval_dataset: Dataset to use for evaluation.\n        model: The model to train.\n        tokenizer: Tokenizer for processing text input.\n        processor: Processor for data preparation.\n        total_num_steps: The total number of training steps.\n        model_ref: Optional reference model for RLHF training. Default is None.\n        peft_config: Optional PEFT (Parameter-Efficient Fine-Tuning) configuration. 
Default is None.\n\n    Returns:\n        A trainer instance (either `HFRLTrainer` or `HFCausalTrainer`) configured based\n            on the provided parameters.\n    \"\"\"\n    from axolotl.core.builders import HFCausalTrainerBuilder, HFRLTrainerBuilder\n\n    if cfg.rl:\n        trainer_builder = HFRLTrainerBuilder(cfg, model, tokenizer, processor)\n        trainer_builder.model_ref = model_ref\n        trainer_builder.peft_config = peft_config\n    else:\n        trainer_builder = HFCausalTrainerBuilder(cfg, model, tokenizer, processor)\n\n    trainer_builder.train_dataset = train_dataset\n    trainer_builder.eval_dataset = eval_dataset\n\n    return trainer_builder.build(total_num_steps)\n"
  },
  {
    "path": "src/axolotl/utils/wandb_.py",
    "content": "\"\"\"Module for wandb utilities\"\"\"\n\nimport os\n\nfrom axolotl.utils.dict import DictDefault\n\n\ndef setup_wandb_env_vars(cfg: DictDefault):\n    for key in cfg.keys():\n        if key.startswith(\"wandb_\"):\n            value = cfg.get(key, \"\")\n\n            if value and isinstance(value, str) and len(value) > 0:\n                os.environ[key.upper()] = value\n\n    # Enable wandb if project name is present\n    if cfg.wandb_project and len(cfg.wandb_project) > 0:\n        cfg.use_wandb = True\n"
  },
  {
    "path": "src/setuptools_axolotl_dynamic_dependencies.py",
    "content": "\"\"\"\ndynamic requirements for axolotl\n\"\"\"\n\nimport platform\nimport re\nfrom importlib.metadata import PackageNotFoundError, version\n\nfrom setuptools.command.build_py import build_py as _build_py\n\n\ndef parse_requirements():\n    _install_requires = []\n    _dependency_links = []\n    with open(\"./requirements.txt\", encoding=\"utf-8\") as requirements_file:\n        lines = [r.strip() for r in requirements_file.readlines()]\n        for line in lines:\n            is_extras = (\n                \"flash-attn\" in line\n                or \"flash-attention\" in line\n                or \"deepspeed\" in line\n                or \"mamba-ssm\" in line\n                or \"lion-pytorch\" in line\n            )\n            if line.startswith(\"--extra-index-url\"):\n                # Handle custom index URLs\n                _, url = line.split()\n                _dependency_links.append(url)\n            elif not is_extras and line and line[0] != \"#\":\n                # Handle standard packages\n                _install_requires.append(line)\n\n    try:\n        xformers_version = [req for req in _install_requires if \"xformers\" in req][0]\n        torchao_version = [req for req in _install_requires if \"torchao\" in req][0]\n\n        if \"Darwin\" in platform.system():\n            # don't install xformers on MacOS\n            _install_requires.pop(_install_requires.index(xformers_version))\n        else:\n            # detect the version of torch already installed\n            # and set it so dependencies don't clobber the torch version\n            try:\n                torch_version = version(\"torch\")\n            except PackageNotFoundError:\n                torch_version = \"2.5.1\"\n            _install_requires.append(f\"torch=={torch_version}\")\n\n            version_match = re.match(r\"^(\\d+)\\.(\\d+)(?:\\.(\\d+))?\", torch_version)\n            if version_match:\n                major, minor, patch = 
version_match.groups()\n                major, minor = int(major), int(minor)\n                patch = (\n                    int(patch) if patch is not None else 0\n                )  # Default patch to 0 if not present\n            else:\n                raise ValueError(\"Invalid version format\")\n\n            if (major, minor) >= (2, 5):\n                _install_requires.pop(_install_requires.index(xformers_version))\n                if patch == 0:\n                    _install_requires.append(\"xformers==0.0.28.post2\")\n                else:\n                    _install_requires.append(\"xformers==0.0.28.post3\")\n            elif (major, minor) >= (2, 4):\n                if patch == 0:\n                    _install_requires.pop(_install_requires.index(xformers_version))\n                    _install_requires.append(\"xformers>=0.0.27\")\n                else:\n                    _install_requires.pop(_install_requires.index(xformers_version))\n                    _install_requires.append(\"xformers==0.0.28.post1\")\n            elif (major, minor) >= (2, 3):\n                _install_requires.pop(_install_requires.index(torchao_version))\n                if patch == 0:\n                    _install_requires.pop(_install_requires.index(xformers_version))\n                    _install_requires.append(\"xformers>=0.0.26.post1\")\n                else:\n                    _install_requires.pop(_install_requires.index(xformers_version))\n                    _install_requires.append(\"xformers>=0.0.27\")\n            elif (major, minor) >= (2, 2):\n                _install_requires.pop(_install_requires.index(torchao_version))\n                _install_requires.pop(_install_requires.index(xformers_version))\n                _install_requires.append(\"xformers>=0.0.25.post1\")\n            else:\n                _install_requires.pop(_install_requires.index(torchao_version))\n                _install_requires.pop(_install_requires.index(xformers_version))\n  
              _install_requires.append(\"xformers>=0.0.23.post1\")\n\n    except PackageNotFoundError:\n        pass\n    return _install_requires, _dependency_links\n\n\nclass BuildPyCommand(_build_py):\n    \"\"\"\n    custom build_py command to parse dynamic requirements\n    \"\"\"\n\n    def finalize_options(self):\n        super().finalize_options()\n        install_requires, _ = parse_requirements()\n        self.distribution.install_requires = install_requires\n"
  },
  {
    "path": "styles.css",
    "content": "/* TYPOGRAPHY SECTION */\n\n/* Import fonts */\n@import url('https://fonts.googleapis.com/css2?family=Be+Vietnam+Pro:wght@400;500&display=swap');\n@import url('https://fonts.googleapis.com/css2?family=JetBrains+Mono:wght@400&display=swap');\n\n/* Typography hierarchy */\n:root {\n    --font-title: 'Be Vietnam Pro', sans-serif;\n    --font-body: 'JetBrains Mono', monospace;\n}\n\n/* Title (h1) */\nh1 {\n    font-family: var(--font-title);\n    font-weight: 400;\n    font-size: 3rem;\n    line-height: 1.1;\n    letter-spacing: -0.05em;\n    font-feature-settings: \"ss01\" on;\n}\n\n/* Heading (h2) */\nh2 {\n    font-family: var(--font-title);\n    font-weight: 500;\n    font-size: 1.5rem;\n    line-height: 1.2;\n    letter-spacing: -0.03em;\n    font-feature-settings: \"ss01\" on;\n}\n\n/* Subtitle/Preamble */\nh3,\nh4 {\n    font-family: var(--font-body);\n    font-weight: 400;\n    font-size: 1.25rem;\n    line-height: 1.5;\n    letter-spacing: -0.02em;\n}\n\n/* Body text */\nbody {\n    font-family: var(--font-body);\n    font-weight: 400;\n    font-size: 1rem;\n    line-height: 1.5;\n    letter-spacing: -0.02em;\n}\n\n/* Links */\na {\n    font-family: var(--font-body);\n    font-weight: 400;\n    font-size: 0.875rem;\n    line-height: 1;\n    letter-spacing: -0.02em;\n}\n\n/* NAV BAR SECTION */\n\n/* Navbar logo styling */\n.navbar-brand img {\n    height: 32px;\n    margin-right: 10px;\n}\n\n/* COLORS SECTION */\n\n/* Brand colors */\n:root {\n    --white: #ffffff;\n    --greige-300: #EEEEE7;\n    --greige-600: #CCCAC0;\n    --black: #141310;\n    --lime: #E3F8A8;\n    --cyan: #A0F4EA;\n    --purple: #C8D0F8;\n}\n\n/* Base styles */\nbody {\n    background-color: var(--black);\n    color: var(--greige-300);\n}\n\n/* Navigation */\n.navbar {\n    background-color: var(--black) !important;\n}\n\n.navbar-dark .navbar-nav .nav-link {\n    color: var(--greige-300);\n}\n\n.navbar-dark .navbar-nav .nav-link:hover {\n    color: var(--lime);\n}\n\n/* 
Sidebar */\n.sidebar-navigation {\n    background-color: var(--black);\n    border-right: 1px solid var(--greige-600);\n}\n\n.sidebar nav[role=\"doc-toc\"] ul>li>a {\n    color: var(--greige-300);\n}\n\n.sidebar nav[role=\"doc-toc\"] ul>li>a:hover {\n    color: var(--lime);\n}\n\n/* Links */\na {\n    color: var(--lime);\n}\n\na:hover {\n    color: var(--cyan);\n}\n\n/* Headers */\nh1,\nh2,\nh3,\nh4,\nh5,\nh6 {\n    color: var(--white);\n}\n\n/* Code blocks */\npre {\n    background-color: #1a1a1a !important;\n    border: 1px solid var(--greige-600);\n}\n\n/* Tables */\n.table {\n    color: var(--greige-300);\n}\n\n/* TOC */\n#toc-title {\n    color: var(--white);\n}\n\n.toc-active {\n    color: var(--lime) !important;\n}\n\n/* Buttons */\n.btn-primary {\n    background-color: var(--lime);\n    color: var(--black);\n    border: none;\n}\n\n.btn-primary:hover {\n    background-color: var(--cyan);\n    color: var(--black);\n}\n\n/* For inline code (single backtick) */\ncode {\n    background-color: #1a1a1a !important;\n    color: var(--lime) !important;\n    padding: 2px 4px;\n    border-radius: 4px;\n}\n\n/* For inline code that is also a link */\na code {\n    color: var(--cyan) !important;\n}\n\n/* For code blocks (triple backtick) */\npre.sourceCode {\n    background-color: #1a1a1a !important;\n}\n\n/* Make comments in bash/shell scripts green */\ncode span.co {\n    color: #5cb85c !important;\n}\n\n/* Remove underlines from JSON comments and make them green */\ncode span.er {\n    color: #5cb85c !important;\n    text-decoration: none !important;\n}\n\n/* API Documentation Styling */\n\n/* Improve docstring section rendering */\n.level3 p {\n    white-space: pre-line !important;\n}\n\n/* Format docstring sections */\n.level3 p strong {\n    display: block;\n    margin-top: 1em;\n    font-weight: bold;\n    color: var(--cyan);\n}\n\n/* Add spacing after sections */\n.level3 p:has(strong) {\n    margin-bottom: 0.5em;\n}\n\n/* Format Args and Returns sections 
*/\np:has(code) {\n    line-height: 1.6;\n}\n\n/* Function signatures */\n.sourceCode {\n    margin-bottom: 1.5em;\n}\n\n/* Parameter tables */\n.doc-section-parameters table,\n.doc-section-returns table {\n    margin-top: 1em;\n    margin-bottom: 1.5em;\n}\n\n/* Make parameter and returns headers smaller */\nh2.anchored[data-anchor-id=\"parameters\"],\nh2.anchored[data-anchor-id=\"returns\"],\n.doc-section-parameters h4,\n.doc-section-returns h4 {\n    font-size: 1.25rem;\n    margin-top: 2rem;\n    margin-bottom: 1rem;\n    color: var(--lime);\n    border-bottom: 1px solid var(--lime);\n    padding-bottom: 0.3rem;\n    font-family: var(--font-body);\n    font-weight: 500;\n    letter-spacing: normal;\n}\n\n/* Style documentation tables */\ntable {\n    width: 100%;\n    margin-bottom: 1.5rem;\n    border-collapse: collapse;\n}\n\ntable th {\n    background-color: #1a1a1a;\n    padding: 0.5rem 1rem;\n    border-bottom: 2px solid var(--greige-600);\n    text-align: left;\n}\n\ntable td {\n    padding: 0.5rem 1rem;\n    border-bottom: 1px solid var(--greige-600);\n}\n\n/* Code in table cells */\ntable td code {\n    background-color: transparent !important;\n    padding: 0;\n}\n\n/* Improve spacing in parameter and return tables */\n.doc-section-parameters,\n.doc-section-returns {\n    margin-top: 1rem;\n}\n"
  },
  {
    "path": "tests/__init__.py",
    "content": ""
  },
  {
    "path": "tests/cli/__init__.py",
    "content": ""
  },
  {
    "path": "tests/cli/conftest.py",
    "content": "\"\"\"Shared pytest fixtures for cli module.\"\"\"\n\nimport pytest\nfrom click.testing import CliRunner\n\nVALID_TEST_CONFIG = \"\"\"\nbase_model: HuggingFaceTB/SmolLM2-135M\ndatasets:\n  - path: mhenrichsen/alpaca_2k_test\n    type: alpaca\nsequence_len: 2048\nmax_steps: 1\nmicro_batch_size: 1\ngradient_accumulation_steps: 1\nlearning_rate: 1e-3\nspecial_tokens:\n  pad_token: <|endoftext|>\n\"\"\"\n\n\n@pytest.fixture\ndef cli_runner():\n    return CliRunner()\n\n\n@pytest.fixture\ndef valid_test_config():\n    return VALID_TEST_CONFIG\n\n\n@pytest.fixture\ndef config_path(tmp_path):\n    \"\"\"Creates a temporary config file\"\"\"\n    path = tmp_path / \"config.yml\"\n    path.write_text(VALID_TEST_CONFIG)\n\n    return path\n"
  },
  {
    "path": "tests/cli/test_cli_base.py",
    "content": "\"\"\"Base test class for CLI commands.\"\"\"\n\nfrom pathlib import Path\nfrom unittest.mock import patch\n\nfrom axolotl.cli.main import cli\n\n\nclass BaseCliTest:\n    \"\"\"Base class for CLI command tests.\"\"\"\n\n    def _test_cli_validation(self, cli_runner, command: str):\n        \"\"\"Test CLI validation for a command.\n\n        Args:\n            cli_runner: CLI runner fixture\n            command: Command to test (train/evaluate)\n        \"\"\"\n        # Test missing config file\n        result = cli_runner.invoke(cli, [command, \"--launcher\", \"python\"])\n        assert result.exit_code != 0\n\n        # Test non-existent config file\n        result = cli_runner.invoke(\n            cli, [command, \"nonexistent.yml\", \"--launcher\", \"python\"]\n        )\n        assert result.exit_code != 0\n        assert \"Error: Invalid value for 'CONFIG'\" in result.output\n\n    def _test_basic_execution(\n        self,\n        cli_runner,\n        tmp_path: Path,\n        valid_test_config: str,\n        command: str,\n        train: bool = True,\n    ):\n        \"\"\"Test basic execution with accelerate.\n\n        Args:\n            cli_runner: CLI runner fixture\n            tmp_path: Temporary path fixture\n            valid_test_config: Valid config fixture\n            command: Command to test (train/evaluate)\n            train: Whether to test training (default) or evaluation\n        \"\"\"\n        config_path = tmp_path / \"config.yml\"\n        config_path.write_text(valid_test_config)\n\n        mock_fn = \"os.execvpe\" if command == \"train\" else \"subprocess.run\"\n\n        with patch(mock_fn) as mock:\n            result = cli_runner.invoke(cli, [command, str(config_path)])\n\n            assert mock.called\n\n            expected = [\n                \"accelerate\",\n                \"launch\",\n                \"-m\",\n                f\"axolotl.cli.{command}\",\n                str(config_path),\n                
\"--debug=False\",\n                \"--debug-text-only=False\",\n                \"--debug-num-examples=0\",\n            ]\n            if train:\n                expected.append(\"--shard=False\")\n\n            if command == \"train\":\n                assert mock.call_args.args[0] == \"accelerate\"\n                assert mock.call_args.args[1] == expected\n            else:\n                assert mock.call_args.args[0] == expected\n                assert mock.call_args.kwargs == {\"check\": True}\n            assert result.exit_code == 0\n\n    def _test_cli_overrides(self, tmp_path: Path, valid_test_config: str):\n        \"\"\"Test CLI argument overrides.\n\n        Args:\n            tmp_path: Temporary path fixture\n            valid_test_config: Valid config fixture\n            command: Command to test (train/evaluate)\n        \"\"\"\n        config_path = tmp_path / \"config.yml\"\n        output_dir = tmp_path / \"model-out\"\n\n        test_config = valid_test_config.replace(\n            \"output_dir: model-out\", f\"output_dir: {output_dir}\"\n        )\n        config_path.write_text(test_config)\n        return config_path\n"
  },
  {
    "path": "tests/cli/test_cli_evaluate.py",
    "content": "\"\"\"Tests for evaluate CLI command.\"\"\"\n\nfrom unittest.mock import patch\n\nfrom axolotl.cli.main import cli\n\nfrom .test_cli_base import BaseCliTest\n\n\nclass TestEvaluateCommand(BaseCliTest):\n    \"\"\"Test cases for evaluate command.\"\"\"\n\n    cli = cli\n\n    def test_evaluate_cli_validation(self, cli_runner):\n        \"\"\"Test CLI validation\"\"\"\n        self._test_cli_validation(cli_runner, \"evaluate\")\n\n    def test_evaluate_basic_execution(self, cli_runner, tmp_path, valid_test_config):\n        \"\"\"Test basic successful execution\"\"\"\n        self._test_basic_execution(\n            cli_runner, tmp_path, valid_test_config, \"evaluate\", train=False\n        )\n\n    def test_evaluate_basic_execution_no_accelerate(\n        self, cli_runner, tmp_path, valid_test_config\n    ):\n        \"\"\"Test basic successful execution without accelerate\"\"\"\n        config_path = tmp_path / \"config.yml\"\n        config_path.write_text(valid_test_config)\n\n        with patch(\"axolotl.cli.evaluate.do_evaluate\") as mock_evaluate:\n            result = cli_runner.invoke(\n                cli,\n                [\n                    \"evaluate\",\n                    str(config_path),\n                    \"--launcher\",\n                    \"python\",\n                ],\n                catch_exceptions=False,\n            )\n\n            assert result.exit_code == 0\n            mock_evaluate.assert_called_once()\n\n    def test_evaluate_cli_overrides(self, cli_runner, tmp_path, valid_test_config):\n        \"\"\"Test CLI arguments properly override config values\"\"\"\n        config_path = self._test_cli_overrides(tmp_path, valid_test_config)\n\n        with patch(\"axolotl.cli.evaluate.do_evaluate\") as mock_evaluate:\n            result = cli_runner.invoke(\n                cli,\n                [\n                    \"evaluate\",\n                    str(config_path),\n                    \"--micro-batch-size\",\n     
               \"2\",\n                    \"--sequence-len\",\n                    \"128\",\n                    \"--launcher\",\n                    \"python\",\n                ],\n                catch_exceptions=False,\n            )\n\n            assert result.exit_code == 0\n            mock_evaluate.assert_called_once()\n            cfg = mock_evaluate.call_args[0][0]\n            assert cfg.micro_batch_size == 2\n            assert cfg.sequence_len == 128\n\n    def test_evaluate_with_launcher_args_torchrun(\n        self, cli_runner, tmp_path, valid_test_config\n    ):\n        \"\"\"Test evaluate with torchrun launcher arguments\"\"\"\n        config_path = tmp_path / \"config.yml\"\n        config_path.write_text(valid_test_config)\n\n        with patch(\"subprocess.run\") as mock_subprocess:\n            result = cli_runner.invoke(\n                cli,\n                [\n                    \"evaluate\",\n                    str(config_path),\n                    \"--launcher\",\n                    \"torchrun\",\n                    \"--\",\n                    \"--nproc_per_node=2\",\n                    \"--nnodes=1\",\n                ],\n                catch_exceptions=False,\n            )\n\n            assert result.exit_code == 0\n            mock_subprocess.assert_called_once()\n\n            # Verify launcher args are passed to torchrun\n            called_cmd = mock_subprocess.call_args.args[0]\n            assert called_cmd[0] == \"torchrun\"\n            assert \"--nproc_per_node=2\" in called_cmd\n            assert \"--nnodes=1\" in called_cmd\n            assert \"-m\" in called_cmd\n            assert \"axolotl.cli.evaluate\" in called_cmd\n\n    def test_evaluate_with_launcher_args_accelerate(\n        self, cli_runner, tmp_path, valid_test_config\n    ):\n        \"\"\"Test evaluate with accelerate launcher arguments\"\"\"\n        config_path = tmp_path / \"config.yml\"\n        config_path.write_text(valid_test_config)\n\n     
   with patch(\"subprocess.run\") as mock_subprocess:\n            result = cli_runner.invoke(\n                cli,\n                [\n                    \"evaluate\",\n                    str(config_path),\n                    \"--launcher\",\n                    \"accelerate\",\n                    \"--\",\n                    \"--config_file=accelerate_config.yml\",\n                    \"--num_processes=4\",\n                ],\n                catch_exceptions=False,\n            )\n\n            assert result.exit_code == 0\n            mock_subprocess.assert_called_once()\n\n            # Verify launcher args are passed to accelerate\n            called_cmd = mock_subprocess.call_args.args[0]\n            assert called_cmd[0] == \"accelerate\"\n            assert called_cmd[1] == \"launch\"\n            assert \"--config_file=accelerate_config.yml\" in called_cmd\n            assert \"--num_processes=4\" in called_cmd\n            assert \"-m\" in called_cmd\n            assert \"axolotl.cli.evaluate\" in called_cmd\n\n    def test_evaluate_backward_compatibility_no_launcher_args(\n        self, cli_runner, tmp_path, valid_test_config\n    ):\n        \"\"\"Test that existing evaluate commands work without launcher args\"\"\"\n        config_path = tmp_path / \"config.yml\"\n        config_path.write_text(valid_test_config)\n\n        with patch(\"subprocess.run\") as mock_subprocess:\n            result = cli_runner.invoke(\n                cli,\n                [\n                    \"evaluate\",\n                    str(config_path),\n                    \"--launcher\",\n                    \"accelerate\",\n                    \"--micro-batch-size\",\n                    \"2\",\n                ],\n                catch_exceptions=False,\n            )\n\n            assert result.exit_code == 0\n            mock_subprocess.assert_called_once()\n\n            # Verify no launcher args contamination\n            called_cmd = 
mock_subprocess.call_args.args[0]\n            assert called_cmd[0] == \"accelerate\"\n            assert called_cmd[1] == \"launch\"\n            # Should not contain any extra launcher args\n            launcher_section = called_cmd[2 : called_cmd.index(\"-m\")]\n            assert (\n                len(launcher_section) == 0\n            )  # No launcher args between 'launch' and '-m'\n"
  },
  {
    "path": "tests/cli/test_cli_fetch.py",
    "content": "\"\"\"pytest tests for axolotl CLI fetch command.\"\"\"\n\nfrom unittest.mock import patch\n\nfrom axolotl.cli.main import fetch\n\n\ndef test_fetch_cli_examples(cli_runner):\n    \"\"\"Test fetch command with examples directory\"\"\"\n    with patch(\"axolotl.cli.main.fetch_from_github\") as mock_fetch:\n        result = cli_runner.invoke(fetch, [\"examples\"])\n\n        assert result.exit_code == 0\n        mock_fetch.assert_called_once_with(\"examples/\", None)\n\n\ndef test_fetch_cli_deepspeed(cli_runner):\n    \"\"\"Test fetch command with deepspeed_configs directory\"\"\"\n    with patch(\"axolotl.cli.main.fetch_from_github\") as mock_fetch:\n        result = cli_runner.invoke(fetch, [\"deepspeed_configs\"])\n\n        assert result.exit_code == 0\n        mock_fetch.assert_called_once_with(\"deepspeed_configs/\", None)\n\n\ndef test_fetch_cli_with_dest(cli_runner, tmp_path):\n    \"\"\"Test fetch command with custom destination\"\"\"\n    with patch(\"axolotl.cli.main.fetch_from_github\") as mock_fetch:\n        custom_dir = tmp_path / \"tmp_examples\"\n        result = cli_runner.invoke(fetch, [\"examples\", \"--dest\", str(custom_dir)])\n\n        assert result.exit_code == 0\n        mock_fetch.assert_called_once_with(\"examples/\", str(custom_dir))\n\n\ndef test_fetch_cli_invalid_directory(cli_runner):\n    \"\"\"Test fetch command with invalid directory choice\"\"\"\n    result = cli_runner.invoke(fetch, [\"invalid\"])\n    assert result.exit_code != 0\n"
  },
  {
    "path": "tests/cli/test_cli_inference.py",
    "content": "\"\"\"pytest tests for axolotl CLI inference command.\"\"\"\n\nfrom unittest.mock import patch\n\nfrom axolotl.cli.main import cli\n\n\ndef test_inference_basic(cli_runner, config_path):\n    \"\"\"Test basic inference\"\"\"\n    with patch(\"axolotl.cli.inference.do_inference\") as mock:\n        result = cli_runner.invoke(\n            cli,\n            [\"inference\", str(config_path), \"--launcher\", \"python\"],\n            catch_exceptions=False,\n        )\n\n        assert mock.called\n        assert result.exit_code == 0\n\n\ndef test_inference_gradio(cli_runner, config_path):\n    \"\"\"Test basic inference (gradio path)\"\"\"\n    with patch(\"axolotl.cli.inference.do_inference_gradio\") as mock:\n        result = cli_runner.invoke(\n            cli,\n            [\"inference\", str(config_path), \"--launcher\", \"python\", \"--gradio\"],\n            catch_exceptions=False,\n        )\n\n        assert mock.called\n        assert result.exit_code == 0\n\n\ndef test_inference_with_launcher_args_torchrun(cli_runner, config_path):\n    \"\"\"Test inference with torchrun launcher arguments\"\"\"\n    with patch(\"subprocess.run\") as mock_subprocess:\n        result = cli_runner.invoke(\n            cli,\n            [\n                \"inference\",\n                str(config_path),\n                \"--launcher\",\n                \"torchrun\",\n                \"--\",\n                \"--nproc_per_node=2\",\n                \"--nnodes=1\",\n            ],\n            catch_exceptions=False,\n        )\n\n        assert result.exit_code == 0\n        mock_subprocess.assert_called_once()\n\n        # Verify launcher args are passed to torchrun\n        called_cmd = mock_subprocess.call_args.args[0]\n        assert called_cmd[0] == \"torchrun\"\n        assert \"--nproc_per_node=2\" in called_cmd\n        assert \"--nnodes=1\" in called_cmd\n        assert \"-m\" in called_cmd\n        assert \"axolotl.cli.inference\" in 
called_cmd\n\n\ndef test_inference_with_launcher_args_accelerate(cli_runner, config_path):\n    \"\"\"Test inference with accelerate launcher arguments\"\"\"\n    with patch(\"subprocess.run\") as mock_subprocess:\n        result = cli_runner.invoke(\n            cli,\n            [\n                \"inference\",\n                str(config_path),\n                \"--launcher\",\n                \"accelerate\",\n                \"--\",\n                \"--config_file=accelerate_config.yml\",\n                \"--num_processes=4\",\n            ],\n            catch_exceptions=False,\n        )\n\n        assert result.exit_code == 0\n        mock_subprocess.assert_called_once()\n\n        # Verify launcher args are passed to accelerate\n        called_cmd = mock_subprocess.call_args.args[0]\n        assert called_cmd[0] == \"accelerate\"\n        assert called_cmd[1] == \"launch\"\n        assert \"--config_file=accelerate_config.yml\" in called_cmd\n        assert \"--num_processes=4\" in called_cmd\n        assert \"-m\" in called_cmd\n        assert \"axolotl.cli.inference\" in called_cmd\n\n\ndef test_inference_gradio_with_launcher_args(cli_runner, config_path):\n    \"\"\"Test inference with gradio and launcher arguments\"\"\"\n    with patch(\"subprocess.run\") as mock_subprocess:\n        result = cli_runner.invoke(\n            cli,\n            [\n                \"inference\",\n                str(config_path),\n                \"--launcher\",\n                \"accelerate\",\n                \"--gradio\",\n                \"--\",\n                \"--num_processes=2\",\n            ],\n            catch_exceptions=False,\n        )\n\n        assert result.exit_code == 0\n        mock_subprocess.assert_called_once()\n\n        # Verify both gradio flag and launcher args are present\n        called_cmd = mock_subprocess.call_args.args[0]\n        assert called_cmd[0] == \"accelerate\"\n        assert called_cmd[1] == \"launch\"\n        assert 
\"--num_processes=2\" in called_cmd\n        assert \"--gradio\" in called_cmd\n        assert \"-m\" in called_cmd\n        assert \"axolotl.cli.inference\" in called_cmd\n\n\ndef test_inference_backward_compatibility_no_launcher_args(cli_runner, config_path):\n    \"\"\"Test that existing inference commands work without launcher args\"\"\"\n    with patch(\"subprocess.run\") as mock_subprocess:\n        result = cli_runner.invoke(\n            cli,\n            [\n                \"inference\",\n                str(config_path),\n                \"--launcher\",\n                \"accelerate\",\n            ],\n            catch_exceptions=False,\n        )\n\n        assert result.exit_code == 0\n        mock_subprocess.assert_called_once()\n\n        # Verify no launcher args contamination\n        called_cmd = mock_subprocess.call_args.args[0]\n        assert called_cmd[0] == \"accelerate\"\n        assert called_cmd[1] == \"launch\"\n        # Should not contain any extra launcher args\n        launcher_section = called_cmd[2 : called_cmd.index(\"-m\")]\n        assert len(launcher_section) == 0  # No launcher args between 'launch' and '-m'\n"
  },
  {
    "path": "tests/cli/test_cli_interface.py",
    "content": "\"\"\"General pytest tests for axolotl.cli.main interface.\"\"\"\n\nfrom axolotl.cli.main import build_command, cli\n\n\ndef test_build_command():\n    \"\"\"Test converting dict of options to CLI arguments\"\"\"\n    base_cmd = [\"accelerate\", \"launch\"]\n    options = {\n        \"learning_rate\": 1e-4,\n        \"batch_size\": 8,\n        \"debug\": True,\n        \"use_fp16\": False,\n        \"null_value\": None,\n    }\n\n    result = build_command(base_cmd, options)\n    assert result == [\n        \"accelerate\",\n        \"launch\",\n        \"--learning-rate=0.0001\",\n        \"--batch-size=8\",\n        \"--debug=True\",\n        \"--use-fp16=False\",\n    ]\n\n\ndef test_invalid_command_options(cli_runner):\n    \"\"\"Test handling of invalid command options\"\"\"\n    result = cli_runner.invoke(\n        cli,\n        [\n            \"train\",\n            \"config.yml\",\n            \"--invalid-option\",\n            \"value\",\n        ],\n    )\n    assert result.exit_code != 0\n    assert \"does not exist\" in result.output\n\n\ndef test_required_config_argument(cli_runner):\n    \"\"\"Test commands fail properly when config argument is missing\"\"\"\n    result = cli_runner.invoke(cli, [\"train\"])\n    assert result.exit_code != 0\n    assert \"Missing argument 'CONFIG'\" in result.output\n"
  },
  {
    "path": "tests/cli/test_cli_merge_lora.py",
    "content": "\"\"\"pytest tests for axolotl CLI merge_lora command.\"\"\"\n\nfrom unittest.mock import patch\n\nfrom axolotl.cli.main import cli\n\n\ndef test_merge_lora_basic(cli_runner, config_path):\n    \"\"\"Test basic merge_lora command\"\"\"\n    with patch(\"axolotl.cli.merge_lora.do_cli\") as mock_do_cli:\n        result = cli_runner.invoke(cli, [\"merge-lora\", str(config_path)])\n        assert result.exit_code == 0\n\n        mock_do_cli.assert_called_once()\n        assert mock_do_cli.call_args.kwargs[\"config\"] == str(config_path)\n\n\ndef test_merge_lora_with_dirs(cli_runner, config_path, tmp_path):\n    \"\"\"Test merge_lora with custom lora and output directories\"\"\"\n    lora_dir = tmp_path / \"lora\"\n    output_dir = tmp_path / \"output\"\n    lora_dir.mkdir()\n\n    with patch(\"axolotl.cli.merge_lora.do_cli\") as mock_do_cli:\n        result = cli_runner.invoke(\n            cli,\n            [\n                \"merge-lora\",\n                str(config_path),\n                \"--lora-model-dir\",\n                str(lora_dir),\n                \"--output-dir\",\n                str(output_dir),\n            ],\n        )\n        assert result.exit_code == 0\n\n        mock_do_cli.assert_called_once()\n        assert mock_do_cli.call_args.kwargs[\"config\"] == str(config_path)\n        assert mock_do_cli.call_args.kwargs[\"lora_model_dir\"] == str(lora_dir)\n        assert mock_do_cli.call_args.kwargs[\"output_dir\"] == str(output_dir)\n\n\ndef test_merge_lora_nonexistent_config(cli_runner, tmp_path):\n    \"\"\"Test merge_lora with nonexistent config\"\"\"\n    config_path = tmp_path / \"nonexistent.yml\"\n    result = cli_runner.invoke(cli, [\"merge-lora\", str(config_path)])\n    assert result.exit_code != 0\n\n\ndef test_merge_lora_nonexistent_lora_dir(cli_runner, config_path, tmp_path):\n    \"\"\"Test merge_lora with nonexistent lora directory\"\"\"\n    lora_dir = tmp_path / \"nonexistent\"\n    result = cli_runner.invoke(\n   
     cli, [\"merge-lora\", str(config_path), \"--lora-model-dir\", str(lora_dir)]\n    )\n    assert result.exit_code != 0\n"
  },
  {
    "path": "tests/cli/test_cli_merge_sharded_fsdp_weights.py",
    "content": "\"\"\"pytest tests for axolotl CLI merge_sharded_fsdp_weights command.\"\"\"\n\nfrom unittest.mock import patch\n\nfrom axolotl.cli.main import cli\n\n\ndef test_merge_sharded_fsdp_weights_no_accelerate(cli_runner, config_path):\n    \"\"\"Test merge_sharded_fsdp_weights command without accelerate\"\"\"\n    with patch(\"axolotl.cli.merge_sharded_fsdp_weights.do_cli\") as mock:\n        result = cli_runner.invoke(\n            cli,\n            [\"merge-sharded-fsdp-weights\", str(config_path), \"--launcher\", \"python\"],\n        )\n\n        assert mock.called\n        assert mock.call_args.kwargs[\"config\"] == str(config_path)\n        assert result.exit_code == 0\n\n\ndef test_merge_sharded_fsdp_weights_with_launcher_args_torchrun(\n    cli_runner, config_path\n):\n    \"\"\"Test merge-sharded-fsdp-weights with torchrun launcher arguments\"\"\"\n    with patch(\"subprocess.run\") as mock_subprocess:\n        result = cli_runner.invoke(\n            cli,\n            [\n                \"merge-sharded-fsdp-weights\",\n                str(config_path),\n                \"--launcher\",\n                \"torchrun\",\n                \"--\",\n                \"--nproc_per_node=2\",\n                \"--nnodes=1\",\n            ],\n            catch_exceptions=False,\n        )\n\n        assert result.exit_code == 0\n        mock_subprocess.assert_called_once()\n\n        # Verify launcher args are passed to torchrun\n        called_cmd = mock_subprocess.call_args.args[0]\n        assert called_cmd[0] == \"torchrun\"\n        assert \"--nproc_per_node=2\" in called_cmd\n        assert \"--nnodes=1\" in called_cmd\n        assert \"-m\" in called_cmd\n        assert \"axolotl.cli.merge_sharded_fsdp_weights\" in called_cmd\n\n\ndef test_merge_sharded_fsdp_weights_with_launcher_args_accelerate(\n    cli_runner, config_path\n):\n    \"\"\"Test merge-sharded-fsdp-weights with accelerate launcher arguments\"\"\"\n    with patch(\"subprocess.run\") as 
mock_subprocess:\n        result = cli_runner.invoke(\n            cli,\n            [\n                \"merge-sharded-fsdp-weights\",\n                str(config_path),\n                \"--launcher\",\n                \"accelerate\",\n                \"--\",\n                \"--config_file=accelerate_config.yml\",\n                \"--num_processes=4\",\n            ],\n            catch_exceptions=False,\n        )\n\n        assert result.exit_code == 0\n        mock_subprocess.assert_called_once()\n\n        # Verify launcher args are passed to accelerate\n        called_cmd = mock_subprocess.call_args.args[0]\n        assert called_cmd[0] == \"accelerate\"\n        assert called_cmd[1] == \"launch\"\n        assert \"--config_file=accelerate_config.yml\" in called_cmd\n        assert \"--num_processes=4\" in called_cmd\n        assert \"-m\" in called_cmd\n        assert \"axolotl.cli.merge_sharded_fsdp_weights\" in called_cmd\n\n\ndef test_merge_sharded_fsdp_weights_backward_compatibility_no_launcher_args(\n    cli_runner, config_path\n):\n    \"\"\"Test that existing merge-sharded-fsdp-weights commands work without launcher args\"\"\"\n    with patch(\"subprocess.run\") as mock_subprocess:\n        result = cli_runner.invoke(\n            cli,\n            [\n                \"merge-sharded-fsdp-weights\",\n                str(config_path),\n                \"--launcher\",\n                \"accelerate\",\n            ],\n            catch_exceptions=False,\n        )\n\n        assert result.exit_code == 0\n        mock_subprocess.assert_called_once()\n\n        # Verify no launcher args contamination\n        called_cmd = mock_subprocess.call_args.args[0]\n        assert called_cmd[0] == \"accelerate\"\n        assert called_cmd[1] == \"launch\"\n        # Should not contain any extra launcher args\n        launcher_section = called_cmd[2 : called_cmd.index(\"-m\")]\n        assert len(launcher_section) == 0  # No launcher args between 'launch' and 
'-m'\n"
  },
  {
    "path": "tests/cli/test_cli_preprocess.py",
    "content": "\"\"\"pytest tests for axolotl CLI preprocess command.\"\"\"\n\nimport shutil\nfrom pathlib import Path\nfrom unittest.mock import MagicMock, patch\n\nimport pytest\n\nfrom axolotl.cli.main import cli\n\n\n@pytest.fixture(autouse=True)\ndef cleanup_last_run_prepared():\n    yield\n\n    if Path(\"last_run_prepared\").exists():\n        shutil.rmtree(\"last_run_prepared\")\n\n\ndef test_preprocess_config_not_found(cli_runner):\n    \"\"\"Test preprocess fails when config not found\"\"\"\n    result = cli_runner.invoke(cli, [\"preprocess\", \"nonexistent.yml\"])\n    assert result.exit_code != 0\n\n\ndef test_preprocess_basic(cli_runner, config_path):\n    \"\"\"Test basic preprocessing with minimal config\"\"\"\n    with patch(\"axolotl.cli.preprocess.do_cli\") as mock_do_cli:\n        with patch(\"axolotl.cli.preprocess.load_datasets\") as mock_load_datasets:\n            mock_load_datasets.return_value = MagicMock()\n\n            result = cli_runner.invoke(cli, [\"preprocess\", str(config_path)])\n            assert result.exit_code == 0\n\n            mock_do_cli.assert_called_once()\n            assert mock_do_cli.call_args.kwargs[\"config\"] == str(config_path)\n            assert mock_do_cli.call_args.kwargs[\"download\"] is True\n\n\ndef test_preprocess_without_download(cli_runner, config_path):\n    \"\"\"Test preprocessing without model download\"\"\"\n    with patch(\"axolotl.cli.preprocess.do_cli\") as mock_do_cli:\n        result = cli_runner.invoke(\n            cli, [\"preprocess\", str(config_path), \"--no-download\"]\n        )\n        assert result.exit_code == 0\n\n        mock_do_cli.assert_called_once()\n        assert mock_do_cli.call_args.kwargs[\"config\"] == str(config_path)\n        assert mock_do_cli.call_args.kwargs[\"download\"] is False\n\n\ndef test_preprocess_custom_path(cli_runner, tmp_path, valid_test_config):\n    \"\"\"Test preprocessing with custom dataset path\"\"\"\n    config_path = tmp_path / \"config.yml\"\n 
   custom_path = tmp_path / \"custom_prepared\"\n    config_path.write_text(valid_test_config)\n\n    with patch(\"axolotl.cli.preprocess.do_cli\") as mock_do_cli:\n        with patch(\"axolotl.cli.preprocess.load_datasets\") as mock_load_datasets:\n            mock_load_datasets.return_value = MagicMock()\n\n            result = cli_runner.invoke(\n                cli,\n                [\n                    \"preprocess\",\n                    str(config_path),\n                    \"--dataset-prepared-path\",\n                    str(custom_path.absolute()),\n                ],\n            )\n            assert result.exit_code == 0\n\n            mock_do_cli.assert_called_once()\n            assert mock_do_cli.call_args.kwargs[\"config\"] == str(config_path)\n            assert mock_do_cli.call_args.kwargs[\"dataset_prepared_path\"] == str(\n                custom_path.absolute()\n            )\n"
  },
  {
    "path": "tests/cli/test_cli_sweeps.py",
    "content": "\"\"\"\nunit tests for generating sweep configurations\n\"\"\"\n\nfrom axolotl.cli.utils import generate_sweep_configs\n\n\ndef test_generate_sweep_configs_no_pairs():\n    base_config = {\n        \"learning_rate\": 0.1,\n        \"micro_batch_size\": 1,\n        \"sample_packing\": True,\n    }\n\n    sweeps_config = {\"micro_batch_size\": [1, 2, 4], \"weight_decay\": [0.0, 0.1]}\n\n    generate_sweep_configs(base_config, sweeps_config)\n\n    assert len(generate_sweep_configs(base_config, sweeps_config)) == 6\n\n    cfg_1 = {\n        \"learning_rate\": 0.1,\n        \"micro_batch_size\": 2,\n        \"weight_decay\": 0.0,\n        \"sample_packing\": True,\n    }\n\n    assert any(\n        cfg_1 == cfg for cfg in generate_sweep_configs(base_config, sweeps_config)\n    )\n\n\ndef test_generate_sweep_configs_with_pairs():\n    base_config = {\n        \"learning_rate\": 0.1,\n        \"micro_batch_size\": 1,\n        \"sample_packing\": True,\n    }\n\n    sweeps_config = {\n        \"_\": [\n            {\n                \"micro_batch_size\": 1,\n                \"gradient_accumulation_steps\": 8,\n            },\n            {\n                \"micro_batch_size\": 2,\n                \"gradient_accumulation_steps\": 4,\n            },\n            {\n                \"micro_batch_size\": 4,\n                \"gradient_accumulation_steps\": 2,\n            },\n            {\n                \"micro_batch_size\": 8,\n                \"gradient_accumulation_steps\": 1,\n            },\n        ],\n        \"weight_decay\": [0.0, 0.1],\n    }\n\n    generate_sweep_configs(base_config, sweeps_config)\n\n    assert len(generate_sweep_configs(base_config, sweeps_config)) == 8\n\n    assert all(\n        cfg[\"gradient_accumulation_steps\"] * cfg[\"micro_batch_size\"] == 8\n        for cfg in generate_sweep_configs(base_config, sweeps_config)\n    )\n"
  },
  {
    "path": "tests/cli/test_cli_train.py",
    "content": "\"\"\"Tests for train CLI command.\"\"\"\n\nfrom unittest.mock import MagicMock, patch\n\nfrom axolotl.cli.main import cli\n\nfrom .test_cli_base import BaseCliTest\n\n\nclass TestTrainCommand(BaseCliTest):\n    \"\"\"Test cases for train command.\"\"\"\n\n    cli = cli\n\n    def test_train_cli_validation(self, cli_runner):\n        \"\"\"Test CLI validation\"\"\"\n        self._test_cli_validation(cli_runner, \"train\")\n\n    def test_train_basic_execution(self, cli_runner, tmp_path, valid_test_config):\n        \"\"\"Test basic successful execution\"\"\"\n        self._test_basic_execution(\n            cli_runner, tmp_path, valid_test_config, \"train\", train=True\n        )\n\n    def test_train_basic_execution_no_accelerate(\n        self, cli_runner, tmp_path, valid_test_config\n    ):\n        \"\"\"Test basic successful execution without accelerate\"\"\"\n        config_path = tmp_path / \"config.yml\"\n        config_path.write_text(valid_test_config)\n\n        with patch(\"axolotl.cli.train.train\") as mock_train:\n            mock_train.return_value = (MagicMock(), MagicMock(), MagicMock())\n            with patch(\"axolotl.cli.train.load_datasets\") as mock_load_datasets:\n                mock_load_datasets.return_value = MagicMock()\n\n                result = cli_runner.invoke(\n                    cli,\n                    [\n                        \"train\",\n                        str(config_path),\n                        \"--launcher\",\n                        \"python\",\n                    ],\n                    catch_exceptions=False,\n                )\n\n                assert result.exit_code == 0\n                mock_train.assert_called_once()\n\n    def test_train_cli_overrides(self, cli_runner, tmp_path, valid_test_config):\n        \"\"\"Test CLI arguments properly override config values\"\"\"\n        config_path = self._test_cli_overrides(tmp_path, valid_test_config)\n\n        with 
patch(\"axolotl.cli.train.train\") as mock_train:\n            mock_train.return_value = (MagicMock(), MagicMock(), MagicMock())\n            with patch(\"axolotl.cli.train.load_datasets\") as mock_load_datasets:\n                mock_load_datasets.return_value = MagicMock()\n\n                result = cli_runner.invoke(\n                    cli,\n                    [\n                        \"train\",\n                        str(config_path),\n                        \"--learning-rate=1e-4\",\n                        \"--micro-batch-size=2\",\n                        \"--launcher\",\n                        \"python\",\n                    ],\n                    catch_exceptions=False,\n                )\n\n                assert result.exit_code == 0\n                mock_train.assert_called_once()\n                cfg = mock_train.call_args[1][\"cfg\"]\n                assert cfg[\"learning_rate\"] == 1e-4\n                assert cfg[\"micro_batch_size\"] == 2\n\n    def test_train_with_launcher_args_torchrun(\n        self, cli_runner, tmp_path, valid_test_config\n    ):\n        \"\"\"Test train with torchrun launcher arguments\"\"\"\n        config_path = tmp_path / \"config.yml\"\n        config_path.write_text(valid_test_config)\n\n        with patch(\"os.execvpe\") as mock_subprocess:\n            result = cli_runner.invoke(\n                cli,\n                [\n                    \"train\",\n                    str(config_path),\n                    \"--launcher\",\n                    \"torchrun\",\n                    \"--\",\n                    \"--nproc_per_node=2\",\n                    \"--nnodes=1\",\n                ],\n                catch_exceptions=False,\n            )\n\n            assert result.exit_code == 0\n            mock_subprocess.assert_called_once()\n\n            # Verify launcher args are passed to torchrun\n            called_cmd = mock_subprocess.call_args.args[1]\n            assert called_cmd[0] == \"torchrun\"\n   
         assert \"--nproc_per_node=2\" in called_cmd\n            assert \"--nnodes=1\" in called_cmd\n            assert \"-m\" in called_cmd\n            assert \"axolotl.cli.train\" in called_cmd\n\n    def test_train_with_launcher_args_accelerate(\n        self, cli_runner, tmp_path, valid_test_config\n    ):\n        \"\"\"Test train with accelerate launcher arguments\"\"\"\n        config_path = tmp_path / \"config.yml\"\n        config_path.write_text(valid_test_config)\n\n        with patch(\"os.execvpe\") as mock_subprocess:\n            result = cli_runner.invoke(\n                cli,\n                [\n                    \"train\",\n                    str(config_path),\n                    \"--launcher\",\n                    \"accelerate\",\n                    \"--\",\n                    \"--config_file=accelerate_config.yml\",\n                    \"--num_processes=4\",\n                ],\n                catch_exceptions=False,\n            )\n\n            assert result.exit_code == 0\n            mock_subprocess.assert_called_once()\n\n            # Verify launcher args are passed to accelerate\n            assert mock_subprocess.call_args.args[0] == \"accelerate\"\n            called_cmd = mock_subprocess.call_args.args[1]\n            assert called_cmd[0] == \"accelerate\"\n            assert called_cmd[1] == \"launch\"\n            assert \"--config_file=accelerate_config.yml\" in called_cmd\n            assert \"--num_processes=4\" in called_cmd\n            assert \"-m\" in called_cmd\n            assert \"axolotl.cli.train\" in called_cmd\n\n    def test_train_backward_compatibility_no_launcher_args(\n        self, cli_runner, tmp_path, valid_test_config\n    ):\n        \"\"\"Test that existing train commands work without launcher args\"\"\"\n        config_path = tmp_path / \"config.yml\"\n        config_path.write_text(valid_test_config)\n\n        with patch(\"os.execvpe\") as mock_subprocess:\n            result = 
cli_runner.invoke(\n                cli,\n                [\n                    \"train\",\n                    str(config_path),\n                    \"--launcher\",\n                    \"accelerate\",\n                    \"--learning-rate\",\n                    \"1e-4\",\n                ],\n                catch_exceptions=False,\n            )\n\n            assert result.exit_code == 0\n            mock_subprocess.assert_called_once()\n\n            # Verify no launcher args contamination\n            assert mock_subprocess.call_args.args[0] == \"accelerate\"\n            called_cmd = mock_subprocess.call_args.args[1]\n            assert called_cmd[0] == \"accelerate\"\n            assert called_cmd[1] == \"launch\"\n            # Should not contain any extra launcher args\n            launcher_section = called_cmd[2 : called_cmd.index(\"-m\")]\n            assert (\n                len(launcher_section) == 0\n            )  # No launcher args between 'launch' and '-m'\n\n    def test_train_mixed_args_with_launcher_args(\n        self, cli_runner, tmp_path, valid_test_config\n    ):\n        \"\"\"Test train with both regular CLI args and launcher args\"\"\"\n        config_path = tmp_path / \"config.yml\"\n        config_path.write_text(valid_test_config)\n\n        with patch(\"os.execvpe\") as mock_subprocess:\n            result = cli_runner.invoke(\n                cli,\n                [\n                    \"train\",\n                    str(config_path),\n                    \"--launcher\",\n                    \"torchrun\",\n                    \"--learning-rate\",\n                    \"2e-4\",\n                    \"--micro-batch-size\",\n                    \"4\",\n                    \"--\",\n                    \"--nproc_per_node=8\",\n                ],\n                catch_exceptions=False,\n            )\n\n            assert result.exit_code == 0\n            mock_subprocess.assert_called_once()\n\n            assert 
mock_subprocess.call_args.args[0] == \"torchrun\"\n            called_cmd = mock_subprocess.call_args.args[1]\n            # Verify launcher args\n            assert \"--nproc_per_node=8\" in called_cmd\n            # Verify axolotl args are also present\n            assert \"--learning-rate=2e-4\" in called_cmd\n            assert \"--micro-batch-size=4\" in called_cmd\n\n    def test_train_cloud_with_launcher_args(\n        self, cli_runner, tmp_path, valid_test_config\n    ):\n        \"\"\"Test train with cloud and launcher arguments\"\"\"\n        config_path = tmp_path / \"config.yml\"\n        config_path.write_text(valid_test_config)\n\n        cloud_path = tmp_path / \"cloud.yml\"\n        cloud_path.write_text(\"provider: modal\\ngpu: a100\")\n\n        with patch(\"axolotl.cli.cloud.do_cli_train\") as mock_cloud_train:\n            result = cli_runner.invoke(\n                cli,\n                [\n                    \"train\",\n                    str(config_path),\n                    \"--cloud\",\n                    str(cloud_path),\n                    \"--launcher\",\n                    \"torchrun\",\n                    \"--\",\n                    \"--nproc_per_node=4\",\n                    \"--nnodes=2\",\n                ],\n                catch_exceptions=False,\n            )\n\n            assert result.exit_code == 0\n            mock_cloud_train.assert_called_once()\n\n            # Verify cloud training was called with launcher args\n            call_kwargs = mock_cloud_train.call_args.kwargs\n            assert call_kwargs[\"launcher\"] == \"torchrun\"\n            assert call_kwargs[\"launcher_args\"] == [\"--nproc_per_node=4\", \"--nnodes=2\"]\n"
  },
  {
    "path": "tests/cli/test_cli_version.py",
    "content": "\"\"\"pytest tests for axolotl CLI --version\"\"\"\n\nfrom axolotl.cli.main import cli\n\n\ndef test_print_version(cli_runner):\n    \"\"\"Test that version is printed when --version is used.\"\"\"\n\n    result = cli_runner.invoke(cli, [\"--version\"])\n    assert result.exit_code == 0\n    assert \"axolotl, version \" in result.output\n"
  },
  {
    "path": "tests/cli/test_nested_options.py",
    "content": "\"\"\"Tests for nested config option handling via CLI dot-notation.\"\"\"\n\nimport click\nfrom click.testing import CliRunner\nfrom pydantic import BaseModel, Field\n\nfrom axolotl.cli.utils.args import add_options_from_config, filter_none_kwargs\n\n\nclass InnerConfig(BaseModel):\n    \"\"\"A nested config model for testing.\"\"\"\n\n    beta: float | None = Field(\n        default=None,\n        description=\"Beta parameter.\",\n    )\n    host: str | None = Field(\n        default=None,\n        description=\"Server host.\",\n    )\n    use_feature: bool = Field(\n        default=False,\n        description=\"Whether to use the feature.\",\n    )\n\n\nclass OuterConfig(BaseModel):\n    \"\"\"A top-level config model for testing.\"\"\"\n\n    learning_rate: float | None = Field(\n        default=None,\n        description=\"Learning rate.\",\n    )\n    inner: InnerConfig | None = Field(\n        default=None,\n        description=\"Inner config.\",\n    )\n    name: str | None = Field(\n        default=None,\n        description=\"Model name.\",\n    )\n\n\nclass TestAddOptionsFromConfigNested:\n    \"\"\"Test that add_options_from_config handles nested BaseModel fields.\"\"\"\n\n    def setup_method(self):\n        self.runner = CliRunner()\n\n    def test_nested_dot_notation_options_are_registered(self):\n        \"\"\"Nested model fields should create --parent.child CLI options.\"\"\"\n\n        @click.command()\n        @add_options_from_config(OuterConfig)\n        @filter_none_kwargs\n        def cmd(**kwargs):\n            for k, v in sorted(kwargs.items()):\n                click.echo(f\"{k}={v}\")\n\n        result = self.runner.invoke(cmd, [\"--inner.beta=0.5\", \"--inner.host=localhost\"])\n        assert result.exit_code == 0, result.output\n        assert \"inner__beta=0.5\" in result.output\n        assert \"inner__host=localhost\" in result.output\n\n    def test_nested_bool_option(self):\n        \"\"\"Nested bool fields should 
support --parent.field/--no-parent.field.\"\"\"\n\n        @click.command()\n        @add_options_from_config(OuterConfig)\n        @filter_none_kwargs\n        def cmd(**kwargs):\n            for k, v in sorted(kwargs.items()):\n                click.echo(f\"{k}={v}\")\n\n        result = self.runner.invoke(cmd, [\"--inner.use-feature\"])\n        assert result.exit_code == 0, result.output\n        assert \"inner__use_feature=True\" in result.output\n\n    def test_flat_and_nested_options_together(self):\n        \"\"\"Flat and nested options should work together.\"\"\"\n\n        @click.command()\n        @add_options_from_config(OuterConfig)\n        @filter_none_kwargs\n        def cmd(**kwargs):\n            for k, v in sorted(kwargs.items()):\n                click.echo(f\"{k}={v}\")\n\n        result = self.runner.invoke(\n            cmd, [\"--learning-rate=0.001\", \"--inner.beta=0.1\", \"--name=test\"]\n        )\n        assert result.exit_code == 0, result.output\n        assert \"learning_rate=0.001\" in result.output\n        assert \"inner__beta=0.1\" in result.output\n        assert \"name=test\" in result.output\n\n    def test_no_nested_options_passed(self):\n        \"\"\"When no nested options are passed, they should not appear in kwargs.\"\"\"\n\n        @click.command()\n        @add_options_from_config(OuterConfig)\n        @filter_none_kwargs\n        def cmd(**kwargs):\n            click.echo(f\"keys={sorted(kwargs.keys())}\")\n\n        result = self.runner.invoke(cmd, [\"--learning-rate=0.01\"])\n        assert result.exit_code == 0, result.output\n        assert \"inner__\" not in result.output\n\n\nclass TestLoadCfgNestedKwargs:\n    \"\"\"Test that load_cfg correctly applies nested (double-underscore) kwargs.\"\"\"\n\n    @staticmethod\n    def _apply_nested_kwargs(cfg, kwargs):\n        \"\"\"Helper that mirrors the nested kwargs handling from load_cfg,\n        including type coercion for string CLI values.\"\"\"\n        from 
axolotl.cli.config import _coerce_value\n\n        nested_kwargs: dict = {}\n        flat_kwargs: dict = {}\n        for key, value in kwargs.items():\n            if \"__\" in key:\n                parent, child = key.split(\"__\", 1)\n                nested_kwargs.setdefault(parent, {})[child] = value\n            else:\n                flat_kwargs[key] = value\n\n        cfg_keys = cfg.keys()\n        for key, value in flat_kwargs.items():\n            if key in cfg_keys:\n                cfg[key] = _coerce_value(value, cfg.get(key))\n\n        for parent, children in nested_kwargs.items():\n            if cfg[parent] is None:\n                cfg[parent] = {}\n            if not isinstance(cfg[parent], dict):\n                cfg[parent] = {}\n            for child_key, child_value in children.items():\n                existing = cfg[parent].get(child_key)\n                cfg[parent][child_key] = _coerce_value(child_value, existing)\n\n        return cfg\n\n    def test_nested_kwargs_applied_to_cfg(self, tmp_path):\n        \"\"\"Double-underscore kwargs should set nested config values.\"\"\"\n        from axolotl.utils.dict import DictDefault\n\n        cfg = DictDefault({\"trl\": {\"beta\": 0.1}, \"learning_rate\": 0.01})\n        # CLI passes strings, so simulate that\n        kwargs = {\n            \"trl__beta\": \"0.5\",\n            \"trl__host\": \"192.168.1.1\",\n            \"learning_rate\": \"0.02\",\n        }\n\n        cfg = self._apply_nested_kwargs(cfg, kwargs)\n\n        assert cfg[\"learning_rate\"] == 0.02\n        assert isinstance(cfg[\"learning_rate\"], float)\n        assert cfg[\"trl\"][\"beta\"] == 0.5\n        assert isinstance(cfg[\"trl\"][\"beta\"], float)\n        assert cfg[\"trl\"][\"host\"] == \"192.168.1.1\"\n\n    def test_nested_kwargs_creates_parent_if_none(self):\n        \"\"\"If the parent key is None, nested kwargs should create the dict.\"\"\"\n        from axolotl.utils.dict import DictDefault\n\n        cfg = 
DictDefault({\"trl\": None, \"learning_rate\": 0.01})\n        cfg = self._apply_nested_kwargs(cfg, {\"trl__beta\": \"0.5\"})\n\n        # No existing value, YAML-style inference: \"0.5\" -> 0.5\n        assert cfg[\"trl\"][\"beta\"] == 0.5\n        assert isinstance(cfg[\"trl\"][\"beta\"], float)\n\n    def test_nested_kwargs_overwrites_string_parent(self):\n        \"\"\"If the parent key is a string, it should be replaced with a dict.\"\"\"\n        from axolotl.utils.dict import DictDefault\n\n        cfg = DictDefault({\"trl\": \"some_string\", \"learning_rate\": 0.01})\n        cfg = self._apply_nested_kwargs(cfg, {\"trl__beta\": \"0.5\"})\n\n        assert cfg[\"trl\"][\"beta\"] == 0.5\n\n\nclass TestCoerceValue:\n    \"\"\"Test YAML-style type coercion for CLI string values.\"\"\"\n\n    def test_coerce_with_existing_float(self):\n        from axolotl.cli.config import _coerce_value\n\n        assert _coerce_value(\"0.5\", 0.1) == 0.5\n        assert isinstance(_coerce_value(\"0.5\", 0.1), float)\n\n    def test_coerce_with_existing_int(self):\n        from axolotl.cli.config import _coerce_value\n\n        assert _coerce_value(\"42\", 10) == 42\n        assert isinstance(_coerce_value(\"42\", 10), int)\n\n    def test_coerce_with_existing_bool(self):\n        from axolotl.cli.config import _coerce_value\n\n        assert _coerce_value(\"true\", False) is True\n        assert _coerce_value(\"false\", True) is False\n        assert _coerce_value(\"1\", False) is True\n        assert _coerce_value(\"0\", True) is False\n\n    def test_coerce_yaml_inference_no_existing(self):\n        \"\"\"Without an existing value, use YAML-style inference.\"\"\"\n        from axolotl.cli.config import _coerce_value\n\n        assert _coerce_value(\"true\", None) is True\n        assert _coerce_value(\"false\", None) is False\n        assert _coerce_value(\"42\", None) == 42\n        assert isinstance(_coerce_value(\"42\", None), int)\n        assert _coerce_value(\"3.14\", 
None) == 3.14\n        assert isinstance(_coerce_value(\"3.14\", None), float)\n        assert _coerce_value(\"null\", None) is None\n        assert _coerce_value(\"hello\", None) == \"hello\"\n\n    def test_coerce_non_string_passthrough(self):\n        \"\"\"Non-string values should pass through unchanged.\"\"\"\n        from axolotl.cli.config import _coerce_value\n\n        assert _coerce_value(0.5, 0.1) == 0.5\n        assert _coerce_value(True, False) is True\n"
  },
  {
    "path": "tests/cli/test_utils.py",
    "content": "\"\"\"pytest tests for axolotl CLI utils.\"\"\"\n\nimport json\nfrom unittest.mock import Mock, patch\n\nimport click\nimport pytest\nimport requests\n\nfrom axolotl.cli.utils import fetch_from_github\n\n# Sample GitHub API response\nMOCK_TREE_RESPONSE = {\n    \"tree\": [\n        {\"path\": \"examples/config1.yml\", \"type\": \"blob\", \"sha\": \"abc123\"},\n        {\"path\": \"examples/config2.yml\", \"type\": \"blob\", \"sha\": \"def456\"},\n        {\"path\": \"other/file.txt\", \"type\": \"blob\", \"sha\": \"xyz789\"},\n    ]\n}\n\n\n@pytest.fixture\ndef mock_responses():\n    \"\"\"Mock responses for API and file downloads\"\"\"\n\n    def mock_get(url, timeout=None):\n        response = Mock()\n        if \"api.github.com\" in url:\n            response.text = json.dumps(MOCK_TREE_RESPONSE)\n        else:\n            response.content = b\"file content\"\n        return response\n\n    return mock_get\n\n\ndef test_fetch_from_github_new_files(tmp_path, mock_responses):\n    \"\"\"Test fetching new files\"\"\"\n    with patch(\"requests.get\", mock_responses):\n        fetch_from_github(\"examples/\", tmp_path)\n\n        # Verify files were created\n        assert (tmp_path / \"config1.yml\").exists()\n        assert (tmp_path / \"config2.yml\").exists()\n        assert not (tmp_path / \"file.txt\").exists()\n\n\ndef test_fetch_from_github_unchanged_files(tmp_path, mock_responses):\n    \"\"\"Test handling of unchanged files\"\"\"\n    # Create existing file with matching SHA\n    existing_file = tmp_path / \"config1.yml\"\n    existing_file.write_bytes(b\"file content\")\n\n    with patch(\"requests.get\", mock_responses):\n        fetch_from_github(\"examples/\", tmp_path)\n\n        # File should not be downloaded again\n        assert existing_file.read_bytes() == b\"file content\"\n\n\ndef test_fetch_from_github_invalid_prefix(mock_responses):\n    \"\"\"Test error handling for invalid directory prefix\"\"\"\n    with 
patch(\"requests.get\", mock_responses):\n        with pytest.raises(click.ClickException):\n            fetch_from_github(\"nonexistent/\", None)\n\n\ndef test_fetch_from_github_network_error():\n    \"\"\"Test handling of network errors\"\"\"\n    with patch(\"requests.get\", side_effect=requests.RequestException):\n        with pytest.raises(requests.RequestException):\n            fetch_from_github(\"examples/\", None)\n\n\ndef assert_launcher_args_in_command(\n    mock_subprocess_call,\n    launcher: str,\n    expected_launcher_args: list[str],\n    command_module: str,\n):\n    \"\"\"\n    Helper function to verify launcher arguments are properly passed in subprocess calls.\n\n    Args:\n        mock_subprocess_call: The mock subprocess.run call\n        launcher: Expected launcher (\"accelerate\", \"torchrun\", etc.)\n        expected_launcher_args: List of expected launcher arguments\n        command_module: Expected module name (e.g., \"axolotl.cli.train\")\n    \"\"\"\n    assert mock_subprocess_call.called, \"subprocess.run should have been called\"\n    called_cmd = mock_subprocess_call.call_args.args[0]\n\n    # Verify launcher\n    assert called_cmd[0] == launcher, (\n        f\"Expected launcher {launcher}, got {called_cmd[0]}\"\n    )\n\n    # Verify launcher args are present\n    for arg in expected_launcher_args:\n        assert arg in called_cmd, (\n            f\"Expected launcher arg '{arg}' not found in command: {called_cmd}\"\n        )\n\n    # Verify module is present\n    assert \"-m\" in called_cmd, \"Expected -m flag for module execution\"\n    assert command_module in called_cmd, (\n        f\"Expected module {command_module} not found in command: {called_cmd}\"\n    )\n\n\ndef assert_no_launcher_args_contamination(mock_subprocess_call, launcher: str):\n    \"\"\"\n    Helper function to verify no unwanted launcher arguments are present.\n\n    Args:\n        mock_subprocess_call: The mock subprocess.run call\n        launcher: Expected 
launcher (\"accelerate\", \"torchrun\", etc.)\n    \"\"\"\n    assert mock_subprocess_call.called, \"subprocess.run should have been called\"\n    called_cmd = mock_subprocess_call.call_args.args[0]\n\n    if launcher == \"accelerate\":\n        # For accelerate, launcher args should be between 'launch' and '-m'\n        launch_idx = called_cmd.index(\"launch\")\n        m_idx = called_cmd.index(\"-m\")\n        launcher_section = called_cmd[launch_idx + 1 : m_idx]\n        assert len(launcher_section) == 0, (\n            f\"Unexpected launcher args found: {launcher_section}\"\n        )\n    elif launcher == \"torchrun\":\n        # For torchrun, launcher args should be between 'torchrun' and '-m'\n        torchrun_idx = called_cmd.index(\"torchrun\")\n        m_idx = called_cmd.index(\"-m\")\n        launcher_section = called_cmd[torchrun_idx + 1 : m_idx]\n        assert len(launcher_section) == 0, (\n            f\"Unexpected launcher args found: {launcher_section}\"\n        )\n\n\n@pytest.fixture\ndef common_launcher_args():\n    \"\"\"Fixture providing common launcher argument combinations for testing.\"\"\"\n    return {\n        \"torchrun\": [\"--nproc_per_node=2\", \"--nnodes=1\"],\n        \"accelerate\": [\"--config_file=accelerate_config.yml\", \"--num_processes=4\"],\n    }\n\n\ndef test_add_default_rdzv_args_with_endpoint():\n    \"\"\"Test that default RDZV args are added when rdzv_endpoint is present.\"\"\"\n    from axolotl.cli.utils.train import _add_default_rdzv_args\n\n    launcher_args = [\"--nnodes=2\", \"--rdzv_endpoint=127.0.0.1:29400\"]\n    result = _add_default_rdzv_args(launcher_args)\n\n    # Should have added rdzv_backend\n    assert \"--rdzv_backend\" in result\n    assert \"c10d\" in result\n\n    # Original args should still be present\n    assert \"--nnodes=2\" in result\n    assert \"--rdzv_endpoint=127.0.0.1:29400\" in result\n\n\ndef test_add_default_rdzv_args_with_existing_backend():\n    \"\"\"Test that existing rdzv_backend 
is not overridden.\"\"\"\n    from axolotl.cli.utils.train import _add_default_rdzv_args\n\n    launcher_args = [\n        \"--nnodes=2\",\n        \"--rdzv_endpoint=127.0.0.1:29400\",\n        \"--rdzv_backend=static\",\n    ]\n    result = _add_default_rdzv_args(launcher_args)\n\n    # Should not add another rdzv_backend\n    backend_count = sum(1 for arg in result if \"--rdzv_backend\" in arg)\n    assert backend_count == 1\n    assert \"--rdzv_backend=static\" in result\n\n\ndef test_add_default_rdzv_args_with_existing_id():\n    \"\"\"Test that existing rdzv_id is not overridden.\"\"\"\n    from axolotl.cli.utils.train import _add_default_rdzv_args\n\n    launcher_args = [\n        \"--nnodes=2\",\n        \"--rdzv_endpoint=127.0.0.1:29400\",\n        \"--rdzv_id=my_job_123\",\n    ]\n    result = _add_default_rdzv_args(launcher_args)\n\n    # Should not add another rdzv_id\n    id_count = sum(1 for arg in result if \"--rdzv_id\" in arg)\n    assert id_count == 1\n    assert \"--rdzv_id=my_job_123\" in result\n\n    # Should still add rdzv_backend\n    assert \"--rdzv_backend\" in result\n    assert \"c10d\" in result\n\n\ndef test_add_default_rdzv_args_without_endpoint():\n    \"\"\"Test that no RDZV args are added when rdzv_endpoint is not present.\"\"\"\n    from axolotl.cli.utils.train import _add_default_rdzv_args\n\n    launcher_args = [\"--nnodes=2\", \"--nproc_per_node=4\"]\n    result = _add_default_rdzv_args(launcher_args)\n\n    # Should not add any rdzv args\n    assert \"--rdzv_backend\" not in result\n    assert result == launcher_args\n\n\ndef test_add_default_rdzv_args_with_all_existing():\n    \"\"\"Test that no defaults are added when all RDZV args are present.\"\"\"\n    from axolotl.cli.utils.train import _add_default_rdzv_args\n\n    launcher_args = [\n        \"--nnodes=2\",\n        \"--rdzv_endpoint=127.0.0.1:29400\",\n        \"--rdzv_backend=static\",\n        \"--rdzv_id=existing_job\",\n    ]\n    result = 
_add_default_rdzv_args(launcher_args)\n\n    # Should not add any additional args\n    assert len(result) == len(launcher_args)\n    assert result == launcher_args\n"
  },
  {
    "path": "tests/conftest.py",
    "content": "\"\"\"Shared pytest fixtures\"\"\"\n\nimport functools\nimport importlib\nimport logging\nimport os\nimport shutil\nimport sys\nimport tempfile\nimport time\nfrom pathlib import Path\nfrom typing import Generator\n\nimport datasets\nimport pytest\nimport requests\nimport torch\nfrom huggingface_hub import snapshot_download\nfrom huggingface_hub.errors import LocalEntryNotFoundError\nfrom tokenizers import AddedToken\nfrom transformers import AutoTokenizer\n\nfrom axolotl.utils.dict import DictDefault\n\nfrom tests.hf_offline_utils import (\n    enable_hf_offline,\n    hf_offline_context,\n)\n\nlogging.getLogger(\"filelock\").setLevel(logging.CRITICAL)\n\n\ndef retry_on_request_exceptions(max_retries=3, delay=1):\n    def decorator(func):\n        @functools.wraps(func)\n        def wrapper(*args, **kwargs):\n            for attempt in range(max_retries):\n                try:\n                    return func(*args, **kwargs)\n                except (\n                    requests.exceptions.ReadTimeout,\n                    requests.exceptions.ConnectionError,\n                    requests.exceptions.HTTPError,\n                ) as exc:\n                    if attempt < max_retries - 1:\n                        wait = 2**attempt * delay  # in seconds\n                        time.sleep(wait)\n                    else:\n                        raise exc\n\n        return wrapper\n\n    return decorator\n\n\n@retry_on_request_exceptions(max_retries=3, delay=5)\ndef snapshot_download_w_retry(*args, **kwargs):\n    \"\"\"\n    download a model or dataset from HF Hub, retrying in requests failures. We also try to fetch it from the local\n    cache first using hf_hub_offline to avoid hitting HF Hub API rate limits. 
If it doesn't exist in the cache,\n    disable hf_hub_offline and actually fetch from the hub\n    \"\"\"\n    with hf_offline_context(True):\n        try:\n            return snapshot_download(*args, local_files_only=True, **kwargs)\n        except LocalEntryNotFoundError:\n            pass\n    with hf_offline_context(False):\n        return snapshot_download(*args, **kwargs)\n\n\n@pytest.fixture(scope=\"session\", autouse=True)\ndef download_ds_fixture_bundle():\n    ds_dir = snapshot_download_w_retry(\n        \"axolotl-ai-internal/axolotl-oss-dataset-fixtures\", repo_type=\"dataset\"\n    )\n    return Path(ds_dir)\n\n\n@pytest.fixture(scope=\"session\", autouse=True)\ndef download_smollm2_135m_model():\n    # download the model\n    snapshot_download_w_retry(\"HuggingFaceTB/SmolLM2-135M\", repo_type=\"model\")\n\n\n@pytest.fixture(scope=\"session\", autouse=True)\ndef download_smollm2_135m_instruct_model():\n    # download the model\n    snapshot_download_w_retry(\"HuggingFaceTB/SmolLM2-135M-Instruct\", repo_type=\"model\")\n\n\n@pytest.fixture(scope=\"session\", autouse=True)\ndef download_smollm2_135m_gptq_model():\n    # download the model\n    snapshot_download_w_retry(\"lilmeaty/SmolLM2-135M-Instruct-GPTQ\", repo_type=\"model\")\n\n\n@pytest.fixture(scope=\"session\", autouse=True)\ndef download_qwen_2_5_half_billion_model():\n    # download the model\n    snapshot_download_w_retry(\"Qwen/Qwen2.5-0.5B\", repo_type=\"model\")\n\n\n@pytest.fixture(scope=\"session\", autouse=True)\ndef download_qwen3_half_billion_model():\n    # download the model\n    snapshot_download_w_retry(\"Qwen/Qwen3-0.6B\", repo_type=\"model\")\n\n\n@pytest.fixture(scope=\"session\", autouse=True)\ndef download_tatsu_lab_alpaca_dataset():\n    # download the dataset\n    snapshot_download_w_retry(\"tatsu-lab/alpaca\", repo_type=\"dataset\")\n\n\n@pytest.fixture(scope=\"session\", autouse=True)\ndef download_mhenrichsen_alpaca_2k_dataset():\n    # download the dataset\n    
snapshot_download_w_retry(\"mhenrichsen/alpaca_2k_test\", repo_type=\"dataset\")\n\n\n@pytest.fixture(scope=\"session\", autouse=True)\ndef download_mhenrichsen_alpaca_2k_w_revision_dataset():\n    # download the dataset\n    snapshot_download_w_retry(\n        \"mhenrichsen/alpaca_2k_test\", repo_type=\"dataset\", revision=\"d05c1cb\"\n    )\n\n\n@pytest.fixture(scope=\"session\", autouse=True)\ndef download_mlabonne_finetome_100k_dataset():\n    # download the dataset\n    snapshot_download_w_retry(\"mlabonne/FineTome-100k\", repo_type=\"dataset\")\n\n\n@pytest.fixture(scope=\"session\", autouse=True)\ndef download_argilla_distilabel_capybara_dpo_7k_binarized_dataset():\n    # download the dataset\n    snapshot_download_w_retry(\n        \"argilla/distilabel-capybara-dpo-7k-binarized\", repo_type=\"dataset\"\n    )\n\n\n@pytest.fixture(scope=\"session\", autouse=True)\ndef download_argilla_distilabel_intel_orca_dpo_dataset():\n    # download the dataset\n    snapshot_download_w_retry(\n        \"argilla/distilabel-intel-orca-dpo-pairs\", repo_type=\"dataset\"\n    )\n\n\n@pytest.fixture(scope=\"session\", autouse=True)\ndef download_argilla_ultrafeedback_binarized_preferences_cleaned_dataset():\n    # download the dataset\n    snapshot_download_w_retry(\n        \"argilla/ultrafeedback-binarized-preferences-cleaned\", repo_type=\"dataset\"\n    )\n\n\n@pytest.fixture(scope=\"session\", autouse=True)\ndef download_argilla_ultrafeedback_binarized_preferences_cleaned_kto_dataset():\n    # download the dataset\n    snapshot_download_w_retry(\n        \"argilla/ultrafeedback-binarized-preferences-cleaned-kto\", repo_type=\"dataset\"\n    )\n\n\n# @pytest.fixture(scope=\"session\", autouse=True)\n# def download_fozzie_alpaca_dpo_dataset():\n#     # download the dataset\n#     snapshot_download_w_retry(\n#         \"fozziethebeat/alpaca_messages_2k_dpo_test\", repo_type=\"dataset\"\n#     )\n#     snapshot_download_w_retry(\n#         
\"fozziethebeat/alpaca_messages_2k_dpo_test\",\n#         repo_type=\"dataset\",\n#         revision=\"ea82cff\",\n#     )\n\n\n# @pytest.fixture(scope=\"session\")\n# @disable_hf_offline\n# def dataset_fozzie_alpaca_dpo_dataset(\n#     download_fozzie_alpaca_dpo_dataset,\n# ):\n#     return load_dataset(\"fozziethebeat/alpaca_messages_2k_dpo_test\", split=\"train\")\n#\n#\n# @pytest.fixture(scope=\"session\")\n# @disable_hf_offline\n# def dataset_fozzie_alpaca_dpo_dataset_rev_ea82cff(\n#     download_fozzie_alpaca_dpo_dataset,\n# ):\n#     return load_dataset(\n#         \"fozziethebeat/alpaca_messages_2k_dpo_test\", split=\"train\", revision=\"ea82cff\"\n#     )\n\n\n@pytest.fixture(scope=\"session\", autouse=True)\ndef download_arcee_ai_distilabel_intel_orca_dpo_pairs_dataset():\n    # download the dataset\n    snapshot_download_w_retry(\n        \"arcee-ai/distilabel-intel-orca-dpo-pairs-binarized\", repo_type=\"dataset\"\n    )\n\n\n@pytest.fixture(scope=\"session\", autouse=True)\ndef download_argilla_dpo_pairs_dataset():\n    # download the dataset\n    snapshot_download_w_retry(\n        \"argilla/distilabel-intel-orca-dpo-pairs\", repo_type=\"dataset\"\n    )\n\n\n@pytest.fixture(scope=\"session\", autouse=True)\ndef download_tiny_shakespeare_dataset():\n    # download the dataset\n    snapshot_download_w_retry(\"winglian/tiny-shakespeare\", repo_type=\"dataset\")\n\n\n@pytest.fixture(scope=\"session\", autouse=True)\ndef download_evolkit_kd_sample_dataset():\n    # download the dataset\n    snapshot_download_w_retry(\n        \"axolotl-ai-co/evolkit-logprobs-pipeline-75k-v2-sample\", repo_type=\"dataset\"\n    )\n\n\n@pytest.fixture(scope=\"session\", autouse=True)\ndef download_deepseek_model_fixture():\n    snapshot_download_w_retry(\"axolotl-ai-co/DeepSeek-V3-11M\", repo_type=\"model\")\n\n\n@pytest.fixture(scope=\"session\", autouse=True)\ndef download_huggyllama_model_fixture():\n    # download the tokenizer only\n    snapshot_download_w_retry(\n     
   \"huggyllama/llama-7b\",\n        repo_type=\"model\",\n        allow_patterns=[\"*token*\", \"config.json\"],\n    )\n\n\n@pytest.fixture(scope=\"session\", autouse=True)\ndef download_llama33_70b_model_fixture():\n    # download the tokenizer only\n    snapshot_download_w_retry(\n        \"axolotl-ai-co/Llama-3.3-70B-Instruct-tokenizer\",\n        repo_type=\"model\",\n        allow_patterns=[\"*token*\", \"config.json\"],\n    )\n\n\n@pytest.fixture(scope=\"session\", autouse=True)\ndef download_llama_1b_model_fixture():\n    # download the tokenizer only\n    snapshot_download_w_retry(\n        \"NousResearch/Llama-3.2-1B\",\n        repo_type=\"model\",\n        allow_patterns=[\"*token*\", \"config.json\"],\n    )\n\n\n@pytest.fixture(scope=\"session\", autouse=True)\ndef download_llama3_8b_model_fixture():\n    # download the tokenizer only\n    snapshot_download_w_retry(\n        \"NousResearch/Meta-Llama-3-8B\",\n        repo_type=\"model\",\n        allow_patterns=[\"*token*\", \"config.json\"],\n    )\n\n\n@pytest.fixture(scope=\"session\", autouse=True)\ndef download_llama3_8b_instruct_model_fixture():\n    # download the tokenizer only\n    snapshot_download_w_retry(\n        \"NousResearch/Meta-Llama-3-8B-Instruct\",\n        repo_type=\"model\",\n        allow_patterns=[\"*token*\", \"config.json\"],\n    )\n\n\n@pytest.fixture(scope=\"session\", autouse=True)\ndef download_phi_35_mini_model_fixture():\n    # download the tokenizer only\n    snapshot_download_w_retry(\n        \"microsoft/Phi-3.5-mini-instruct\",\n        repo_type=\"model\",\n        allow_patterns=[\"*token*\", \"config.json\"],\n    )\n\n\n@pytest.fixture(scope=\"session\", autouse=True)\ndef download_phi_4_reasoning_model_fixture():\n    # download the tokenizer only\n    snapshot_download_w_retry(\n        \"microsoft/Phi-4-reasoning\",\n        repo_type=\"model\",\n        allow_patterns=[\"*token*\", \"config.json\"],\n    )\n\n\n@pytest.fixture(scope=\"session\", 
autouse=True)\ndef download_phi_3_medium_model_fixture():\n    # download the tokenizer only\n    snapshot_download_w_retry(\n        \"microsoft/Phi-3-medium-128k-instruct\",\n        repo_type=\"model\",\n        allow_patterns=[\"*token*\", \"config.json\"],\n    )\n\n\n@pytest.fixture(scope=\"session\", autouse=True)\ndef download_mistral_7b_model_fixture():\n    # download the tokenizer only\n    snapshot_download_w_retry(\n        \"casperhansen/mistral-7b-instruct-v0.1-awq\",\n        repo_type=\"model\",\n        allow_patterns=[\"*token*\", \"config.json\"],\n    )\n\n\n@pytest.fixture(scope=\"session\", autouse=True)\ndef download_gemma3_4b_model_fixture():\n    # download the tokenizer only\n    snapshot_download_w_retry(\n        \"mlx-community/gemma-3-4b-it-8bit\",\n        repo_type=\"model\",\n        allow_patterns=[\"*token*\", \"config.json\"],\n    )\n\n\n@pytest.fixture(scope=\"session\", autouse=True)\ndef download_gemma_2b_model_fixture():\n    # download the tokenizer only\n    snapshot_download_w_retry(\n        \"unsloth/gemma-2b-it\",\n        revision=\"703fb4a\",\n        repo_type=\"model\",\n        allow_patterns=[\"*token*\", \"config.json\"],\n    )\n\n\n@pytest.fixture(scope=\"session\", autouse=True)\ndef download_gemma2_9b_model_fixture():\n    # download the tokenizer only\n    snapshot_download_w_retry(\n        \"mlx-community/gemma-2-9b-it-4bit\",\n        repo_type=\"model\",\n        allow_patterns=[\"*token*\", \"config.json\"],\n    )\n\n\n@pytest.fixture(scope=\"session\", autouse=True)\ndef download_mlx_mistral_7b_model_fixture():\n    # download the tokenizer only\n    snapshot_download_w_retry(\n        \"mlx-community/Mistral-7B-Instruct-v0.3-4bit\",\n        repo_type=\"model\",\n        allow_patterns=[\"*token*\", \"config.json\"],\n    )\n\n\n@pytest.fixture\ndef download_llama2_model_fixture():\n    # download the tokenizer only\n    snapshot_download_w_retry(\n        \"NousResearch/Llama-2-7b-hf\",\n        
repo_type=\"model\",\n        allow_patterns=[\"*token*\", \"config.json\"],\n    )\n\n\n@pytest.fixture(scope=\"session\", autouse=True)\ndef download_llama32_1b_model_fixture():\n    snapshot_download_w_retry(\n        \"osllmai-community/Llama-3.2-1B\",\n        repo_type=\"model\",\n    )\n\n\n@pytest.fixture\n@enable_hf_offline\ndef tokenizer_huggyllama(\n    download_huggyllama_model_fixture,\n):\n    tokenizer = AutoTokenizer.from_pretrained(\"huggyllama/llama-7b\")\n    tokenizer.pad_token = \"</s>\"\n\n    return tokenizer\n\n\n@pytest.fixture\n@enable_hf_offline\ndef tokenizer_huggyllama_w_special_tokens(\n    tokenizer_huggyllama,\n):\n    tokenizer_huggyllama.add_special_tokens(\n        {\n            \"bos_token\": \"<s>\",\n            \"eos_token\": \"</s>\",\n            \"unk_token\": \"<unk>\",\n        }\n    )\n\n    return tokenizer_huggyllama\n\n\n@pytest.fixture\n@enable_hf_offline\ndef tokenizer_llama2_7b(\n    download_llama2_model_fixture,\n):\n    tokenizer = AutoTokenizer.from_pretrained(\"NousResearch/Llama-2-7b-hf\")\n\n    return tokenizer\n\n\n@pytest.fixture\n@enable_hf_offline\ndef tokenizer_mistral_7b_instruct(\n    download_mlx_mistral_7b_model_fixture,\n):\n    return AutoTokenizer.from_pretrained(\"casperhansen/mistral-7b-instruct-v0.1-awq\")\n\n\n@pytest.fixture\ndef tokenizer_mistral_7b_instruct_chatml(tokenizer_mistral_7b_instruct):\n    tokenizer_mistral_7b_instruct.add_special_tokens(\n        {\n            \"eos_token\": AddedToken(\n                \"<|im_end|>\", rstrip=False, lstrip=False, normalized=False\n            )\n        }\n    )\n    tokenizer_mistral_7b_instruct.add_tokens(\n        [\n            AddedToken(\"<|im_start|>\", rstrip=False, lstrip=False, normalized=False),\n        ]\n    )\n    return tokenizer_mistral_7b_instruct\n\n\n@pytest.fixture\ndef temp_dir() -> Generator[str, None, None]:\n    # Create a temporary directory\n    _temp_dir = tempfile.mkdtemp()\n    yield _temp_dir\n    # Clean up 
the directory after the test\n    shutil.rmtree(_temp_dir)\n\n\n@pytest.fixture(scope=\"function\", autouse=True)\ndef torch_manual_seed():\n    torch.manual_seed(42)\n\n\n@pytest.fixture(scope=\"function\", autouse=True)\ndef cleanup_monkeypatches():\n    from transformers import Trainer\n    from transformers.models.llama.modeling_llama import (  # LlamaFlashAttention2,\n        LlamaAttention,\n        LlamaForCausalLM,\n    )\n\n    # original_fa2_forward = LlamaFlashAttention2.forward\n    original_llama_attn_forward = LlamaAttention.forward\n    original_llama_forward = LlamaForCausalLM.forward\n    original_trainer_inner_training_loop = Trainer._inner_training_loop\n    original_trainer_training_step = Trainer.training_step\n    # monkey patches can happen inside the tests\n    yield\n    # Reset LlamaFlashAttention2 forward\n    # LlamaFlashAttention2.forward = original_fa2_forward\n    LlamaAttention.forward = original_llama_attn_forward\n    LlamaForCausalLM.forward = original_llama_forward\n    Trainer._inner_training_loop = original_trainer_inner_training_loop\n    Trainer.training_step = original_trainer_training_step\n\n    # Reset other known monkeypatches\n    modules_to_reset: list[tuple[str, list[str]]] = [\n        (\"transformers.models.llama\",),\n        (\n            \"transformers.models.llama.modeling_llama\",\n            [\n                # \"LlamaFlashAttention2\",\n                \"LlamaAttention\",\n            ],\n        ),\n        (\"transformers.trainer\",),\n        (\"transformers\", [\"Trainer\"]),\n        (\"transformers.loss.loss_utils\",),\n    ]\n    for module_name_tuple in modules_to_reset:\n        module_name = module_name_tuple[0]\n\n        spec = importlib.util.spec_from_file_location(\n            module_name, sys.modules[module_name].__file__\n        )\n        sys.modules[module_name] = importlib.util.module_from_spec(spec)\n        spec.loader.exec_module(sys.modules[module_name])\n\n        
sys.modules[module_name] = importlib.reload(sys.modules[module_name])\n        if len(module_name_tuple) > 1:\n            module_globals = module_name_tuple[1]\n            for module_global in module_globals:\n                globals().pop(module_global, None)\n\n\n@pytest.fixture\ndef dataset_winglian_tiny_shakespeare(\n    download_ds_fixture_bundle: Path,\n):\n    ds_path = download_ds_fixture_bundle / \"winglian__tiny-shakespeare\"\n    return datasets.load_from_disk(ds_path)\n\n\n@pytest.fixture\ndef dataset_tatsu_lab_alpaca(\n    download_ds_fixture_bundle: Path,\n):\n    ds_path = download_ds_fixture_bundle / \"tatsu-lab__alpaca\"\n    return datasets.load_from_disk(ds_path)[\"train\"]\n\n\n@pytest.fixture\ndef dataset_mhenrichsen_alpaca_2k_test(\n    download_ds_fixture_bundle: Path,\n):\n    ds_path = download_ds_fixture_bundle / \"mhenrichsen__alpaca_2k_test\"\n    return datasets.load_from_disk(ds_path)[\"train\"]\n\n\n@pytest.fixture\ndef dataset_argilla_ultrafeedback_binarized_preferences_cleaned(\n    download_ds_fixture_bundle: Path,\n):\n    ds_path = (\n        download_ds_fixture_bundle\n        / \"argilla__ultrafeedback-binarized-preferences-cleaned\"\n    )\n    return datasets.load_from_disk(ds_path)[\"train\"]\n\n\n@pytest.fixture\ndef dataset_fozziethebeat_alpaca_messages_2k_dpo_test(\n    download_ds_fixture_bundle: Path,\n):\n    ds_path = download_ds_fixture_bundle / \"fozziethebeat__alpaca_messages_2k_dpo_test\"\n    return datasets.load_from_disk(ds_path)[\"train\"]\n\n\n@pytest.fixture\ndef dataset_fozziethebeat_alpaca_messages_2k_dpo_test_rev_ea82cff(\n    download_ds_fixture_bundle: Path,\n):\n    ds_path = (\n        download_ds_fixture_bundle\n        / \"fozziethebeat__alpaca_messages_2k_dpo_test__rev_ea82cff\"\n    )\n    return datasets.load_from_disk(ds_path)[\"train\"]\n\n\n@pytest.fixture(name=\"min_base_cfg\")\ndef fixture_min_base_cfg():\n    return DictDefault(\n        base_model=\"HuggingFaceTB/SmolLM2-135M\",\n        
learning_rate=1e-3,\n        datasets=[\n            {\n                \"path\": \"mhenrichsen/alpaca_2k_test\",\n                \"type\": \"alpaca\",\n            },\n        ],\n        micro_batch_size=1,\n        gradient_accumulation_steps=1,\n    )\n\n\n#\n@pytest.mark.skipif(\n    os.environ.get(\"AXOLOTL_IS_CI_CACHE_PRELOAD\", \"-1\") != \"1\",\n    reason=\"Not running in CI cache preload\",\n)\ndef test_load_fixtures(\n    download_smollm2_135m_model,\n    download_qwen_2_5_half_billion_model,\n    download_tatsu_lab_alpaca_dataset,\n    download_mhenrichsen_alpaca_2k_dataset,\n    download_mhenrichsen_alpaca_2k_w_revision_dataset,\n    download_mlabonne_finetome_100k_dataset,\n    download_argilla_ultrafeedback_binarized_preferences_cleaned_dataset,\n    download_argilla_ultrafeedback_binarized_preferences_cleaned_kto_dataset,\n    download_argilla_distilabel_capybara_dpo_7k_binarized_dataset,\n    download_arcee_ai_distilabel_intel_orca_dpo_pairs_dataset,\n    download_argilla_dpo_pairs_dataset,\n    download_tiny_shakespeare_dataset,\n    download_deepseek_model_fixture,\n    download_huggyllama_model_fixture,\n    download_llama_1b_model_fixture,\n    download_llama3_8b_model_fixture,\n    download_llama3_8b_instruct_model_fixture,\n    download_phi_35_mini_model_fixture,\n    download_phi_3_medium_model_fixture,\n    download_phi_4_reasoning_model_fixture,\n    download_mistral_7b_model_fixture,\n    download_gemma_2b_model_fixture,\n    download_gemma2_9b_model_fixture,\n    download_mlx_mistral_7b_model_fixture,\n    download_llama2_model_fixture,\n):\n    pass\n\n\n@pytest.fixture(autouse=True)\ndef disable_telemetry(monkeypatch):\n    monkeypatch.setenv(\"AXOLOTL_DO_NOT_TRACK\", \"1\")\n    yield\n"
  },
  {
    "path": "tests/constants.py",
    "content": "# constants.py\n\"\"\"\nThis module contains constants and configuration dictionaries used for\ndatasets and other utilities in the Axolotl project, specifically for testing.\n\"\"\"\n\n# Configuration for Alpaca Messages Dataset\nALPACA_MESSAGES_CONFIG_OG = {\n    \"path\": \"fozziethebeat/alpaca_messages_2k_dpo_test\",\n    \"type\": \"chat_template.default\",\n    \"chat_template\": \"llama3\",\n    \"field_messages\": \"conversation\",\n    \"field_chosen\": \"chosen\",\n    \"field_rejected\": \"rejected\",\n    \"message_field_role\": \"role\",\n    \"message_field_content\": \"content\",\n    \"roles\": {\n        \"system\": [\"system\"],\n        \"user\": [\"user\"],\n        \"assistant\": [\"assistant\"],\n    },\n}\n\n# Revision configuration extending the original\nALPACA_MESSAGES_CONFIG_REVISION = ALPACA_MESSAGES_CONFIG_OG.copy()\nALPACA_MESSAGES_CONFIG_REVISION[\"revision\"] = \"ea82cff\"\n\n\nSPECIAL_TOKENS = {\n    \"bos_token\": \"<s>\",\n    \"eos_token\": \"</s>\",\n    \"unk_token\": \"<unk>\",\n}\n"
  },
  {
    "path": "tests/core/chat/__init__.py",
    "content": ""
  },
  {
    "path": "tests/core/chat/format/__init__.py",
    "content": ""
  },
  {
    "path": "tests/core/chat/test_messages.py",
    "content": "\"\"\"\nTests for the chat messages module\n\"\"\"\n\nimport unittest\n\nimport pytest\nfrom transformers import AddedToken, AutoTokenizer\n\nfrom axolotl.core.chat.format.chatml import format_message\nfrom axolotl.core.chat.messages import ChatFormattedChats, Chats\n\nfrom tests.hf_offline_utils import enable_hf_offline  # noqa\n\n\n@pytest.fixture(scope=\"session\", name=\"llama_tokenizer\")\n@enable_hf_offline\ndef llama_tokenizer_fixture():\n    return AutoTokenizer.from_pretrained(\"NousResearch/Meta-Llama-3-8B\")\n\n\n@pytest.fixture(scope=\"session\", name=\"chatml_tokenizer\")\ndef llama_tokenizer_w_chatml(llama_tokenizer):\n    llama_tokenizer.add_special_tokens(\n        {\n            \"eos_token\": AddedToken(\n                \"<|im_end|>\", rstrip=False, lstrip=False, normalized=False\n            )\n        }\n    )\n    llama_tokenizer.add_tokens(\n        [\n            AddedToken(\"<|im_start|>\", rstrip=False, lstrip=False, normalized=False),\n        ]\n    )\n\n    return llama_tokenizer\n\n\n@pytest.fixture(scope=\"session\", name=\"chat_msgs\")\ndef chat_msgs_fixture():\n    return {\n        \"conversation\": [\n            {\n                \"role\": \"system\",\n                \"content\": [\n                    {\"type\": \"text\", \"value\": \"You are a helpful assistant.\"},\n                ],\n            },\n            {\n                \"role\": \"user\",\n                \"content\": [\n                    {\"type\": \"text\", \"value\": \"What is today's stock price of Apple?\"},\n                ],\n            },\n            {\n                \"role\": \"assistant\",\n                \"content\": [\n                    {\n                        \"type\": \"tool_call\",\n                        \"value\": {\n                            \"name\": \"get_date\",\n                            \"arguments\": {},\n                        },\n                    },\n                    {\n                        
\"type\": \"tool_call\",\n                        \"value\": {\n                            \"name\": \"get_stock_price\",\n                            \"arguments\": {\"symbol\": \"AAPL\"},\n                        },\n                    },\n                ],\n                \"weight\": 1,\n            },\n            {\n                \"role\": \"tool\",\n                \"content\": [\n                    {\n                        \"type\": \"tool_response\",\n                        \"value\": {\n                            \"name\": \"get_date\",\n                            \"content\": {\"date\": \"2024-09-09\"},\n                        },\n                    },\n                    {\n                        \"type\": \"tool_response\",\n                        \"value\": {\n                            \"name\": \"get_stock_price\",\n                            \"content\": {\"symbol\": \"AAPL\", \"price\": 123.45},\n                        },\n                    },\n                ],\n            },\n            {\n                \"role\": \"assistant\",\n                \"content\": [\n                    {\n                        \"type\": \"text\",\n                        \"value\": \"The stock price of Apple is $123.45.\\n\",\n                        \"weight\": 0,\n                    },\n                    {\n                        \"type\": \"text\",\n                        \"value\": \"<reflection>The original query asked for today's stock price of Apple. 
This implies they also wanted the date included in the response.</reflection>\",\n                    },\n                    {\n                        \"type\": \"text\",\n                        \"value\": \"The stock price of Apple on September 9, 2024 is $123.45.\",\n                    },\n                ],\n                \"weight\": 1,\n            },\n        ]\n    }\n\n\nclass TestMessagesCase:\n    \"\"\"\n    Test cases for the chat messages module\n    \"\"\"\n\n    def test_tool_call_stringify(self, chat_msgs):\n        chat_msgs_as_obj = Chats(**chat_msgs)\n        assert '{\"name\": \"get_stock_price\", \"arguments\": {\"symbol\": \"AAPL\"}}' == str(\n            chat_msgs_as_obj.conversation[2].content[1].value\n        )\n\n    def test_chatml_formatted_wrapper(self, chat_msgs):\n        chat_msg_formatted = ChatFormattedChats(**chat_msgs, formatter=format_message)\n        target_chatml = \"\"\"<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\nWhat is today's stock price of Apple?<|im_end|>\n<|im_start|>assistant\n<tool_call>\n{\"name\": \"get_date\", \"arguments\": {}}\n</tool_call>\n<tool_call>\n{\"name\": \"get_stock_price\", \"arguments\": {\"symbol\": \"AAPL\"}}\n</tool_call>\n<|im_end|>\n<|im_start|>tool\n<tool_response>\n{\"name\": \"get_date\", \"content\": {\"date\": \"2024-09-09\"}}\n</tool_response>\n<tool_response>\n{\"name\": \"get_stock_price\", \"content\": {\"symbol\": \"AAPL\", \"price\": 123.45}}\n</tool_response>\n<|im_end|>\n<|im_start|>assistant\nThe stock price of Apple is $123.45.\n<reflection>The original query asked for today's stock price of Apple. 
This implies they also wanted the date included in the response.</reflection>The stock price of Apple on September 9, 2024 is $123.45.<|im_end|>\\n\"\"\"\n        assert target_chatml == str(chat_msg_formatted)\n\n    def test_chatml_formatting_tool_call(self, chat_msgs):\n        chat_msgs_as_obj = Chats(**chat_msgs)\n        target_chatml_turn2 = \"\"\"<|im_start|>assistant\\n<tool_call>\\n{\"name\": \"get_date\", \"arguments\": {}}\\n</tool_call>\\n<tool_call>\\n{\"name\": \"get_stock_price\", \"arguments\": {\"symbol\": \"AAPL\"}}\\n</tool_call>\\n<|im_end|>\\n\"\"\"\n        assert target_chatml_turn2 == str(\n            format_message(chat_msgs_as_obj.conversation[2])\n        )\n\n    def test_train_labels(self, chatml_tokenizer, chat_msgs):\n        chat_msg_formatted = ChatFormattedChats(**chat_msgs, formatter=format_message)\n        tokenized = chat_msg_formatted.conversation[2].tokenized(chatml_tokenizer)\n        # fmt: off\n        target_labels = [\n            -100, -100, -100,  # role\n            27, 14506, 13735, 397, 5018, 609, 794,\n            330, 456, 4257, 498, 330, 16774, 794, 4792, 534, 524,\n            14506, 13735, 397, 27, 14506, 13735, 397, 5018, 609, 794,\n            330, 456, 31641, 9217, 498, 330, 16774, 794, 5324, 19314,\n            794, 330, 84016, 43, 96742, 524, 14506, 13735, 397,\n            128256,  # <|im_end|>\n            -100  # trailing newline\n        ]\n        # fmt: on\n        assert tokenized[\"labels\"] == target_labels\n\n    def test_train_labels_2(self, chatml_tokenizer, chat_msgs):\n        # also test that individual content parts can be excluded from training\n        chat_msg_formatted = ChatFormattedChats(**chat_msgs, formatter=format_message)\n        tokenized = chat_msg_formatted.conversation[4].tokenized(chatml_tokenizer)\n        # fmt: off\n        target_labels = [\n            -100, -100, -100,  # role\n            -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,  # initial response\n         
   27, 78098, 16761, 4113, 3319, 4691, 369, 3432, 596, 5708, 3430,\n            315, 8325, 13, 1115, 24897, 814, 1101, 4934, 279, 2457,\n            5343, 304, 279, 2077, 4005, 78098, 16761, 5708, 3430, 315,\n            8325, 389, 6250, 220, 24, 11, 220, 2366, 19, 374, 400,\n            4513, 13, 1774, 13,\n            128256,  # <|im_end|>\n            -100,  # trailing newline\n        ]\n        # fmt: on\n        assert tokenized[\"labels\"] == target_labels\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/core/test_async_grpo.py",
    "content": "\"\"\"Unit tests for async GRPO\"\"\"\n\nimport unittest\nfrom unittest.mock import MagicMock\n\nimport torch\n\n\nclass TestReplayBuffer(unittest.TestCase):\n    \"\"\"Tests for ReplayBuffer edge cases.\"\"\"\n\n    def test_add_noop_when_max_size_zero(self):\n        from axolotl.core.trainers.grpo.replay_buffer import ReplayBuffer\n\n        buf = ReplayBuffer(max_size=0)\n        buf.add(1.0, {\"data\": \"test\"})\n        self.assertEqual(len(buf), 0)\n\n    def test_add_noop_when_max_size_negative(self):\n        from axolotl.core.trainers.grpo.replay_buffer import ReplayBuffer\n\n        buf = ReplayBuffer(max_size=-1)\n        buf.add(1.0, {\"data\": \"test\"})\n        self.assertEqual(len(buf), 0)\n\n    def test_sample_returns_none_when_max_size_zero(self):\n        from axolotl.core.trainers.grpo.replay_buffer import ReplayBuffer\n\n        buf = ReplayBuffer(max_size=0)\n        self.assertIsNone(buf.sample(1))\n\n    def test_sample_returns_none_when_empty(self):\n        from axolotl.core.trainers.grpo.replay_buffer import ReplayBuffer\n\n        buf = ReplayBuffer(max_size=5)\n        self.assertIsNone(buf.sample(1))\n\n    def test_normal_add_and_sample(self):\n        from axolotl.core.trainers.grpo.replay_buffer import ReplayBuffer\n\n        buf = ReplayBuffer(max_size=3)\n        buf.add(1.0, {\"a\": 1})\n        buf.add(2.0, {\"a\": 2})\n        buf.add(3.0, {\"a\": 3})\n        self.assertEqual(len(buf), 3)\n        result = buf.sample(1)\n        self.assertIsNotNone(result)\n        self.assertEqual(len(result), 1)\n\n    def test_replaces_lowest_when_full(self):\n        from axolotl.core.trainers.grpo.replay_buffer import ReplayBuffer\n\n        buf = ReplayBuffer(max_size=2)\n        buf.add(1.0, {\"a\": 1})\n        buf.add(2.0, {\"a\": 2})\n        buf.add(3.0, {\"a\": 3})  # should replace score=1.0\n        self.assertEqual(len(buf), 2)\n        scores = sorted(item[0] for item in buf._heap)\n        
self.assertEqual(scores, [2.0, 3.0])\n\n\nclass TestGRPOStrategyConflict(unittest.TestCase):\n    \"\"\"Tests for sequence_parallel + async_grpo conflict detection.\"\"\"\n\n    def test_raises_on_both_enabled(self):\n        from axolotl.core.trainers.grpo import GRPOStrategy\n\n        with self.assertRaises(ValueError) as ctx:\n            GRPOStrategy.get_trainer_class(sequence_parallel=True, async_grpo=True)\n        self.assertIn(\"sequence_parallel\", str(ctx.exception))\n        self.assertIn(\"async_grpo\", str(ctx.exception))\n\n    def test_sequence_parallel_only(self):\n        from axolotl.core.trainers.grpo import GRPOStrategy\n        from axolotl.core.trainers.grpo.trainer import (\n            AxolotlGRPOSequenceParallelTrainer,\n        )\n\n        cls = GRPOStrategy.get_trainer_class(sequence_parallel=True, async_grpo=False)\n        self.assertIs(cls, AxolotlGRPOSequenceParallelTrainer)\n\n    def test_async_only(self):\n        from axolotl.core.trainers.grpo import GRPOStrategy\n        from axolotl.core.trainers.grpo.trainer import AxolotlAsyncGRPOTrainer\n\n        cls = GRPOStrategy.get_trainer_class(sequence_parallel=False, async_grpo=True)\n        self.assertIs(cls, AxolotlAsyncGRPOTrainer)\n\n    def test_neither(self):\n        from axolotl.core.trainers.grpo import GRPOStrategy\n        from axolotl.core.trainers.grpo.trainer import AxolotlGRPOTrainer\n\n        cls = GRPOStrategy.get_trainer_class(sequence_parallel=False, async_grpo=False)\n        self.assertIs(cls, AxolotlGRPOTrainer)\n\n\nclass TestDequantizeFP8TailBlocks(unittest.TestCase):\n    \"\"\"Tests for FP8 dequantization with non-divisible dimensions.\"\"\"\n\n    def test_exact_divisible_shape(self):\n        from axolotl.kernels.quantize import dequantize_fp8\n\n        W = torch.randn(256, 128, dtype=torch.bfloat16).to(torch.float8_e4m3fn)\n        scale_inv = torch.ones(2, 1, dtype=torch.bfloat16)\n        result = dequantize_fp8(W, scale_inv)\n        
self.assertEqual(result.shape, (256, 128))\n        self.assertEqual(result.dtype, torch.bfloat16)\n\n    def test_non_divisible_rows(self):\n        from axolotl.kernels.quantize import dequantize_fp8\n\n        # 130 rows, scale has 2 blocks (block_size ~65 for exact div, but with\n        # tail blocks: first block=65 rows, second=65 rows, 130%2=0 actually).\n        # Use 131 rows with 2 scale blocks to trigger tail handling.\n        W = torch.ones(131, 128, dtype=torch.bfloat16).to(torch.float8_e4m3fn)\n        scale_inv = torch.tensor([[2.0], [3.0]], dtype=torch.bfloat16)\n        result = dequantize_fp8(W, scale_inv)\n        self.assertEqual(result.shape, (131, 128))\n        self.assertEqual(result.dtype, torch.bfloat16)\n\n    def test_non_divisible_cols(self):\n        from axolotl.kernels.quantize import dequantize_fp8\n\n        W = torch.ones(128, 200, dtype=torch.bfloat16).to(torch.float8_e4m3fn)\n        scale_inv = torch.ones(1, 2, dtype=torch.bfloat16)\n        result = dequantize_fp8(W, scale_inv)\n        self.assertEqual(result.shape, (128, 200))\n\n    def test_scalar_scale(self):\n        from axolotl.kernels.quantize import dequantize_fp8\n\n        W = torch.ones(64, 64, dtype=torch.bfloat16).to(torch.float8_e4m3fn)\n        scale_inv = torch.tensor(2.0, dtype=torch.bfloat16)\n        result = dequantize_fp8(W, scale_inv)\n        self.assertEqual(result.shape, (64, 64))\n\n\nclass TestLoraFP8Guard(unittest.TestCase):\n    \"\"\"Tests that get_lora_parameters only uses weight_scale_inv for FP8 weights.\"\"\"\n\n    def test_non_fp8_weight_skips_scale_inv(self):\n        \"\"\"Non-FP8 weight should NOT pick up weight_scale_inv as quant_state.\"\"\"\n        from axolotl.kernels.lora import get_lora_parameters\n\n        proj = MagicMock()\n        proj.disable_adapters = True\n        base_layer = MagicMock(spec=[])  # empty spec to control attrs precisely\n\n        # Use a real tensor for weight (bf16, no quant_state attr)\n        
base_layer.weight = torch.randn(64, 64, dtype=torch.bfloat16)\n        base_layer.bias = None\n        base_layer.weight_scale_inv = torch.ones(1)  # should NOT be used for bf16\n\n        proj.base_layer = base_layer\n\n        W, b, quant_state, A, B, s = get_lora_parameters(proj)\n        # quant_state should be None since weight is bf16, not FP8\n        self.assertIsNone(quant_state)\n\n    def test_fp8_weight_uses_scale_inv(self):\n        \"\"\"FP8 weight should pick up weight_scale_inv as quant_state.\"\"\"\n        from axolotl.kernels.lora import get_lora_parameters\n\n        proj = MagicMock()\n        proj.disable_adapters = True\n        base_layer = MagicMock()\n        proj.base_layer = base_layer\n\n        # FP8 weight\n        base_layer.weight = torch.randn(64, 64, dtype=torch.bfloat16).to(\n            torch.float8_e4m3fn\n        )\n        base_layer.bias = None\n        scale_inv = torch.ones(1)\n        base_layer.weight_scale_inv = scale_inv\n\n        W, b, quant_state, A, B, s = get_lora_parameters(proj)\n        self.assertIs(quant_state, scale_inv)\n\n\nclass TestValidateQuantPatchRestore(unittest.TestCase):\n    \"\"\"Test that validate_quantization_for_training is restored after trainer creation.\"\"\"\n\n    def test_patch_restored_on_success(self):\n        \"\"\"Monkeypatch should be restored even after successful trainer creation.\"\"\"\n        import transformers.trainer as _trainer_module\n\n        original = _trainer_module.validate_quantization_for_training\n\n        # After the build() method runs, original should be restored.\n        # We can't easily test the full build(), but we can test the pattern.\n        _orig = _trainer_module.validate_quantization_for_training\n        _trainer_module.validate_quantization_for_training = lambda model: None\n        try:\n            pass  # simulate trainer_cls() succeeding\n        finally:\n            _trainer_module.validate_quantization_for_training = _orig\n\n        
self.assertIs(_trainer_module.validate_quantization_for_training, original)\n\n    def test_patch_restored_on_error(self):\n        \"\"\"Monkeypatch should be restored even if trainer creation raises.\"\"\"\n        import transformers.trainer as _trainer_module\n\n        original = _trainer_module.validate_quantization_for_training\n\n        _orig = _trainer_module.validate_quantization_for_training\n        _trainer_module.validate_quantization_for_training = lambda model: None\n        try:\n            raise ValueError(\"test error\")\n        except ValueError:\n            pass\n        finally:\n            _trainer_module.validate_quantization_for_training = _orig\n\n        self.assertIs(_trainer_module.validate_quantization_for_training, original)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/core/test_builders.py",
    "content": "\"\"\"Unit tests for axolotl.core.builders\"\"\"\n\nimport sys\nfrom pathlib import Path\nfrom unittest.mock import MagicMock, patch\n\nimport pytest\n\nfrom axolotl.common.datasets import load_datasets\nfrom axolotl.core.builders import HFCausalTrainerBuilder, HFRLTrainerBuilder\nfrom axolotl.loaders import ModelLoader, load_tokenizer\nfrom axolotl.utils.config import normalize_config\nfrom axolotl.utils.data import prepare_preference_datasets\nfrom axolotl.utils.dict import DictDefault\nfrom axolotl.utils.schemas.enums import RLType\n\nfrom tests.constants import ALPACA_MESSAGES_CONFIG_REVISION\n\n\n@pytest.fixture(name=\"base_cfg\")\ndef fixture_base_cfg():\n    \"\"\"\n    Base config with all common arguments between SFT and RLHF\n    \"\"\"\n    cfg = DictDefault(\n        {\n            # Model and tokenizer settings\n            \"base_model\": \"HuggingFaceTB/SmolLM2-135M-Instruct\",\n            \"sequence_len\": 2048,\n            \"model_config_type\": \"llama\",  # example type\n            # Basic training settings\n            \"micro_batch_size\": 2,\n            \"eval_batch_size\": 2,\n            \"num_epochs\": 1,\n            \"gradient_accumulation_steps\": 1,\n            \"max_steps\": 100,\n            \"val_set_size\": 0,\n            # Optimizer settings\n            \"optimizer\": \"adamw_torch_fused\",\n            \"learning_rate\": 0.00005,\n            \"weight_decay\": 0.01,\n            \"adam_beta1\": 0.998,\n            \"adam_beta2\": 0.9,\n            \"adam_epsilon\": 0.00001,\n            \"max_grad_norm\": 1.0,\n            # LR scheduler settings\n            \"lr_scheduler\": \"cosine\",\n            \"lr_scheduler_kwargs\": {\"foo\": \"bar\"},\n            \"warmup_steps\": 10,\n            \"warmup_ratio\": None,\n            \"cosine_min_lr_ratio\": 0.1,\n            \"cosine_constant_lr_ratio\": 0.2,\n            # Checkpointing and saving\n            \"save_steps\": 100,\n            \"output_dir\": 
\"./model-out\",\n            \"save_total_limit\": 4,\n            \"save_only_model\": False,\n            # Hardware/performance settings\n            \"gradient_checkpointing\": False,\n            \"gradient_checkpointing_kwargs\": {\"use_reentrant\": False},\n            \"dataloader_num_workers\": 1,\n            \"dataloader_pin_memory\": True,\n            \"dataloader_prefetch_factor\": 2,\n            \"context_parallel_size\": 1,\n            \"tensor_parallel_size\": 1,\n            # Dtype\n            \"fp16\": False,\n            \"bf16\": False,\n            \"tf32\": False,\n            # Logging and evaluation\n            \"logging_steps\": 10,\n            \"eval_steps\": 50,\n            \"eval_strategy\": \"steps\",\n            \"save_strategy\": \"steps\",\n            \"include_tokens_per_second\": True,\n            # Other common settings\n            \"seed\": 42,\n            \"remove_unused_columns\": True,\n            \"ddp_timeout\": 1800,\n            \"ddp_bucket_cap_mb\": 25,\n            \"ddp_broadcast_buffers\": False,\n            \"dataset_num_proc\": 4,\n        }\n    )\n\n    normalize_config(cfg)\n    return cfg\n\n\n@pytest.fixture(name=\"dpo_cfg\")\ndef fixture_dpo_cfg(base_cfg):\n    cfg = base_cfg.copy()\n    cfg.update(\n        {\n            \"rl\": RLType.DPO,\n            \"dpo_use_weighting\": True,\n            \"dpo_label_smoothing\": 0.1,\n            \"beta\": 0.1,  # DPO beta\n        }\n    )\n    return cfg\n\n\n@pytest.fixture(name=\"orpo_cfg\")\ndef fixture_orpo_cfg(base_cfg):\n    cfg = base_cfg.copy()\n    cfg.update(\n        {\n            \"rl\": RLType.ORPO,\n            \"orpo_alpha\": 0.1,\n            \"max_prompt_len\": 512,\n        }\n    )\n    return cfg\n\n\n@pytest.fixture(name=\"kto_cfg\")\ndef fixture_kto_cfg(base_cfg):\n    cfg = base_cfg.copy()\n    cfg.update(\n        {\n            \"rl\": RLType.KTO,\n            \"kto_desirable_weight\": 1.0,\n            
\"kto_undesirable_weight\": 1.0,\n            \"max_prompt_len\": 512,\n        }\n    )\n    return cfg\n\n\n@pytest.fixture(name=\"grpo_cfg\")\ndef fixture_grpo_cfg(base_cfg):\n    cfg = base_cfg.copy()\n    cfg.update(\n        {\n            \"rl\": RLType.GRPO,\n            \"trl\": DictDefault(\n                {\n                    \"beta\": 0.001,\n                    \"max_completion_length\": 256,\n                    \"use_vllm\": False,  # run on CPU\n                    # \"vllm_device\": \"auto\",\n                    # \"vllm_gpu_memory_utilization\": 0.15,\n                    \"num_generations\": 4,\n                    \"reward_funcs\": [\"rewards.rand_reward_func\"],\n                }\n            ),\n            # Must be evenly divisible by num_generations\n            \"micro_batch_size\": 4,\n            \"datasets\": [\n                {\n                    \"path\": \"openai/gsm8k\",\n                    \"name\": \"main\",\n                    \"split\": \"train[:1%]\",\n                }\n            ],\n        }\n    )\n    return DictDefault(cfg)\n\n\n@pytest.fixture(name=\"ipo_cfg\")\ndef fixture_ipo_cfg(base_cfg):\n    cfg = base_cfg.copy()\n    cfg.update(\n        {\n            \"rl\": RLType.IPO,\n            \"dpo_label_smoothing\": 0,\n            \"beta\": 0.1,\n        }\n    )\n    return cfg\n\n\n@pytest.fixture(name=\"simpo_cfg\")\ndef fixture_simpo_cfg(base_cfg):\n    cfg = base_cfg.copy()\n    cfg.update(\n        {\n            \"rl\": RLType.SIMPO,\n            \"rl_beta\": 0.2,\n            \"cpo_alpha\": 0.9,\n            \"simpo_gamma\": 0.4,\n        }\n    )\n    return cfg\n\n\n@pytest.fixture(name=\"sft_cfg\")\ndef fixture_sft_cfg(base_cfg):\n    cfg = base_cfg.copy()\n    cfg.update(\n        {\n            \"rl\": None,\n            \"sample_packing\": False,\n            \"eval_sample_packing\": False,\n            \"flash_attention\": False,\n        }\n    )\n    return 
cfg\n\n\n@pytest.fixture(name=\"rm_cfg\")\ndef fixture_rm_cfg(sft_cfg):\n    cfg = sft_cfg.copy()\n    cfg.update(\n        DictDefault(\n            {\n                \"reward_model\": True,\n                \"datasets\": [\n                    {\n                        \"path\": \"argilla/distilabel-intel-orca-dpo-pairs\",\n                        \"type\": \"bradley_terry.chat_template\",\n                        \"split\": \"train[:1%]\",\n                    }\n                ],\n            }\n        )\n    )\n    return cfg\n\n\n@pytest.fixture(name=\"prm_cfg\")\ndef fixture_prm_cfg(sft_cfg):\n    cfg = sft_cfg.copy()\n    cfg.update(\n        DictDefault(\n            {\n                \"process_reward_model\": True,\n                \"datasets\": [\n                    {\n                        \"path\": \"trl-lib/math_shepherd\",\n                        \"type\": \"stepwise_supervised\",\n                        \"split\": \"train[:1%]\",\n                    }\n                ],\n            }\n        )\n    )\n    return cfg\n\n\n@pytest.fixture(name=\"tokenizer\")\ndef fixture_tokenizer(base_cfg):\n    return load_tokenizer(base_cfg)\n\n\n@pytest.fixture(name=\"model\")\ndef fixture_model(base_cfg, tokenizer):\n    model, _ = ModelLoader(base_cfg, tokenizer).load()\n    return model\n\n\nclass TestHFRLTrainerBuilder:\n    \"\"\"\n    TestCase class for RLHF trainer builders\n    \"\"\"\n\n    def _test_common_training_arguments(self, training_arguments, rl: str):\n        \"\"\"Helper to test common arguments across all variants\"\"\"\n        # Basic training settings\n        if rl == \"grpo\":\n            # grpo_cfg's micro_batch_size is diff from others\n            assert training_arguments.per_device_train_batch_size == 4\n        else:\n            assert training_arguments.per_device_train_batch_size == 2\n        assert training_arguments.gradient_accumulation_steps == 1\n        assert training_arguments.max_steps == 100\n\n        
# Optimizer settings\n        assert training_arguments.learning_rate == 0.00005\n        assert training_arguments.weight_decay == 0.01\n        assert training_arguments.adam_beta1 == 0.998\n        assert training_arguments.adam_beta2 == 0.9\n        assert training_arguments.adam_epsilon == 0.00001\n        assert training_arguments.max_grad_norm == 1.0\n\n        # LR scheduler settings\n        assert training_arguments.lr_scheduler_type == \"cosine\"\n        assert training_arguments.warmup_steps == 10\n        assert training_arguments.cosine_min_lr_ratio == 0.1\n        assert training_arguments.cosine_constant_lr_ratio == 0.2\n\n        # Other settings\n        assert training_arguments.dataloader_num_workers == 1\n        assert training_arguments.dataloader_pin_memory is True\n\n        # TODO(wing): restore once trl releases 0.22.0\n        # assert training_arguments.gradient_checkpointing is True\n\n    def test_dpo_training_arguments(self, dpo_cfg, model, tokenizer):\n        builder = HFRLTrainerBuilder(dpo_cfg, model, tokenizer)\n        training_arguments, _ = builder._build_training_arguments(100)\n\n        self._test_common_training_arguments(training_arguments, rl=dpo_cfg.rl)\n        # DPO specific\n        assert training_arguments.beta == 0.1\n        assert hasattr(training_arguments, \"use_weighting\")\n        assert training_arguments.use_weighting is True\n        assert training_arguments.label_smoothing == 0.1\n\n    def test_orpo_training_arguments(self, orpo_cfg, model, tokenizer):\n        builder = HFRLTrainerBuilder(orpo_cfg, model, tokenizer)\n        training_arguments, _ = builder._build_training_arguments(100)\n\n        self._test_common_training_arguments(training_arguments, rl=orpo_cfg.rl)\n        # ORPO specific\n        assert training_arguments.beta == 0.1  # maps from orpo_alpha\n\n    def test_kto_training_arguments(self, kto_cfg, model, tokenizer):\n        builder = HFRLTrainerBuilder(kto_cfg, model, 
tokenizer)\n        training_arguments, _ = builder._build_training_arguments(100)\n\n        self._test_common_training_arguments(training_arguments, rl=kto_cfg.rl)\n        # KTO specific\n        assert training_arguments.desirable_weight == 1.0\n        assert training_arguments.undesirable_weight == 1.0\n\n    def _write_rewards_file(self, rewards_dir: Path):\n        \"\"\"\n        Writes reward function to local tmp path to be loaded on trainer building\n        \"\"\"\n        # Create rewards.py in a directory we can import from\n        rewards_dir.mkdir()\n        rewards_file = rewards_dir / \"rewards.py\"\n        rewards_file.write_text(\n            \"\"\"import random\ndef rand_reward_func(prompts, completions) -> list[float]:\n    return [random.uniform(0, 1) for _ in completions]\n\"\"\"\n        )\n\n    def test_grpo_training_arguments(self, grpo_cfg, model, tokenizer, tmp_path):\n        rewards_dir = tmp_path / \"rewards_test\"\n        self._write_rewards_file(rewards_dir)\n\n        # Add the directory to Python path so we can import the module\n        sys.path.insert(0, str(rewards_dir))\n\n        try:\n            builder = HFRLTrainerBuilder(grpo_cfg, model, tokenizer)\n            training_arguments, _ = builder._build_training_arguments(100)\n            builder.train_dataset = MagicMock()\n\n            self._test_common_training_arguments(training_arguments, rl=grpo_cfg.rl)\n            # GRPO specific\n            assert training_arguments.beta == 0.001\n            assert training_arguments.max_completion_length == 256\n            assert training_arguments.use_vllm is False\n            # assert training_arguments.vllm_device == \"auto\"\n            # assert training_arguments.vllm_gpu_memory_utilization == 0.15\n            assert training_arguments.num_generations == 4\n\n            # Test trainer creation to verify reward_funcs\n            trainer = builder.build(100)\n\n            # Verify reward functions are properly 
loaded\n            assert len(trainer.reward_funcs) == 1\n            assert trainer.reward_funcs[0].__module__ == \"rewards\"\n            assert trainer.reward_funcs[0].__name__ == \"rand_reward_func\"\n        finally:\n            # remove imported module from path\n            if str(rewards_dir) in sys.path:\n                sys.path.remove(str(rewards_dir))\n\n    def test_ipo_training_arguments(self, ipo_cfg, model, tokenizer):\n        builder = HFRLTrainerBuilder(ipo_cfg, model, tokenizer)\n        training_arguments, _ = builder._build_training_arguments(100)\n\n        self._test_common_training_arguments(training_arguments, rl=ipo_cfg.rl)\n        # IPO specific\n        assert training_arguments.beta == 0.1\n        assert training_arguments.loss_type == [\"ipo\"]\n        assert training_arguments.label_smoothing == 0\n\n    def test_simpo_training_arguments(self, simpo_cfg, model, tokenizer):\n        builder = HFRLTrainerBuilder(simpo_cfg, model, tokenizer)\n        training_arguments, _ = builder._build_training_arguments(100)\n\n        self._test_common_training_arguments(training_arguments, rl=simpo_cfg.rl)\n        # SIMPO specific\n        assert training_arguments.beta == 0.2\n        assert training_arguments.cpo_alpha == 0.9\n        assert training_arguments.simpo_gamma == 0.4\n\n    @pytest.mark.parametrize(\n        (\"cfg_string\", \"dataset_name\"),\n        [\n            (\n                \"dpo_cfg\",\n                \"dataset_fozziethebeat_alpaca_messages_2k_dpo_test_rev_ea82cff\",\n            ),\n            (\n                \"ipo_cfg\",\n                \"dataset_fozziethebeat_alpaca_messages_2k_dpo_test_rev_ea82cff\",\n            ),\n            (\n                \"grpo_cfg\",\n                \"dataset_fozziethebeat_alpaca_messages_2k_dpo_test_rev_ea82cff\",\n            ),\n            (\"orpo_cfg\", None),  # don't use fixture for orpo to use smaller split\n            (\"kto_cfg\", None),  # no fixture for kto\n      
      # (\n            #     \"simpo_cfg\",\n            #     \"dataset_fozziethebeat_alpaca_messages_2k_dpo_test_rev_ea82cff\",\n            # ),\n        ],\n    )\n    def test_custom_optimizer_cls_and_kwargs(\n        self,\n        request,\n        cfg_string,\n        dataset_name,\n        tmp_path,\n        model,\n        tokenizer,\n    ):\n        cfg = request.getfixturevalue(cfg_string)\n\n        builder = HFRLTrainerBuilder(cfg, model, tokenizer)\n        cfg[\"optimizer\"] = \"muon\"\n\n        if cfg_string in [\"dpo_cfg\", \"ipo_cfg\", \"grpo_cfg\", \"simpo_cfg\"]:\n            cfg[\"datasets\"] = [DictDefault(ALPACA_MESSAGES_CONFIG_REVISION)]\n        elif cfg_string == \"kto_cfg\":\n            cfg[\"datasets\"] = [\n                DictDefault(\n                    {\n                        \"path\": \"argilla/ultrafeedback-binarized-preferences-cleaned-kto\",\n                        \"type\": \"llama3.ultra\",\n                        \"split\": \"train[:1%]\",\n                    }\n                )\n            ]\n        elif cfg_string == \"orpo_cfg\":\n            cfg[\"datasets\"] = [\n                DictDefault(\n                    {\n                        \"path\": \"argilla/ultrafeedback-binarized-preferences-cleaned\",\n                        \"type\": \"chat_template.argilla\",\n                        \"split\": \"train[:1%]\",\n                    }\n                )\n            ]\n        else:\n            raise ValueError(f\"Unhandled cfg_string: {cfg_string}\")\n        cfg[\"dataset_num_proc\"] = 4\n\n        if cfg_string == \"grpo_cfg\":\n            rewards_dir = tmp_path / \"rewards_test\"\n            self._write_rewards_file(rewards_dir)\n\n            # Add the directory to Python path so we can import the module\n            sys.path.insert(0, str(rewards_dir))\n\n        try:\n            # Only use mock for the commented out configs\n            if dataset_name is not None:\n                with 
patch(\n                    \"axolotl.utils.data.rl.load_dataset_with_config\"\n                ) as mock_load_dataset:\n                    mock_load_dataset.return_value = request.getfixturevalue(\n                        dataset_name\n                    )\n                    train_dataset, eval_dataset = prepare_preference_datasets(\n                        cfg, tokenizer\n                    )\n            else:\n                # Load actual datasets for orpo_cfg and kto_cfg\n                train_dataset, eval_dataset = prepare_preference_datasets(\n                    cfg, tokenizer\n                )\n\n            builder.train_dataset = train_dataset\n            builder.eval_dataset = eval_dataset\n\n            trainer = builder.build(100)\n\n            assert trainer.optimizer_cls_and_kwargs is not None\n\n            from axolotl.contribs.mit.muon import MuonOptimizerFactory\n            from axolotl.contribs.mit.muon.muon import Muon\n\n            optimizer_cls, optimizer_kwargs = trainer.optimizer_cls_and_kwargs\n            assert optimizer_cls is MuonOptimizerFactory\n            assert optimizer_kwargs[\"lr\"] == 0.00005\n            assert optimizer_kwargs[\"weight_decay\"] == 0.01\n            assert optimizer_kwargs[\"betas\"] == (0.998, 0.9)\n            assert optimizer_kwargs[\"eps\"] == 0.00001\n\n            # Ensure optimizer is created with correct class\n            optim = trainer.create_optimizer()\n            assert isinstance(optim, Muon)\n\n        finally:\n            # remove imported module from path\n            if cfg_string == \"grpo_cfg\" and str(rewards_dir) in sys.path:\n                sys.path.remove(str(rewards_dir))\n\n\nclass TestHFCausalTrainerBuilder:\n    \"\"\"\n    TestCase class for SFT trainer builder\n    \"\"\"\n\n    def test_training_arguments(self, sft_cfg, model, tokenizer):\n        builder = HFCausalTrainerBuilder(sft_cfg, model, tokenizer)\n        trainer = builder.build(100)\n        
training_arguments = trainer.args\n\n        # Test common arguments\n        assert training_arguments.per_device_train_batch_size == 2\n        assert training_arguments.gradient_accumulation_steps == 1\n        assert training_arguments.max_steps == 100\n\n        assert training_arguments.learning_rate == 0.00005\n        assert training_arguments.weight_decay == 0.01\n        assert training_arguments.adam_beta1 == 0.998\n        assert training_arguments.adam_beta2 == 0.9\n        assert training_arguments.adam_epsilon == 0.00001\n        assert training_arguments.max_grad_norm == 1.0\n\n        assert training_arguments.lr_scheduler_type == \"cosine\"\n        assert training_arguments.warmup_steps == 10\n        assert training_arguments.cosine_min_lr_ratio == 0.1\n\n        assert training_arguments.dataloader_num_workers == 1\n        assert training_arguments.dataloader_pin_memory is True\n        assert training_arguments.gradient_checkpointing is False\n\n        # SFT specific\n        assert training_arguments.sample_packing is False\n        assert training_arguments.eval_sample_packing is False\n\n    @pytest.mark.parametrize(\n        \"cfg_string\",\n        [\n            \"sft_cfg\",\n            \"rm_cfg\",\n            \"prm_cfg\",\n        ],\n    )\n    def test_builder_w_rm_trainers(self, request, cfg_string, model, tokenizer):\n        cfg = request.getfixturevalue(cfg_string)\n        builder = HFCausalTrainerBuilder(cfg, model, tokenizer)\n        cfg[\"optimizer\"] = \"muon\"\n\n        # need to load datasets for reward model and process reward model trainer\n        if cfg_string in [\"rm_cfg\", \"prm_cfg\"]:\n            dataset_meta = load_datasets(cfg=cfg)\n\n            builder.train_dataset = dataset_meta.train_dataset\n            builder.eval_dataset = dataset_meta.eval_dataset\n\n        trainer = builder.build(100)\n\n        assert trainer.optimizer_cls_and_kwargs is not None\n\n        from axolotl.contribs.mit.muon import 
MuonOptimizerFactory\n        from axolotl.contribs.mit.muon.muon import Muon\n\n        optimizer_cls, optimizer_kwargs = trainer.optimizer_cls_and_kwargs\n        assert optimizer_cls is MuonOptimizerFactory\n        assert optimizer_kwargs[\"lr\"] == 0.00005\n        assert optimizer_kwargs[\"weight_decay\"] == 0.01\n        assert optimizer_kwargs[\"betas\"] == (0.998, 0.9)\n        assert optimizer_kwargs[\"eps\"] == 0.00001\n\n        # Ensure optimizer is created with correct class\n        optim = trainer.create_optimizer()\n        assert isinstance(optim, Muon)\n\n\nclass TestTrainerClsPlugin:\n    \"\"\"\n    TestCase class for trainer builder with plugin\n    \"\"\"\n\n    def test_trainer_cls_is_not_none_with_plugin(self, kto_cfg, model, tokenizer):\n        \"\"\"\n        Test that the trainer cls is not none with plugin\n\n        Fixes #2693\n        \"\"\"\n        cfg = kto_cfg.copy()\n        cfg.plugins = [\"axolotl.integrations.liger.LigerPlugin\"]\n\n        # Expected AttributeError as we don't pass regular model configs to RL trainer builder\n        # If it throws `TypeError: None is not a callable object`, trainer_cls could be None\n        try:\n            builder = HFRLTrainerBuilder(cfg, model, tokenizer)\n\n            builder.build(100)\n        except TypeError as e:\n            # Error raised if trainer_cls is None\n            assert \"'tuple' object has no attribute 'config'\" not in str(e)\n        except Exception:\n            # Another error happens, so we passed trainer_cls to builder\n            pass\n"
  },
  {
    "path": "tests/e2e/.gitignore",
    "content": "last_run_prepared\n"
  },
  {
    "path": "tests/e2e/__init__.py",
    "content": ""
  },
  {
    "path": "tests/e2e/integrations/test_cut_cross_entropy.py",
    "content": "\"\"\"\nSimple end-to-end test for Cut Cross Entropy integration\n\"\"\"\n\nimport pytest\n\nfrom axolotl.common.datasets import load_datasets\nfrom axolotl.train import train\nfrom axolotl.utils import get_pytorch_version\nfrom axolotl.utils.config import normalize_config, prepare_plugins, validate_config\nfrom axolotl.utils.dict import DictDefault\n\nfrom tests.e2e.utils import check_model_output_exists\n\n\n@pytest.fixture()\ndef min_cfg(temp_dir):\n    return {\n        \"base_model\": \"HuggingFaceTB/SmolLM2-135M\",\n        \"plugins\": [\n            \"axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin\",\n        ],\n        \"cut_cross_entropy\": True,\n        \"sequence_len\": 1024,\n        \"val_set_size\": 0.02,\n        \"special_tokens\": {\n            \"pad_token\": \"<|endoftext|>\",\n        },\n        \"datasets\": [\n            {\n                \"path\": \"mhenrichsen/alpaca_2k_test\",\n                \"type\": \"alpaca\",\n            },\n        ],\n        \"num_epochs\": 1,\n        \"micro_batch_size\": 8,\n        \"gradient_accumulation_steps\": 1,\n        \"learning_rate\": 0.00001,\n        \"optimizer\": \"adamw_torch_fused\",\n        \"output_dir\": temp_dir,\n        \"lr_scheduler\": \"cosine\",\n        \"max_steps\": 10,\n        \"bf16\": \"auto\",\n        \"save_first_step\": False,\n    }\n\n\nclass TestCutCrossEntropyIntegration:\n    \"\"\"\n    e2e tests for cut_cross_entropy integration with Axolotl\n    \"\"\"\n\n    def test_llama_w_cce(self, min_cfg, temp_dir):\n        cfg = DictDefault(min_cfg)\n        cfg = validate_config(cfg)\n        prepare_plugins(cfg)\n        normalize_config(cfg)\n        dataset_meta = load_datasets(cfg=cfg)\n\n        major, minor, _ = get_pytorch_version()\n        if (major, minor) < (2, 4):\n            with pytest.raises(ImportError):\n                train(cfg=cfg, dataset_meta=dataset_meta)\n        else:\n            train(cfg=cfg, 
dataset_meta=dataset_meta)\n            check_model_output_exists(temp_dir, cfg)\n\n    def test_qwen2_w_cce(self, temp_dir):\n        cfg = DictDefault(\n            {\n                \"base_model\": \"Qwen/Qwen2.5-0.5B\",\n                \"plugins\": [\n                    \"axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin\",\n                ],\n                \"cut_cross_entropy\": True,\n                \"sequence_len\": 1024,\n                \"val_set_size\": 0.02,\n                \"special_tokens\": {\n                    \"pad_token\": \"<|endoftext|>\",\n                },\n                \"datasets\": [\n                    {\n                        \"path\": \"mhenrichsen/alpaca_2k_test\",\n                        \"type\": \"alpaca\",\n                    },\n                ],\n                \"num_epochs\": 1,\n                \"micro_batch_size\": 4,\n                \"gradient_accumulation_steps\": 1,\n                \"learning_rate\": 0.00001,\n                \"optimizer\": \"adamw_torch_fused\",\n                \"output_dir\": temp_dir,\n                \"lr_scheduler\": \"cosine\",\n                \"max_steps\": 10,\n                \"bf16\": \"auto\",\n                \"save_first_step\": False,\n            }\n        )\n        cfg = validate_config(cfg)\n        prepare_plugins(cfg)\n        normalize_config(cfg)\n        dataset_meta = load_datasets(cfg=cfg)\n\n        major, minor, _ = get_pytorch_version()\n        if (major, minor) < (2, 4):\n            with pytest.raises(ImportError):\n                train(cfg=cfg, dataset_meta=dataset_meta)\n        else:\n            train(cfg=cfg, dataset_meta=dataset_meta)\n            check_model_output_exists(temp_dir, cfg)\n\n    @pytest.mark.parametrize(\n        \"attention_type\",\n        [\n            \"flash_attention\",\n            \"sdp_attention\",\n            # \"xformers_attention\",\n        ],\n    )\n    def test_llama_w_cce_and_attention(self, min_cfg, 
temp_dir, attention_type):\n        cfg = DictDefault(\n            min_cfg\n            | {\n                attention_type: True,\n            }\n        )\n        cfg = validate_config(cfg)\n        prepare_plugins(cfg)\n        normalize_config(cfg)\n        dataset_meta = load_datasets(cfg=cfg)\n\n        major, minor, _ = get_pytorch_version()\n        if (major, minor) < (2, 4):\n            with pytest.raises(ImportError):\n                train(cfg=cfg, dataset_meta=dataset_meta)\n        else:\n            train(cfg=cfg, dataset_meta=dataset_meta)\n            check_model_output_exists(temp_dir, cfg)\n"
  },
  {
    "path": "tests/e2e/integrations/test_fp8.py",
    "content": "\"\"\"\nSimple end-to-end smoke tests for FP8 mixed precision training\n\"\"\"\n\nfrom axolotl.common.datasets import load_datasets\nfrom axolotl.train import train\nfrom axolotl.utils.config import normalize_config, validate_config\nfrom axolotl.utils.dict import DictDefault\n\nfrom tests.e2e.utils import check_model_output_exists, require_torch_2_7_0\n\n\nclass FP8IntegrationTestCase:\n    \"\"\"\n    e2e smoke tests for FP8 mixed precision training with Axolotl\n    \"\"\"\n\n    @require_torch_2_7_0\n    def test_fp8_single_gpu_smoke(self, temp_dir):\n        \"\"\"Smoke test for single GPU FP8 + torch.compile training\"\"\"\n\n        cfg = DictDefault(\n            {\n                \"base_model\": \"HuggingFaceTB/SmolLM2-135M\",\n                \"tokenizer_type\": \"AutoTokenizer\",\n                \"trust_remote_code\": True,\n                \"sequence_len\": 512,\n                \"val_set_size\": 0.05,\n                \"special_tokens\": {\n                    \"pad_token\": \"<|endoftext|>\",\n                },\n                \"datasets\": [\n                    {\n                        \"path\": \"mhenrichsen/alpaca_2k_test\",\n                        \"type\": \"alpaca\",\n                    },\n                ],\n                \"num_epochs\": 1,\n                \"max_steps\": 3,  # Very short smoke test\n                \"micro_batch_size\": 1,\n                \"gradient_accumulation_steps\": 2,\n                \"output_dir\": temp_dir,\n                \"learning_rate\": 0.00001,\n                \"optimizer\": \"adamw_torch_fused\",\n                \"lr_scheduler\": \"cosine\",\n                \"sdp_attention\": True,\n                \"pad_to_seq_len\": True,\n                \"sample_packing\": True,\n                \"fp8\": True,\n                \"torch_compile\": True,\n                \"save_first_step\": False,\n            }\n        )\n\n        cfg = validate_config(cfg)\n        normalize_config(cfg)\n  
      dataset_meta = load_datasets(cfg=cfg)\n\n        train(cfg=cfg, dataset_meta=dataset_meta)\n        check_model_output_exists(temp_dir, cfg)\n"
  },
  {
    "path": "tests/e2e/integrations/test_hooks.py",
    "content": "\"\"\"\ne2e tests to make sure all the hooks are fired on the plugin\n\"\"\"\n\nimport os\nfrom pathlib import Path\n\nfrom axolotl.common.datasets import load_datasets\nfrom axolotl.integrations.base import BasePlugin\nfrom axolotl.train import train\nfrom axolotl.utils.config import normalize_config, prepare_plugins, validate_config\nfrom axolotl.utils.dict import DictDefault\n\nfrom tests.e2e.utils import check_model_output_exists\n\n\nclass LogHooksPlugin(BasePlugin):\n    \"\"\"\n    fixture to capture in a log file each hook that was fired\n    \"\"\"\n\n    base_dir = Path(\"/tmp/axolotl-log-hooks\")\n\n    def __init__(self):\n        self.base_dir.mkdir(parents=True, exist_ok=True)\n        try:\n            os.remove(self.base_dir.joinpath(\"plugin_hooks.log\"))\n        except FileNotFoundError:\n            pass\n\n    def post_trainer_create(self, cfg, trainer):\n        with open(\n            self.base_dir.joinpath(\"plugin_hooks.log\"), \"a\", encoding=\"utf-8\"\n        ) as f:\n            f.write(\"post_trainer_create\\n\")\n\n    def pre_model_load(self, cfg):\n        with open(\n            self.base_dir.joinpath(\"plugin_hooks.log\"), \"a\", encoding=\"utf-8\"\n        ) as f:\n            f.write(\"pre_model_load\\n\")\n\n    def post_model_build(self, cfg, model):\n        with open(\n            self.base_dir.joinpath(\"plugin_hooks.log\"), \"a\", encoding=\"utf-8\"\n        ) as f:\n            f.write(\"post_model_build\\n\")\n\n    def pre_lora_load(self, cfg, model):\n        with open(\n            self.base_dir.joinpath(\"plugin_hooks.log\"), \"a\", encoding=\"utf-8\"\n        ) as f:\n            f.write(\"pre_lora_load\\n\")\n\n    def post_lora_load(self, cfg, model):\n        with open(\n            self.base_dir.joinpath(\"plugin_hooks.log\"), \"a\", encoding=\"utf-8\"\n        ) as f:\n            f.write(\"post_lora_load\\n\")\n\n    def post_model_load(self, cfg, model):\n        with open(\n            
self.base_dir.joinpath(\"plugin_hooks.log\"), \"a\", encoding=\"utf-8\"\n        ) as f:\n            f.write(\"post_model_load\\n\")\n\n    def create_optimizer(self, cfg, trainer):\n        with open(\n            self.base_dir.joinpath(\"plugin_hooks.log\"), \"a\", encoding=\"utf-8\"\n        ) as f:\n            f.write(\"create_optimizer\\n\")\n\n    def get_trainer_cls(self, cfg):\n        with open(\n            self.base_dir.joinpath(\"plugin_hooks.log\"), \"a\", encoding=\"utf-8\"\n        ) as f:\n            f.write(\"get_trainer_cls\\n\")\n\n    def create_lr_scheduler(self, cfg, trainer, optimizer, num_training_steps):\n        with open(\n            self.base_dir.joinpath(\"plugin_hooks.log\"), \"a\", encoding=\"utf-8\"\n        ) as f:\n            f.write(\"create_lr_scheduler\\n\")\n\n    def add_callbacks_pre_trainer(self, cfg, model):\n        with open(\n            self.base_dir.joinpath(\"plugin_hooks.log\"), \"a\", encoding=\"utf-8\"\n        ) as f:\n            f.write(\"add_callbacks_pre_trainer\\n\")\n        return []\n\n    def add_callbacks_post_trainer(self, cfg, trainer):\n        with open(\n            self.base_dir.joinpath(\"plugin_hooks.log\"), \"a\", encoding=\"utf-8\"\n        ) as f:\n            f.write(\"add_callbacks_post_trainer\\n\")\n        return []\n\n    def post_train(self, cfg, model):\n        with open(\n            self.base_dir.joinpath(\"plugin_hooks.log\"), \"a\", encoding=\"utf-8\"\n        ) as f:\n            f.write(\"post_train\\n\")\n\n    def post_train_unload(self, cfg):\n        with open(\n            self.base_dir.joinpath(\"plugin_hooks.log\"), \"a\", encoding=\"utf-8\"\n        ) as f:\n            f.write(\"post_train_unload\\n\")\n\n\nclass TestPluginHooks:\n    \"\"\"\n    e2e tests to make sure all the hooks are fired during the training\n    \"\"\"\n\n    def test_plugin_hooks(self, temp_dir):\n        cfg = DictDefault(\n            {\n                \"base_model\": 
\"HuggingFaceTB/SmolLM2-135M\",\n                \"plugins\": [\n                    \"tests.e2e.integrations.test_hooks.LogHooksPlugin\",\n                ],\n                \"tokenizer_type\": \"AutoTokenizer\",\n                \"sequence_len\": 1024,\n                \"adapter\": \"lora\",\n                \"lora_r\": 8,\n                \"lora_alpha\": 16,\n                \"lora_dropout\": 0.05,\n                \"lora_target_linear\": True,\n                \"val_set_size\": 0.02,\n                \"special_tokens\": {\n                    \"pad_token\": \"<|endoftext|>\",\n                },\n                \"datasets\": [\n                    {\n                        \"path\": \"mhenrichsen/alpaca_2k_test\",\n                        \"type\": \"alpaca\",\n                    },\n                ],\n                \"num_epochs\": 1,\n                \"micro_batch_size\": 2,\n                \"gradient_accumulation_steps\": 1,\n                \"output_dir\": temp_dir,\n                \"learning_rate\": 0.00001,\n                \"optimizer\": \"adamw_torch_fused\",\n                \"lr_scheduler\": \"cosine\",\n                \"max_steps\": 5,\n                \"flash_attention\": True,\n                \"bf16\": \"auto\",\n                \"save_first_step\": False,\n            }\n        )\n\n        cfg = validate_config(cfg)\n        prepare_plugins(cfg)\n        normalize_config(cfg)\n        dataset_meta = load_datasets(cfg=cfg)\n\n        train(cfg=cfg, dataset_meta=dataset_meta)\n        check_model_output_exists(temp_dir, cfg)\n\n        with open(\n            \"/tmp/axolotl-log-hooks\" + \"/plugin_hooks.log\", \"r\", encoding=\"utf-8\"\n        ) as f:\n            file_contents = f.readlines()\n            file_contents = \"\\n\".join(file_contents)\n            assert \"post_trainer_create\" in file_contents\n            assert \"pre_model_load\" in file_contents\n            assert \"post_model_build\" in file_contents\n            
assert \"pre_lora_load\" in file_contents\n            assert \"post_lora_load\" in file_contents\n            assert \"post_model_load\" in file_contents\n            # assert \"create_optimizer\" in file_contents  # not implemented yet\n            assert \"get_trainer_cls\" in file_contents\n            assert \"create_lr_scheduler\" in file_contents\n            assert \"add_callbacks_pre_trainer\" in file_contents\n            assert \"add_callbacks_post_trainer\" in file_contents\n            assert \"post_train\" in file_contents\n            # assert \"post_train_unload\" in file_contents  # not called from test train call\n\n        try:\n            os.remove(\"/tmp/axolotl-log-hooks\" + \"/plugin_hooks.log\")\n        except FileNotFoundError:\n            pass\n"
  },
  {
    "path": "tests/e2e/integrations/test_kd.py",
    "content": "\"\"\"\ne2e tests for kd trainer support in Axolotl\n\"\"\"\n\nfrom pathlib import Path\n\nimport pytest\nimport yaml\nfrom accelerate.test_utils import execute_subprocess_async, get_torch_dist_unique_port\n\nfrom axolotl.utils.dict import DictDefault\n\nfrom tests.e2e.utils import check_tensorboard, require_torch_2_5_1\n\n\n@pytest.fixture(name=\"kd_min_cfg\")\ndef min_cfg(temp_dir):\n    return {\n        \"base_model\": \"Qwen/Qwen3-0.6B\",\n        \"tokenizer_config\": \"winglian/qwen3-14b-math\",\n        \"plugins\": [\n            \"axolotl.integrations.kd.KDPlugin\",\n            \"axolotl.integrations.liger.LigerPlugin\",\n        ],\n        \"liger_rms_norm\": True,\n        \"liger_glu_activation\": True,\n        \"torch_compile\": True,\n        \"chat_template\": \"qwen3\",\n        \"kd_trainer\": True,\n        \"kd_ce_alpha\": 0.1,\n        \"kd_alpha\": 0.9,\n        \"kd_temperature\": 1.0,\n        \"kd_beta\": 0.0,\n        \"kd_normalize_topk\": True,\n        \"dataloader_prefetch_factor\": 8,\n        \"dataloader_num_workers\": 4,\n        \"dataloader_pin_memory\": True,\n        \"datasets\": [\n            {\n                \"path\": \"winglian/OpenThoughts-114k-math-correct-qwen3-14b-math-prepared-topk128-normalized\",\n                \"type\": \"chat_template\",\n                \"split\": \"train\",\n                \"split_thinking\": True,\n                \"eot_tokens\": [\"<|im_end|>\"],\n                \"data_files\": [\"train/batch-000000.parquet\"],\n            },\n        ],\n        \"skip_prepare_dataset\": True,\n        \"val_set_size\": 0.0,\n        \"sequence_len\": 2048,\n        \"sample_packing\": True,\n        \"pad_to_sequence_len\": True,\n        \"gradient_accumulation_steps\": 2,\n        \"micro_batch_size\": 1,\n        \"num_epochs\": 1,\n        \"optimizer\": \"adamw_8bit\",\n        \"lr_scheduler\": \"cosine\",\n        \"learning_rate\": 0.00001,\n        \"bf16\": \"auto\",\n     
   \"gradient_checkpointing\": True,\n        \"flash_attention\": True,\n        \"special_tokens\": {\n            \"pad_token\": \"<|end_of_text|>\",\n            \"eos_token\": \"<|eot_id|>\",\n        },\n        \"max_steps\": 5,\n        \"output_dir\": temp_dir,\n        \"use_tensorboard\": True,\n        \"save_first_step\": False,\n    }\n\n\nclass TestKnowledgeDistillation:\n    \"\"\"\n    Test case for Knowledge Distillation\n    \"\"\"\n\n    # While this will run on torch 2.4.x without torch_compile enabled\n    # the VRAM requirement is higher than what is available in CI\n    @require_torch_2_5_1\n    def test_llama_kd(self, temp_dir, kd_min_cfg):\n        cfg = DictDefault(kd_min_cfg)\n\n        # write cfg to yaml file\n        Path(temp_dir).mkdir(parents=True, exist_ok=True)\n        with open(Path(temp_dir) / \"config.yaml\", \"w\", encoding=\"utf-8\") as fout:\n            fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))\n\n        execute_subprocess_async(\n            [\n                \"axolotl\",\n                \"train\",\n                str(Path(temp_dir) / \"config.yaml\"),\n                \"--num-processes\",\n                \"1\",\n                \"--main-process-port\",\n                f\"{get_torch_dist_unique_port()}\",\n            ]\n        )\n\n        assert (Path(temp_dir) / \"model.safetensors\").exists()\n        check_tensorboard(\n            temp_dir + \"/runs\", \"train/loss\", 1.4, \"Train Loss (%s) is too high\"\n        )\n\n    @pytest.mark.parametrize(\n        \"load_in_8bit\",\n        [True, False],\n    )\n    def test_llama_lora_kd(self, temp_dir, kd_min_cfg, load_in_8bit):\n        cfg = DictDefault(\n            {\n                \"load_in_8bit\": load_in_8bit,\n                \"torch_compile\": False,\n                \"adapter\": \"lora\",\n                \"peft_use_dora\": True,\n                \"lora_target_linear\": True,\n                \"lora_r\": 16,\n                
\"lora_alpha\": 32,\n                \"lora_dropout\": 0.0,\n                \"lora_modules_to_save\": [\"embed_tokens\", \"lm_head\"],\n                \"lora_mlp_kernel\": False,\n                \"lora_qkv_kernel\": False,\n                \"lora_o_kernel\": False,\n            }\n            | kd_min_cfg\n        )\n\n        # write cfg to yaml file\n        Path(temp_dir).mkdir(parents=True, exist_ok=True)\n        with open(Path(temp_dir) / \"config.yaml\", \"w\", encoding=\"utf-8\") as fout:\n            fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))\n\n        execute_subprocess_async(\n            [\n                \"axolotl\",\n                \"train\",\n                str(Path(temp_dir) / \"config.yaml\"),\n                \"--num-processes\",\n                \"1\",\n                \"--main-process-port\",\n                f\"{get_torch_dist_unique_port()}\",\n            ]\n        )\n        assert (Path(temp_dir) / \"adapter_model.safetensors\").exists()\n        check_tensorboard(\n            temp_dir + \"/runs\", \"train/loss\", 1.2, \"Train Loss (%s) is too high\"\n        )\n"
  },
  {
    "path": "tests/e2e/integrations/test_liger.py",
    "content": "\"\"\"\nSimple end-to-end test for Liger integration\n\"\"\"\n\nimport pytest\n\nfrom axolotl.common.datasets import load_datasets\nfrom axolotl.train import train\nfrom axolotl.utils.config import normalize_config, prepare_plugins, validate_config\nfrom axolotl.utils.dict import DictDefault\n\nfrom tests.e2e.utils import check_model_output_exists, require_torch_2_4_1\n\n\nclass LigerIntegrationTestCase:\n    \"\"\"\n    e2e tests for liger integration with Axolotl\n    \"\"\"\n\n    @require_torch_2_4_1\n    def test_llama_wo_flce(self, temp_dir):\n        cfg = DictDefault(\n            {\n                \"base_model\": \"HuggingFaceTB/SmolLM2-135M\",\n                \"plugins\": [\n                    \"axolotl.integrations.liger.LigerPlugin\",\n                ],\n                \"liger_rope\": True,\n                \"liger_rms_norm\": True,\n                \"liger_glu_activation\": True,\n                \"liger_cross_entropy\": True,\n                \"liger_fused_linear_cross_entropy\": False,\n                \"sequence_len\": 1024,\n                \"val_set_size\": 0.05,\n                \"special_tokens\": {\n                    \"pad_token\": \"<|endoftext|>\",\n                },\n                \"datasets\": [\n                    {\n                        \"path\": \"mhenrichsen/alpaca_2k_test\",\n                        \"type\": \"alpaca\",\n                    },\n                ],\n                \"num_epochs\": 1,\n                \"micro_batch_size\": 2,\n                \"gradient_accumulation_steps\": 2,\n                \"output_dir\": temp_dir,\n                \"learning_rate\": 0.00001,\n                \"optimizer\": \"adamw_torch_fused\",\n                \"lr_scheduler\": \"cosine\",\n                \"bf16\": \"auto\",\n                \"max_steps\": 5,\n                \"save_first_step\": False,\n            }\n        )\n\n        cfg = validate_config(cfg)\n        prepare_plugins(cfg)\n        
normalize_config(cfg)\n        dataset_meta = load_datasets(cfg=cfg)\n\n        train(cfg=cfg, dataset_meta=dataset_meta)\n        check_model_output_exists(temp_dir, cfg)\n\n    @require_torch_2_4_1\n    @pytest.mark.parametrize(\n        \"liger_use_token_scaling\",\n        [True, False],\n    )\n    def test_llama_w_flce(self, temp_dir, liger_use_token_scaling):\n        cfg = DictDefault(\n            {\n                \"base_model\": \"HuggingFaceTB/SmolLM2-135M\",\n                \"plugins\": [\n                    \"axolotl.integrations.liger.LigerPlugin\",\n                ],\n                \"liger_rope\": True,\n                \"liger_rms_norm\": True,\n                \"liger_glu_activation\": True,\n                \"liger_cross_entropy\": False,\n                \"liger_fused_linear_cross_entropy\": True,\n                \"liger_use_token_scaling\": liger_use_token_scaling,\n                \"sequence_len\": 1024,\n                \"val_set_size\": 0.05,\n                \"special_tokens\": {\n                    \"pad_token\": \"<|endoftext|>\",\n                },\n                \"datasets\": [\n                    {\n                        \"path\": \"mhenrichsen/alpaca_2k_test\",\n                        \"type\": \"alpaca\",\n                    },\n                ],\n                \"num_epochs\": 1,\n                \"micro_batch_size\": 2,\n                \"gradient_accumulation_steps\": 2,\n                \"output_dir\": temp_dir,\n                \"learning_rate\": 0.00001,\n                \"optimizer\": \"adamw_torch_fused\",\n                \"lr_scheduler\": \"cosine\",\n                \"bf16\": \"auto\",\n                \"max_steps\": 5,\n                \"save_first_step\": False,\n            }\n        )\n\n        cfg = validate_config(cfg)\n        prepare_plugins(cfg)\n        normalize_config(cfg)\n        dataset_meta = load_datasets(cfg=cfg)\n\n        train(cfg=cfg, dataset_meta=dataset_meta)\n        
check_model_output_exists(temp_dir, cfg)\n"
  },
  {
    "path": "tests/e2e/integrations/test_llm_compressor.py",
    "content": "\"\"\"\nE2E smoke tests for LLMCompressorPlugin integration\n\"\"\"\n\nfrom pathlib import Path\n\nimport pytest\n\nfrom axolotl.common.datasets import load_datasets\nfrom axolotl.train import train\nfrom axolotl.utils.config import normalize_config, prepare_plugins, validate_config\nfrom axolotl.utils.dict import DictDefault\n\nfrom tests.e2e.utils import (\n    check_model_output_exists,\n    require_llmcompressor,\n    require_torch_2_4_1,\n)\n\nMODELS = [\n    \"nm-testing/llama2.c-stories42M-pruned2.4-compressed\",\n    \"nm-testing/llama2.c-stories42M-gsm8k-sparse-only-compressed\",\n]\n\n\n@pytest.mark.parametrize(\n    \"base_model\", MODELS, ids=[\"no-checkpoint-recipe\", \"with-checkpoint-recipe\"]\n)\n@pytest.mark.parametrize(\n    \"save_compressed\", [True, False], ids=[\"save_compressed\", \"save_uncompressed\"]\n)\nclass TestLLMCompressorIntegration:\n    \"\"\"\n    e2e tests for axolotl.integrations.llm_compressor.LLMCompressorPlugin\n    \"\"\"\n\n    @require_llmcompressor\n    @require_torch_2_4_1\n    def test_llmcompressor_plugin(\n        self, temp_dir, base_model: str, save_compressed: bool\n    ):\n        from llmcompressor import active_session\n\n        # core cfg\n        cfg = DictDefault(\n            {\n                \"base_model\": base_model,\n                \"plugins\": [\"axolotl.integrations.llm_compressor.LLMCompressorPlugin\"],\n                \"sequence_len\": 1024,\n                \"val_set_size\": 0.05,\n                \"special_tokens\": {\"pad_token\": \"<|endoftext|>\"},\n                \"datasets\": [{\"path\": \"mhenrichsen/alpaca_2k_test\", \"type\": \"alpaca\"}],\n                \"num_epochs\": 1,\n                \"micro_batch_size\": 2,\n                \"gradient_accumulation_steps\": 2,\n                \"output_dir\": temp_dir,\n                \"learning_rate\": 1e-5,\n                \"optimizer\": \"adamw_torch_fused\",\n                \"lr_scheduler\": \"cosine\",\n                
\"bf16\": \"auto\",\n                \"max_steps\": 5,\n                \"llmcompressor\": {\n                    \"recipe\": {\n                        \"finetuning_stage\": {\n                            \"finetuning_modifiers\": {\n                                \"ConstantPruningModifier\": {\n                                    \"targets\": [\n                                        \"re:.*q_proj.weight\",\n                                        \"re:.*k_proj.weight\",\n                                        \"re:.*v_proj.weight\",\n                                        \"re:.*o_proj.weight\",\n                                        \"re:.*gate_proj.weight\",\n                                        \"re:.*up_proj.weight\",\n                                        \"re:.*down_proj.weight\",\n                                    ],\n                                    \"start\": 0,\n                                },\n                            },\n                        },\n                    },\n                    \"save_compressed\": save_compressed,\n                },\n                \"save_first_step\": False,\n            }\n        )\n\n        prepare_plugins(cfg)\n        cfg = validate_config(cfg)\n        normalize_config(cfg)\n        dataset_meta = load_datasets(cfg=cfg)\n\n        try:\n            train(cfg=cfg, dataset_meta=dataset_meta)\n            check_model_output_exists(temp_dir, cfg)\n            _check_llmcompressor_model_outputs(temp_dir, save_compressed)\n        finally:\n            active_session().reset()\n\n\ndef _check_llmcompressor_model_outputs(temp_dir, save_compressed):\n    if save_compressed:\n        assert (Path(temp_dir) / \"recipe.yaml\").exists()\n\n        from compressed_tensors import ModelCompressor\n        from compressed_tensors.config import Sparse24BitMaskConfig\n\n        compressor = ModelCompressor.from_pretrained(temp_dir)\n        assert compressor is not None\n        assert 
isinstance(compressor.sparsity_config, Sparse24BitMaskConfig)\n"
  },
  {
    "path": "tests/e2e/integrations/test_scattermoe_lora_kernels.py",
    "content": "# SPDX-License-Identifier: Apache-2.0\n# Copyright (c) Axolotl AI\n# Licensed under the Apache License, Version 2.0\n\n\"\"\"\nTests for ScatterMoE + LoRA Fused Kernels\n==========================================\n\nTests verify correctness of:\n1. Forward pass: fused kernel matches naive PyTorch reference\n2. Backward pass: gradients for LoRA A, B, and input match reference\n3. Frozen weights: expert weight gradients are correctly skipped\n4. Various configurations: top-k, grouped_in/out, with/without bias\n5. Numerical stability: bf16/fp16 outputs within tolerance of fp32 reference\n6. HFScatterMoEGatedMLP with sigmoid routing (GLM/DeepSeek/MiniMax M2)\n\nTest strategy:\n- Reference implementation uses pure PyTorch ops (no Triton)\n- ScatterMoE routing (flatten_sort_count) is shared between reference and kernel\n- Tolerances account for tf32 accumulation in Triton kernels\n\"\"\"\n\nfrom types import SimpleNamespace\n\nimport pytest\nimport torch\n\n# Skip all tests if CUDA is not available\npytestmark = pytest.mark.skipif(\n    not torch.cuda.is_available(),\n    reason=\"CUDA required for Triton kernels\",\n)\n\n_SMOE = \"axolotl.integrations.kernels.libs.scattermoe_lora\"\n\n\n# =============================================================================\n# Helpers\n# =============================================================================\n\n\ndef flatten_sort_count_ref(expert_idxs: torch.Tensor, num_experts: int):\n    \"\"\"Reference implementation of routing.\"\"\"\n    with torch.no_grad():\n        flat = expert_idxs.flatten()\n        sorted_expert_idxs, sorted_scattered_idxs = torch.sort(flat)\n        counts = flat.bincount(minlength=num_experts)\n        offsets = counts.cumsum(-1)\n    return sorted_expert_idxs, sorted_scattered_idxs, offsets\n\n\ndef reference_parallel_linear_lora(\n    X,\n    W,\n    k,\n    sorted_expert_idxs,\n    sorted_scattered_idxs,\n    lora_A,\n    lora_B,\n    scaling,\n    x_grouped=False,\n    
y_grouped=False,\n    bias=None,\n):\n    \"\"\"\n    Pure PyTorch reference for: Y[i] = X[i] @ W[e] + scaling * (X[i] @ A[e]^T) @ B[e]^T + b[e]\n\n    Args:\n        X: [M, K] input (token order)\n        W: [E, K, N] expert weights\n        sorted_expert_idxs: [M*k] expert assignments (sorted)\n        sorted_scattered_idxs: [M*k] original token indices (sorted)\n        lora_A: [r*E, K] LoRA A weights\n        lora_B: [N, r*E] LoRA B weights\n        scaling: LoRA scaling factor\n    \"\"\"\n    E, K, N = W.shape\n    R = lora_A.size(0) // E\n    L = sorted_expert_idxs.size(0)  # M * k\n\n    output = torch.zeros(L, N, device=X.device, dtype=X.dtype)\n\n    for i in range(L):\n        e = sorted_expert_idxs[i].item()\n        if x_grouped:\n            x_i = X[i]\n        else:\n            token_idx = sorted_scattered_idxs[i].item() // k\n            x_i = X[token_idx]\n\n        w_e = W[e]  # [K, N]\n        a_e = lora_A[e * R : (e + 1) * R, :]  # [r, K]\n        b_e = lora_B[:, e * R : (e + 1) * R]  # [N, r]\n\n        # Y = X @ W + scaling * (X @ A^T) @ B^T\n        base = x_i @ w_e  # [N]\n        lora = scaling * ((x_i @ a_e.T) @ b_e.T)  # [N]\n        out_i = base + lora\n\n        if bias is not None:\n            out_i = out_i + bias[e]\n\n        if y_grouped:\n            output[i] = out_i\n        else:\n            output[sorted_scattered_idxs[i]] = out_i\n\n    return output\n\n\ndef reference_lora_backward(\n    grad_out,\n    X,\n    W,\n    lora_A,\n    lora_B,\n    scaling,\n    sorted_expert_idxs,\n    sorted_scattered_idxs,\n    expert_offsets,\n    k,\n    E,\n):\n    \"\"\"\n    Pure PyTorch reference for LoRA backward pass on grouped data.\n\n    Returns:\n        dX: [M*k, K] input gradient (in grouped order)\n        dA: [r*E, K] LoRA A gradient\n        dB: [N, r*E] LoRA B gradient\n    \"\"\"\n    R = lora_A.size(0) // E\n\n    dA = torch.zeros_like(lora_A)\n    dB = torch.zeros_like(lora_B)\n    dX = torch.zeros_like(X)\n\n    
prev_offset = 0\n    for e in range(E):\n        curr_offset = expert_offsets[e].item()\n        if curr_offset > prev_offset:\n            dy_e = grad_out[prev_offset:curr_offset]  # [M_e, N]\n            x_e = X[prev_offset:curr_offset]  # [M_e, K]\n            a_e = lora_A[e * R : (e + 1) * R, :]  # [r, K]\n            b_e = lora_B[:, e * R : (e + 1) * R]  # [N, r]\n            w_e = W[e]  # [K, N]\n\n            # Input gradient: dX = dY @ W^T + scaling * (dY @ B) @ A\n            dx_base = dy_e @ w_e.T  # [M_e, K]\n            dy_b = dy_e @ b_e  # [M_e, r]\n            dx_lora = scaling * (dy_b @ a_e)  # [M_e, K]\n            dX[prev_offset:curr_offset] = dx_base + dx_lora\n\n            # LoRA A gradient: dA = scaling * (dY @ B)^T @ X\n            xa = x_e @ a_e.T  # [M_e, r]\n            dA[e * R : (e + 1) * R, :] = scaling * (dy_b.T @ x_e)\n\n            # LoRA B gradient: dB = scaling * dY^T @ (X @ A^T)\n            dB[:, e * R : (e + 1) * R] = scaling * (dy_e.T @ xa)\n\n        prev_offset = curr_offset\n\n    return dX, dA, dB\n\n\ndef make_test_data(\n    M=32,\n    K=64,\n    N=128,\n    E=4,\n    R=8,\n    k=2,\n    dtype=torch.float32,\n    device=\"cuda\",\n    seed=42,\n):\n    \"\"\"Create test data for ScatterMoE + LoRA tests.\"\"\"\n    torch.manual_seed(seed)\n\n    X = torch.randn(M, K, device=device, dtype=dtype)\n    W = torch.randn(E, K, N, device=device, dtype=dtype) * 0.02\n    lora_A = torch.randn(R * E, K, device=device, dtype=dtype) * 0.01\n    lora_B = torch.randn(N, R * E, device=device, dtype=dtype) * 0.01\n    scaling = 0.5\n\n    # Generate routing\n    selected_experts = torch.randint(0, E, (M, k), device=device)\n    sorted_expert_idxs, sorted_scattered_idxs, expert_offsets = flatten_sort_count_ref(\n        selected_experts, E\n    )\n\n    return {\n        \"X\": X,\n        \"W\": W,\n        \"lora_A\": lora_A,\n        \"lora_B\": lora_B,\n        \"scaling\": scaling,\n        \"k\": k,\n        \"E\": E,\n        \"R\": 
R,\n        \"sorted_expert_idxs\": sorted_expert_idxs,\n        \"sorted_scattered_idxs\": sorted_scattered_idxs,\n        \"expert_offsets\": expert_offsets,\n    }\n\n\n# =============================================================================\n# Test: Forward Pass Correctness\n# =============================================================================\n\n\nclass TestForwardPass:\n    \"\"\"Test forward pass of fused scatter2scatter_lora kernel.\"\"\"\n\n    def _run_forward_test(\n        self, M, K, N, E, R, k, dtype=torch.float32, atol=1e-2, rtol=1e-2\n    ):\n        from importlib import import_module\n\n        lora_ops = import_module(f\"{_SMOE}.kernels.lora_ops\")\n\n        data = make_test_data(M=M, K=K, N=N, E=E, R=R, k=k, dtype=dtype)\n\n        # Reference\n        ref_output = reference_parallel_linear_lora(\n            data[\"X\"],\n            data[\"W\"],\n            data[\"k\"],\n            data[\"sorted_expert_idxs\"],\n            data[\"sorted_scattered_idxs\"],\n            data[\"lora_A\"],\n            data[\"lora_B\"],\n            data[\"scaling\"],\n        )\n\n        # Kernel\n        kernel_output = lora_ops.scatter2scatter_lora(\n            X=data[\"X\"],\n            W=data[\"W\"],\n            sorted_expert_idxs=data[\"sorted_expert_idxs\"],\n            sorted_scattered_idxs=data[\"sorted_scattered_idxs\"],\n            k=data[\"k\"],\n            lora_A=data[\"lora_A\"],\n            lora_B=data[\"lora_B\"],\n            scaling=data[\"scaling\"],\n        )\n\n        torch.testing.assert_close(kernel_output, ref_output, atol=atol, rtol=rtol)\n\n    def test_basic(self):\n        \"\"\"Basic forward pass with small dimensions.\"\"\"\n        self._run_forward_test(M=16, K=64, N=64, E=4, R=8, k=1)\n\n    def test_topk2(self):\n        \"\"\"Forward pass with top-2 routing.\"\"\"\n        self._run_forward_test(M=32, K=64, N=128, E=4, R=8, k=2)\n\n    def test_larger_rank(self):\n        \"\"\"Forward pass with 
larger LoRA rank.\"\"\"\n        self._run_forward_test(M=16, K=128, N=128, E=8, R=32, k=2)\n\n    def test_small_rank(self):\n        \"\"\"Forward pass with very small LoRA rank.\"\"\"\n        self._run_forward_test(M=32, K=64, N=64, E=4, R=4, k=1)\n\n    def test_many_experts(self):\n        \"\"\"Forward with many experts, fewer tokens per expert.\"\"\"\n        self._run_forward_test(M=64, K=64, N=64, E=16, R=8, k=2)\n\n    def test_non_power_of_2_dims(self):\n        \"\"\"Test with dimensions that are not powers of 2.\"\"\"\n        self._run_forward_test(M=17, K=96, N=80, E=6, R=16, k=2, atol=2e-2, rtol=2e-2)\n\n    def test_single_token(self):\n        \"\"\"Test with a single token.\"\"\"\n        self._run_forward_test(M=1, K=64, N=64, E=4, R=8, k=1)\n\n    def test_bf16(self):\n        \"\"\"Test with bfloat16 precision.\"\"\"\n        self._run_forward_test(\n            M=32, K=64, N=128, E=4, R=8, k=2, dtype=torch.bfloat16, atol=5e-2, rtol=5e-2\n        )\n\n    def test_fp16(self):\n        \"\"\"Test with float16 precision.\"\"\"\n        self._run_forward_test(\n            M=32, K=64, N=128, E=4, R=8, k=2, dtype=torch.float16, atol=5e-2, rtol=5e-2\n        )\n\n\nclass TestForwardGrouped:\n    \"\"\"Test forward pass with grouped_in/grouped_out configurations.\"\"\"\n\n    def _make_grouped_data(self, M=32, K=64, N=128, E=4, R=8, k=2, dtype=torch.float32):\n        from importlib import import_module\n\n        base_ops = import_module(f\"{_SMOE}.kernels.ops\")\n\n        data = make_test_data(M=M, K=K, N=N, E=E, R=R, k=k, dtype=dtype)\n\n        # Create grouped X\n        grouped_X = base_ops.group(data[\"X\"], data[\"sorted_scattered_idxs\"], fan_out=k)\n        data[\"grouped_X\"] = grouped_X\n        return data\n\n    def test_x_grouped(self):\n        \"\"\"Forward with pre-grouped input.\"\"\"\n        from importlib import import_module\n\n        lora_ops = import_module(f\"{_SMOE}.kernels.lora_ops\")\n\n        data = 
self._make_grouped_data()\n\n        ref_output = reference_parallel_linear_lora(\n            data[\"grouped_X\"],\n            data[\"W\"],\n            data[\"k\"],\n            data[\"sorted_expert_idxs\"],\n            data[\"sorted_scattered_idxs\"],\n            data[\"lora_A\"],\n            data[\"lora_B\"],\n            data[\"scaling\"],\n            x_grouped=True,\n        )\n\n        kernel_output = lora_ops.scatter2scatter_lora(\n            X=data[\"grouped_X\"],\n            W=data[\"W\"],\n            sorted_expert_idxs=data[\"sorted_expert_idxs\"],\n            sorted_scattered_idxs=data[\"sorted_scattered_idxs\"],\n            k=1,  # When x_grouped, fan_out=1 (already expanded)\n            lora_A=data[\"lora_A\"],\n            lora_B=data[\"lora_B\"],\n            scaling=data[\"scaling\"],\n            x_grouped=True,\n        )\n\n        torch.testing.assert_close(kernel_output, ref_output, atol=1e-2, rtol=1e-2)\n\n    def test_y_grouped(self):\n        \"\"\"Forward with grouped output.\"\"\"\n        from importlib import import_module\n\n        lora_ops = import_module(f\"{_SMOE}.kernels.lora_ops\")\n\n        data = make_test_data()\n\n        ref_output = reference_parallel_linear_lora(\n            data[\"X\"],\n            data[\"W\"],\n            data[\"k\"],\n            data[\"sorted_expert_idxs\"],\n            data[\"sorted_scattered_idxs\"],\n            data[\"lora_A\"],\n            data[\"lora_B\"],\n            data[\"scaling\"],\n            y_grouped=True,\n        )\n\n        kernel_output = lora_ops.scatter2scatter_lora(\n            X=data[\"X\"],\n            W=data[\"W\"],\n            sorted_expert_idxs=data[\"sorted_expert_idxs\"],\n            sorted_scattered_idxs=data[\"sorted_scattered_idxs\"],\n            k=data[\"k\"],\n            lora_A=data[\"lora_A\"],\n            lora_B=data[\"lora_B\"],\n            scaling=data[\"scaling\"],\n            y_grouped=True,\n        )\n\n        
torch.testing.assert_close(kernel_output, ref_output, atol=1e-2, rtol=1e-2)\n\n\n# =============================================================================\n# Test: Backward Pass Correctness (LoRA Gradients)\n# =============================================================================\n\n\nclass TestLoRAGradients:\n    \"\"\"Test backward LoRA gradient computation (dA, dB).\"\"\"\n\n    def _run_lora_grad_test(self, M, K, N, E, R, k, atol=1e-2, rtol=1e-2):\n        from importlib import import_module\n\n        lora_ops = import_module(f\"{_SMOE}.kernels.lora_ops\")\n        base_ops = import_module(f\"{_SMOE}.kernels.ops\")\n\n        data = make_test_data(M=M, K=K, N=N, E=E, R=R, k=k)\n\n        # Group X for backward\n        grouped_X = base_ops.group(data[\"X\"], data[\"sorted_scattered_idxs\"], fan_out=k)\n\n        # Create fake grad_out in grouped order\n        grad_out = torch.randn(\n            data[\"sorted_expert_idxs\"].size(0),\n            N,\n            device=\"cuda\",\n            dtype=torch.float32,\n        )\n\n        # Reference\n        _, ref_dA, ref_dB = reference_lora_backward(\n            grad_out,\n            grouped_X,\n            data[\"W\"],\n            data[\"lora_A\"],\n            data[\"lora_B\"],\n            data[\"scaling\"],\n            data[\"sorted_expert_idxs\"],\n            data[\"sorted_scattered_idxs\"],\n            data[\"expert_offsets\"],\n            k,\n            E,\n        )\n\n        # Kernel\n        kernel_dA, kernel_dB = lora_ops.group_bwd_lora(\n            DY=grad_out,\n            X=grouped_X,\n            lora_A=data[\"lora_A\"],\n            lora_B=data[\"lora_B\"],\n            expert_offsets=data[\"expert_offsets\"],\n            E=E,\n            scaling=data[\"scaling\"],\n        )\n\n        torch.testing.assert_close(kernel_dA, ref_dA, atol=atol, rtol=rtol)\n        torch.testing.assert_close(kernel_dB, ref_dB, atol=atol, rtol=rtol)\n\n    def test_basic_lora_grads(self):\n   
     self._run_lora_grad_test(M=32, K=64, N=128, E=4, R=8, k=2)\n\n    def test_small_rank(self):\n        self._run_lora_grad_test(M=16, K=64, N=64, E=4, R=4, k=1)\n\n    def test_larger_rank(self):\n        self._run_lora_grad_test(\n            M=16, K=128, N=128, E=8, R=32, k=2, atol=5e-2, rtol=5e-2\n        )\n\n    def test_many_experts(self):\n        self._run_lora_grad_test(M=64, K=64, N=64, E=16, R=8, k=2)\n\n    def test_single_token_per_expert(self):\n        \"\"\"Edge case: roughly 1 token per expert.\"\"\"\n        self._run_lora_grad_test(M=8, K=64, N=64, E=8, R=4, k=1)\n\n\n# =============================================================================\n# Test: Full Autograd (Forward + Backward) via torch.autograd\n# =============================================================================\n\n\nclass TestAutograd:\n    \"\"\"Test full autograd integration through ScatterMoELoRA.\"\"\"\n\n    def test_lora_receives_gradients(self):\n        \"\"\"LoRA A and B receive non-zero gradients; frozen W does not.\"\"\"\n        from importlib import import_module\n\n        pll = import_module(f\"{_SMOE}.parallel_linear_lora\")\n\n        M, K, N, E, R, k = 16, 64, 64, 4, 8, 2\n        data = make_test_data(M=M, K=K, N=N, E=E, R=R, k=k)\n\n        X = data[\"X\"].clone().requires_grad_(True)\n        W = data[\"W\"].clone().requires_grad_(False)  # Frozen\n        lora_A = data[\"lora_A\"].clone().requires_grad_(True)\n        lora_B = data[\"lora_B\"].clone().requires_grad_(True)\n\n        output = pll.ScatterMoELoRA.apply(\n            X,\n            W,\n            k,\n            data[\"sorted_expert_idxs\"],\n            data[\"sorted_scattered_idxs\"],\n            data[\"expert_offsets\"],\n            lora_A,\n            lora_B,\n            data[\"scaling\"],\n            None,\n            None,\n            False,\n            False,\n        )\n\n        loss = output.sum()\n        loss.backward()\n\n        # LoRA params should have 
gradients\n        assert lora_A.grad is not None, \"lora_A should have gradient\"\n        assert lora_B.grad is not None, \"lora_B should have gradient\"\n        assert lora_A.grad.abs().sum() > 0, \"lora_A gradient should be non-zero\"\n        assert lora_B.grad.abs().sum() > 0, \"lora_B gradient should be non-zero\"\n\n        # Input should have gradient (needed for upstream backprop)\n        assert X.grad is not None, \"X should have gradient\"\n        assert X.grad.abs().sum() > 0, \"X gradient should be non-zero\"\n\n    def test_input_gradient_matches_reference(self):\n        \"\"\"Input gradient from autograd matches pure PyTorch reference.\"\"\"\n        from importlib import import_module\n\n        pll = import_module(f\"{_SMOE}.parallel_linear_lora\")\n        base_ops = import_module(f\"{_SMOE}.kernels.ops\")\n\n        M, K, N, E, R, k = 16, 64, 64, 4, 8, 1\n        data = make_test_data(M=M, K=K, N=N, E=E, R=R, k=k)\n\n        # Autograd path\n        X_kern = data[\"X\"].clone().requires_grad_(True)\n        lora_A_kern = data[\"lora_A\"].clone().requires_grad_(True)\n        lora_B_kern = data[\"lora_B\"].clone().requires_grad_(True)\n\n        out_kern = pll.ScatterMoELoRA.apply(\n            X_kern,\n            data[\"W\"],\n            k,\n            data[\"sorted_expert_idxs\"],\n            data[\"sorted_scattered_idxs\"],\n            data[\"expert_offsets\"],\n            lora_A_kern,\n            lora_B_kern,\n            data[\"scaling\"],\n            None,\n            None,\n            False,\n            False,\n        )\n        grad_out = torch.randn_like(out_kern)\n        out_kern.backward(grad_out)\n\n        # Reference path\n        grouped_X = base_ops.group(data[\"X\"], data[\"sorted_scattered_idxs\"], fan_out=k)\n        grouped_grad = base_ops.group(\n            grad_out, data[\"sorted_scattered_idxs\"], fan_out=1\n        )\n\n        ref_dX, ref_dA, ref_dB = reference_lora_backward(\n            grouped_grad,\n 
           grouped_X,\n            data[\"W\"],\n            data[\"lora_A\"],\n            data[\"lora_B\"],\n            data[\"scaling\"],\n            data[\"sorted_expert_idxs\"],\n            data[\"sorted_scattered_idxs\"],\n            data[\"expert_offsets\"],\n            k,\n            E,\n        )\n\n        # Compare input gradient (for k=1, no reduction needed)\n        # ref_dX is in grouped (expert-sorted) order; X_kern.grad is in original order.\n        # Ungroup ref_dX by scattering back to original positions.\n        ref_dX_ungrouped = torch.zeros_like(ref_dX)\n        ref_dX_ungrouped[data[\"sorted_scattered_idxs\"]] = ref_dX\n        torch.testing.assert_close(X_kern.grad, ref_dX_ungrouped, atol=5e-2, rtol=5e-2)\n\n    def test_lora_gradient_matches_reference(self):\n        \"\"\"LoRA A/B gradients from autograd match reference.\"\"\"\n        from importlib import import_module\n\n        pll = import_module(f\"{_SMOE}.parallel_linear_lora\")\n        base_ops = import_module(f\"{_SMOE}.kernels.ops\")\n\n        M, K, N, E, R, k = 16, 64, 64, 4, 8, 1\n        data = make_test_data(M=M, K=K, N=N, E=E, R=R, k=k)\n\n        # Autograd path\n        X_kern = data[\"X\"].clone().requires_grad_(True)\n        lora_A_kern = data[\"lora_A\"].clone().requires_grad_(True)\n        lora_B_kern = data[\"lora_B\"].clone().requires_grad_(True)\n\n        out_kern = pll.ScatterMoELoRA.apply(\n            X_kern,\n            data[\"W\"],\n            k,\n            data[\"sorted_expert_idxs\"],\n            data[\"sorted_scattered_idxs\"],\n            data[\"expert_offsets\"],\n            lora_A_kern,\n            lora_B_kern,\n            data[\"scaling\"],\n            None,\n            None,\n            False,\n            False,\n        )\n        grad_out = torch.randn_like(out_kern)\n        out_kern.backward(grad_out)\n\n        # Reference path\n        grouped_X = base_ops.group(data[\"X\"], data[\"sorted_scattered_idxs\"], fan_out=k)\n   
     grouped_grad = base_ops.group(\n            grad_out, data[\"sorted_scattered_idxs\"], fan_out=1\n        )\n\n        _, ref_dA, ref_dB = reference_lora_backward(\n            grouped_grad,\n            grouped_X,\n            data[\"W\"],\n            data[\"lora_A\"],\n            data[\"lora_B\"],\n            data[\"scaling\"],\n            data[\"sorted_expert_idxs\"],\n            data[\"sorted_scattered_idxs\"],\n            data[\"expert_offsets\"],\n            k,\n            E,\n        )\n\n        torch.testing.assert_close(lora_A_kern.grad, ref_dA, atol=5e-2, rtol=5e-2)\n        torch.testing.assert_close(lora_B_kern.grad, ref_dB, atol=5e-2, rtol=5e-2)\n\n\n# =============================================================================\n# Test: Equivalence with Base ScatterMoE (scaling=0 should match base)\n# =============================================================================\n\n\nclass TestBaseEquivalence:\n    \"\"\"When scaling=0, fused kernel should match base scatter2scatter.\"\"\"\n\n    def test_zero_scaling_matches_base(self):\n        \"\"\"With scaling=0, LoRA contribution vanishes; should match base.\"\"\"\n        from importlib import import_module\n\n        lora_ops = import_module(f\"{_SMOE}.kernels.lora_ops\")\n        base_ops = import_module(f\"{_SMOE}.kernels.ops\")\n\n        data = make_test_data(M=32, K=64, N=128, E=4, R=8, k=2)\n\n        base_output = base_ops.scatter2scatter(\n            X=data[\"X\"],\n            W=data[\"W\"],\n            sorted_expert_idxs=data[\"sorted_expert_idxs\"],\n            sorted_scattered_idxs=data[\"sorted_scattered_idxs\"],\n            k=data[\"k\"],\n        )\n\n        lora_output = lora_ops.scatter2scatter_lora(\n            X=data[\"X\"],\n            W=data[\"W\"],\n            sorted_expert_idxs=data[\"sorted_expert_idxs\"],\n            sorted_scattered_idxs=data[\"sorted_scattered_idxs\"],\n            k=data[\"k\"],\n            lora_A=data[\"lora_A\"],\n           
 lora_B=data[\"lora_B\"],\n            scaling=0.0,\n        )\n\n        torch.testing.assert_close(lora_output, base_output, atol=1e-3, rtol=1e-3)\n\n    def test_zero_lora_weights_matches_base(self):\n        \"\"\"With A=0, B=0, should match base scatter2scatter.\"\"\"\n        from importlib import import_module\n\n        lora_ops = import_module(f\"{_SMOE}.kernels.lora_ops\")\n        base_ops = import_module(f\"{_SMOE}.kernels.ops\")\n\n        data = make_test_data(M=32, K=64, N=128, E=4, R=8, k=2)\n\n        zero_A = torch.zeros_like(data[\"lora_A\"])\n        zero_B = torch.zeros_like(data[\"lora_B\"])\n\n        base_output = base_ops.scatter2scatter(\n            X=data[\"X\"],\n            W=data[\"W\"],\n            sorted_expert_idxs=data[\"sorted_expert_idxs\"],\n            sorted_scattered_idxs=data[\"sorted_scattered_idxs\"],\n            k=data[\"k\"],\n        )\n\n        lora_output = lora_ops.scatter2scatter_lora(\n            X=data[\"X\"],\n            W=data[\"W\"],\n            sorted_expert_idxs=data[\"sorted_expert_idxs\"],\n            sorted_scattered_idxs=data[\"sorted_scattered_idxs\"],\n            k=data[\"k\"],\n            lora_A=zero_A,\n            lora_B=zero_B,\n            scaling=1.0,\n        )\n\n        torch.testing.assert_close(lora_output, base_output, atol=1e-3, rtol=1e-3)\n\n\n# =============================================================================\n# Test: LoRA Additivity\n# =============================================================================\n\n\nclass TestLoRAAdditivity:\n    \"\"\"Test that the LoRA component is correctly additive.\"\"\"\n\n    def test_lora_additivity(self):\n        \"\"\"\n        Verify: fused(X, W, A, B, s) == base(X, W) + s * per_expert_lora(X, A, B)\n        \"\"\"\n        from importlib import import_module\n\n        lora_ops = import_module(f\"{_SMOE}.kernels.lora_ops\")\n        base_ops = import_module(f\"{_SMOE}.kernels.ops\")\n\n        data = 
make_test_data(M=32, K=64, N=128, E=4, R=8, k=2)\n\n        # Base output (no LoRA)\n        base_output = base_ops.scatter2scatter(\n            X=data[\"X\"],\n            W=data[\"W\"],\n            sorted_expert_idxs=data[\"sorted_expert_idxs\"],\n            sorted_scattered_idxs=data[\"sorted_scattered_idxs\"],\n            k=data[\"k\"],\n        )\n\n        # Fused output\n        fused_output = lora_ops.scatter2scatter_lora(\n            X=data[\"X\"],\n            W=data[\"W\"],\n            sorted_expert_idxs=data[\"sorted_expert_idxs\"],\n            sorted_scattered_idxs=data[\"sorted_scattered_idxs\"],\n            k=data[\"k\"],\n            lora_A=data[\"lora_A\"],\n            lora_B=data[\"lora_B\"],\n            scaling=data[\"scaling\"],\n        )\n\n        # Compute LoRA contribution manually (reference)\n        lora_only = reference_parallel_linear_lora(\n            data[\"X\"],\n            torch.zeros_like(data[\"W\"]),\n            data[\"k\"],\n            data[\"sorted_expert_idxs\"],\n            data[\"sorted_scattered_idxs\"],\n            data[\"lora_A\"],\n            data[\"lora_B\"],\n            data[\"scaling\"],\n        )\n\n        # fused = base + lora\n        expected = base_output + lora_only\n        torch.testing.assert_close(fused_output, expected, atol=2e-2, rtol=2e-2)\n\n\n# =============================================================================\n# Test: ParallelExperts module integration\n# =============================================================================\n\n\nclass TestParallelExpertsModule:\n    \"\"\"Test the ParallelExperts module with LoRA.\"\"\"\n\n    def test_set_and_clear_lora(self):\n        \"\"\"Test set_lora/clear_lora lifecycle.\"\"\"\n        from importlib import import_module\n\n        lora_module = import_module(f\"{_SMOE}.lora_ops\")\n\n        pe = lora_module.ParallelExperts(4, 64, 128).cuda()\n\n        A = torch.randn(32, 64, device=\"cuda\")  # r=8, E=4\n        B = 
torch.randn(128, 32, device=\"cuda\")\n        pe.set_lora(A, B, 0.5)\n\n        assert pe._lora_A is A\n        assert pe._lora_B is B\n        assert pe._lora_scaling == 0.5\n\n        pe.clear_lora()\n        assert pe._lora_A is None\n        assert pe._lora_B is None\n\n    def test_forward_with_lora(self):\n        \"\"\"ParallelExperts forward with LoRA matches reference.\"\"\"\n        from importlib import import_module\n\n        lora_module = import_module(f\"{_SMOE}.lora_ops\")\n\n        E, K, N, R = 4, 64, 128, 8\n        M, k = 16, 2\n        data = make_test_data(M=M, K=K, N=N, E=E, R=R, k=k)\n\n        pe = lora_module.ParallelExperts(E, K, N).cuda()\n        # Set weights to match test data\n        with torch.no_grad():\n            pe.weight.copy_(data[\"W\"].permute(0, 2, 1))  # [E, N, K]\n\n        pe.set_lora(data[\"lora_A\"], data[\"lora_B\"], data[\"scaling\"])\n\n        output = pe(\n            data[\"X\"],\n            k,\n            data[\"sorted_expert_idxs\"],\n            data[\"sorted_scattered_idxs\"],\n            data[\"expert_offsets\"],\n        )\n\n        ref = reference_parallel_linear_lora(\n            data[\"X\"],\n            data[\"W\"],\n            k,\n            data[\"sorted_expert_idxs\"],\n            data[\"sorted_scattered_idxs\"],\n            data[\"lora_A\"],\n            data[\"lora_B\"],\n            data[\"scaling\"],\n        )\n\n        torch.testing.assert_close(output, ref, atol=2e-2, rtol=2e-2)\n\n\n# =============================================================================\n# Test: Edge Cases\n# =============================================================================\n\n\nclass TestEdgeCases:\n    \"\"\"Edge cases and boundary conditions.\"\"\"\n\n    def test_all_tokens_one_expert(self):\n        \"\"\"All tokens routed to a single expert.\"\"\"\n        from importlib import import_module\n\n        lora_ops = import_module(f\"{_SMOE}.kernels.lora_ops\")\n\n        M, K, N, E, R, k = 
16, 64, 64, 4, 8, 1\n        torch.manual_seed(42)\n\n        X = torch.randn(M, K, device=\"cuda\")\n        W = torch.randn(E, K, N, device=\"cuda\") * 0.02\n        lora_A = torch.randn(R * E, K, device=\"cuda\") * 0.01\n        lora_B = torch.randn(N, R * E, device=\"cuda\") * 0.01\n\n        # All tokens go to expert 0\n        selected_experts = torch.zeros(M, k, device=\"cuda\", dtype=torch.long)\n        sorted_expert_idxs, sorted_scattered_idxs, expert_offsets = (\n            flatten_sort_count_ref(selected_experts, E)\n        )\n\n        ref = reference_parallel_linear_lora(\n            X,\n            W,\n            k,\n            sorted_expert_idxs,\n            sorted_scattered_idxs,\n            lora_A,\n            lora_B,\n            0.5,\n        )\n\n        kernel = lora_ops.scatter2scatter_lora(\n            X=X,\n            W=W,\n            sorted_expert_idxs=sorted_expert_idxs,\n            sorted_scattered_idxs=sorted_scattered_idxs,\n            k=k,\n            lora_A=lora_A,\n            lora_B=lora_B,\n            scaling=0.5,\n        )\n\n        torch.testing.assert_close(kernel, ref, atol=1e-2, rtol=1e-2)\n\n    def test_empty_experts(self):\n        \"\"\"Some experts have no tokens assigned.\"\"\"\n        from importlib import import_module\n\n        lora_ops = import_module(f\"{_SMOE}.kernels.lora_ops\")\n\n        M, K, N, E, R, k = 8, 64, 64, 8, 4, 1\n        torch.manual_seed(42)\n\n        X = torch.randn(M, K, device=\"cuda\")\n        W = torch.randn(E, K, N, device=\"cuda\") * 0.02\n        lora_A = torch.randn(R * E, K, device=\"cuda\") * 0.01\n        lora_B = torch.randn(N, R * E, device=\"cuda\") * 0.01\n\n        # Only use experts 0 and 1\n        selected_experts = torch.randint(0, 2, (M, k), device=\"cuda\")\n        sorted_expert_idxs, sorted_scattered_idxs, expert_offsets = (\n            flatten_sort_count_ref(selected_experts, E)\n        )\n\n        ref = reference_parallel_linear_lora(\n            
X,\n            W,\n            k,\n            sorted_expert_idxs,\n            sorted_scattered_idxs,\n            lora_A,\n            lora_B,\n            0.5,\n        )\n\n        kernel = lora_ops.scatter2scatter_lora(\n            X=X,\n            W=W,\n            sorted_expert_idxs=sorted_expert_idxs,\n            sorted_scattered_idxs=sorted_scattered_idxs,\n            k=k,\n            lora_A=lora_A,\n            lora_B=lora_B,\n            scaling=0.5,\n        )\n\n        torch.testing.assert_close(kernel, ref, atol=1e-2, rtol=1e-2)\n\n\n# =============================================================================\n# Test: Optimization 1 - Fused dX Kernel\n# =============================================================================\n\n\nclass TestFusedDX:\n    \"\"\"Test fused backward dX kernel: dX = dY @ W^T + scaling * (dY @ B) @ A.\"\"\"\n\n    def _run_fused_dX_test(\n        self, M, K, N, E, R, k, dtype=torch.float32, atol=5e-2, rtol=5e-2\n    ):\n        from importlib import import_module\n\n        lora_ops = import_module(f\"{_SMOE}.kernels.lora_ops\")\n        base_ops = import_module(f\"{_SMOE}.kernels.ops\")\n        pll = import_module(f\"{_SMOE}.parallel_linear_lora\")\n\n        data = make_test_data(M=M, K=K, N=N, E=E, R=R, k=k, dtype=dtype)\n\n        # Create dummy grad_out in grouped order\n        grad_out = torch.randn(\n            data[\"sorted_expert_idxs\"].size(0), N, device=\"cuda\", dtype=dtype\n        )\n        grouped_grad = base_ops.group(\n            grad_out,\n            data[\"sorted_scattered_idxs\"],\n            fan_out=1,\n        )\n\n        # Reference: separate scatter2scatter(DY, W^T) + _compute_lora_input_grad\n        ref_base = base_ops.scatter2scatter(\n            X=grouped_grad,\n            x_grouped=True,\n            W=data[\"W\"].permute(0, 2, 1),\n            sorted_expert_idxs=data[\"sorted_expert_idxs\"],\n            sorted_scattered_idxs=data[\"sorted_scattered_idxs\"],\n          
  k=1,\n            y_grouped=False,\n        )\n\n        ref_lora = pll._compute_lora_input_grad(\n            grouped_grad,\n            data[\"lora_A\"],\n            data[\"lora_B\"],\n            data[\"expert_offsets\"],\n            E,\n            data[\"scaling\"],\n        )\n        # Scatter lora from grouped to ungrouped order\n        ref_lora_ungrouped = torch.zeros_like(ref_base)\n        ref_lora_ungrouped[data[\"sorted_scattered_idxs\"]] = ref_lora\n        ref_total = ref_base + ref_lora_ungrouped\n\n        # Fused kernel\n        fused_result = lora_ops.scatter2scatter_lora_dX(\n            DY=grouped_grad,\n            W=data[\"W\"],\n            sorted_expert_idxs=data[\"sorted_expert_idxs\"],\n            sorted_scattered_idxs=data[\"sorted_scattered_idxs\"],\n            k=1,\n            lora_A=data[\"lora_A\"],\n            lora_B=data[\"lora_B\"],\n            scaling=data[\"scaling\"],\n            dy_grouped=True,\n            dx_grouped=False,\n        )\n\n        torch.testing.assert_close(fused_result, ref_total, atol=atol, rtol=rtol)\n\n    def test_basic(self):\n        self._run_fused_dX_test(M=32, K=64, N=128, E=4, R=8, k=2)\n\n    def test_large(self):\n        self._run_fused_dX_test(M=256, K=256, N=512, E=8, R=16, k=2)\n\n    def test_single_expert(self):\n        self._run_fused_dX_test(M=64, K=128, N=256, E=1, R=8, k=1)\n\n    def test_k1(self):\n        self._run_fused_dX_test(M=64, K=64, N=128, E=4, R=8, k=1)\n\n    def test_bf16(self):\n        self._run_fused_dX_test(\n            M=64,\n            K=128,\n            N=256,\n            E=4,\n            R=16,\n            k=2,\n            dtype=torch.bfloat16,\n            atol=1e-1,\n            rtol=1e-1,\n        )\n\n    def test_grouped_output(self):\n        \"\"\"Test fused dX with dx_grouped=True.\"\"\"\n        from importlib import import_module\n\n        lora_ops = import_module(f\"{_SMOE}.kernels.lora_ops\")\n        base_ops = 
import_module(f\"{_SMOE}.kernels.ops\")\n        pll = import_module(f\"{_SMOE}.parallel_linear_lora\")\n\n        M, K, N, E, R, k = 32, 64, 128, 4, 8, 2\n        data = make_test_data(M=M, K=K, N=N, E=E, R=R, k=k)\n\n        grad_out = torch.randn(data[\"sorted_expert_idxs\"].size(0), N, device=\"cuda\")\n        grouped_grad = base_ops.group(\n            grad_out, data[\"sorted_scattered_idxs\"], fan_out=1\n        )\n\n        # Reference: grouped output\n        ref_base = base_ops.scatter2scatter(\n            X=grouped_grad,\n            x_grouped=True,\n            W=data[\"W\"].permute(0, 2, 1),\n            sorted_expert_idxs=data[\"sorted_expert_idxs\"],\n            sorted_scattered_idxs=data[\"sorted_scattered_idxs\"],\n            k=1,\n            y_grouped=True,  # grouped output\n        )\n\n        ref_lora = pll._compute_lora_input_grad(\n            grouped_grad,\n            data[\"lora_A\"],\n            data[\"lora_B\"],\n            data[\"expert_offsets\"],\n            E,\n            data[\"scaling\"],\n        )\n        ref_total = ref_base + ref_lora\n\n        # Fused kernel with grouped output\n        fused_result = lora_ops.scatter2scatter_lora_dX(\n            DY=grouped_grad,\n            W=data[\"W\"],\n            sorted_expert_idxs=data[\"sorted_expert_idxs\"],\n            sorted_scattered_idxs=data[\"sorted_scattered_idxs\"],\n            k=1,\n            lora_A=data[\"lora_A\"],\n            lora_B=data[\"lora_B\"],\n            scaling=data[\"scaling\"],\n            dy_grouped=True,\n            dx_grouped=True,\n        )\n\n        torch.testing.assert_close(fused_result, ref_total, atol=5e-2, rtol=5e-2)\n\n    def test_autograd_with_fused_dX(self):\n        \"\"\"Full autograd round-trip with use_fused_dX=True.\"\"\"\n        from importlib import import_module\n\n        pll = import_module(f\"{_SMOE}.parallel_linear_lora\")\n\n        M, K, N, E, R, k = 32, 64, 128, 4, 8, 2\n        data = make_test_data(M=M, K=K, 
N=N, E=E, R=R, k=k)\n\n        # Run without fused dX\n        X1 = data[\"X\"].clone().requires_grad_(True)\n        A1 = data[\"lora_A\"].clone().requires_grad_(True)\n        B1 = data[\"lora_B\"].clone().requires_grad_(True)\n        out1 = pll.ScatterMoELoRA.apply(\n            X1,\n            data[\"W\"],\n            k,\n            data[\"sorted_expert_idxs\"],\n            data[\"sorted_scattered_idxs\"],\n            data[\"expert_offsets\"],\n            A1,\n            B1,\n            data[\"scaling\"],\n            None,\n            None,\n            False,\n            False,\n            False,  # use_fused_dX=False\n        )\n        out1.sum().backward()\n\n        # Run with fused dX\n        X2 = data[\"X\"].clone().requires_grad_(True)\n        A2 = data[\"lora_A\"].clone().requires_grad_(True)\n        B2 = data[\"lora_B\"].clone().requires_grad_(True)\n        out2 = pll.ScatterMoELoRA.apply(\n            X2,\n            data[\"W\"],\n            k,\n            data[\"sorted_expert_idxs\"],\n            data[\"sorted_scattered_idxs\"],\n            data[\"expert_offsets\"],\n            A2,\n            B2,\n            data[\"scaling\"],\n            None,\n            None,\n            False,\n            False,\n            True,  # use_fused_dX=True\n        )\n        out2.sum().backward()\n\n        # Forward should be identical\n        torch.testing.assert_close(out1, out2, atol=1e-5, rtol=1e-5)\n\n        # Gradients should match\n        torch.testing.assert_close(X1.grad, X2.grad, atol=5e-2, rtol=5e-2)\n        torch.testing.assert_close(A1.grad, A2.grad, atol=5e-2, rtol=5e-2)\n        torch.testing.assert_close(B1.grad, B2.grad, atol=5e-2, rtol=5e-2)\n\n\n# =============================================================================\n# Test: Optimization 2 - Fused Gather Backward\n# =============================================================================\n\n\nclass TestFusedGatherBackward:\n    \"\"\"Test fused 
gather + backward dA/dB kernel.\"\"\"\n\n    def _run_fused_gather_test(\n        self, M, K, N, E, R, k, dtype=torch.float32, atol=5e-2, rtol=5e-2\n    ):\n        from importlib import import_module\n\n        lora_ops = import_module(f\"{_SMOE}.kernels.lora_ops\")\n        base_ops = import_module(f\"{_SMOE}.kernels.ops\")\n\n        data = make_test_data(M=M, K=K, N=N, E=E, R=R, k=k, dtype=dtype)\n\n        # Create grad_out in ungrouped order (M*k, N)\n        M_total = data[\"sorted_expert_idxs\"].size(0)\n        grad_out = torch.randn(M_total, N, device=\"cuda\", dtype=dtype)\n\n        # Reference: group() + group_bwd_lora()\n        grouped_grad = base_ops.group(\n            grad_out, data[\"sorted_scattered_idxs\"], fan_out=1\n        )\n        grouped_x = base_ops.group(data[\"X\"], data[\"sorted_scattered_idxs\"], fan_out=k)\n\n        ref_dA, ref_dB = lora_ops.group_bwd_lora(\n            DY=grouped_grad,\n            X=grouped_x,\n            lora_A=data[\"lora_A\"],\n            lora_B=data[\"lora_B\"],\n            expert_offsets=data[\"expert_offsets\"],\n            E=E,\n            scaling=data[\"scaling\"],\n        )\n\n        # Fused kernel: no group() calls\n        fused_dA, fused_dB = lora_ops.group_bwd_lora_fused(\n            DY=grad_out,\n            X=data[\"X\"],\n            lora_A=data[\"lora_A\"],\n            lora_B=data[\"lora_B\"],\n            expert_offsets=data[\"expert_offsets\"],\n            sorted_scattered_idxs=data[\"sorted_scattered_idxs\"],\n            E=E,\n            k=k,\n            scaling=data[\"scaling\"],\n        )\n\n        torch.testing.assert_close(fused_dA, ref_dA, atol=atol, rtol=rtol)\n        torch.testing.assert_close(fused_dB, ref_dB, atol=atol, rtol=rtol)\n\n    def test_basic(self):\n        self._run_fused_gather_test(M=32, K=64, N=128, E=4, R=8, k=2)\n\n    def test_large(self):\n        self._run_fused_gather_test(M=256, K=256, N=512, E=8, R=16, k=2)\n\n    def test_single_expert(self):\n 
       self._run_fused_gather_test(M=64, K=128, N=256, E=1, R=8, k=1)\n\n    def test_k1(self):\n        self._run_fused_gather_test(M=64, K=64, N=128, E=4, R=8, k=1)\n\n    def test_many_experts(self):\n        self._run_fused_gather_test(M=128, K=64, N=128, E=16, R=8, k=4)\n\n    def test_bf16(self):\n        self._run_fused_gather_test(\n            M=64,\n            K=128,\n            N=256,\n            E=4,\n            R=16,\n            k=2,\n            dtype=torch.bfloat16,\n            atol=1e-1,\n            rtol=1e-1,\n        )\n\n    def test_autograd_with_fused_gather(self):\n        \"\"\"Full autograd round-trip with use_fused_gather=True.\"\"\"\n        from importlib import import_module\n\n        pll = import_module(f\"{_SMOE}.parallel_linear_lora\")\n\n        M, K, N, E, R, k = 32, 64, 128, 4, 8, 2\n        data = make_test_data(M=M, K=K, N=N, E=E, R=R, k=k)\n\n        # Run without fused gather\n        X1 = data[\"X\"].clone().requires_grad_(True)\n        A1 = data[\"lora_A\"].clone().requires_grad_(True)\n        B1 = data[\"lora_B\"].clone().requires_grad_(True)\n        out1 = pll.ScatterMoELoRA.apply(\n            X1,\n            data[\"W\"],\n            k,\n            data[\"sorted_expert_idxs\"],\n            data[\"sorted_scattered_idxs\"],\n            data[\"expert_offsets\"],\n            A1,\n            B1,\n            data[\"scaling\"],\n            None,\n            None,\n            False,\n            False,\n            False,\n            False,  # use_fused_dX=False, use_fused_gather=False\n        )\n        out1.sum().backward()\n\n        # Run with fused gather\n        X2 = data[\"X\"].clone().requires_grad_(True)\n        A2 = data[\"lora_A\"].clone().requires_grad_(True)\n        B2 = data[\"lora_B\"].clone().requires_grad_(True)\n        out2 = pll.ScatterMoELoRA.apply(\n            X2,\n            data[\"W\"],\n            k,\n            data[\"sorted_expert_idxs\"],\n            
data[\"sorted_scattered_idxs\"],\n            data[\"expert_offsets\"],\n            A2,\n            B2,\n            data[\"scaling\"],\n            None,\n            None,\n            False,\n            False,\n            False,\n            True,  # use_fused_dX=False, use_fused_gather=True\n        )\n        out2.sum().backward()\n\n        # Forward identical\n        torch.testing.assert_close(out1, out2, atol=1e-5, rtol=1e-5)\n\n        # dA/dB should match\n        torch.testing.assert_close(A1.grad, A2.grad, atol=5e-2, rtol=5e-2)\n        torch.testing.assert_close(B1.grad, B2.grad, atol=5e-2, rtol=5e-2)\n        # dX should also match (same path for dX)\n        torch.testing.assert_close(X1.grad, X2.grad, atol=5e-2, rtol=5e-2)\n\n\n# =============================================================================\n# Test: Optimization 3 - Token Rounding\n# =============================================================================\n\n\nclass TestTokenRounding:\n    \"\"\"Test token rounding utility and its integration with backward kernels.\"\"\"\n\n    def test_round_expert_counts_basic(self):\n        \"\"\"Verify round_expert_counts produces correct shapes and values.\"\"\"\n        from importlib import import_module\n\n        lora_ops = import_module(f\"{_SMOE}.kernels.lora_ops\")\n\n        M, K, N, E, R, k = 32, 64, 128, 4, 8, 2\n        data = make_test_data(M=M, K=K, N=N, E=E, R=R, k=k)\n\n        padded_ei, padded_si, padded_offsets, real_offsets = (\n            lora_ops.round_expert_counts(\n                data[\"sorted_expert_idxs\"],\n                data[\"sorted_scattered_idxs\"],\n                data[\"expert_offsets\"],\n                E=E,\n                block_m=lora_ops.BLOCK_M,\n            )\n        )\n\n        # Real offsets should match original\n        torch.testing.assert_close(real_offsets, data[\"expert_offsets\"])\n\n        # Padded offsets should be >= real offsets\n        assert (padded_offsets >= 
real_offsets).all(), (\n            \"Padded offsets should be >= real offsets\"\n        )\n\n        # Each expert's padded count should be multiple of BLOCK_M (if non-zero)\n        prev = 0\n        for e in range(E):\n            count = padded_offsets[e].item() - prev\n            real_count = real_offsets[e].item() - (\n                real_offsets[e - 1].item() if e > 0 else 0\n            )\n            if real_count > 0:\n                assert count % lora_ops.BLOCK_M == 0, (\n                    f\"Expert {e}: padded count {count} not multiple of {lora_ops.BLOCK_M}\"\n                )\n                assert count >= real_count, (\n                    f\"Expert {e}: padded count {count} < real count {real_count}\"\n                )\n            prev = padded_offsets[e].item()\n\n    def test_round_with_fused_gather(self):\n        \"\"\"Token rounding + fused gather gives same result as plain fused gather.\"\"\"\n        from importlib import import_module\n\n        lora_ops = import_module(f\"{_SMOE}.kernels.lora_ops\")\n        base_ops = import_module(f\"{_SMOE}.kernels.ops\")\n\n        M, K, N, E, R, k = 64, 64, 128, 4, 8, 2\n        data = make_test_data(M=M, K=K, N=N, E=E, R=R, k=k)\n\n        M_total = data[\"sorted_expert_idxs\"].size(0)\n        grad_out = torch.randn(M_total, N, device=\"cuda\")\n\n        # Reference: group() + group_bwd_lora() (the gold standard)\n        grouped_grad = base_ops.group(\n            grad_out, data[\"sorted_scattered_idxs\"], fan_out=1\n        )\n        grouped_x = base_ops.group(data[\"X\"], data[\"sorted_scattered_idxs\"], fan_out=k)\n        ref_dA, ref_dB = lora_ops.group_bwd_lora(\n            DY=grouped_grad,\n            X=grouped_x,\n            lora_A=data[\"lora_A\"],\n            lora_B=data[\"lora_B\"],\n            expert_offsets=data[\"expert_offsets\"],\n            E=E,\n            scaling=data[\"scaling\"],\n        )\n\n        # Apply token rounding\n        padded_ei, padded_si, 
padded_offsets, real_offsets = (\n            lora_ops.round_expert_counts(\n                data[\"sorted_expert_idxs\"],\n                data[\"sorted_scattered_idxs\"],\n                data[\"expert_offsets\"],\n                E=E,\n            )\n        )\n\n        # Fused gather with token rounding\n        rounded_dA, rounded_dB = lora_ops.group_bwd_lora_fused(\n            DY=grad_out,\n            X=data[\"X\"],\n            lora_A=data[\"lora_A\"],\n            lora_B=data[\"lora_B\"],\n            expert_offsets=padded_offsets,\n            sorted_scattered_idxs=padded_si,\n            E=E,\n            k=k,\n            scaling=data[\"scaling\"],\n            real_expert_offsets=real_offsets,\n        )\n\n        torch.testing.assert_close(rounded_dA, ref_dA, atol=5e-2, rtol=5e-2)\n        torch.testing.assert_close(rounded_dB, ref_dB, atol=5e-2, rtol=5e-2)\n\n    def test_empty_experts_with_rounding(self):\n        \"\"\"Token rounding handles experts with 0 tokens correctly.\"\"\"\n        from importlib import import_module\n\n        lora_ops = import_module(f\"{_SMOE}.kernels.lora_ops\")\n\n        E, k = 8, 1\n        M = 8\n        torch.manual_seed(42)\n\n        # Only use experts 0 and 1 (rest have 0 tokens)\n        selected_experts = torch.randint(0, 2, (M, k), device=\"cuda\")\n        sorted_expert_idxs, sorted_scattered_idxs, expert_offsets = (\n            flatten_sort_count_ref(selected_experts, E)\n        )\n\n        padded_ei, padded_si, padded_offsets, real_offsets = (\n            lora_ops.round_expert_counts(\n                sorted_expert_idxs,\n                sorted_scattered_idxs,\n                expert_offsets,\n                E=E,\n            )\n        )\n\n        # Verify empty experts have same count (0)\n        for e in range(E):\n            real_count = real_offsets[e].item() - (\n                real_offsets[e - 1].item() if e > 0 else 0\n            )\n            padded_count = padded_offsets[e].item() - 
(\n                padded_offsets[e - 1].item() if e > 0 else 0\n            )\n            if real_count == 0:\n                assert padded_count == 0, (\n                    f\"Expert {e}: empty expert should have padded_count=0, got {padded_count}\"\n                )\n\n\n# =============================================================================\n# Test: Combined Optimizations\n# =============================================================================\n\n\nclass TestCombinedOptimizations:\n    \"\"\"Test all optimizations together.\"\"\"\n\n    def test_fused_dX_and_fused_gather(self):\n        \"\"\"Both fused dX and fused gather together.\"\"\"\n        from importlib import import_module\n\n        pll = import_module(f\"{_SMOE}.parallel_linear_lora\")\n\n        M, K, N, E, R, k = 64, 128, 256, 4, 8, 2\n        data = make_test_data(M=M, K=K, N=N, E=E, R=R, k=k)\n\n        # Baseline: no optimizations\n        X1 = data[\"X\"].clone().requires_grad_(True)\n        A1 = data[\"lora_A\"].clone().requires_grad_(True)\n        B1 = data[\"lora_B\"].clone().requires_grad_(True)\n        out1 = pll.ScatterMoELoRA.apply(\n            X1,\n            data[\"W\"],\n            k,\n            data[\"sorted_expert_idxs\"],\n            data[\"sorted_scattered_idxs\"],\n            data[\"expert_offsets\"],\n            A1,\n            B1,\n            data[\"scaling\"],\n            None,\n            None,\n            False,\n            False,\n            False,\n            False,  # no optimizations\n        )\n        out1.sum().backward()\n\n        # Both optimizations\n        X2 = data[\"X\"].clone().requires_grad_(True)\n        A2 = data[\"lora_A\"].clone().requires_grad_(True)\n        B2 = data[\"lora_B\"].clone().requires_grad_(True)\n        out2 = pll.ScatterMoELoRA.apply(\n            X2,\n            data[\"W\"],\n            k,\n            data[\"sorted_expert_idxs\"],\n            data[\"sorted_scattered_idxs\"],\n            
data[\"expert_offsets\"],\n            A2,\n            B2,\n            data[\"scaling\"],\n            None,\n            None,\n            False,\n            False,\n            True,\n            True,  # use_fused_dX=True, use_fused_gather=True\n        )\n        out2.sum().backward()\n\n        # Forward identical\n        torch.testing.assert_close(out1, out2, atol=1e-5, rtol=1e-5)\n\n        # All gradients match\n        torch.testing.assert_close(X1.grad, X2.grad, atol=5e-2, rtol=5e-2)\n        torch.testing.assert_close(A1.grad, A2.grad, atol=5e-2, rtol=5e-2)\n        torch.testing.assert_close(B1.grad, B2.grad, atol=5e-2, rtol=5e-2)\n\n\n# =============================================================================\n# Test: HFScatterMoEGatedMLP with Sigmoid Routing\n# =============================================================================\n\n\ndef _reference_moe_forward(\n    hidden_states,\n    gate_weight,\n    gate_up_proj,\n    down_proj,\n    act_fn,\n    routing_weights,\n    selected_experts,\n    num_experts,\n):\n    \"\"\"Pure PyTorch reference for a full MoE forward pass.\n\n    Args:\n        hidden_states: [T, H]\n        gate_weight: [E, H]\n        gate_up_proj: [E, 2*FF, H]\n        down_proj: [E, H, FF]\n        act_fn: activation function (e.g. 
torch.nn.SiLU())\n        routing_weights: [T, K] routing weights\n        selected_experts: [T, K] expert indices\n        num_experts: int\n\n    Returns:\n        output: [T, H]\n    \"\"\"\n    T, H = hidden_states.shape\n    K = selected_experts.shape[1]\n    output = torch.zeros(T, H, device=hidden_states.device, dtype=hidden_states.dtype)\n\n    for t in range(T):\n        for j in range(K):\n            e = selected_experts[t, j].item()\n            w = routing_weights[t, j].item()\n\n            # gate_up projection\n            gup = hidden_states[t] @ gate_up_proj[e].T  # [2*I]\n            I_dim = gup.shape[0] // 2\n            gates = gup[:I_dim]\n            up = gup[I_dim:]\n\n            # activation\n            h = act_fn(gates) * up\n\n            # down projection\n            out = h @ down_proj[e].T  # [H]\n\n            output[t] += w * out\n\n    return output\n\n\ndef _make_mock_sigmoid_moe_block(\n    T=16, H=64, FF=32, E=8, K=2, n_group=2, topk_group=1, bias_on_gate=True\n):\n    \"\"\"Create a mock MoE block with sigmoid routing for GPU testing.\"\"\"\n    gate_up_proj = torch.randn(E, 2 * FF, H, device=\"cuda\") * 0.02\n    down_proj = torch.randn(E, H, FF, device=\"cuda\") * 0.02\n    act_fn = torch.nn.SiLU()\n\n    experts = SimpleNamespace(\n        gate_up_proj=gate_up_proj,\n        down_proj=down_proj,\n        act_fn=act_fn,\n        num_experts=E,\n    )\n\n    if bias_on_gate:\n        gate = SimpleNamespace(\n            weight=torch.randn(E, H, device=\"cuda\") * 0.1,\n            e_score_correction_bias=torch.zeros(E, device=\"cuda\"),\n        )\n        moe_block = SimpleNamespace(\n            gate=gate,\n            experts=experts,\n            top_k=K,\n            n_routed_experts=E,\n            n_group=n_group,\n            topk_group=topk_group,\n            norm_topk_prob=True,\n            routed_scaling_factor=1.0,\n        )\n    else:\n        # minimax_m2 style\n        gate = SimpleNamespace(\n            
weight=torch.randn(E, H, device=\"cuda\") * 0.1,\n            top_k=K,\n        )\n        moe_block = SimpleNamespace(\n            gate=gate,\n            experts=experts,\n            top_k=K,\n            e_score_correction_bias=torch.zeros(E, device=\"cuda\"),\n        )\n\n    return moe_block, T, H, FF, E, K\n\n\nclass TestHFScatterMoESigmoidRouting:\n    \"\"\"Test HFScatterMoEGatedMLP forward with sigmoid routing on GPU.\"\"\"\n\n    def test_forward_matches_reference_bias_on_gate(self):\n        \"\"\"Forward pass with sigmoid routing (bias on gate) matches reference.\"\"\"\n        from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (\n            HFScatterMoEGatedMLP,\n            _sigmoid_topk_route,\n        )\n\n        moe_block, T, H, FF, E, K = _make_mock_sigmoid_moe_block(\n            T=16, H=64, FF=32, E=8, K=2, n_group=2, topk_group=1, bias_on_gate=True\n        )\n\n        hidden = torch.randn(1, T, H, device=\"cuda\")\n\n        # Get routing for reference\n        gate = moe_block.gate\n        hidden_flat = hidden.view(-1, H)\n        routing_weights, selected_experts, _, _ = _sigmoid_topk_route(\n            moe_block, gate, hidden_flat, gate.weight, None\n        )\n\n        # Reference output\n        ref_output = _reference_moe_forward(\n            hidden_flat,\n            gate.weight,\n            moe_block.experts.gate_up_proj,\n            moe_block.experts.down_proj,\n            moe_block.experts.act_fn,\n            routing_weights,\n            selected_experts,\n            E,\n        )\n\n        # Kernel output\n        kernel_output = HFScatterMoEGatedMLP.forward(moe_block, hidden)\n        kernel_output_flat = kernel_output.view(-1, H)\n\n        torch.testing.assert_close(\n            kernel_output_flat.float(),\n            ref_output.float(),\n            atol=5e-2,\n            rtol=5e-2,\n        )\n\n    def test_forward_matches_reference_bias_on_block(self):\n        \"\"\"Forward pass with 
sigmoid routing (minimax_m2 style, bias on block).\"\"\"\n        from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (\n            HFScatterMoEGatedMLP,\n            _sigmoid_topk_route,\n        )\n\n        moe_block, T, H, FF, E, K = _make_mock_sigmoid_moe_block(\n            T=16, H=64, FF=32, E=8, K=2, n_group=1, bias_on_gate=False\n        )\n\n        hidden = torch.randn(1, T, H, device=\"cuda\")\n        hidden_flat = hidden.view(-1, H)\n\n        gate = moe_block.gate\n        routing_weights, selected_experts, _, _ = _sigmoid_topk_route(\n            moe_block, gate, hidden_flat, gate.weight, None\n        )\n\n        ref_output = _reference_moe_forward(\n            hidden_flat,\n            gate.weight,\n            moe_block.experts.gate_up_proj,\n            moe_block.experts.down_proj,\n            moe_block.experts.act_fn,\n            routing_weights,\n            selected_experts,\n            E,\n        )\n\n        kernel_output = HFScatterMoEGatedMLP.forward(moe_block, hidden)\n        kernel_output_flat = kernel_output.view(-1, H)\n\n        torch.testing.assert_close(\n            kernel_output_flat.float(),\n            ref_output.float(),\n            atol=5e-2,\n            rtol=5e-2,\n        )\n\n    def test_softmax_routing_still_works(self):\n        \"\"\"Verify softmax routing (Qwen/OLMoE) is not broken.\"\"\"\n        from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (\n            HFScatterMoEGatedMLP,\n            _softmax_topk_route,\n        )\n\n        T, H, FF, E, K = 16, 64, 32, 4, 2\n        gate_up_proj = torch.randn(E, 2 * FF, H, device=\"cuda\") * 0.02\n        down_proj = torch.randn(E, H, FF, device=\"cuda\") * 0.02\n        act_fn = torch.nn.SiLU()\n\n        experts = SimpleNamespace(\n            gate_up_proj=gate_up_proj,\n            down_proj=down_proj,\n            act_fn=act_fn,\n            num_experts=E,\n        )\n        gate = SimpleNamespace(\n            
weight=torch.randn(E, H, device=\"cuda\") * 0.1,\n            top_k=K,\n            num_experts=E,\n            norm_topk_prob=True,\n        )\n        moe_block = SimpleNamespace(gate=gate, experts=experts)\n\n        hidden = torch.randn(1, T, H, device=\"cuda\")\n        hidden_flat = hidden.view(-1, H)\n\n        routing_weights, selected_experts, _, _ = _softmax_topk_route(\n            moe_block, gate, hidden_flat, gate.weight, None\n        )\n\n        ref_output = _reference_moe_forward(\n            hidden_flat,\n            gate.weight,\n            gate_up_proj,\n            down_proj,\n            act_fn,\n            routing_weights,\n            selected_experts,\n            E,\n        )\n\n        kernel_output = HFScatterMoEGatedMLP.forward(moe_block, hidden)\n        kernel_output_flat = kernel_output.view(-1, H)\n\n        torch.testing.assert_close(\n            kernel_output_flat.float(),\n            ref_output.float(),\n            atol=5e-2,\n            rtol=5e-2,\n        )\n\n\nclass TestHFScatterMoESigmoidWithSharedExperts:\n    \"\"\"Test HFScatterMoEGatedMLP with sigmoid routing + shared experts.\"\"\"\n\n    def test_shared_experts_plural(self):\n        \"\"\"DeepSeek V3 style: shared_experts attribute (plural).\"\"\"\n        from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (\n            HFScatterMoEGatedMLP,\n        )\n\n        T, H, FF, E, K = 8, 64, 32, 8, 2\n        gate_up_proj = torch.randn(E, 2 * FF, H, device=\"cuda\") * 0.02\n        down_proj = torch.randn(E, H, FF, device=\"cuda\") * 0.02\n        act_fn = torch.nn.SiLU()\n\n        experts = SimpleNamespace(\n            gate_up_proj=gate_up_proj,\n            down_proj=down_proj,\n            act_fn=act_fn,\n            num_experts=E,\n        )\n\n        # Shared expert as a simple linear for testing\n        shared_W = torch.randn(H, H, device=\"cuda\") * 0.01\n        shared_experts_fn = lambda x: x @ shared_W.T  # noqa: E731\n\n        
gate = SimpleNamespace(\n            weight=torch.randn(E, H, device=\"cuda\") * 0.1,\n            e_score_correction_bias=torch.zeros(E, device=\"cuda\"),\n        )\n        moe_block = SimpleNamespace(\n            gate=gate,\n            experts=experts,\n            shared_experts=shared_experts_fn,\n            top_k=K,\n            n_routed_experts=E,\n            n_group=1,\n            norm_topk_prob=True,\n            routed_scaling_factor=1.0,\n        )\n\n        hidden = torch.randn(1, T, H, device=\"cuda\")\n\n        # Should not raise; output should include shared expert contribution\n        output = HFScatterMoEGatedMLP.forward(moe_block, hidden)\n        assert output.shape == (1, T, H)\n\n        # Run without shared expert to verify it changes the output\n        moe_block_no_shared = SimpleNamespace(\n            gate=gate,\n            experts=experts,\n            top_k=K,\n            n_routed_experts=E,\n            n_group=1,\n            norm_topk_prob=True,\n            routed_scaling_factor=1.0,\n        )\n        output_no_shared = HFScatterMoEGatedMLP.forward(moe_block_no_shared, hidden)\n        assert not torch.equal(output, output_no_shared)\n\n    def test_shared_expert_with_gate(self):\n        \"\"\"Qwen2MoE style: shared_expert + shared_expert_gate.\"\"\"\n        from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (\n            HFScatterMoEGatedMLP,\n        )\n\n        T, H, FF, E, K = 8, 64, 32, 4, 2\n        gate_up_proj = torch.randn(E, 2 * FF, H, device=\"cuda\") * 0.02\n        down_proj = torch.randn(E, H, FF, device=\"cuda\") * 0.02\n        act_fn = torch.nn.SiLU()\n\n        experts = SimpleNamespace(\n            gate_up_proj=gate_up_proj,\n            down_proj=down_proj,\n            act_fn=act_fn,\n            num_experts=E,\n        )\n\n        shared_W = torch.randn(H, H, device=\"cuda\") * 0.01\n        shared_expert_fn = lambda x: x @ shared_W.T  # noqa: E731\n        # Gate that 
returns 0 -> sigmoid(0) = 0.5\n        gate_W = torch.zeros(H, H, device=\"cuda\")\n        shared_expert_gate_fn = lambda x: x @ gate_W.T  # noqa: E731\n\n        gate = SimpleNamespace(\n            weight=torch.randn(E, H, device=\"cuda\") * 0.1,\n            top_k=K,\n            num_experts=E,\n            norm_topk_prob=True,\n        )\n        moe_block = SimpleNamespace(\n            gate=gate,\n            experts=experts,\n            shared_expert=shared_expert_fn,\n            shared_expert_gate=shared_expert_gate_fn,\n        )\n\n        hidden = torch.randn(1, T, H, device=\"cuda\")\n        output = HFScatterMoEGatedMLP.forward(moe_block, hidden)\n        assert output.shape == (1, T, H)\n"
  },
  {
    "path": "tests/e2e/integrations/test_scattermoe_lora_olmoe.py",
    "content": "# SPDX-License-Identifier: Apache-2.0\n# Copyright (c) Axolotl AI\n# Licensed under the Apache License, Version 2.0\n\n\"\"\"\nIntegration tests: OLMoE + peft LoRA + ScatterMoE fused kernels.\n\nValidates that scattermoe_lora fused kernels produce correct results when used\nwith HuggingFace OLMoE models and peft LoRA adapters applied via\n``target_parameters``.\n\nKey things tested\n-----------------\n- LoRA weight layout conversion between peft (rank-major) and scattermoe (expert-major)\n- Base forward equivalence: per-expert reference vs ScatterMoE kernels (no LoRA)\n- LoRA forward equivalence: peft merged-weight approach vs scattermoe fused kernels\n- Backward gradient correctness through the fused LoRA path\n- ``kernelize()`` integration via ``LocalLayerRepository``\n\"\"\"\n\nfrom pathlib import Path\n\nimport pytest\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom peft import LoraConfig, get_peft_model\nfrom transformers import OlmoeConfig\nfrom transformers.models.olmoe.modeling_olmoe import OlmoeSparseMoeBlock\n\n_SMOE = \"axolotl.integrations.kernels.libs.scattermoe_lora\"\n\n# Try to import from axolotl's scattermoe_lora.layers; may fail on CPU without triton.\ntry:\n    from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (\n        _unwrap_experts_lora,\n        _unwrap_gate_lora,\n        peft_lora_B_to_scattermoe,\n        peft_lora_to_scattermoe,\n    )\n\n    HAS_SCATTERMOE = True\nexcept (ImportError, ModuleNotFoundError):\n    HAS_SCATTERMOE = False\n\n    # Provide pure-torch fallbacks for CPU-only layout conversion tests.\n    def peft_lora_B_to_scattermoe(peft_B, num_experts, rank):\n        N = peft_B.shape[0]\n        return (\n            peft_B.reshape(N, rank, num_experts)\n            .permute(0, 2, 1)\n            .contiguous()\n            .reshape(N, num_experts * rank)\n        )\n\n    def peft_lora_to_scattermoe(peft_A, peft_B, num_experts, rank):\n        peft_B_em = 
peft_lora_B_to_scattermoe(peft_B, num_experts, rank)\n        K_inter, N_hidden = peft_B.shape[0], peft_A.shape[1]\n        smoe_A = torch.zeros(\n            rank * num_experts,\n            K_inter,\n            device=peft_A.device,\n            dtype=peft_A.dtype,\n        )\n        smoe_B = torch.zeros(\n            N_hidden,\n            rank * num_experts,\n            device=peft_A.device,\n            dtype=peft_A.dtype,\n        )\n        for e in range(num_experts):\n            s = e * rank\n            smoe_A[s : s + rank, :] = peft_B_em[:, s : s + rank].T\n            smoe_B[:, s : s + rank] = peft_A[s : s + rank, :].T\n        return smoe_A, smoe_B\n\n    def _unwrap_experts_lora(experts_module):\n        return experts_module, None, None\n\n    def _unwrap_gate_lora(gate_module):\n        if hasattr(gate_module, \"base_layer\") and hasattr(gate_module, \"lora_A\"):\n            base_gate = gate_module.base_layer\n            active = getattr(gate_module, \"active_adapters\", [\"default\"])\n            name = active[0] if active else \"default\"\n            lora_A_dict = getattr(gate_module, \"lora_A\", {})\n            lora_B_dict = getattr(gate_module, \"lora_B\", {})\n            scaling_dict = getattr(gate_module, \"scaling\", {})\n            if name in lora_A_dict:\n                lora_A = lora_A_dict[name].weight\n                lora_B = lora_B_dict[name].weight\n                s = scaling_dict[name]\n                delta = s * (lora_B @ lora_A)\n                return base_gate, base_gate.weight, delta\n            return base_gate, base_gate.weight, None\n        return gate_module, gate_module.weight, None\n\n\n# =============================================================================\n# Configuration\n# =============================================================================\n\nFULL_OLMOE_CONFIG = dict(\n    hidden_size=2048,\n    intermediate_size=1024,\n    num_experts=64,\n    num_experts_per_tok=8,\n    
hidden_act=\"silu\",\n    norm_topk_prob=False,\n)\n\nSMALL_OLMOE_CONFIG = dict(\n    hidden_size=128,\n    intermediate_size=48,  # non-square: 2*inter=96 != hidden=128\n    num_experts=8,\n    num_experts_per_tok=2,\n    hidden_act=\"silu\",\n    norm_topk_prob=False,\n)\n\nrequires_cuda = pytest.mark.skipif(\n    not torch.cuda.is_available(), reason=\"CUDA not available\"\n)\n\n\ndef make_olmoe_config(use_full=False):\n    cfg = dict(FULL_OLMOE_CONFIG if use_full else SMALL_OLMOE_CONFIG)\n    cfg[\"experts_implementation\"] = \"grouped_mm\"\n    return OlmoeConfig(**cfg)\n\n\n# =============================================================================\n# Layout conversion utilities (test-local helpers)\n# =============================================================================\n\n\ndef scattermoe_lora_B_to_peft(smoe_B, num_experts, rank):\n    \"\"\"Inverse of ``peft_lora_B_to_scattermoe``.\"\"\"\n    N = smoe_B.shape[0]\n    return (\n        smoe_B.reshape(N, num_experts, rank)\n        .permute(0, 2, 1)\n        .contiguous()\n        .reshape(N, num_experts * rank)\n    )\n\n\ndef peft_gate_up_lora_to_scattermoe(peft_A, peft_B, num_experts, rank):\n    \"\"\"Convert peft LoRA for gate_up_proj to scattermoe layout.\n\n    Both gate_up_proj and down_proj need the A<->B swap because\n    scattermoe transposes the parameter (W = param.T).\n    \"\"\"\n    return peft_lora_to_scattermoe(peft_A, peft_B, num_experts, rank)\n\n\n# =============================================================================\n# Helpers\n# =============================================================================\n\n\ndef _init_expert_weights(moe_block):\n    \"\"\"Initialize OlmoeExperts parameters which use torch.empty (uninitialized).\n\n    Without this, gate_up_proj and down_proj contain garbage/NaN values.\n    \"\"\"\n    with torch.no_grad():\n        nn.init.kaiming_uniform_(moe_block.experts.gate_up_proj)\n        
nn.init.kaiming_uniform_(moe_block.experts.down_proj)\n    return moe_block\n\n\nclass MinimalOLMoEModel(nn.Module):\n    \"\"\"Thin wrapper so peft's get_peft_model can attach adapters.\"\"\"\n\n    def __init__(self, config):\n        super().__init__()\n        self.moe = OlmoeSparseMoeBlock(config)\n        _init_expert_weights(self.moe)\n\n    def forward(self, x):\n        return self.moe(x)\n\n\ndef _get_routing(moe_block, hidden_states):\n    \"\"\"Run the router and return (routing_weights, selected_experts).\"\"\"\n    with torch.no_grad():\n        _, routing_weights, selected_experts = moe_block.gate(\n            hidden_states.view(-1, hidden_states.size(-1))\n        )\n    return routing_weights, selected_experts\n\n\ndef _reference_moe_forward(\n    x_flat,\n    gate_up_proj,\n    down_proj,\n    act_fn,\n    top_k_index,\n    top_k_weights,\n    num_experts,\n):\n    \"\"\"Pure-PyTorch per-expert reference MoE forward (no LoRA).\n\n    Uses F.linear per expert for an apples-to-apples comparison with\n    the ScatterMoE kernel path.\n    \"\"\"\n    final = torch.zeros_like(x_flat)\n    expert_mask = F.one_hot(top_k_index, num_classes=num_experts).permute(2, 1, 0)\n    for e in range(num_experts):\n        top_k_pos, token_idx = torch.where(expert_mask[e])\n        if token_idx.numel() == 0:\n            continue\n        cur = x_flat[token_idx]\n        gate_up = F.linear(cur, gate_up_proj[e])\n        g, u = gate_up.chunk(2, dim=-1)\n        h = act_fn(g) * u\n        out = F.linear(h, down_proj[e])\n        out = out * top_k_weights[token_idx, top_k_pos, None]\n        final.index_add_(0, token_idx, out.to(final.dtype))\n    return final\n\n\ndef _reference_moe_forward_with_lora(\n    x_flat,\n    gate_up_proj,\n    down_proj,\n    act_fn,\n    top_k_index,\n    top_k_weights,\n    num_experts,\n    gup_delta,\n    down_delta,\n):\n    \"\"\"Pure-PyTorch reference MoE forward with pre-computed weight deltas.\"\"\"\n    merged_gup = gate_up_proj + 
gup_delta\n    merged_down = down_proj + down_delta\n    return _reference_moe_forward(\n        x_flat,\n        merged_gup,\n        merged_down,\n        act_fn,\n        top_k_index,\n        top_k_weights,\n        num_experts,\n    )\n\n\ndef _compute_delta_from_scattermoe_lora(lora_A, lora_B, scaling, E, r, param_shape):\n    \"\"\"Compute additive weight delta from scattermoe-layout LoRA weights.\n\n    delta[e] = scaling * B_e @ A_e  where A_e [r,K], B_e [N,r] -> [N,K].\n    \"\"\"\n    delta = torch.zeros(param_shape, device=lora_A.device, dtype=lora_A.dtype)\n    for e in range(E):\n        A_e = lora_A[e * r : (e + 1) * r, :]\n        B_e = lora_B[:, e * r : (e + 1) * r]\n        delta[e] = scaling * (B_e @ A_e)\n    return delta\n\n\n# =============================================================================\n# Tests: Layout conversion\n# =============================================================================\n\n\nclass TestLoRABLayoutConversion:\n    \"\"\"Test the peft <-> scattermoe lora_B layout conversion.\"\"\"\n\n    def test_roundtrip(self):\n        E, r, N = 8, 4, 64\n        original = torch.randn(N, E * r)\n        converted = peft_lora_B_to_scattermoe(original, E, r)\n        back = scattermoe_lora_B_to_peft(converted, E, r)\n        torch.testing.assert_close(back, original)\n\n    def test_per_expert_slices(self):\n        \"\"\"After conversion, scattermoe slicing gives the same per-expert\n        matrices as peft's reshape slicing.\"\"\"\n        E, r, N = 4, 2, 16\n        peft_B = torch.randn(N, E * r)\n        smoe_B = peft_lora_B_to_scattermoe(peft_B, E, r)\n\n        peft_reshaped = peft_B.reshape(N, r, E)\n        for e in range(E):\n            torch.testing.assert_close(\n                smoe_B[:, e * r : (e + 1) * r],\n                peft_reshaped[:, :, e],\n            )\n\n    def test_lora_A_already_compatible(self):\n        \"\"\"lora_A layout is identical between peft and scattermoe.\"\"\"\n        E, r, K = 
4, 2, 16\n        lora_A = torch.randn(E * r, K)\n        peft_reshaped = lora_A.reshape(E, r, K)\n        for e in range(E):\n            torch.testing.assert_close(\n                lora_A[e * r : (e + 1) * r, :],\n                peft_reshaped[e],\n            )\n\n    def test_delta_weight_equivalence(self):\n        \"\"\"peft's einsum delta matches per-expert B @ A with converted layouts.\"\"\"\n        E, r, K, N = 8, 4, 32, 64\n        peft_A = torch.randn(E * r, K)\n        peft_B = torch.randn(N, E * r)\n        scaling = 2.0\n\n        A_r = peft_A.reshape(E, r, K)\n        B_r = peft_B.reshape(N, r, E)\n        delta_peft = torch.einsum(\"o r e, e r i -> e i o\", B_r, A_r) * scaling\n\n        smoe_B = peft_lora_B_to_scattermoe(peft_B, E, r)\n        for e in range(E):\n            A_e = peft_A[e * r : (e + 1) * r, :]\n            B_e = smoe_B[:, e * r : (e + 1) * r]\n            delta_e = scaling * (B_e @ A_e)\n            torch.testing.assert_close(delta_e, delta_peft[e].T, atol=1e-5, rtol=1e-5)\n\n    def test_down_proj_conversion(self):\n        \"\"\"Verify peft_lora_to_scattermoe produces correct delta.\"\"\"\n        E, r = 4, 2\n        hidden, inter = 32, 16\n        scaling = 2.0\n\n        peft_A = torch.randn(E * r, hidden)\n        peft_B = torch.randn(inter, E * r)\n\n        A_r = peft_A.reshape(E, r, hidden)\n        B_r = peft_B.reshape(inter, r, E)\n        delta_peft = torch.einsum(\"o r e, e r i -> e i o\", B_r, A_r) * scaling\n\n        smoe_A, smoe_B = peft_lora_to_scattermoe(peft_A, peft_B, E, r)\n        for e in range(E):\n            A_e = smoe_A[e * r : (e + 1) * r, :]\n            B_e = smoe_B[:, e * r : (e + 1) * r]\n            delta_smoe_e = scaling * (B_e @ A_e)\n            torch.testing.assert_close(\n                delta_smoe_e, delta_peft[e], atol=1e-5, rtol=1e-5\n            )\n\n    def test_gate_up_proj_conversion(self):\n        \"\"\"Verify gate_up_proj LoRA conversion with non-square dims (Qwen3-like).\n\n      
  gate_up_proj param: [E, 2*inter, hidden].\n        peft: in_features=2*inter, out_features=hidden.\n        peft lora_A: [r*E, 2*inter], lora_B: [hidden, r*E].\n\n        scattermoe W = param.T = [E, hidden, 2*inter], K=hidden, N=2*inter.\n        scattermoe needs: lora_A [r*E, K=hidden], lora_B [N=2*inter, r*E].\n\n        Uses non-square dims (hidden=32 != 2*inter=24) to catch A<->B swap bugs.\n        \"\"\"\n        E, r = 4, 2\n        hidden, inter = 32, 12  # 2*inter=24 != hidden=32\n        scaling = 2.0\n\n        # peft assigns: in_features=2*inter, out_features=hidden\n        peft_A = torch.randn(E * r, 2 * inter)  # [r*E, in_features=2*inter]\n        peft_B = torch.randn(hidden, E * r)  # [out_features=hidden, r*E]\n\n        # peft delta via einsum: \"o r e, e r i -> e i o\"\n        A_r = peft_A.reshape(E, r, 2 * inter)\n        B_r = peft_B.reshape(hidden, r, E)\n        delta_peft = torch.einsum(\"o r e, e r i -> e i o\", B_r, A_r) * scaling\n        # delta_peft[e] has shape [in_features, out_features] = [2*inter, hidden]\n        # = param[e] shape [2*inter, hidden]\n\n        smoe_A, smoe_B = peft_gate_up_lora_to_scattermoe(peft_A, peft_B, E, r)\n        # smoe_A should be [r*E, K=hidden], smoe_B should be [N=2*inter, r*E]\n        assert smoe_A.shape == (E * r, hidden), (\n            f\"Expected {(E * r, hidden)}, got {smoe_A.shape}\"\n        )\n        assert smoe_B.shape == (2 * inter, E * r), (\n            f\"Expected {(2 * inter, E * r)}, got {smoe_B.shape}\"\n        )\n\n        for e in range(E):\n            A_e = smoe_A[e * r : (e + 1) * r, :]  # [r, K=hidden]\n            B_e = smoe_B[:, e * r : (e + 1) * r]  # [N=2*inter, r]\n            delta_smoe_e = scaling * (B_e @ A_e)  # [2*inter, hidden]\n            # Should match peft delta which is [2*inter, hidden] = param[e]\n            torch.testing.assert_close(\n                delta_smoe_e, delta_peft[e], atol=1e-5, rtol=1e-5\n            )\n\n\n# 
=============================================================================\n# Tests: peft weight extraction\n# =============================================================================\n\n\nclass TestPeftLoRAWeightExtraction:\n    \"\"\"Test extracting peft LoRA weights for OLMoE.\"\"\"\n\n    def test_peft_creates_correct_shapes(self):\n        config = make_olmoe_config(use_full=False)\n        E, r = config.num_experts, 4\n\n        model = MinimalOLMoEModel(config)\n        lora_config = LoraConfig(\n            r=r,\n            lora_alpha=16,\n            target_modules=[],\n            target_parameters=[\n                \"gate.weight\",\n                \"experts.gate_up_proj\",\n                \"experts.down_proj\",\n            ],\n            bias=\"none\",\n        )\n        peft_model = get_peft_model(model, lora_config)\n        trainable = {n: p for n, p in peft_model.named_parameters() if p.requires_grad}\n\n        # Gate router\n        assert trainable[\"base_model.model.moe.gate.lora_A.default.weight\"].shape == (\n            r,\n            config.hidden_size,\n        )\n        assert trainable[\"base_model.model.moe.gate.lora_B.default.weight\"].shape == (\n            E,\n            r,\n        )\n\n        # gate_up_proj [E, 2*inter, hidden]\n        # peft: in_features=2*inter (dim 1), out_features=hidden (dim 2)\n        assert trainable[\n            \"base_model.model.moe.experts.base_layer.lora_A.default.weight\"\n        ].shape == (E * r, 2 * config.intermediate_size)\n        assert trainable[\n            \"base_model.model.moe.experts.base_layer.lora_B.default.weight\"\n        ].shape == (config.hidden_size, E * r)\n\n        # down_proj [E, hidden, inter]\n        # peft: in_features=hidden (dim 1), out_features=inter (dim 2)\n        assert trainable[\n            \"base_model.model.moe.experts.lora_A.default.weight\"\n        ].shape == (E * r, config.hidden_size)\n        assert trainable[\n            
\"base_model.model.moe.experts.lora_B.default.weight\"\n        ].shape == (config.intermediate_size, E * r)\n\n    @requires_cuda\n    def test_peft_forward_runs(self):\n        \"\"\"Smoke test: peft model forward pass completes (needs CUDA for grouped_mm).\"\"\"\n        config = make_olmoe_config(use_full=False)\n        model = MinimalOLMoEModel(config)\n        lora_config = LoraConfig(\n            r=4,\n            lora_alpha=16,\n            target_modules=[],\n            target_parameters=[\n                \"gate.weight\",\n                \"experts.gate_up_proj\",\n                \"experts.down_proj\",\n            ],\n            bias=\"none\",\n        )\n        peft_model = get_peft_model(model, lora_config)\n        x = torch.randn(1, 4, config.hidden_size)\n        out = peft_model(x)\n        assert out.shape == x.shape\n\n    @pytest.mark.skipif(\n        not HAS_SCATTERMOE, reason=\"scattermoe_lora not importable (no triton)\"\n    )\n    def test_unwrap_experts_lora(self):\n        \"\"\"Test that _unwrap_experts_lora correctly detects LoRA wrappers.\"\"\"\n        config = make_olmoe_config(use_full=False)\n        model = MinimalOLMoEModel(config)\n        lora_config = LoraConfig(\n            r=4,\n            lora_alpha=16,\n            target_modules=[],\n            target_parameters=[\"experts.gate_up_proj\", \"experts.down_proj\"],\n            bias=\"none\",\n        )\n        peft_model = get_peft_model(model, lora_config)\n        base_moe = peft_model.base_model.model.moe\n\n        # Experts should be wrapped by ParamWrapper\n        experts, gup_lora, down_lora = _unwrap_experts_lora(base_moe.experts)\n\n        # Base experts should have the raw parameters\n        assert hasattr(experts, \"gate_up_proj\")\n        assert hasattr(experts, \"down_proj\")\n\n        # LoRA should be detected\n        assert gup_lora is not None, \"gate_up_proj LoRA not detected\"\n        assert down_lora is not None, \"down_proj LoRA not 
detected\"\n\n        # Check shapes (after peft->scattermoe conversion with A<->B swap)\n        # gate_up_proj W = param.T = [E, hidden, 2*inter], K=hidden, N=2*inter\n        E, r = config.num_experts, 4\n        gup_A, gup_B, gup_s = gup_lora\n        assert gup_A.shape == (E * r, config.hidden_size), (\n            f\"gate_up_proj smoe_A: expected [r*E, K=hidden]={(E * r, config.hidden_size)}, \"\n            f\"got {gup_A.shape}\"\n        )\n        assert gup_B.shape == (2 * config.intermediate_size, E * r), (\n            f\"gate_up_proj smoe_B: expected [N=2*inter, r*E]=\"\n            f\"{(2 * config.intermediate_size, E * r)}, got {gup_B.shape}\"\n        )\n\n        # down_proj W = param.T = [E, inter, hidden], K=inter, N=hidden\n        down_A, down_B, down_s = down_lora\n        assert down_A.shape == (E * r, config.intermediate_size), (\n            f\"down_proj smoe_A: expected [r*E, K=inter]={(E * r, config.intermediate_size)}, \"\n            f\"got {down_A.shape}\"\n        )\n        assert down_B.shape == (config.hidden_size, E * r), (\n            f\"down_proj smoe_B: expected [N=hidden, r*E]={(config.hidden_size, E * r)}, \"\n            f\"got {down_B.shape}\"\n        )\n\n    def test_unwrap_no_lora(self):\n        \"\"\"Without peft, _unwrap_experts_lora returns no LoRA.\"\"\"\n        config = make_olmoe_config(use_full=False)\n        moe = OlmoeSparseMoeBlock(config)\n        experts, gup_lora, down_lora = _unwrap_experts_lora(moe.experts)\n        assert gup_lora is None\n        assert down_lora is None\n        assert hasattr(experts, \"gate_up_proj\")\n\n    def test_unwrap_gate_lora(self):\n        \"\"\"Test that _unwrap_gate_lora detects LoRA on the router gate.\"\"\"\n        config = make_olmoe_config(use_full=False)\n        model = MinimalOLMoEModel(config)\n        r = 4\n        lora_config = LoraConfig(\n            r=r,\n            lora_alpha=16,\n            target_modules=[],\n            
target_parameters=[\"gate.weight\"],\n            bias=\"none\",\n        )\n        peft_model = get_peft_model(model, lora_config)\n        base_moe = peft_model.base_model.model.moe\n\n        # Set non-zero LoRA weights (peft initializes lora_B to zeros)\n        with torch.no_grad():\n            base_moe.gate.lora_B[\"default\"].weight.normal_(0, 0.01)\n\n        base_gate, gate_weight, gate_delta = _unwrap_gate_lora(base_moe.gate)\n\n        # Base gate should be the original router\n        assert hasattr(base_gate, \"top_k\")\n        assert hasattr(base_gate, \"num_experts\")\n        assert base_gate.top_k == config.num_experts_per_tok\n        assert base_gate.num_experts == config.num_experts\n\n        # Gate weight should be the base weight (delta returned separately)\n        assert gate_weight.shape == (config.num_experts, config.hidden_size)\n        torch.testing.assert_close(gate_weight, base_gate.weight)\n\n        # Delta should be non-zero (LoRA was applied)\n        assert gate_delta is not None\n        assert gate_delta.shape == (config.num_experts, config.hidden_size)\n        assert gate_delta.abs().max() > 0, \"Gate LoRA delta should be non-zero\"\n\n    def test_unwrap_gate_no_lora(self):\n        \"\"\"Without peft, _unwrap_gate_lora returns the original gate.\"\"\"\n        config = make_olmoe_config(use_full=False)\n        moe = OlmoeSparseMoeBlock(config)\n        base_gate, gate_weight, gate_delta = _unwrap_gate_lora(moe.gate)\n        assert base_gate is moe.gate\n        torch.testing.assert_close(gate_weight, moe.gate.weight)\n        assert gate_delta is None\n\n    def test_gate_lora_delta_matches_peft(self):\n        \"\"\"Verify _unwrap_gate_lora computes the same delta as peft.\"\"\"\n        config = make_olmoe_config(use_full=False)\n        model = MinimalOLMoEModel(config)\n        r = 4\n        lora_alpha = 16\n        scaling = lora_alpha / r\n        lora_config = LoraConfig(\n            r=r,\n            
lora_alpha=lora_alpha,\n            target_modules=[],\n            target_parameters=[\"gate.weight\"],\n            bias=\"none\",\n        )\n        peft_model = get_peft_model(model, lora_config)\n        base_moe = peft_model.base_model.model.moe\n\n        # Our unwrapped weight + delta\n        _, gate_weight, gate_delta = _unwrap_gate_lora(base_moe.gate)\n\n        # Manually compute expected delta\n        lora_A = base_moe.gate.lora_A[\"default\"].weight  # [r, hidden]\n        lora_B = base_moe.gate.lora_B[\"default\"].weight  # [E, r]\n        base_weight = base_moe.gate.base_layer.weight  # [E, hidden]\n        expected_delta = scaling * (lora_B @ lora_A)\n\n        torch.testing.assert_close(gate_weight, base_weight)\n        torch.testing.assert_close(gate_delta, expected_delta)\n        # Combined should match the old behavior\n        torch.testing.assert_close(\n            gate_weight + gate_delta, base_weight + expected_delta\n        )\n\n\n# =============================================================================\n# Tests: Base forward equivalence (no LoRA)\n# =============================================================================\n\n\n@requires_cuda\nclass TestOLMoEReferenceVsScatterMoE:\n    \"\"\"Base forward equivalence: per-expert reference vs ScatterMoE kernels.\"\"\"\n\n    def test_small(self):\n        self._run(use_full=False, M=16)\n\n    @pytest.mark.slow\n    def test_full(self):\n        self._run(use_full=True, M=32)\n\n    def _run(self, use_full, M):\n        from axolotl.integrations.kernels.libs.scattermoe_lora import (\n            flatten_sort_count,\n            parallel_linear,\n        )\n\n        config = make_olmoe_config(use_full=use_full)\n        torch.manual_seed(42)\n        moe = _init_expert_weights(OlmoeSparseMoeBlock(config)).cuda().float()\n        E, k = config.num_experts, config.num_experts_per_tok\n\n        x = torch.randn(1, M, config.hidden_size, device=\"cuda\")\n        x_flat = 
x.view(-1, config.hidden_size)\n\n        with torch.no_grad():\n            # Shared routing for both paths\n            _, rw, sel = moe.gate(x_flat)\n            sei, ssi, eo = flatten_sort_count(sel, num_experts=E)\n\n            # Per-expert reference\n            ref_out = _reference_moe_forward(\n                x_flat,\n                moe.experts.gate_up_proj,\n                moe.experts.down_proj,\n                moe.experts.act_fn,\n                sel,\n                rw,\n                E,\n            ).view(1, M, config.hidden_size)\n\n            # ScatterMoE kernel path\n            gup = parallel_linear(\n                x_flat,\n                moe.experts.gate_up_proj.transpose(2, 1),\n                k,\n                sei,\n                ssi,\n                eo,\n                grouped_in=False,\n                grouped_out=True,\n            )\n            g, u = gup.chunk(2, dim=-1)\n            h = moe.experts.act_fn(g) * u\n\n            smoe_out = parallel_linear(\n                h,\n                moe.experts.down_proj.transpose(2, 1),\n                1,\n                sei,\n                ssi,\n                eo,\n                grouped_in=True,\n                grouped_out=False,\n                gates=rw,\n            ).view(1, M, config.hidden_size)\n\n        torch.testing.assert_close(smoe_out, ref_out, atol=1e-3, rtol=1e-3)\n\n\n# =============================================================================\n# Tests: LoRA forward equivalence (peft vs scattermoe fused)\n# =============================================================================\n\n\n@requires_cuda\nclass TestOLMoEPeftLoRAForward:\n    \"\"\"Fused LoRA forward: peft merged-weight vs scattermoe_lora kernel.\"\"\"\n\n    def test_small(self):\n        self._run(use_full=False, M=16, r=4)\n\n    @pytest.mark.slow\n    def test_full(self):\n        self._run(use_full=True, M=32, r=8)\n\n    def _run(self, use_full, M, r):\n        from 
axolotl.integrations.kernels.libs.scattermoe_lora import (\n            flatten_sort_count,\n            parallel_linear_lora,\n        )\n\n        config = make_olmoe_config(use_full=use_full)\n        E, k = config.num_experts, config.num_experts_per_tok\n        lora_alpha = 16\n        scaling = lora_alpha / r\n\n        # Create peft model\n        model = MinimalOLMoEModel(config).cuda().float()\n        lora_config = LoraConfig(\n            r=r,\n            lora_alpha=lora_alpha,\n            target_modules=[],\n            target_parameters=[\"experts.gate_up_proj\", \"experts.down_proj\"],\n            bias=\"none\",\n        )\n        peft_model = get_peft_model(model, lora_config)\n\n        torch.manual_seed(42)\n        x = torch.randn(1, M, config.hidden_size, device=\"cuda\")\n\n        # peft forward\n        with torch.no_grad():\n            peft_out = peft_model(x)\n\n        # Extract base weights and LoRA weights\n        base_moe = peft_model.base_model.model.moe\n        base_experts = base_moe.experts.base_layer.base_layer\n        gate_up_proj = base_experts.gate_up_proj\n        down_proj = base_experts.down_proj\n        act_fn = base_experts.act_fn\n\n        # gate_up_proj LoRA\n        gup_w = base_moe.experts.base_layer\n        peft_gup_A = gup_w.lora_A[\"default\"].weight.detach()\n        peft_gup_B = gup_w.lora_B[\"default\"].weight.detach()\n        smoe_gup_A, smoe_gup_B = peft_gate_up_lora_to_scattermoe(\n            peft_gup_A, peft_gup_B, E, r\n        )\n\n        # down_proj LoRA\n        down_w = base_moe.experts\n        peft_down_A = down_w.lora_A[\"default\"].weight.detach()\n        peft_down_B = down_w.lora_B[\"default\"].weight.detach()\n        smoe_down_A, smoe_down_B = peft_lora_to_scattermoe(\n            peft_down_A, peft_down_B, E, r\n        )\n\n        # ScatterMoE fused forward -- gate is NOT peft-wrapped, access directly\n        x_flat = x.view(-1, config.hidden_size)\n\n        with 
torch.no_grad():\n            _, rw, sel = base_moe.gate(x_flat)\n            sei, ssi, eo = flatten_sort_count(sel, num_experts=E)\n\n            gup = parallel_linear_lora(\n                x_flat,\n                gate_up_proj.transpose(2, 1),\n                k,\n                sei,\n                ssi,\n                eo,\n                lora_A=smoe_gup_A,\n                lora_B=smoe_gup_B,\n                scaling=scaling,\n                grouped_in=False,\n                grouped_out=True,\n            )\n            g, u = gup.chunk(2, dim=-1)\n            h = act_fn(g) * u\n\n            smoe_out = parallel_linear_lora(\n                h,\n                down_proj.transpose(2, 1),\n                1,\n                sei,\n                ssi,\n                eo,\n                lora_A=smoe_down_A,\n                lora_B=smoe_down_B,\n                scaling=scaling,\n                grouped_in=True,\n                grouped_out=False,\n                gates=rw,\n            ).view(1, M, config.hidden_size)\n\n        torch.testing.assert_close(smoe_out, peft_out, atol=5e-3, rtol=5e-3)\n\n\n# =============================================================================\n# Tests: Backward gradient correctness\n# =============================================================================\n\n\n@requires_cuda\nclass TestOLMoEPeftLoRABackward:\n    \"\"\"Backward gradients through scattermoe_lora vs pure-PyTorch reference.\"\"\"\n\n    def test_small(self):\n        self._run(use_full=False, M=16, r=4)\n\n    def _run(self, use_full, M, r):\n        from axolotl.integrations.kernels.libs.scattermoe_lora import (\n            flatten_sort_count,\n            parallel_linear_lora,\n        )\n\n        config = make_olmoe_config(use_full=use_full)\n        E, k = config.num_experts, config.num_experts_per_tok\n        lora_alpha = 16\n        scaling = lora_alpha / r\n\n        torch.manual_seed(42)\n        moe = 
_init_expert_weights(OlmoeSparseMoeBlock(config)).cuda().float()\n        x = torch.randn(1, M, config.hidden_size, device=\"cuda\")\n        x_flat = x.view(-1, config.hidden_size)\n        gate_up_proj = moe.experts.gate_up_proj\n        down_proj = moe.experts.down_proj\n\n        # Create LoRA weights in scattermoe layout directly\n        gup_A = torch.randn(r * E, config.hidden_size, device=\"cuda\") * 0.01\n        gup_B = torch.randn(2 * config.intermediate_size, r * E, device=\"cuda\") * 0.01\n        down_A = torch.randn(r * E, config.intermediate_size, device=\"cuda\") * 0.01\n        down_B = torch.randn(config.hidden_size, r * E, device=\"cuda\") * 0.01\n\n        rw, sel = _get_routing(moe, x)\n        sei, ssi, eo = flatten_sort_count(sel, num_experts=E)\n\n        # --- Reference ---\n        gup_delta = _compute_delta_from_scattermoe_lora(\n            gup_A, gup_B, scaling, E, r, gate_up_proj.shape\n        )\n        down_delta = _compute_delta_from_scattermoe_lora(\n            down_A, down_B, scaling, E, r, down_proj.shape\n        )\n\n        x_ref = x_flat.clone().detach().requires_grad_(True)\n        ref_out = _reference_moe_forward_with_lora(\n            x_ref,\n            gate_up_proj,\n            down_proj,\n            moe.experts.act_fn,\n            sel,\n            rw,\n            E,\n            gup_delta,\n            down_delta,\n        )\n        ref_out.sum().backward()\n\n        # --- ScatterMoE fused path ---\n        x_smoe = x_flat.clone().detach().requires_grad_(True)\n        gup_A_s = gup_A.clone().requires_grad_(True)\n        gup_B_s = gup_B.clone().requires_grad_(True)\n        down_A_s = down_A.clone().requires_grad_(True)\n        down_B_s = down_B.clone().requires_grad_(True)\n\n        gup_out = parallel_linear_lora(\n            x_smoe,\n            gate_up_proj.transpose(2, 1),\n            k,\n            sei,\n            ssi,\n            eo,\n            lora_A=gup_A_s,\n            lora_B=gup_B_s,\n  
          scaling=scaling,\n            grouped_in=False,\n            grouped_out=True,\n        )\n        g, u = gup_out.chunk(2, dim=-1)\n        h = moe.experts.act_fn(g) * u\n\n        smoe_out = parallel_linear_lora(\n            h,\n            down_proj.transpose(2, 1),\n            1,\n            sei,\n            ssi,\n            eo,\n            lora_A=down_A_s,\n            lora_B=down_B_s,\n            scaling=scaling,\n            grouped_in=True,\n            grouped_out=False,\n            gates=rw,\n        )\n        smoe_out.sum().backward()\n\n        torch.testing.assert_close(\n            smoe_out.detach(),\n            ref_out.detach(),\n            atol=5e-3,\n            rtol=5e-3,\n        )\n        torch.testing.assert_close(\n            x_smoe.grad,\n            x_ref.grad,\n            atol=5e-2,\n            rtol=5e-2,\n        )\n\n\n# =============================================================================\n# Tests: kernelize() integration via LocalLayerRepository\n# =============================================================================\n\n\n@requires_cuda\nclass TestKernelizeIntegration:\n    \"\"\"Test the HF kernels library integration with LocalLayerRepository.\"\"\"\n\n    @staticmethod\n    def _get_kernelize_imports():\n        \"\"\"Import kernels library components, skip if not available.\"\"\"\n        try:\n            from kernels import (\n                LocalLayerRepository,\n                Mode,\n                kernelize,\n                register_kernel_mapping,\n                replace_kernel_forward_from_hub,\n            )\n\n            return (\n                LocalLayerRepository,\n                Mode,\n                register_kernel_mapping,\n                replace_kernel_forward_from_hub,\n                kernelize,\n            )\n        except ImportError:\n            pytest.skip(\"kernels library not installed\")\n\n    @staticmethod\n    def _get_repo_path():\n        \"\"\"Get 
the path to scattermoe_lora within axolotl's plugin.\"\"\"\n        return (\n            Path(__file__).parent.parent.parent\n            / \"src\"\n            / \"axolotl\"\n            / \"integrations\"\n            / \"kernels\"\n            / \"libs\"\n            / \"scattermoe_lora\"\n        )\n\n    def _setup_kernels(\n        self,\n        LocalLayerRepository,\n        Mode,\n        register_kernel_mapping,\n        replace_kernel_forward_from_hub,\n    ):\n        \"\"\"Register kernel mapping for tests.\"\"\"\n        repo_path = self._get_repo_path()\n        local_repo = LocalLayerRepository(\n            repo_path=repo_path,\n            package_name=\"scattermoe_lora\",\n            layer_name=\"HFScatterMoEGatedMLP\",\n        )\n\n        replace_kernel_forward_from_hub(\n            OlmoeSparseMoeBlock, \"HFScatterMoEParallelExperts\"\n        )\n        register_kernel_mapping(\n            {\n                \"HFScatterMoEParallelExperts\": {\n                    \"cuda\": {\n                        Mode.TRAINING: local_repo,\n                        Mode.INFERENCE: local_repo,\n                    },\n                }\n            }\n        )\n\n    def test_base_forward_via_kernelize(self):\n        \"\"\"Kernelized OlmoeSparseMoeBlock (no LoRA) matches per-expert reference.\"\"\"\n        (\n            LocalLayerRepository,\n            Mode,\n            register_kernel_mapping,\n            replace_kernel_forward_from_hub,\n            kernelize,\n        ) = self._get_kernelize_imports()\n\n        config = make_olmoe_config(use_full=False)\n        E = config.num_experts\n\n        # Create model\n        torch.manual_seed(42)\n        moe = _init_expert_weights(OlmoeSparseMoeBlock(config)).cuda().float()\n        x = torch.randn(1, 8, config.hidden_size, device=\"cuda\")\n        x_flat = x.view(-1, config.hidden_size)\n\n        # Compute reference BEFORE kernelizing\n        with torch.no_grad():\n            _, rw, sel = 
moe.gate(x_flat)\n            ref_out = _reference_moe_forward(\n                x_flat,\n                moe.experts.gate_up_proj,\n                moe.experts.down_proj,\n                moe.experts.act_fn,\n                sel,\n                rw,\n                E,\n            ).view(1, 8, config.hidden_size)\n\n        # Set up kernel mapping\n        self._setup_kernels(\n            LocalLayerRepository,\n            Mode,\n            register_kernel_mapping,\n            replace_kernel_forward_from_hub,\n        )\n\n        # Kernelize the model\n        kernelize(moe, mode=Mode.TRAINING, device=\"cuda\")\n\n        # Forward through kernelized model\n        with torch.no_grad():\n            kern_out = moe(x)\n\n        torch.testing.assert_close(kern_out, ref_out, atol=1e-3, rtol=1e-3)\n\n    def test_lora_forward_via_kernelize(self):\n        \"\"\"Kernelized OlmoeSparseMoeBlock with peft LoRA matches reference.\"\"\"\n        (\n            LocalLayerRepository,\n            Mode,\n            register_kernel_mapping,\n            replace_kernel_forward_from_hub,\n            kernelize,\n        ) = self._get_kernelize_imports()\n\n        config = make_olmoe_config(use_full=False)\n        r = 4\n\n        # Create peft model\n        torch.manual_seed(42)\n        model = MinimalOLMoEModel(config).cuda().float()\n        lora_config = LoraConfig(\n            r=r,\n            lora_alpha=16,\n            target_modules=[],\n            target_parameters=[\"experts.gate_up_proj\", \"experts.down_proj\"],\n            bias=\"none\",\n        )\n        peft_model = get_peft_model(model, lora_config)\n\n        x = torch.randn(1, 8, config.hidden_size, device=\"cuda\")\n\n        # Reference: peft's own forward (uses _activate_lora context manager)\n        with torch.no_grad():\n            ref_out = peft_model(x)\n\n        # Set up kernel mapping\n        self._setup_kernels(\n            LocalLayerRepository,\n            Mode,\n            
register_kernel_mapping,\n            replace_kernel_forward_from_hub,\n        )\n\n        # Kernelize the MoE block inside the peft model\n        base_moe = peft_model.base_model.model.moe\n        kernelize(base_moe, mode=Mode.TRAINING, device=\"cuda\")\n\n        # Forward through kernelized peft model\n        with torch.no_grad():\n            kern_out = peft_model(x)\n\n        torch.testing.assert_close(kern_out, ref_out, atol=5e-3, rtol=5e-3)\n\n    def test_gate_lora_forward_via_kernelize(self):\n        \"\"\"Kernelized forward with gate LoRA matches peft reference.\"\"\"\n        (\n            LocalLayerRepository,\n            Mode,\n            register_kernel_mapping,\n            replace_kernel_forward_from_hub,\n            kernelize,\n        ) = self._get_kernelize_imports()\n\n        config = make_olmoe_config(use_full=False)\n        r = 4\n\n        # Create peft model with gate + experts LoRA\n        torch.manual_seed(42)\n        model = MinimalOLMoEModel(config).cuda().float()\n        lora_config = LoraConfig(\n            r=r,\n            lora_alpha=16,\n            target_modules=[],\n            target_parameters=[\n                \"gate.weight\",\n                \"experts.gate_up_proj\",\n                \"experts.down_proj\",\n            ],\n            bias=\"none\",\n        )\n        peft_model = get_peft_model(model, lora_config)\n\n        x = torch.randn(1, 8, config.hidden_size, device=\"cuda\")\n\n        # Reference: peft's own forward\n        with torch.no_grad():\n            ref_out = peft_model(x)\n\n        # Set up kernel mapping\n        self._setup_kernels(\n            LocalLayerRepository,\n            Mode,\n            register_kernel_mapping,\n            replace_kernel_forward_from_hub,\n        )\n\n        # Kernelize the MoE block inside the peft model\n        base_moe = peft_model.base_model.model.moe\n        kernelize(base_moe, mode=Mode.TRAINING, device=\"cuda\")\n\n        # Forward through 
kernelized peft model\n        with torch.no_grad():\n            kern_out = peft_model(x)\n\n        torch.testing.assert_close(kern_out, ref_out, atol=5e-3, rtol=5e-3)\n\n\n# =============================================================================\n# Tests: Shared expert handling\n# =============================================================================\n\n\nclass TestSharedExpertHandling:\n    \"\"\"Test that HFScatterMoEGatedMLP.forward handles shared experts.\"\"\"\n\n    @staticmethod\n    def _make_shared_expert_block(config):\n        \"\"\"Create an OlmoeSparseMoeBlock with a mock shared expert attached.\"\"\"\n        moe = OlmoeSparseMoeBlock(config)\n        _init_expert_weights(moe)\n\n        hidden = config.hidden_size\n        inter = config.intermediate_size\n\n        # Attach a simple shared expert MLP (mimics Qwen2MoE structure)\n        class SharedExpertMLP(nn.Module):\n            def __init__(self, hidden_size, intermediate_size):\n                super().__init__()\n                self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)\n                self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)\n                self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)\n                self.act_fn = nn.SiLU()\n\n            def forward(self, x):\n                return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))\n\n        moe.shared_expert = SharedExpertMLP(hidden, inter)\n        moe.shared_expert_gate = nn.Linear(hidden, 1, bias=False)\n\n        return moe\n\n    def test_shared_expert_is_used(self):\n        \"\"\"Verify shared expert output affects final result.\"\"\"\n        config = make_olmoe_config(use_full=False)\n        moe = self._make_shared_expert_block(config)\n\n        # Compute reference without shared expert\n        torch.manual_seed(42)\n        x = torch.randn(1, 4, config.hidden_size)\n        x_flat = x.view(-1, 
config.hidden_size)\n\n        with torch.no_grad():\n            # Shared expert contribution\n            shared_out = moe.shared_expert(x_flat)\n            gate_val = F.sigmoid(moe.shared_expert_gate(x_flat))\n            shared_contribution = shared_out * gate_val\n\n        # Verify shared expert produces non-zero output\n        assert shared_contribution.abs().max() > 0\n\n    @requires_cuda\n    def test_shared_expert_forward_via_kernelize(self):\n        \"\"\"Kernelized forward with shared expert matches manual reference.\"\"\"\n        try:\n            from kernels import (\n                LocalLayerRepository,\n                Mode,\n                kernelize,\n                register_kernel_mapping,\n                replace_kernel_forward_from_hub,\n            )\n        except ImportError:\n            pytest.skip(\"kernels library not installed\")\n\n        config = make_olmoe_config(use_full=False)\n        E = config.num_experts\n\n        torch.manual_seed(42)\n        moe = self._make_shared_expert_block(config).cuda().float()\n        x = torch.randn(1, 8, config.hidden_size, device=\"cuda\")\n        x_flat = x.view(-1, config.hidden_size)\n\n        # Compute reference: per-expert + shared expert\n        with torch.no_grad():\n            _, rw, sel = moe.gate(x_flat)\n\n            expert_out = _reference_moe_forward(\n                x_flat,\n                moe.experts.gate_up_proj,\n                moe.experts.down_proj,\n                moe.experts.act_fn,\n                sel,\n                rw,\n                E,\n            )\n            shared_out = moe.shared_expert(x_flat)\n            gate_val = F.sigmoid(moe.shared_expert_gate(x_flat))\n            ref_out = (expert_out + shared_out * gate_val).view(\n                1, 8, config.hidden_size\n            )\n\n        # Kernelize\n        repo_path = (\n            Path(__file__).parent.parent.parent\n            / \"src\"\n            / \"axolotl\"\n            / 
\"integrations\"\n            / \"kernels\"\n            / \"libs\"\n            / \"scattermoe_lora\"\n        )\n        local_repo = LocalLayerRepository(\n            repo_path=repo_path,\n            package_name=\"scattermoe_lora\",\n            layer_name=\"HFScatterMoEGatedMLP\",\n        )\n\n        replace_kernel_forward_from_hub(\n            OlmoeSparseMoeBlock, \"HFScatterMoEParallelExperts\"\n        )\n        register_kernel_mapping(\n            {\n                \"HFScatterMoEParallelExperts\": {\n                    \"cuda\": {\n                        Mode.TRAINING: local_repo,\n                        Mode.INFERENCE: local_repo,\n                    },\n                }\n            }\n        )\n\n        kernelize(moe, mode=Mode.TRAINING, device=\"cuda\")\n\n        with torch.no_grad():\n            kern_out = moe(x)\n\n        torch.testing.assert_close(kern_out, ref_out, atol=1e-3, rtol=1e-3)\n"
  },
  {
    "path": "tests/e2e/integrations/test_sonicmoe.py",
    "content": "\"\"\"\nEnd-to-end gradient and convergence tests for SonicMoE integration.\n\nRequires:\n    - H100/H200 GPU (SonicMoE CUTLASS kernels target sm_90)\n    - sonicmoe package installed\n    - transformers with Qwen3MoE support\n\nUsage:\n    pytest tests/e2e/integrations/test_sonicmoe.py -v -s\n\"\"\"\n\nimport importlib.util\nimport math\n\nimport pytest\nimport torch\n\n_sonicmoe_available = importlib.util.find_spec(\"sonicmoe\") is not None\n_is_hopper = torch.cuda.is_available() and torch.cuda.get_device_capability() == (9, 0)\n\npytestmark = [\n    pytest.mark.skipif(not torch.cuda.is_available(), reason=\"Requires CUDA GPU\"),\n    pytest.mark.skipif(\n        not _is_hopper, reason=\"SonicMoE CUTLASS kernels require Hopper (sm_90)\"\n    ),\n    pytest.mark.skipif(not _sonicmoe_available, reason=\"SonicMoE not installed\"),\n]\n\n\ndef _create_tiny_qwen3_config():\n    \"\"\"Create a minimal Qwen3MoE config for fast testing.\"\"\"\n    from transformers import AutoConfig\n\n    config = AutoConfig.for_model(\"qwen3_moe\")\n    config.hidden_size = 512\n    config.intermediate_size = 1024\n    config.moe_intermediate_size = 64\n    config.num_attention_heads = 16\n    config.num_key_value_heads = 2\n    config.head_dim = 32\n    config.num_hidden_layers = 2\n    config.num_experts = 8\n    config.num_experts_per_tok = 2\n    config.vocab_size = 1000\n    config.max_position_embeddings = 128\n    config.norm_topk_prob = True\n    config.torch_dtype = torch.bfloat16\n    return config\n\n\ndef _interleave_gate_up_weights(model):\n    \"\"\"Interleave all gate_up_proj parameters in-place for SonicMoE.\"\"\"\n    from axolotl.integrations.kernels.sonicmoe.weight_converter import (\n        interleave_gate_up,\n    )\n\n    with torch.no_grad():\n        for name, param in model.named_parameters():\n            if \"gate_up_proj\" in name:\n                param.copy_(interleave_gate_up(param))\n\n\ndef _unpatch_sonicmoe():\n    \"\"\"Restore 
original forward on the MoE block class if it was patched.\"\"\"\n    from axolotl.integrations.kernels.constants import resolve_moe_block_classes\n\n    for moe_cls in resolve_moe_block_classes(\"qwen3_moe\"):\n        if hasattr(moe_cls, \"_original_forward\"):\n            moe_cls.forward = moe_cls._original_forward\n            del moe_cls._original_forward\n\n\nclass TestSonicMoEForwardCorrectness:\n    \"\"\"Verify SonicMoE-patched model produces same output as original.\"\"\"\n\n    def teardown_method(self):\n        _unpatch_sonicmoe()\n\n    def test_forward_output_matches(self):\n        from transformers import AutoModelForCausalLM\n\n        from axolotl.integrations.kernels.sonicmoe.patch import patch_sonicmoe\n\n        config = _create_tiny_qwen3_config()\n        input_ids = torch.randint(0, config.vocab_size, (1, 16), device=\"cuda\")\n\n        # Original model\n        model_orig = AutoModelForCausalLM.from_config(config).cuda().bfloat16()\n\n        with torch.no_grad():\n            out_orig = model_orig(input_ids)\n\n        # Patched model (same weights, interleaved for SonicMoE)\n        model_patched = AutoModelForCausalLM.from_config(config).cuda().bfloat16()\n        model_patched.load_state_dict(model_orig.state_dict())\n\n        patch_sonicmoe(\"qwen3_moe\")\n        _interleave_gate_up_weights(model_patched)\n\n        with torch.no_grad():\n            out_patched = model_patched(input_ids)\n\n        max_diff = (out_orig.logits - out_patched.logits).abs().max().item()\n        assert torch.allclose(\n            out_orig.logits, out_patched.logits, atol=1e-1, rtol=1e-1\n        ), f\"Output mismatch: max diff={max_diff:.6f}\"\n\n\nclass TestSonicMoEGradientCorrectness:\n    \"\"\"Compare gradients between original HuggingFace and SonicMoE-patched forward.\"\"\"\n\n    def teardown_method(self):\n        _unpatch_sonicmoe()\n\n    def test_gradients_match(self):\n        \"\"\"Verify all parameter gradients match between original 
and patched.\"\"\"\n        from transformers import AutoModelForCausalLM\n\n        from axolotl.integrations.kernels.sonicmoe.patch import patch_sonicmoe\n        from axolotl.integrations.kernels.sonicmoe.weight_converter import (\n            deinterleave_gate_up,\n        )\n\n        config = _create_tiny_qwen3_config()\n        input_ids = torch.randint(0, config.vocab_size, (1, 16), device=\"cuda\")\n\n        # ---------- Original model ----------\n        model_orig = AutoModelForCausalLM.from_config(config).cuda().bfloat16()\n        out_orig = model_orig(input_ids, labels=input_ids)\n        out_orig.loss.backward()\n        grads_orig = {\n            n: p.grad.float().clone()\n            for n, p in model_orig.named_parameters()\n            if p.grad is not None\n        }\n        loss_orig = out_orig.loss.item()\n\n        # ---------- SonicMoE-patched model (same weights, interleaved) ----------\n        model_patched = AutoModelForCausalLM.from_config(config).cuda().bfloat16()\n        model_patched.load_state_dict(model_orig.state_dict())\n\n        patch_sonicmoe(\"qwen3_moe\")\n        _interleave_gate_up_weights(model_patched)\n\n        out_patched = model_patched(input_ids, labels=input_ids)\n        out_patched.loss.backward()\n        grads_patched = {}\n        for n, p in model_patched.named_parameters():\n            if p.grad is None:\n                continue\n            g = p.grad.float().clone()\n            # gate_up_proj grads are in interleaved layout, de-interleave to match orig\n            if \"gate_up_proj\" in n:\n                g = deinterleave_gate_up(g)\n            grads_patched[n] = g\n        loss_patched = out_patched.loss.item()\n\n        # ---------- Compare ----------\n        assert abs(loss_orig - loss_patched) < 0.5, (\n            f\"Loss mismatch: orig={loss_orig:.4f}, patched={loss_patched:.4f}\"\n        )\n\n        # All parameters with gradients in original should have them in patched\n        
missing = set(grads_orig.keys()) - set(grads_patched.keys())\n        assert not missing, f\"Missing gradients in patched model: {missing}\"\n\n        # Compare gradient values\n        # bf16 with different GEMM impls (cuBLAS vs CUTLASS) can diverge,\n        # so use generous tolerance: flag only if both rel >10% AND abs >1e-2\n        mismatches = []\n        for name in grads_orig:\n            if name not in grads_patched:\n                continue\n            g_orig = grads_orig[name]\n            g_patched = grads_patched[name]\n            max_diff = (g_orig - g_patched).abs().max().item()\n            rel_diff = max_diff / (g_orig.abs().max().item() + 1e-8)\n\n            if rel_diff > 0.1 and max_diff > 1e-2:\n                mismatches.append(\n                    f\"  {name}: max_abs_diff={max_diff:.6f}, rel_diff={rel_diff:.4f}\"\n                )\n\n        assert not mismatches, (\n            \"Gradient mismatches (rel_diff > 10% and abs_diff > 1e-2):\\n\"\n            + \"\\n\".join(mismatches)\n        )\n\n    def test_router_weights_receive_gradients(self):\n        \"\"\"Verify that router (gate) weights get non-zero gradients.\"\"\"\n        from transformers import AutoModelForCausalLM\n\n        from axolotl.integrations.kernels.sonicmoe.patch import patch_sonicmoe\n\n        config = _create_tiny_qwen3_config()\n        input_ids = torch.randint(0, config.vocab_size, (1, 16), device=\"cuda\")\n\n        model = AutoModelForCausalLM.from_config(config).cuda().bfloat16()\n        patch_sonicmoe(\"qwen3_moe\")\n        _interleave_gate_up_weights(model)\n\n        out = model(input_ids, labels=input_ids)\n        out.loss.backward()\n\n        gate_grads_found = False\n        for name, param in model.named_parameters():\n            if \"gate\" in name and \"weight\" in name:\n                gate_grads_found = True\n                assert param.grad is not None, f\"No gradient for router: {name}\"\n                assert 
param.grad.abs().max() > 0, f\"Zero gradient for router: {name}\"\n\n        assert gate_grads_found, \"No gate.weight parameters found in model\"\n\n\nclass TestSonicMoETrainingConvergence:\n    \"\"\"Verify loss decreases during training with SonicMoE.\"\"\"\n\n    def teardown_method(self):\n        _unpatch_sonicmoe()\n\n    def test_loss_decreases(self):\n        \"\"\"Run 30 training steps, verify loss decreases and no NaN/Inf.\"\"\"\n        from transformers import AutoModelForCausalLM\n\n        from axolotl.integrations.kernels.sonicmoe.patch import patch_sonicmoe\n\n        config = _create_tiny_qwen3_config()\n        input_ids = torch.randint(0, config.vocab_size, (2, 32), device=\"cuda\")\n\n        model = AutoModelForCausalLM.from_config(config).cuda().bfloat16()\n        patch_sonicmoe(\"qwen3_moe\")\n        _interleave_gate_up_weights(model)\n\n        optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)\n        losses = []\n\n        for step in range(30):\n            out = model(input_ids, labels=input_ids)\n            loss = out.loss\n            assert not math.isnan(loss.item()), f\"NaN loss at step {step}\"\n            assert not math.isinf(loss.item()), f\"Inf loss at step {step}\"\n            losses.append(loss.item())\n\n            loss.backward()\n            optimizer.step()\n            optimizer.zero_grad()\n\n        assert losses[-1] < losses[0], (\n            f\"Loss did not decrease: first={losses[0]:.4f}, last={losses[-1]:.4f}\"\n        )\n\n    def test_expert_weights_update(self):\n        \"\"\"Verify expert weights change during training (not frozen).\"\"\"\n        from transformers import AutoModelForCausalLM\n\n        from axolotl.integrations.kernels.sonicmoe.patch import patch_sonicmoe\n\n        config = _create_tiny_qwen3_config()\n        input_ids = torch.randint(0, config.vocab_size, (2, 32), device=\"cuda\")\n\n        model = AutoModelForCausalLM.from_config(config).cuda().bfloat16()\n        
patch_sonicmoe(\"qwen3_moe\")\n        _interleave_gate_up_weights(model)\n\n        # Snapshot expert weights before training\n        expert_weights_before = {}\n        for name, param in model.named_parameters():\n            if \"experts\" in name:\n                expert_weights_before[name] = param.data.clone()\n\n        assert expert_weights_before, \"No expert parameters found\"\n\n        optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)\n        for _ in range(5):\n            out = model(input_ids, labels=input_ids)\n            out.loss.backward()\n            optimizer.step()\n            optimizer.zero_grad()\n\n        # Check that expert weights changed\n        changed = 0\n        for name, param in model.named_parameters():\n            if name in expert_weights_before:\n                if not torch.equal(param.data, expert_weights_before[name]):\n                    changed += 1\n\n        assert changed > 0, \"No expert weights changed after 5 training steps\"\n"
  },
  {
    "path": "tests/e2e/kernels/test_geglu.py",
    "content": "\"\"\"Tests for GEGLU activation function Triton kernels.\"\"\"\n\nimport pytest\nimport torch\nimport torch.nn.functional as F\n\nfrom axolotl.kernels.geglu import geglu_backward, geglu_forward\n\n\ndef test_geglu_forward_shape():\n    \"\"\"Test that GEGLU forward pass preserves expected shapes.\"\"\"\n    batch, seq_len, hidden_dim = 2, 3, 64\n    gate = torch.randn(batch, seq_len, hidden_dim, device=\"cuda\")\n    up = torch.randn(batch, seq_len, hidden_dim, device=\"cuda\")\n\n    out = geglu_forward(gate, up)\n    assert out.shape == (batch, seq_len, hidden_dim)\n    assert out.dtype == gate.dtype\n    assert out.device == gate.device\n\n\n@pytest.mark.flaky(retries=1, delay=5)\n@pytest.mark.parametrize(\n    \"torch_seed\",\n    [0, 42],\n)\ndef test_geglu_forward_values(torch_seed):\n    \"\"\"Test GEGLU forward pass matches PyTorch reference implementation.\"\"\"\n    torch.manual_seed(torch_seed)\n\n    gate = torch.randn(2, 3, 64, device=\"cuda\")\n    up = torch.randn(2, 3, 64, device=\"cuda\")\n\n    # Custom implementation\n    triton_out = geglu_forward(gate.clone(), up.clone())\n\n    # PyTorch reference\n    torch_out = F.gelu(gate) * up\n\n    assert torch.allclose(triton_out, torch_out, rtol=1e-3)\n\n\n@pytest.mark.flaky(retries=1, delay=5)\n@pytest.mark.parametrize(\n    \"torch_seed\",\n    [0, 42],\n)\ndef test_geglu_backward(torch_seed):\n    \"\"\"Test GEGLU backward pass matches PyTorch autograd.\"\"\"\n    torch.manual_seed(torch_seed)\n\n    gate = torch.randn(2, 3, 64, device=\"cuda\", requires_grad=True)\n    up = torch.randn(2, 3, 64, device=\"cuda\", requires_grad=True)\n    grad_output = torch.randn(2, 3, 64, device=\"cuda\")\n\n    # PyTorch reference - compute intermediates\n    gelu_gate = F.gelu(gate)\n    torch_out = gelu_gate * up\n    torch_out.backward(grad_output)\n\n    # Custom backward pass\n    gate_clone = gate.clone().detach()\n    up_clone = up.clone().detach()\n    grad_output_clone = 
grad_output.clone()\n\n    h, grad_gate, grad_up = geglu_backward(grad_output_clone, gate_clone, up_clone)\n\n    # Compare outputs and gradients\n    assert torch.allclose(h, torch_out, rtol=1e-3)\n    assert torch.allclose(grad_gate, gate.grad, rtol=1e-3)\n    assert torch.allclose(grad_up, up.grad, rtol=1e-3)\n\n\ndef test_geglu_inplace_preservation():\n    \"\"\"Test that GEGLU backward overwrites its gate, up, and grad_output inputs in-place.\"\"\"\n    gate = torch.randn(2, 3, 64, device=\"cuda\")\n    up = torch.randn(2, 3, 64, device=\"cuda\")\n    grad_output = torch.randn(2, 3, 64, device=\"cuda\")\n\n    gate_copy = gate.clone()\n    up_copy = up.clone()\n    grad_copy = grad_output.clone()\n\n    geglu_backward(grad_output, gate, up)\n\n    assert not torch.equal(gate, gate_copy), \"Gate should be modified in-place\"\n    assert not torch.equal(up, up_copy), \"Up should be modified in-place\"\n    assert not torch.equal(grad_output, grad_copy), (\n        \"Grad output should be modified in-place\"\n    )\n"
  },
  {
    "path": "tests/e2e/kernels/test_lora.py",
    "content": "\"\"\"Tests for LoRA custom autograd.\"\"\"\n\nimport pytest\nimport torch\nfrom bitsandbytes.functional import QuantState\nfrom torch import nn\n\nfrom axolotl.kernels.geglu import geglu_backward, geglu_forward\nfrom axolotl.kernels.lora import (\n    LoRA_MLP,\n    LoRA_O,\n    LoRA_QKV,\n    apply_lora_mlp_geglu,\n    apply_lora_mlp_swiglu,\n    get_lora_parameters,\n    matmul_lora,\n)\nfrom axolotl.kernels.swiglu import swiglu_backward, swiglu_forward\n\n\n@pytest.fixture\ndef mock_quantstate():\n    \"\"\"Creates a mock QuantState for testing\"\"\"\n    shape = (64, 64)\n    n_blocks = shape[0]  # Assuming blockwise quantization along first dimension\n\n    # Create nested state first\n    nested_state = QuantState(\n        absmax=torch.ones(n_blocks, device=\"cuda\"),  # One value per block\n        shape=shape,\n        code=torch.randint(0, 15, shape, device=\"cuda\"),  # NF4 range is 0-15\n        dtype=torch.float16,\n        blocksize=64,\n        quant_type=\"nf4\",\n        offset=None,\n        state2=None,\n    )\n\n    # Create main state with nested state\n    return QuantState(\n        absmax=torch.ones(n_blocks, device=\"cuda\"),\n        shape=shape,\n        code=torch.randint(0, 15, shape, device=\"cuda\"),\n        dtype=torch.float16,\n        blocksize=64,\n        quant_type=\"nf4\",\n        offset=torch.zeros(n_blocks, dtype=torch.int32, device=\"cuda\"),\n        state2=nested_state,\n    )\n\n\n@pytest.fixture\ndef sample_tensors():\n    \"\"\"Creates sample tensors for testing\"\"\"\n    torch.manual_seed(42)\n    batch_size, seq_len, hidden_dim = 2, 3, 64\n    rank = 8\n    out_dim = hidden_dim\n\n    return {\n        \"X\": torch.randn(\n            batch_size, seq_len, hidden_dim, device=\"cuda\", dtype=torch.float16\n        ),\n        \"W\": torch.randn(out_dim, hidden_dim, device=\"cuda\", dtype=torch.float16),\n        \"b\": torch.randn(out_dim, device=\"cuda\", dtype=torch.float16),\n        \"scale\": 
0.5,\n        \"shapes\": {\n            \"batch\": batch_size,\n            \"seq\": seq_len,\n            \"hidden\": hidden_dim,\n            \"out\": out_dim,\n            \"rank\": rank,\n        },\n    }\n\n\n@pytest.fixture\ndef mock_proj():\n    \"\"\"Creates a mock projection module for testing.\"\"\"\n\n    class MockProj(nn.Module):\n        \"\"\"Mock projection class.\"\"\"\n\n        def __init__(self, in_features=64, out_features=128, rank=8):\n            super().__init__()\n            self.base_layer = nn.Linear(in_features, out_features)\n            self.base_layer.to(\"cuda\")\n            self.lora_A = nn.ModuleDict(\n                {\"default\": nn.Linear(in_features, rank, bias=False).to(\"cuda\")}\n            )\n            self.lora_B = nn.ModuleDict(\n                {\"default\": nn.Linear(rank, out_features, bias=False).to(\"cuda\")}\n            )\n            self.scaling = {\"default\": 0.5}\n            self.active_adapter = \"default\"\n            self.disable_adapters = False\n            self.merged = False\n\n    return MockProj()\n\n\ndef test_get_lora_parameters(mock_proj):\n    \"\"\"Tests get_lora_parameters function\"\"\"\n    # Test with LoRA enabled\n    W, b, _, A, B, s = get_lora_parameters(mock_proj)\n\n    assert isinstance(W, torch.Tensor)\n    assert W.shape == (128, 64)\n    assert b.shape == (128,)\n    assert A.shape == (8, 64)\n    assert B.shape == (128, 8)\n    assert s == 0.5\n\n    # Test with LoRA disabled\n    mock_proj.disable_adapters = True\n    W, b, _, A, B, s = get_lora_parameters(mock_proj)\n    assert A is None and B is None and s is None\n\n    # Test with merged state\n    mock_proj.disable_adapters = False\n    mock_proj.merged = True\n    W, b, _, A, B, s = get_lora_parameters(mock_proj)\n    assert A is None and B is None and s is None\n\n\ndef test_matmul_lora(sample_tensors):\n    \"\"\"Tests matmul_lora function\"\"\"\n    X = sample_tensors[\"X\"]\n    W = sample_tensors[\"W\"]\n    b 
= sample_tensors[\"b\"]\n    scale = sample_tensors[\"scale\"]\n\n    shapes = sample_tensors[\"shapes\"]\n    hidden_dim = shapes[\"hidden\"]\n    out_dim = shapes[\"out\"]\n    rank = shapes[\"rank\"]\n\n    A = torch.randn(rank, hidden_dim, device=\"cuda\", dtype=torch.float16)\n    B = torch.randn(out_dim, rank, device=\"cuda\", dtype=torch.float16)\n\n    # Test base matmul\n    out1 = matmul_lora(X, W, b, None, None, None, None)\n    matmul = torch.matmul(X, W.t())\n    expected1 = matmul + b\n    assert torch.allclose(out1, expected1, rtol=1e-3)\n\n    # Test with LoRA\n    out2 = matmul_lora(X, W, b, None, A, B, scale)\n    lora_term = scale * torch.matmul(torch.matmul(X, A.t()), B.t())\n    expected2 = matmul + lora_term + b\n    assert torch.allclose(out2, expected2, rtol=1e-3)\n\n    # Test 3D input reshaping\n    X_3d = X.clone()\n    out3 = matmul_lora(X_3d, W, b, None, A, B, scale)\n    assert out3.shape == (X.shape[0], X.shape[1], W.shape[0])\n\n\n@pytest.mark.parametrize(\n    \"activation_forward,activation_backward\",\n    [(swiglu_forward, swiglu_backward), (geglu_forward, geglu_backward)],\n)\ndef test_lora_mlp_direct(sample_tensors, activation_forward, activation_backward):\n    \"\"\"Tests LoRA_MLP directly with different activation functions\"\"\"\n    X = sample_tensors[\"X\"]\n    shapes = sample_tensors[\"shapes\"]\n    hidden_dim = shapes[\"hidden\"]\n    out_dim = shapes[\"out\"]\n\n    # Create linear layers\n    gate_proj = nn.Linear(hidden_dim, out_dim).to(device=\"cuda\", dtype=torch.float16)\n    up_proj = nn.Linear(hidden_dim, out_dim).to(device=\"cuda\", dtype=torch.float16)\n    down_proj = nn.Linear(out_dim, hidden_dim).to(device=\"cuda\", dtype=torch.float16)\n\n    # Forward pass with the parametrized activation (SwiGLU or GEGLU), no LoRA adapters\n    X.requires_grad = True\n    output = LoRA_MLP.apply(\n        X,\n        gate_proj.weight,\n        gate_proj.bias,\n        None,  # gate_quant\n        None,  # gate_A\n        None,  # gate_B\n        None,  # gate_scale\n        
up_proj.weight,\n        up_proj.bias,\n        None,  # up_quant\n        None,  # up_A\n        None,  # up_B\n        None,  # up_scale\n        down_proj.weight,\n        down_proj.bias,\n        None,  # down_quant\n        None,  # down_A\n        None,  # down_B\n        None,  # down_scale\n        activation_forward,\n        activation_backward,\n        True,  # inplace\n    )\n\n    assert output.shape == X.shape\n    assert not torch.isnan(output).any()\n\n    # Test backward pass\n    loss = output.sum()\n    loss.backward()\n    assert X.grad is not None\n    assert not torch.isnan(X.grad).any()\n\n\n@pytest.mark.parametrize(\n    \"activation_forward,activation_backward\",\n    [(swiglu_forward, swiglu_backward), (geglu_forward, geglu_backward)],\n)\ndef test_lora_mlp_with_adapters(\n    sample_tensors, activation_forward, activation_backward\n):\n    \"\"\"Tests LoRA_MLP with LoRA adapters\"\"\"\n    X = sample_tensors[\"X\"]\n    shapes = sample_tensors[\"shapes\"]\n    hidden_dim = shapes[\"hidden\"]\n    out_dim = shapes[\"out\"]\n    rank = shapes[\"rank\"]\n\n    # Create LoRA components\n    gate_A = torch.randn(rank, hidden_dim, device=\"cuda\", dtype=torch.float16)\n    gate_B = torch.randn(out_dim, rank, device=\"cuda\", dtype=torch.float16)\n    up_A = torch.randn(rank, hidden_dim, device=\"cuda\", dtype=torch.float16)\n    up_B = torch.randn(out_dim, rank, device=\"cuda\", dtype=torch.float16)\n    down_A = torch.randn(rank, out_dim, device=\"cuda\", dtype=torch.float16)\n    down_B = torch.randn(hidden_dim, rank, device=\"cuda\", dtype=torch.float16)\n    scale = 0.5\n\n    gate_proj = nn.Linear(hidden_dim, out_dim).to(device=\"cuda\", dtype=torch.float16)\n    up_proj = nn.Linear(hidden_dim, out_dim).to(device=\"cuda\", dtype=torch.float16)\n    down_proj = nn.Linear(out_dim, hidden_dim).to(device=\"cuda\", dtype=torch.float16)\n\n    X.requires_grad = True\n    gate_A.requires_grad = True\n    gate_B.requires_grad = True\n    
up_A.requires_grad = True\n    up_B.requires_grad = True\n    down_A.requires_grad = True\n    down_B.requires_grad = True\n\n    # Forward pass with adapters\n    output = LoRA_MLP.apply(\n        X,\n        gate_proj.weight,\n        gate_proj.bias,\n        None,\n        gate_A,\n        gate_B,\n        scale,\n        up_proj.weight,\n        up_proj.bias,\n        None,\n        up_A,\n        up_B,\n        scale,\n        down_proj.weight,\n        down_proj.bias,\n        None,\n        down_A,\n        down_B,\n        scale,\n        activation_forward,\n        activation_backward,\n        True,\n    )\n\n    assert output.shape == X.shape\n    assert not torch.isnan(output).any()\n\n    # Test backward pass\n    loss = output.sum()\n    loss.backward()\n\n    # Check all gradients\n    assert X.grad is not None\n    assert gate_A.grad is not None\n    assert gate_B.grad is not None\n    assert up_A.grad is not None\n    assert up_B.grad is not None\n    assert down_A.grad is not None\n    assert down_B.grad is not None\n\n    assert not torch.isnan(X.grad).any()\n    assert not torch.isnan(gate_A.grad).any()\n    assert not torch.isnan(gate_B.grad).any()\n    assert not torch.isnan(up_A.grad).any()\n    assert not torch.isnan(up_B.grad).any()\n    assert not torch.isnan(down_A.grad).any()\n    assert not torch.isnan(down_B.grad).any()\n\n\ndef test_lora_qkv(sample_tensors):\n    \"\"\"Tests LoRA QKV implementation with and without adapters\"\"\"\n    X = sample_tensors[\"X\"]\n    shapes = sample_tensors[\"shapes\"]\n    hidden_dim = shapes[\"hidden\"]\n    rank = shapes[\"rank\"]\n\n    # Create base weights\n    q_weight = torch.randn(hidden_dim, hidden_dim, device=\"cuda\", dtype=torch.float16)\n    k_weight = torch.randn(hidden_dim, hidden_dim, device=\"cuda\", dtype=torch.float16)\n    v_weight = torch.randn(hidden_dim, hidden_dim, device=\"cuda\", dtype=torch.float16)\n\n    # Create LoRA matrices\n    q_A = torch.randn(\n        rank, 
hidden_dim, device=\"cuda\", dtype=torch.float16, requires_grad=True\n    )\n    q_B = torch.randn(\n        hidden_dim, rank, device=\"cuda\", dtype=torch.float16, requires_grad=True\n    )\n    k_A = torch.randn(\n        rank, hidden_dim, device=\"cuda\", dtype=torch.float16, requires_grad=True\n    )\n    k_B = torch.randn(\n        hidden_dim, rank, device=\"cuda\", dtype=torch.float16, requires_grad=True\n    )\n    v_A = torch.randn(\n        rank, hidden_dim, device=\"cuda\", dtype=torch.float16, requires_grad=True\n    )\n    v_B = torch.randn(\n        hidden_dim, rank, device=\"cuda\", dtype=torch.float16, requires_grad=True\n    )\n    scale = 0.5\n\n    X.requires_grad = True\n\n    # Test without LoRA adapters\n\n    Q1, K1, V1 = LoRA_QKV.apply(\n        X,\n        q_weight,\n        None,\n        None,\n        None,\n        None,\n        None,\n        k_weight,\n        None,\n        None,\n        None,\n        None,\n        None,\n        v_weight,\n        None,\n        None,\n        None,\n        None,\n        None,\n        True,\n    )\n\n    assert Q1.shape == K1.shape == V1.shape == X.shape\n    loss1 = (Q1 + K1 + V1).sum()\n    loss1.backward()\n    assert X.grad is not None\n\n    # Clear gradients\n    X.grad = None\n\n    # Test with LoRA adapters\n    Q2, K2, V2 = LoRA_QKV.apply(\n        X,\n        q_weight,\n        None,\n        None,\n        q_A,\n        q_B,\n        scale,\n        k_weight,\n        None,\n        None,\n        k_A,\n        k_B,\n        scale,\n        v_weight,\n        None,\n        None,\n        v_A,\n        v_B,\n        scale,\n        True,\n    )\n\n    assert Q2.shape == K2.shape == V2.shape == X.shape\n    loss2 = (Q2 + K2 + V2).sum()\n    loss2.backward()\n\n    # Check gradients\n    assert X.grad is not None\n    assert q_A.grad is not None\n    assert q_B.grad is not None\n    assert k_A.grad is not None\n    assert k_B.grad is not None\n    assert v_A.grad is not None\n    
assert v_B.grad is not None\n\n    # Check for NaN values\n    assert not torch.isnan(X.grad).any()\n    assert not torch.isnan(q_A.grad).any()\n    assert not torch.isnan(q_B.grad).any()\n    assert not torch.isnan(k_A.grad).any()\n    assert not torch.isnan(k_B.grad).any()\n    assert not torch.isnan(v_A.grad).any()\n    assert not torch.isnan(v_B.grad).any()\n\n\ndef test_lora_o(sample_tensors):\n    \"\"\"Tests LoRA output projection\"\"\"\n    X = sample_tensors[\"X\"]\n    W = sample_tensors[\"W\"]\n    b = sample_tensors[\"b\"]\n    scale = sample_tensors[\"scale\"]\n\n    shapes = sample_tensors[\"shapes\"]\n    hidden_dim = shapes[\"hidden\"]\n    out_dim = shapes[\"out\"]\n    rank = shapes[\"rank\"]\n\n    A = torch.randn(rank, hidden_dim, device=\"cuda\", dtype=torch.float16)\n    B = torch.randn(out_dim, rank, device=\"cuda\", dtype=torch.float16)\n\n    # Test forward pass\n    X.requires_grad = True\n    output = LoRA_O.apply(X, W, b, None, A, B, scale)\n\n    assert output.shape == (X.shape[0], X.shape[1], W.shape[0])\n\n    # Test backward pass\n    loss = output.sum()\n    loss.backward()\n    assert X.grad is not None\n\n\ndef test_with_quantization(sample_tensors, mock_quantstate):\n    \"\"\"Tests LoRA with quantized weights\"\"\"\n    X = sample_tensors[\"X\"]  # [batch, seq, hidden]\n    W = sample_tensors[\"W\"]  # [out, hidden]\n    b = sample_tensors[\"b\"]  # [out]\n    scale = 0.5\n\n    shapes = sample_tensors[\"shapes\"]\n    hidden_dim = shapes[\"hidden\"]\n    out_dim = shapes[\"out\"]\n    rank = shapes[\"rank\"]\n\n    A = torch.randn(rank, hidden_dim, device=\"cuda\", dtype=torch.float16)\n    B = torch.randn(out_dim, rank, device=\"cuda\", dtype=torch.float16)\n\n    # Test matmul with quantization\n    out = matmul_lora(X, W, b, mock_quantstate, A, B, scale)\n    assert out.shape == (X.shape[0], X.shape[1], W.shape[0])\n    assert not torch.isnan(out).any()\n\n    # Test with different batch sizes\n    X2 = torch.randn(4, 6, 
hidden_dim, device=\"cuda\", dtype=torch.float16)\n    out2 = matmul_lora(X2, W, b, mock_quantstate, A, B, scale)\n    assert out2.shape == (4, 6, W.shape[0])\n    assert not torch.isnan(out2).any()\n\n\n@pytest.mark.parametrize(\n    \"batch,seq,hidden,rank,out\",\n    [\n        (1, 1, 32, 4, 64),\n        (2, 3, 64, 8, 128),\n        (4, 5, 128, 16, 256),\n    ],\n)\ndef test_shapes_and_dimensions(batch, seq, hidden, rank, out):\n    \"\"\"Tests various input shapes and dimensions\"\"\"\n    X = torch.randn(batch, seq, hidden, device=\"cuda\", dtype=torch.float16)\n    W = torch.randn(out, hidden, device=\"cuda\", dtype=torch.float16)\n    b = torch.randn(out, device=\"cuda\", dtype=torch.float16)\n    A = torch.randn(rank, hidden, device=\"cuda\", dtype=torch.float16)\n    B = torch.randn(out, rank, device=\"cuda\", dtype=torch.float16)\n    scale = 0.5\n\n    result = matmul_lora(X, W, b, None, A, B, scale)\n    assert result.shape == (batch, seq, out)\n\n\ndef test_gradient_flow(sample_tensors):\n    \"\"\"Tests gradient flow through LoRA layers\"\"\"\n    X = sample_tensors[\"X\"].clone()\n    W = sample_tensors[\"W\"].clone()\n    b = sample_tensors[\"b\"].clone()\n    scale = sample_tensors[\"scale\"]\n\n    shapes = sample_tensors[\"shapes\"]\n    hidden_dim = shapes[\"hidden\"]\n    out_dim = shapes[\"out\"]\n    rank = shapes[\"rank\"]\n\n    A = torch.randn(rank, hidden_dim, device=\"cuda\", dtype=torch.float16)\n    B = torch.randn(out_dim, rank, device=\"cuda\", dtype=torch.float16)\n\n    X.requires_grad = True\n    A.requires_grad = True\n    B.requires_grad = True\n\n    # Forward pass\n    out = matmul_lora(X, W, b, None, A, B, scale)\n    loss = out.sum()\n\n    # Backward pass\n    loss.backward()\n\n    assert X.grad is not None\n    assert A.grad is not None\n    assert B.grad is not None\n    assert not torch.isnan(X.grad).any()\n    assert not torch.isnan(A.grad).any()\n    assert not 
torch.isnan(B.grad).any()\n\n\n@pytest.mark.parametrize(\n    \"apply_function\",\n    [apply_lora_mlp_swiglu, apply_lora_mlp_geglu],\n)\ndef test_inplace_operations(sample_tensors, apply_function):\n    \"\"\"Tests that inplace=True and inplace=False produce matching MLP outputs\"\"\"\n    X = sample_tensors[\"X\"]\n    shapes = sample_tensors[\"shapes\"]\n\n    # Build a minimal MLP stub exposing gate_proj/up_proj/down_proj; it is applied below in both modes\n    mlp = type(\n        \"MLPModule\",\n        (),\n        {\n            \"gate_proj\": nn.Linear(shapes[\"hidden\"], shapes[\"out\"]).to(\n                device=\"cuda\", dtype=torch.float16\n            ),\n            \"up_proj\": nn.Linear(shapes[\"hidden\"], shapes[\"out\"]).to(\n                device=\"cuda\", dtype=torch.float16\n            ),\n            \"down_proj\": nn.Linear(shapes[\"out\"], shapes[\"hidden\"]).to(\n                device=\"cuda\", dtype=torch.float16\n            ),\n        },\n    )\n\n    out1 = apply_function(mlp, X.clone(), inplace=True)\n    out2 = apply_function(mlp, X.clone(), inplace=False)\n\n    assert torch.allclose(out1, out2, rtol=1e-3)\n"
  },
  {
    "path": "tests/e2e/kernels/test_quantize.py",
    "content": "\"\"\"Tests for quantization utility functions.\"\"\"\n\nimport torch\nfrom bitsandbytes.functional import QuantState\n\nfrom axolotl.kernels.quantize import dequantize\n\n\ndef test_dequantize_null_state():\n    \"\"\"Test that dequantize returns input unchanged when quant_state is None\"\"\"\n    W = torch.randn(32, 32)\n    assert torch.equal(dequantize(W, None), W)\n\n\ndef test_dequantize_shape_preservation():\n    \"\"\"Test that dequantization preserves expected shapes\"\"\"\n    shape = (32, 32)\n    W = torch.randn(shape, device=\"cuda\")\n\n    quant_state = QuantState(\n        absmax=torch.ones(shape[0], device=\"cuda\"),\n        shape=shape,\n        code=torch.randint(0, 15, shape, device=\"cuda\"),\n        dtype=torch.float16,\n        blocksize=32,\n        quant_type=\"nf4\",\n        offset=torch.zeros(shape[0], dtype=torch.int32, device=\"cuda\"),\n        state2=QuantState(\n            absmax=torch.ones(shape[0], device=\"cuda\"),\n            shape=shape,\n            code=torch.randint(0, 15, shape, device=\"cuda\"),\n            dtype=torch.float16,\n            blocksize=32,\n            quant_type=\"nf4\",\n            offset=None,\n            state2=None,\n        ),\n    )\n\n    result = dequantize(W, quant_state)\n    assert result.shape == shape\n    assert result.dtype == torch.float16\n    assert result.device == W.device\n\n\ndef test_dequantize_transposed():\n    \"\"\"Test that transposed input produces transposed output\"\"\"\n    shape = (32, 32)\n    W = torch.randn(1, shape[1], device=\"cuda\")  # Transposed input\n\n    quant_state = QuantState(\n        absmax=torch.ones(1),\n        shape=shape,\n        code=torch.randint(0, 15, shape),\n        dtype=torch.float16,\n        blocksize=32,\n        quant_type=\"nf4\",\n        offset=torch.zeros(1, dtype=torch.int32),\n        state2=QuantState(\n            absmax=torch.ones(1),\n            shape=shape,\n            code=torch.randint(0, 15, shape),\n  
          dtype=torch.float16,\n            blocksize=32,\n            quant_type=\"nf4\",\n            offset=None,\n            state2=None,\n        ),\n    )\n\n    result = dequantize(W, quant_state)\n    assert result.shape[0] == shape[0]\n\n\ndef test_dequantize_output_tensor():\n    \"\"\"Test dequantization with provided output tensor\"\"\"\n    shape = (32, 32)\n    W = torch.randn(shape, device=\"cuda\")\n    out = torch.empty(shape, dtype=torch.float16, device=\"cuda\")\n\n    quant_state = QuantState(\n        absmax=torch.ones(shape[0]),\n        shape=shape,\n        code=torch.randint(0, 15, shape),\n        dtype=torch.float16,\n        blocksize=32,\n        quant_type=\"nf4\",\n        offset=torch.zeros(shape[0], dtype=torch.int32),\n        state2=QuantState(\n            absmax=torch.ones(shape[0]),\n            shape=shape,\n            code=torch.randint(0, 15, shape),\n            dtype=torch.float16,\n            blocksize=32,\n            quant_type=\"nf4\",\n            offset=None,\n            state2=None,\n        ),\n    )\n\n    result = dequantize(W, quant_state, out=out)\n    assert result is out\n"
  },
  {
    "path": "tests/e2e/kernels/test_swiglu.py",
    "content": "\"\"\"Tests for SwiGLU activation function Triton kernels.\"\"\"\n\nimport torch\nimport torch.nn.functional as F\n\nfrom axolotl.kernels.swiglu import swiglu_backward, swiglu_forward\n\n\ndef test_swiglu_forward_shape():\n    \"\"\"Test that SwiGLU forward pass preserves expected shapes\"\"\"\n    batch, seq_len, hidden_dim = 2, 3, 64\n    gate = torch.randn(batch, seq_len, hidden_dim, device=\"cuda\")\n    up = torch.randn(batch, seq_len, hidden_dim, device=\"cuda\")\n\n    out = swiglu_forward(gate, up)\n    assert out.shape == (batch, seq_len, hidden_dim)\n    assert out.dtype == gate.dtype\n    assert out.device == gate.device\n\n\ndef test_swiglu_forward_values():\n    \"\"\"Test SwiGLU forward pass matches PyTorch reference implementation\"\"\"\n    gate = torch.randn(2, 3, 64, device=\"cuda\")\n    up = torch.randn(2, 3, 64, device=\"cuda\")\n\n    # Custom implementation\n    triton_out = swiglu_forward(gate.clone(), up.clone())\n\n    # PyTorch reference\n    torch_out = F.silu(gate) * up\n\n    assert torch.allclose(triton_out, torch_out, rtol=1e-3)\n\n\ndef test_swiglu_backward():\n    \"\"\"Test SwiGLU backward pass matches PyTorch autograd\"\"\"\n    gate = torch.randn(2, 3, 64, device=\"cuda\", requires_grad=True)\n    up = torch.randn(2, 3, 64, device=\"cuda\", requires_grad=True)\n    grad_output = torch.randn(2, 3, 64, device=\"cuda\")\n\n    # PyTorch reference - compute intermediates\n    silu_gate = F.silu(gate)\n    torch_out = silu_gate * up\n    torch_out.backward(grad_output)\n\n    # Custom backward pass\n    gate_clone = gate.clone().detach()\n    up_clone = up.clone().detach()\n    grad_output_clone = grad_output.clone()\n\n    h, our_grad_gate, our_grad_up = swiglu_backward(\n        grad_output_clone, gate_clone, up_clone\n    )\n\n    # Compare outputs and gradients\n    assert torch.allclose(h, torch_out, rtol=1e-3)\n    assert torch.allclose(our_grad_gate, gate.grad, rtol=1e-3)\n    assert torch.allclose(our_grad_up, 
up.grad, rtol=1e-3)\n\n\ndef test_swiglu_inplace_preservation():\n    \"\"\"Test that SwiGLU backward overwrites its input tensors in-place\"\"\"\n    gate = torch.randn(2, 3, 64, device=\"cuda\")\n    up = torch.randn(2, 3, 64, device=\"cuda\")\n    grad_output = torch.randn(2, 3, 64, device=\"cuda\")\n\n    gate_copy = gate.clone()\n    up_copy = up.clone()\n    grad_copy = grad_output.clone()\n\n    swiglu_backward(grad_output, gate, up)\n\n    assert not torch.equal(gate, gate_copy), \"Gate should be modified in-place\"\n    assert not torch.equal(up, up_copy), \"Up should be modified in-place\"\n    assert not torch.equal(grad_output, grad_copy), (\n        \"Grad output should be modified in-place\"\n    )\n"
  },
  {
    "path": "tests/e2e/multigpu/__init__.py",
    "content": ""
  },
  {
    "path": "tests/e2e/multigpu/patched/__init__.py",
    "content": ""
  },
  {
    "path": "tests/e2e/multigpu/patched/test_sp.py",
    "content": "\"\"\"E2E tests for sequence parallelism\"\"\"\n\nfrom pathlib import Path\n\nimport pytest\nimport yaml\nfrom accelerate.test_utils import execute_subprocess_async\nfrom transformers.testing_utils import get_torch_dist_unique_port\n\nfrom axolotl.utils.dict import DictDefault\n\nfrom ...utils import check_tensorboard\n\n\nclass TestSequenceParallelism:\n    \"\"\"Test case for training with sequence parallelism enabled\"\"\"\n\n    def _run_sequence_parallel_test(\n        self,\n        temp_dir,\n        sample_packing=True,\n        micro_batch_size=1,\n        pad_to_sequence_len=True,\n        ring_attn_func=None,\n        threshold=2.0,\n    ):\n        \"\"\"Helper method to run sequence parallel tests with different configurations\"\"\"\n        cfg = DictDefault(\n            {\n                \"base_model\": \"HuggingFaceTB/SmolLM2-135M\",\n                \"load_in_8bit\": False,\n                \"load_in_4bit\": True,\n                \"strict\": False,\n                \"sequence_len\": 2048,\n                \"adapter\": \"qlora\",\n                \"sample_packing\": sample_packing,\n                \"eval_sample_packing\": sample_packing,\n                \"pad_to_sequence_len\": pad_to_sequence_len,\n                \"lora_r\": 8,\n                \"lora_alpha\": 16,\n                \"lora_dropout\": 0.05,\n                \"lora_target_linear\": True,\n                \"lora_modules_to_save\": [\"embed_tokens\", \"lm_head\"],\n                \"special_tokens\": {\"pad_token\": \"<|endoftext|>\"},\n                \"datasets\": [\n                    {\n                        \"path\": \"tatsu-lab/alpaca\",\n                        \"type\": \"alpaca\",\n                        \"split\": \"train[:10%]\",\n                    },\n                ],\n                \"num_epochs\": 1,\n                \"max_steps\": 8,\n                \"micro_batch_size\": micro_batch_size,\n                \"gradient_accumulation_steps\": 
2,\n                \"output_dir\": temp_dir,\n                \"dataset_prepared_path\": temp_dir + \"/last_run_prepared\",\n                \"learning_rate\": 0.00001,\n                \"optimizer\": \"adamw_8bit\",\n                \"lr_scheduler\": \"cosine\",\n                \"flash_attention\": True,\n                \"loss_watchdog_threshold\": 5.0,\n                \"loss_watchdog_patience\": 3,\n                \"bf16\": \"auto\",\n                \"warmup_steps\": 1,\n                \"saves_per_epoch\": 1,\n                \"logging_steps\": 1,\n                \"weight_decay\": 0.0,\n                \"use_tensorboard\": True,\n                \"context_parallel_size\": 2,\n                \"ring_attn_func\": ring_attn_func,\n                \"save_first_step\": False,\n            }\n        )\n\n        # write cfg to yaml file\n        Path(temp_dir).mkdir(parents=True, exist_ok=True)\n        with open(Path(temp_dir) / \"config.yaml\", \"w\", encoding=\"utf-8\") as fout:\n            fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))\n\n        execute_subprocess_async(\n            [\n                \"accelerate\",\n                \"launch\",\n                \"--num-processes\",\n                \"2\",\n                \"--main_process_port\",\n                f\"{get_torch_dist_unique_port()}\",\n                \"-m\",\n                \"axolotl.cli.train\",\n                str(Path(temp_dir) / \"config.yaml\"),\n            ]\n        )\n\n        check_tensorboard(\n            temp_dir + \"/runs\",\n            \"train/train_loss\",\n            threshold,\n            \"Train Loss (%s) is too high\",\n        )\n\n    @pytest.mark.parametrize(\n        \"sample_packing, micro_batch_size, pad_to_sequence_len, ring_attn_func, threshold\",\n        [\n            (True, 1, True, None, 2.5),  # defaults to varlen_llama3 ring_attn_func\n            (False, 2, True, None, 2.5),  # defaults to batch_ring ring_attn_func\n            # 
(False, 2, True, \"batch_zigzag\", 2.5),\n            # (False, 2, False, None, 2.65),  # defaults to batch_ring ring_attn_func\n        ],\n        ids=[\n            \"sample_packing, varlen_llama3 ring_attn_func\",\n            \"no sample_packing, pad_to_sequence_len, batch_ring ring_attn_func\",\n            # \"no sample_packing, no pad_to_sequence_len, batch_zigzag ring_attn_func\",\n            # \"no sample_packing, no pad_to_sequence_len, batch_ring ring_attn_func\",\n        ],\n    )\n    def test_sequence_parallel_training(\n        self,\n        temp_dir,\n        sample_packing,\n        micro_batch_size,\n        pad_to_sequence_len,\n        ring_attn_func,\n        threshold,\n    ):\n        \"\"\"Test sequence parallel training with different configurations\"\"\"\n        self._run_sequence_parallel_test(\n            temp_dir,\n            sample_packing=sample_packing,\n            micro_batch_size=micro_batch_size,\n            pad_to_sequence_len=pad_to_sequence_len,\n            ring_attn_func=ring_attn_func,\n            threshold=threshold,\n        )\n"
  },
  {
    "path": "tests/e2e/multigpu/solo/__init__.py",
    "content": "# Tests under this directory should be run \"solo\" (on their own), as they\n# seem to cause issues when run in the same batch as other tests.\n"
  },
  {
    "path": "tests/e2e/multigpu/solo/test_flex.py",
    "content": "\"\"\"\nE2E tests for multigpu lora tinyllama\n\"\"\"\n\nfrom pathlib import Path\n\nimport pytest\nimport yaml\nfrom accelerate.test_utils import execute_subprocess_async\nfrom huggingface_hub import snapshot_download\nfrom transformers.testing_utils import get_torch_dist_unique_port\nfrom transformers.utils import is_torch_bf16_gpu_available\n\nfrom axolotl.utils.dict import DictDefault\n\nfrom tests.e2e.utils import check_tensorboard, require_torch_2_6_0\n\nAXOLOTL_ROOT = Path(__file__).parent.parent.parent.parent\n\n\n@pytest.fixture(scope=\"session\", autouse=True)\ndef download_model():\n    # download the model\n    snapshot_download(\"HuggingFaceTB/SmolLM2-135M\")\n\n\nclass TestPackedFlex:\n    \"\"\"\n    Test case for Packed training of llama models\n    \"\"\"\n\n    @require_torch_2_6_0\n    def test_loss_llama(self, temp_dir):\n        cfg = DictDefault(\n            {\n                \"base_model\": \"HuggingFaceTB/SmolLM2-135M\",\n                \"sequence_len\": 1024,\n                \"sample_packing\": True,\n                \"flex_attention\": True,\n                \"val_set_size\": 0.0,\n                \"special_tokens\": {\n                    \"pad_token\": \"<|endoftext|>\",\n                },\n                \"datasets\": [\n                    {\n                        \"path\": \"tatsu-lab/alpaca\",\n                        \"type\": \"alpaca\",\n                        \"split\": \"train[:10%]\",\n                    },\n                ],\n                \"num_epochs\": 1,\n                \"micro_batch_size\": 2,\n                \"gradient_accumulation_steps\": 2,\n                \"gradient_checkpointing\": True,\n                \"output_dir\": temp_dir,\n                \"dataset_prepared_path\": temp_dir + \"/last_run_prepared\",\n                \"learning_rate\": 0.00001,\n                \"optimizer\": \"adamw_torch_fused\",\n                \"lr_scheduler\": \"cosine\",\n                \"max_steps\": 
2,\n                \"use_tensorboard\": True,\n                \"save_strategy\": \"no\",\n                \"save_first_step\": False,\n            }\n        )\n        if is_torch_bf16_gpu_available():\n            cfg.bf16 = True\n        else:\n            cfg.fp16 = True\n\n        # write cfg to yaml file\n        Path(temp_dir).mkdir(parents=True, exist_ok=True)\n        with open(Path(temp_dir) / \"config.yaml\", \"w\", encoding=\"utf-8\") as fout:\n            fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))\n\n        execute_subprocess_async(\n            [\n                \"axolotl\",\n                \"train\",\n                str(Path(temp_dir) / \"config.yaml\"),\n                \"--num-processes\",\n                \"2\",\n                \"--main-process-port\",\n                f\"{get_torch_dist_unique_port()}\",\n            ]\n        )\n\n        check_tensorboard(\n            temp_dir + \"/runs\", \"train/train_loss\", 2.1, \"Train Loss (%s) is too high\"\n        )\n"
  },
  {
    "path": "tests/e2e/multigpu/solo/test_gdpo.py",
    "content": "\"\"\"\nGDPO test suite\n\nGDPO uses TRL's multi_objective_aggregation=\"normalize_then_sum\" for\nper-reward normalization in multi-reward RL training.\n\"\"\"\n\nimport os\nimport random\nfrom pathlib import Path\n\nimport pytest\nimport yaml\nfrom accelerate.test_utils import execute_subprocess_async\nfrom transformers.testing_utils import get_torch_dist_unique_port\n\nfrom axolotl.utils.dict import DictDefault\n\nfrom tests.e2e.multigpu.solo.test_grpo import recursive_kill, start_vllm\nfrom tests.e2e.utils import require_vllm\n\n\n@pytest.mark.skip(reason=\"flaky vllm tests in modal\")\nclass TestGDPO:\n    \"\"\"Test case for GDPO training using TRL's native multi-objective aggregation.\"\"\"\n\n    def _utils_write_yaml_and_rewards(self, cfg, temp_dir, suffix=\"\"):\n        Path(temp_dir).mkdir(parents=True, exist_ok=True)\n        with open(Path(temp_dir) / \"config.yaml\", \"w\", encoding=\"utf-8\") as fout:\n            fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))\n        with open(f\"rewards_gdpo_{suffix}.py\", \"w\", encoding=\"utf-8\") as fout:\n            fout.write(\n                \"\"\"import random\n\ndef format_reward(prompts, completions, **kwargs) -> list[float]:\n    return [1.0 if len(c) > 10 else 0.0 for c in completions]\n\ndef correctness_reward(prompts, completions, **kwargs) -> list[float]:\n    return [random.uniform(-1, 3) for _ in completions]\n\ndef safety_reward(prompts, completions, **kwargs) -> list[float]:\n    return [1.0 if 'error' not in c.lower() else 0.0 for c in completions]\n\ndef single_reward(prompts, completions, **kwargs) -> list[float]:\n    return [random.uniform(0, 1) for _ in completions]\n\ndef oai_gsm8k_transform(cfg, *args, **kwargs):\n    def transform_fn(example, tokenizer=None):\n        label = example[\"answer\"].split(\"####\")[-1].strip().replace(\",\", \"\")\n        return {\n            \"prompt\": [{\"role\": \"user\", \"content\": example[\"question\"]}],\n            
\"answer\": label,\n        }\n    return transform_fn, {\"remove_columns\": [\"question\"]}\n\"\"\"\n            )\n\n    @pytest.mark.parametrize(\"num_gpus\", [1, 2])\n    @require_vllm\n    def test_gdpo_multi_reward_lora(self, temp_dir, num_gpus):\n        \"\"\"Test GDPO with multiple reward functions using LoRA.\"\"\"\n        rnd_suffix = str(random.randint(1000, 9999))\n        cfg = DictDefault(\n            {\n                \"base_model\": \"HuggingFaceTB/SmolLM2-135M\",\n                \"chat_template\": \"llama3\",\n                \"rl\": \"gdpo\",\n                \"trl\": {\n                    \"beta\": 0.001,\n                    \"max_completion_length\": 256,\n                    \"use_vllm\": True,\n                    \"num_generations\": 4,\n                    \"reward_funcs\": [\n                        f\"rewards_gdpo_{rnd_suffix}.format_reward\",\n                        f\"rewards_gdpo_{rnd_suffix}.correctness_reward\",\n                    ],\n                    \"reward_weights\": [1.0, 2.0],\n                    \"scale_rewards\": True,\n                },\n                \"vllm\": {\n                    \"max_model_len\": 800,\n                    \"enable_prefix_caching\": True,\n                },\n                \"datasets\": [\n                    {\n                        \"path\": \"openai/gsm8k\",\n                        \"name\": \"main\",\n                        \"type\": f\"rewards_gdpo_{rnd_suffix}.oai_gsm8k_transform\",\n                    },\n                ],\n                \"adapter\": \"lora\",\n                \"lora_r\": 8,\n                \"lora_alpha\": 16,\n                \"lora_dropout\": 0.05,\n                \"lora_target_linear\": True,\n                \"flash_attention\": True,\n                \"sequence_len\": 1024,\n                \"special_tokens\": {\n                    \"pad_token\": \"<|endoftext|>\",\n                },\n                \"max_steps\": 3,\n                
\"num_epochs\": 1,\n                \"micro_batch_size\": 4,\n                \"gradient_accumulation_steps\": 2,\n                \"warmup_steps\": 10,\n                \"val_set_size\": 0.0,\n                \"output_dir\": temp_dir,\n                \"learning_rate\": 0.0001,\n                \"optimizer\": \"adamw_torch_fused\",\n                \"lr_scheduler\": \"cosine\",\n                \"save_safetensors\": True,\n                \"bf16\": \"auto\",\n                \"use_tensorboard\": True,\n                \"save_first_step\": False,\n            }\n        )\n\n        self._utils_write_yaml_and_rewards(cfg, temp_dir, suffix=rnd_suffix)\n\n        current_env = os.environ.copy()\n        env = {\n            \"NCCL_P2P_LEVEL\": \"LOC\",\n            **current_env,\n            \"CUDA_VISIBLE_DEVICES\": \"1\",\n        }\n        vllm_process = start_vllm(\n            cfg.base_model,\n            env=env,\n            quiet=True,\n            wait=300,\n            gpu_memory_utilization=0.15,\n            max_model_len=cfg.vllm.max_model_len,\n            enable_prefix_caching=cfg.vllm.enable_prefix_caching,\n            host=\"0.0.0.0\",\n            port=8000,\n        )\n\n        try:\n            execute_subprocess_async(\n                [\n                    \"axolotl\",\n                    \"train\",\n                    str(Path(temp_dir) / \"config.yaml\"),\n                    \"--num-processes\",\n                    str(num_gpus),\n                    \"--main-process-port\",\n                    f\"{get_torch_dist_unique_port()}\",\n                ],\n                env={\n                    \"NCCL_P2P_LEVEL\": \"LOC\",\n                    \"NCCL_DEBUG\": \"INFO\",\n                    **current_env,\n                },\n            )\n        finally:\n            recursive_kill(vllm_process)\n\n    @require_vllm\n    def test_gdpo_three_rewards(self, temp_dir):\n        \"\"\"Test GDPO with three reward functions (format, 
correctness, safety).\"\"\"\n        rnd_suffix = str(random.randint(1000, 9999))\n        cfg = DictDefault(\n            {\n                \"base_model\": \"HuggingFaceTB/SmolLM2-135M\",\n                \"chat_template\": \"llama3\",\n                \"rl\": \"gdpo\",\n                \"trl\": {\n                    \"beta\": 0.001,\n                    \"max_completion_length\": 256,\n                    \"use_vllm\": True,\n                    \"num_generations\": 4,\n                    \"reward_funcs\": [\n                        f\"rewards_gdpo_{rnd_suffix}.format_reward\",\n                        f\"rewards_gdpo_{rnd_suffix}.correctness_reward\",\n                        f\"rewards_gdpo_{rnd_suffix}.safety_reward\",\n                    ],\n                    \"reward_weights\": [1.0, 2.0, 1.5],\n                },\n                \"vllm\": {\n                    \"max_model_len\": 800,\n                    \"enable_prefix_caching\": True,\n                },\n                \"datasets\": [\n                    {\n                        \"path\": \"openai/gsm8k\",\n                        \"name\": \"main\",\n                        \"type\": f\"rewards_gdpo_{rnd_suffix}.oai_gsm8k_transform\",\n                    },\n                ],\n                \"adapter\": \"lora\",\n                \"lora_r\": 8,\n                \"lora_alpha\": 16,\n                \"lora_dropout\": 0.05,\n                \"lora_target_linear\": True,\n                \"flash_attention\": True,\n                \"sequence_len\": 1024,\n                \"special_tokens\": {\n                    \"pad_token\": \"<|endoftext|>\",\n                },\n                \"max_steps\": 3,\n                \"num_epochs\": 1,\n                \"micro_batch_size\": 4,\n                \"gradient_accumulation_steps\": 2,\n                \"warmup_steps\": 10,\n                \"val_set_size\": 0.0,\n                \"output_dir\": temp_dir,\n                \"learning_rate\": 
0.0001,\n                \"optimizer\": \"adamw_torch_fused\",\n                \"lr_scheduler\": \"cosine\",\n                \"save_safetensors\": True,\n                \"bf16\": \"auto\",\n            }\n        )\n\n        self._utils_write_yaml_and_rewards(cfg, temp_dir, suffix=rnd_suffix)\n\n        current_env = os.environ.copy()\n        env = {\n            \"NCCL_P2P_LEVEL\": \"LOC\",\n            **current_env,\n            \"CUDA_VISIBLE_DEVICES\": \"1\",\n        }\n        vllm_process = start_vllm(\n            cfg.base_model,\n            env=env,\n            quiet=True,\n            wait=300,\n            gpu_memory_utilization=0.15,\n            max_model_len=cfg.vllm.max_model_len,\n            enable_prefix_caching=cfg.vllm.enable_prefix_caching,\n            host=\"0.0.0.0\",\n            port=8000,\n        )\n\n        try:\n            execute_subprocess_async(\n                [\n                    \"axolotl\",\n                    \"train\",\n                    str(Path(temp_dir) / \"config.yaml\"),\n                    \"--num-processes\",\n                    \"1\",\n                    \"--main-process-port\",\n                    f\"{get_torch_dist_unique_port()}\",\n                ],\n                env={\n                    \"NCCL_P2P_LEVEL\": \"LOC\",\n                    \"NCCL_DEBUG\": \"INFO\",\n                    **current_env,\n                },\n            )\n        finally:\n            recursive_kill(vllm_process)\n\n    @require_vllm\n    def test_gdpo_single_reward_fallback(self, temp_dir):\n        \"\"\"Test GDPO with single reward.\"\"\"\n        rnd_suffix = str(random.randint(1000, 9999))\n        cfg = DictDefault(\n            {\n                \"base_model\": \"HuggingFaceTB/SmolLM2-135M\",\n                \"chat_template\": \"llama3\",\n                \"rl\": \"gdpo\",\n                \"trl\": {\n                    \"beta\": 0.001,\n                    \"max_completion_length\": 256,\n             
       \"use_vllm\": True,\n                    \"num_generations\": 4,\n                    \"reward_funcs\": [\n                        f\"rewards_gdpo_{rnd_suffix}.single_reward\",\n                    ],\n                    \"reward_weights\": [1.0],\n                },\n                \"vllm\": {\n                    \"max_model_len\": 800,\n                    \"enable_prefix_caching\": True,\n                },\n                \"datasets\": [\n                    {\n                        \"path\": \"openai/gsm8k\",\n                        \"name\": \"main\",\n                        \"type\": f\"rewards_gdpo_{rnd_suffix}.oai_gsm8k_transform\",\n                    },\n                ],\n                \"adapter\": \"lora\",\n                \"lora_r\": 8,\n                \"lora_alpha\": 16,\n                \"lora_dropout\": 0.05,\n                \"lora_target_linear\": True,\n                \"flash_attention\": True,\n                \"sequence_len\": 1024,\n                \"special_tokens\": {\n                    \"pad_token\": \"<|endoftext|>\",\n                },\n                \"max_steps\": 3,\n                \"num_epochs\": 1,\n                \"micro_batch_size\": 4,\n                \"gradient_accumulation_steps\": 2,\n                \"warmup_steps\": 10,\n                \"val_set_size\": 0.0,\n                \"output_dir\": temp_dir,\n                \"learning_rate\": 0.0001,\n                \"optimizer\": \"adamw_torch_fused\",\n                \"lr_scheduler\": \"cosine\",\n                \"save_safetensors\": True,\n                \"bf16\": \"auto\",\n            }\n        )\n\n        self._utils_write_yaml_and_rewards(cfg, temp_dir, suffix=rnd_suffix)\n\n        current_env = os.environ.copy()\n        env = {\n            \"NCCL_P2P_LEVEL\": \"LOC\",\n            **current_env,\n            \"CUDA_VISIBLE_DEVICES\": \"1\",\n        }\n        vllm_process = start_vllm(\n            cfg.base_model,\n            
env=env,\n            quiet=True,\n            wait=300,\n            gpu_memory_utilization=0.15,\n            max_model_len=cfg.vllm.max_model_len,\n            enable_prefix_caching=cfg.vllm.enable_prefix_caching,\n            host=\"0.0.0.0\",\n            port=8000,\n        )\n\n        try:\n            execute_subprocess_async(\n                [\n                    \"axolotl\",\n                    \"train\",\n                    str(Path(temp_dir) / \"config.yaml\"),\n                    \"--num-processes\",\n                    \"1\",\n                    \"--main-process-port\",\n                    f\"{get_torch_dist_unique_port()}\",\n                ],\n                env={\n                    \"NCCL_P2P_LEVEL\": \"LOC\",\n                    \"NCCL_DEBUG\": \"INFO\",\n                    **current_env,\n                },\n            )\n        finally:\n            recursive_kill(vllm_process)\n\n    @require_vllm\n    def test_gdpo_fft(self, temp_dir):\n        \"\"\"Test GDPO with full fine-tuning (no adapter).\"\"\"\n        rnd_suffix = str(random.randint(1000, 9999))\n        cfg = DictDefault(\n            {\n                \"base_model\": \"HuggingFaceTB/SmolLM2-135M\",\n                \"chat_template\": \"llama3\",\n                \"rl\": \"gdpo\",\n                \"trl\": {\n                    \"beta\": 0.001,\n                    \"max_completion_length\": 256,\n                    \"use_vllm\": True,\n                    \"num_generations\": 4,\n                    \"reward_funcs\": [\n                        f\"rewards_gdpo_{rnd_suffix}.format_reward\",\n                        f\"rewards_gdpo_{rnd_suffix}.correctness_reward\",\n                    ],\n                    \"reward_weights\": [1.0, 2.0],\n                },\n                \"vllm\": {\n                    \"max_model_len\": 800,\n                    \"enable_prefix_caching\": True,\n                },\n                \"datasets\": [\n                    {\n    
                    \"path\": \"openai/gsm8k\",\n                        \"name\": \"main\",\n                        \"type\": f\"rewards_gdpo_{rnd_suffix}.oai_gsm8k_transform\",\n                    },\n                ],\n                # No adapter - full fine-tuning\n                \"flash_attention\": True,\n                \"sequence_len\": 1024,\n                \"special_tokens\": {\n                    \"pad_token\": \"<|endoftext|>\",\n                },\n                \"max_steps\": 3,\n                \"num_epochs\": 1,\n                \"micro_batch_size\": 4,\n                \"gradient_accumulation_steps\": 2,\n                \"warmup_steps\": 10,\n                \"val_set_size\": 0.0,\n                \"output_dir\": temp_dir,\n                \"learning_rate\": 0.0001,\n                \"optimizer\": \"adamw_torch_fused\",\n                \"lr_scheduler\": \"cosine\",\n                \"save_safetensors\": True,\n                \"bf16\": \"auto\",\n            }\n        )\n\n        self._utils_write_yaml_and_rewards(cfg, temp_dir, suffix=rnd_suffix)\n\n        current_env = os.environ.copy()\n        env = {\n            \"NCCL_P2P_LEVEL\": \"LOC\",\n            **current_env,\n            \"CUDA_VISIBLE_DEVICES\": \"1\",\n        }\n        vllm_process = start_vllm(\n            cfg.base_model,\n            env=env,\n            quiet=True,\n            wait=300,\n            gpu_memory_utilization=0.15,\n            max_model_len=cfg.vllm.max_model_len,\n            enable_prefix_caching=cfg.vllm.enable_prefix_caching,\n            host=\"0.0.0.0\",\n            port=8000,\n        )\n\n        try:\n            execute_subprocess_async(\n                [\n                    \"axolotl\",\n                    \"train\",\n                    str(Path(temp_dir) / \"config.yaml\"),\n                    \"--num-processes\",\n                    \"1\",\n                    \"--main-process-port\",\n                    
f\"{get_torch_dist_unique_port()}\",\n                ],\n                env={\n                    \"NCCL_P2P_LEVEL\": \"LOC\",\n                    \"NCCL_DEBUG\": \"INFO\",\n                    **current_env,\n                },\n            )\n        finally:\n            recursive_kill(vllm_process)\n\n    @require_vllm\n    def test_gdpo_sequence_parallel(self, temp_dir):\n        \"\"\"Test GDPO with sequence parallelism.\"\"\"\n        rnd_suffix = str(random.randint(1000, 9999))\n        cfg = DictDefault(\n            {\n                \"base_model\": \"HuggingFaceTB/SmolLM2-135M\",\n                \"chat_template\": \"llama3\",\n                \"rl\": \"gdpo\",\n                \"context_parallel_size\": 2,\n                \"trl\": {\n                    \"beta\": 0.001,\n                    \"max_completion_length\": 256,\n                    \"use_vllm\": True,\n                    \"num_generations\": 4,\n                    \"reward_funcs\": [\n                        f\"rewards_gdpo_{rnd_suffix}.format_reward\",\n                        f\"rewards_gdpo_{rnd_suffix}.correctness_reward\",\n                    ],\n                    \"reward_weights\": [1.0, 2.0],\n                },\n                \"vllm\": {\n                    \"max_model_len\": 800,\n                    \"enable_prefix_caching\": True,\n                },\n                \"datasets\": [\n                    {\n                        \"path\": \"openai/gsm8k\",\n                        \"name\": \"main\",\n                        \"type\": f\"rewards_gdpo_{rnd_suffix}.oai_gsm8k_transform\",\n                    },\n                ],\n                \"adapter\": \"lora\",\n                \"lora_r\": 8,\n                \"lora_alpha\": 16,\n                \"lora_dropout\": 0.05,\n                \"lora_target_linear\": True,\n                \"flash_attention\": True,\n                \"sequence_len\": 1024,\n                \"special_tokens\": {\n                    
\"pad_token\": \"<|endoftext|>\",\n                },\n                \"max_steps\": 3,\n                \"num_epochs\": 1,\n                \"micro_batch_size\": 4,\n                \"gradient_accumulation_steps\": 2,\n                \"warmup_steps\": 10,\n                \"val_set_size\": 0.0,\n                \"output_dir\": temp_dir,\n                \"dataset_prepared_path\": temp_dir + \"/last_run_prepared\",\n                \"learning_rate\": 0.0001,\n                \"optimizer\": \"adamw_torch_fused\",\n                \"lr_scheduler\": \"cosine\",\n                \"save_safetensors\": True,\n                \"bf16\": \"auto\",\n            }\n        )\n\n        self._utils_write_yaml_and_rewards(cfg, temp_dir, suffix=rnd_suffix)\n\n        current_env = os.environ.copy()\n        env = {\n            \"NCCL_P2P_LEVEL\": \"LOC\",\n            **current_env,\n            \"CUDA_VISIBLE_DEVICES\": \"1\",\n        }\n        vllm_process = start_vllm(\n            cfg.base_model,\n            env=env,\n            quiet=True,\n            wait=300,\n            gpu_memory_utilization=0.15,\n            max_model_len=cfg.vllm.max_model_len,\n            enable_prefix_caching=cfg.vllm.enable_prefix_caching,\n            host=\"0.0.0.0\",\n            port=8000,\n        )\n\n        try:\n            execute_subprocess_async(\n                [\n                    \"axolotl\",\n                    \"train\",\n                    str(Path(temp_dir) / \"config.yaml\"),\n                    \"--num-processes\",\n                    \"2\",\n                    \"--main-process-port\",\n                    f\"{get_torch_dist_unique_port()}\",\n                ],\n                env={\n                    \"NCCL_P2P_LEVEL\": \"LOC\",\n                    \"NCCL_DEBUG\": \"INFO\",\n                    **current_env,\n                },\n            )\n        finally:\n            recursive_kill(vllm_process)\n"
  },
  {
    "path": "tests/e2e/multigpu/solo/test_grpo.py",
    "content": "\"\"\"\nGRPO test suite\n\"\"\"\n\nimport os\nimport random\nimport subprocess  # nosec B404\nimport sys\nimport tempfile\nimport time\nfrom pathlib import Path\n\nimport psutil\nimport pytest\nimport requests\nimport yaml\nfrom accelerate.test_utils import execute_subprocess_async\nfrom transformers.testing_utils import get_torch_dist_unique_port\n\nfrom axolotl.utils.dict import DictDefault\n\nfrom tests.e2e.utils import require_vllm\n\n\ndef start_vllm(\n    model: str, env: dict, wait: int | None = None, quiet=False, **kwargs\n) -> subprocess.Popen:\n    \"\"\"\n    helper function to start the VLLM server in the background, mostly for testing purposes\n    \"\"\"\n    cmd = [sys.executable, \"-m\", \"trl.scripts.vllm_serve\", \"--model\", model]\n\n    if tensor_parallel_size := kwargs.get(\"tensor_parallel_size\"):\n        cmd.extend([\"--tensor-parallel-size\", str(tensor_parallel_size)])\n    if host := kwargs.get(\"host\"):\n        cmd.extend([\"--host\", host])\n    if port := kwargs.get(\"port\"):\n        cmd.extend([\"--port\", str(port)])\n    if gpu_memory_utilization := kwargs.get(\"gpu_memory_utilization\"):\n        cmd.extend([\"--gpu-memory-utilization\", str(gpu_memory_utilization)])\n    if dtype := kwargs.get(\"dtype\"):\n        cmd.extend([\"--dtype\", dtype])\n    if max_model_len := kwargs.get(\"max_model_len\"):\n        cmd.extend([\"--max-model-len\", str(max_model_len)])\n    if kwargs.get(\"enable_prefix_caching\"):\n        cmd.extend([\"--enable-prefix-caching\", \"True\"])\n\n    # print out the command to be executed\n    print(\" \".join(cmd))\n\n    vllm_logging_json = Path(tempfile.mkdtemp()) / \"vllm_logging.json\"\n    with open(vllm_logging_json, \"w\", encoding=\"utf-8\") as temp_file:\n        temp_file.write(\n            \"\"\"{\n  \"formatters\": {\n    \"json\": {\n      \"class\": \"pythonjsonlogger.jsonlogger.JsonFormatter\"\n    }\n  },\n  \"handlers\": {\n    \"file\": {\n      \"class\": 
\"logging.FileHandler\",\n      \"formatter\": \"json\",\n      \"level\": \"DEBUG\",\n      \"filename\": \"/tmp/vllm.log\",\n      \"mode\": \"a\"\n    }\n  },\n  \"loggers\": {\n    \"vllm\": {\n      \"handlers\": [\"file\"],\n      \"level\": \"DEBUG\",\n      \"propagate\": false\n    }\n  },\n  \"version\": 1\n}\"\"\"\n        )\n\n    cmd_env = env.copy()\n    cmd_env.update({\"VLLM_LOGGING_CONFIG_PATH\": vllm_logging_json})\n    # start `trl vllm-serve` command in the background and capture the process id\n    process = subprocess.Popen(\n        cmd,\n        env=cmd_env,\n        stdout=subprocess.DEVNULL if quiet else subprocess.PIPE,\n        stderr=subprocess.DEVNULL if quiet else subprocess.PIPE,\n    )  # nosec B603\n\n    # print out the process id so the user can easily kill it later\n    print(f\"VLLM server process started (PID: {process.pid})\")\n\n    # wait until the http server is ready, even if it 404s, but timeout after 60 seconds\n    period_seconds = 5\n    started = False\n    if wait and host and port:\n        for i in range(0, int(wait), period_seconds):\n            try:\n                response = requests.get(f\"http://{host}:{port}\", timeout=1)\n                print(f\"{i}: VLLM server (status: {response.status_code})\")\n                if int(response.status_code) in [200, 404]:\n                    started = True\n                    break\n            except requests.exceptions.RequestException as exc:\n                print(f\"{i}: VLLM server failed to start: {str(exc)}\")\n\n            # also check if the process.pid is still running\n            if process.poll() is not None:\n                break\n\n            time.sleep(period_seconds)\n\n    if wait and not started:\n        print(\n            f\"VLLM server process did not start within {wait} seconds. 
Please check your server logs.\"\n        )\n        recursive_kill(process)\n        with open(\"/tmp/vllm.log\", \"r\", encoding=\"utf-8\") as log_file:\n            print(log_file.read())\n        try:\n            os.remove(\"/tmp/vllm.log\")\n        except FileNotFoundError:\n            pass\n        raise RuntimeError(f\"VLLM server process did not start within {wait} seconds.\")\n\n    # return the process\n    return process\n\n\ndef recursive_kill(process: subprocess.Popen):\n    \"\"\"\n    Recursively kill a process and its children\n    \"\"\"\n    process = psutil.Process(process.pid)\n    for child in psutil.Process(process.pid).children(recursive=True):\n        child.terminate()\n        child.kill()\n        os.kill(child.pid, 9)\n    process.terminate()\n    process.kill()\n    os.kill(process.pid, 9)\n\n\n@pytest.mark.skip(reason=\"flaky vllm tests in modal\")\nclass TestGRPO:\n    \"\"\"\n    Test case for GRPO training using multiple GPUs\n    \"\"\"\n\n    def _utils_write_yaml_and_rewards(self, cfg, temp_dir, suffix=\"\"):\n        # write cfg to yaml file\n        Path(temp_dir).mkdir(parents=True, exist_ok=True)\n        with open(Path(temp_dir) / \"config.yaml\", \"w\", encoding=\"utf-8\") as fout:\n            fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))\n        with open(f\"rewards_{suffix}.py\", \"w\", encoding=\"utf-8\") as fout:\n            fout.write(\n                \"\"\"import random\ndef rand_reward_func(completions, **kwargs) -> list[float]:\n    return [random.uniform(0, 1) for _ in completions]\n\ndef oai_gsm8k_transform(cfg, *args, **kwargs):\n    def transform_fn(example, tokenizer=None):\n        label = example[\"answer\"].split(\"####\")[-1].strip().replace(\",\", \"\")\n        return {\n            \"prompt\": [{\"role\": \"user\", \"content\": example[\"question\"]},],\n            \"answer\": label,\n        }\n    return transform_fn, {\"remove_columns\": [\"question\"]}\n\"\"\"\n            )\n\n    
@pytest.mark.parametrize(\n        \"num_gpus\",\n        [1, 2],\n    )\n    @require_vllm\n    def test_llama_dora(self, temp_dir, num_gpus):\n        rnd_reward_suffix = str(random.randint(1000, 9999))\n        cfg = DictDefault(\n            {\n                \"base_model\": \"HuggingFaceTB/SmolLM2-135M\",\n                \"chat_template\": \"llama3\",\n                \"rl\": \"grpo\",\n                \"trl\": {\n                    \"beta\": 0.001,\n                    \"max_completion_length\": 256,\n                    \"use_vllm\": True,\n                    \"num_generations\": 4,\n                    \"reward_funcs\": [f\"rewards_{rnd_reward_suffix}.rand_reward_func\"],\n                },\n                \"vllm\": {\n                    \"max_model_len\": 800,\n                    \"enable_prefix_caching\": True,\n                },\n                \"datasets\": [\n                    {\n                        \"path\": \"openai/gsm8k\",\n                        \"name\": \"main\",\n                        \"type\": f\"rewards_{rnd_reward_suffix}.oai_gsm8k_transform\",\n                    },\n                ],\n                \"adapter\": \"lora\",\n                \"lora_r\": 8,\n                \"lora_alpha\": 16,\n                \"lora_dropout\": 0.05,\n                \"lora_target_linear\": True,\n                \"peft_use_dora\": True,\n                \"flash_attention\": True,\n                \"sequence_len\": 1024,\n                \"special_tokens\": {\n                    \"pad_token\": \"<|endoftext|>\",\n                },\n                \"max_steps\": 3,\n                \"num_epochs\": 1,\n                \"micro_batch_size\": 4,\n                \"gradient_accumulation_steps\": 2,\n                \"warmup_steps\": 10,\n                \"val_set_size\": 0.0,\n                \"output_dir\": temp_dir,\n                \"learning_rate\": 0.0001,\n                \"optimizer\": \"adamw_torch_fused\",\n                
\"lr_scheduler\": \"cosine\",\n                \"bf16\": \"auto\",\n                \"use_tensorboard\": True,\n                \"save_first_step\": False,\n            }\n        )\n\n        self._utils_write_yaml_and_rewards(cfg, temp_dir, suffix=rnd_reward_suffix)\n\n        current_env = os.environ.copy()\n        env = {\n            \"NCCL_P2P_LEVEL\": \"LOC\",\n            **current_env,\n            \"CUDA_VISIBLE_DEVICES\": \"1\",\n        }\n        vllm_process = start_vllm(\n            cfg.base_model,\n            env=env,\n            quiet=True,\n            wait=300,\n            gpu_memory_utilization=0.15,\n            max_model_len=cfg.vllm.max_model_len,\n            enable_prefix_caching=cfg.vllm.enable_prefix_caching,\n            host=\"0.0.0.0\",\n            port=8000,\n        )\n\n        try:\n            execute_subprocess_async(\n                [\n                    \"axolotl\",\n                    \"train\",\n                    str(Path(temp_dir) / \"config.yaml\"),\n                    \"--num-processes\",\n                    str(num_gpus),\n                    \"--main-process-port\",\n                    f\"{get_torch_dist_unique_port()}\",\n                ],\n                env={\n                    \"NCCL_P2P_LEVEL\": \"LOC\",\n                    \"NCCL_DEBUG\": \"INFO\",\n                    **current_env,\n                },\n            )\n        finally:\n            (recursive_kill(vllm_process))\n\n    @require_vllm\n    def test_llama_lora_sp(self, temp_dir):\n        rnd_reward_suffix = str(random.randint(1000, 9999))\n        cfg = DictDefault(\n            {\n                \"base_model\": \"HuggingFaceTB/SmolLM2-135M\",\n                \"chat_template\": \"llama3\",\n                \"rl\": \"grpo\",\n                \"trl\": {\n                    \"beta\": 0.001,\n                    \"max_completion_length\": 256,\n                    \"use_vllm\": True,\n                    \"num_generations\": 4,\n    
                \"reward_funcs\": [f\"rewards_{rnd_reward_suffix}.rand_reward_func\"],\n                },\n                \"vllm\": {\n                    \"max_model_len\": 800,\n                    \"enable_prefix_caching\": True,\n                },\n                \"datasets\": [\n                    {\n                        \"path\": \"openai/gsm8k\",\n                        \"name\": \"main\",\n                        \"type\": f\"rewards_{rnd_reward_suffix}.oai_gsm8k_transform\",\n                    },\n                ],\n                \"adapter\": \"lora\",\n                \"lora_r\": 8,\n                \"lora_alpha\": 16,\n                \"lora_dropout\": 0.05,\n                \"lora_target_linear\": True,\n                \"context_parallel_size\": 2,\n                \"flash_attention\": True,\n                \"sequence_len\": 1024,\n                \"special_tokens\": {\n                    \"pad_token\": \"<|endoftext|>\",\n                },\n                \"max_steps\": 3,\n                \"num_epochs\": 1,\n                \"micro_batch_size\": 4,\n                \"gradient_accumulation_steps\": 2,\n                \"warmup_steps\": 10,\n                \"val_set_size\": 0.0,\n                \"output_dir\": temp_dir,\n                \"dataset_prepared_path\": temp_dir + \"/last_run_prepared\",\n                \"learning_rate\": 0.0001,\n                \"optimizer\": \"adamw_torch_fused\",\n                \"lr_scheduler\": \"cosine\",\n                \"bf16\": \"auto\",\n                \"use_tensorboard\": True,\n                \"save_first_step\": False,\n            }\n        )\n\n        self._utils_write_yaml_and_rewards(cfg, temp_dir, suffix=rnd_reward_suffix)\n\n        current_env = os.environ.copy()\n        env = {\n            \"NCCL_P2P_LEVEL\": \"LOC\",\n            **current_env,\n            \"CUDA_VISIBLE_DEVICES\": \"1\",\n        }\n        vllm_process = start_vllm(\n            cfg.base_model,\n          
  env=env,\n            quiet=True,\n            wait=300,\n            gpu_memory_utilization=0.15,\n            max_model_len=cfg.vllm.max_model_len,\n            enable_prefix_caching=cfg.vllm.enable_prefix_caching,\n            host=\"0.0.0.0\",\n            port=8000,\n        )\n\n        try:\n            execute_subprocess_async(\n                [\n                    \"axolotl\",\n                    \"train\",\n                    str(Path(temp_dir) / \"config.yaml\"),\n                    \"--num-processes\",\n                    str(2),\n                    \"--main-process-port\",\n                    f\"{get_torch_dist_unique_port()}\",\n                ],\n                env={\n                    \"NCCL_P2P_LEVEL\": \"LOC\",\n                    \"NCCL_DEBUG\": \"INFO\",\n                    **current_env,\n                },\n            )\n        finally:\n            recursive_kill(vllm_process)\n\n    @pytest.mark.parametrize(\n        \"num_gpus\",\n        [1, 2],\n    )\n    @require_vllm\n    def test_llama_fft(self, temp_dir, num_gpus):\n        rnd_reward_suffix = str(random.randint(1000, 9999))\n        cfg = DictDefault(\n            {\n                \"base_model\": \"HuggingFaceTB/SmolLM2-135M\",\n                \"chat_template\": \"llama3\",\n                \"rl\": \"grpo\",\n                \"trl\": {\n                    \"beta\": 0.001,\n                    \"max_completion_length\": 256,\n                    \"use_vllm\": True,\n                    \"num_generations\": 4,\n                    \"reward_funcs\": [f\"rewards_{rnd_reward_suffix}.rand_reward_func\"],\n                },\n                \"vllm\": {\n                    \"max_model_len\": 800,\n                    \"enable_prefix_caching\": True,\n                },\n                \"datasets\": [\n                    {\n                        \"path\": \"openai/gsm8k\",\n                        \"name\": \"main\",\n                        \"type\": 
f\"rewards_{rnd_reward_suffix}.oai_gsm8k_transform\",\n                    },\n                ],\n                \"flash_attention\": True,\n                \"sequence_len\": 1024,\n                \"special_tokens\": {\n                    \"pad_token\": \"<|endoftext|>\",\n                },\n                \"max_steps\": 3,\n                \"num_epochs\": 1,\n                \"micro_batch_size\": 4,\n                \"gradient_accumulation_steps\": 2,\n                \"warmup_steps\": 10,\n                \"val_set_size\": 0.0,\n                \"output_dir\": temp_dir,\n                \"dataset_prepared_path\": temp_dir + \"/last_run_prepared\",\n                \"learning_rate\": 0.0001,\n                \"optimizer\": \"adamw_torch_fused\",\n                \"lr_scheduler\": \"cosine\",\n                \"bf16\": \"auto\",\n                \"use_tensorboard\": True,\n                \"save_first_step\": False,\n            }\n        )\n\n        self._utils_write_yaml_and_rewards(cfg, temp_dir, suffix=rnd_reward_suffix)\n\n        current_env = os.environ.copy()\n        env = {\n            \"NCCL_P2P_LEVEL\": \"LOC\",  # nccl can be brittle, assume P2P isn't reliable\n            **current_env,\n            \"CUDA_VISIBLE_DEVICES\": \"1\",\n        }\n        vllm_process = start_vllm(\n            cfg.base_model,\n            env=env,\n            quiet=True,\n            wait=300,\n            gpu_memory_utilization=0.15,\n            max_model_len=cfg.vllm.max_model_len,\n            enable_prefix_caching=cfg.vllm.enable_prefix_caching,\n            host=\"0.0.0.0\",\n            port=8000,\n        )\n\n        try:\n            execute_subprocess_async(\n                [\n                    \"axolotl\",\n                    \"train\",\n                    str(Path(temp_dir) / \"config.yaml\"),\n                    \"--num-processes\",\n                    str(num_gpus),\n                    \"--main-process-port\",\n                    
f\"{get_torch_dist_unique_port()}\",\n                ],\n                env={\n                    \"NCCL_P2P_LEVEL\": \"LOC\",\n                    \"NCCL_DEBUG\": \"INFO\",\n                    **current_env,\n                },\n            )\n        finally:\n            recursive_kill(vllm_process)\n"
  },
  {
    "path": "tests/e2e/multigpu/test_dist_muon_fsdp2.py",
    "content": "\"\"\"Test module for DistMuon optimizer with FSDP2 multi-GPU functionality.\"\"\"\n\nimport os\nfrom pathlib import Path\n\nimport torch\nimport yaml\nfrom accelerate.test_utils import execute_subprocess_async\nfrom tbparse import SummaryReader\nfrom transformers.testing_utils import get_torch_dist_unique_port\n\nfrom axolotl.utils.dict import DictDefault\n\nfrom tests.e2e.utils import most_recent_subdir, require_torch_2_7_0\n\nAXOLOTL_ROOT = Path(__file__).parent.parent.parent.parent\n\n\ndef verify_training_success(temp_dir):\n    \"\"\"Verify that training completed successfully by checking artifacts and loss.\"\"\"\n    output_path = Path(temp_dir)\n\n    model_files = list(output_path.glob(\"*.bin\")) + list(\n        output_path.glob(\"*.safetensors\")\n    )\n    assert len(model_files) > 0, \"No model files found - training may have failed\"\n\n    checkpoint_files = list(output_path.glob(\"checkpoint-*\"))\n    assert len(checkpoint_files) > 0, (\n        \"No checkpoint files found - training may have failed\"\n    )\n\n    tb_log_path = most_recent_subdir(temp_dir + \"/runs\")\n    if tb_log_path:\n        event_files = sorted(os.listdir(tb_log_path))\n        if event_files:\n            event_file = os.path.join(tb_log_path, event_files[0])\n            reader = SummaryReader(event_file)\n            df = reader.scalars\n            train_loss_df = df[df.tag == \"train/train_loss\"]\n            if len(train_loss_df) > 0:\n                final_loss = train_loss_df.value.values[-1]\n                assert not torch.isnan(torch.tensor(final_loss)), (\n                    f\"Training loss is NaN: {final_loss}\"\n                )\n\n\nclass TestDistMuon:\n    \"\"\"Test class for DistMuon optimizer with FSDP2 functionality.\"\"\"\n\n    @require_torch_2_7_0\n    def test_fft_sft(self, temp_dir):\n        cfg = DictDefault(\n            {\n                \"base_model\": \"Qwen/Qwen2.5-0.5B\",\n                \"sequence_len\": 2048,\n    
            \"val_set_size\": 0.01,\n                \"datasets\": [\n                    {\n                        \"path\": \"tatsu-lab/alpaca\",\n                        \"type\": \"alpaca\",\n                        \"split\": \"train[:10%]\",\n                    },\n                ],\n                \"num_epochs\": 1,\n                \"max_steps\": 2,\n                \"micro_batch_size\": 2,\n                \"gradient_accumulation_steps\": 1,\n                \"output_dir\": temp_dir,\n                \"learning_rate\": 0.02,\n                \"optimizer\": \"muon\",\n                \"weight_decay\": 0.01,\n                \"lr_scheduler\": \"cosine\",\n                \"flash_attention\": True,\n                \"fsdp_version\": 2,\n                \"fsdp_config\": {\n                    \"offload_params\": False,\n                    \"cpu_ram_efficient_loading\": False,\n                    \"transformer_layer_cls_to_wrap\": \"Qwen2DecoderLayer\",\n                    \"state_dict_type\": \"FULL_STATE_DICT\",\n                    \"auto_wrap_policy\": \"TRANSFORMER_BASED_WRAP\",\n                    \"reshard_after_forward\": True,\n                },\n                \"use_tensorboard\": True,\n                \"bf16\": True,\n            }\n        )\n\n        # write cfg to yaml file\n        Path(temp_dir).mkdir(parents=True, exist_ok=True)\n        with open(Path(temp_dir) / \"config.yaml\", \"w\", encoding=\"utf-8\") as fout:\n            fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))\n\n        execute_subprocess_async(\n            [\n                \"axolotl\",\n                \"train\",\n                str(Path(temp_dir) / \"config.yaml\"),\n                \"--num-processes\",\n                \"2\",\n                \"--main-process-port\",\n                f\"{get_torch_dist_unique_port()}\",\n            ]\n        )\n\n        verify_training_success(temp_dir)\n\n    @require_torch_2_7_0\n    def test_lora_sft(self, 
temp_dir):\n        cfg = DictDefault(\n            {\n                \"base_model\": \"Qwen/Qwen2.5-0.5B\",\n                \"sequence_len\": 2048,\n                \"val_set_size\": 0.01,\n                \"datasets\": [\n                    {\n                        \"path\": \"tatsu-lab/alpaca\",\n                        \"type\": \"alpaca\",\n                        \"split\": \"train[:10%]\",\n                    },\n                ],\n                \"adapter\": \"lora\",\n                \"lora_r\": 8,\n                \"lora_alpha\": 16,\n                \"lora_dropout\": 0.05,\n                \"lora_target_linear\": True,\n                \"num_epochs\": 1,\n                \"max_steps\": 2,\n                \"micro_batch_size\": 2,\n                \"gradient_accumulation_steps\": 1,\n                \"output_dir\": temp_dir,\n                \"learning_rate\": 0.02,\n                \"optimizer\": \"muon\",\n                \"weight_decay\": 0.01,\n                \"lr_scheduler\": \"cosine\",\n                \"flash_attention\": True,\n                \"fsdp_version\": 2,\n                \"fsdp_config\": {\n                    \"offload_params\": False,\n                    \"cpu_ram_efficient_loading\": False,\n                    \"transformer_layer_cls_to_wrap\": \"Qwen2DecoderLayer\",\n                    \"state_dict_type\": \"FULL_STATE_DICT\",\n                    \"auto_wrap_policy\": \"TRANSFORMER_BASED_WRAP\",\n                    \"reshard_after_forward\": True,\n                },\n                \"use_tensorboard\": True,\n                \"bf16\": True,\n            }\n        )\n\n        # write cfg to yaml file\n        Path(temp_dir).mkdir(parents=True, exist_ok=True)\n        with open(Path(temp_dir) / \"config.yaml\", \"w\", encoding=\"utf-8\") as fout:\n            fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))\n\n        execute_subprocess_async(\n            [\n                \"axolotl\",\n                
\"train\",\n                str(Path(temp_dir) / \"config.yaml\"),\n                \"--num-processes\",\n                \"2\",\n                \"--main-process-port\",\n                f\"{get_torch_dist_unique_port()}\",\n            ]\n        )\n\n        verify_training_success(temp_dir)\n"
  },
  {
    "path": "tests/e2e/multigpu/test_eval.py",
    "content": "\"\"\"\nE2E tests for multigpu eval\n\"\"\"\n\nfrom pathlib import Path\n\nimport yaml\nfrom accelerate.test_utils import execute_subprocess_async\nfrom transformers.testing_utils import get_torch_dist_unique_port\n\nfrom axolotl.utils.dict import DictDefault\n\nfrom ..utils import check_tensorboard\n\nAXOLOTL_ROOT = Path(__file__).parent.parent.parent.parent\n\n\nclass TestMultiGPUEval:\n    \"\"\"\n    Test case for MultiGPU Eval Sample Packing\n    \"\"\"\n\n    def test_eval_sample_packing(self, temp_dir):\n        cfg = DictDefault(\n            {\n                \"base_model\": \"HuggingFaceTB/SmolLM2-135M\",\n                \"load_in_8bit\": False,\n                \"load_in_4bit\": True,\n                \"strict\": False,\n                \"sequence_len\": 2048,\n                \"adapter\": \"qlora\",\n                \"sample_packing\": True,\n                \"eval_sample_packing\": True,\n                \"pad_to_sequence_len\": True,\n                \"lora_r\": 8,\n                \"lora_alpha\": 16,\n                \"lora_dropout\": 0.05,\n                \"lora_target_linear\": True,\n                \"lora_modules_to_save\": [\"embed_tokens\", \"lm_head\"],\n                \"val_set_size\": 0.05,\n                \"special_tokens\": {\"pad_token\": \"<|endoftext|>\"},\n                \"datasets\": [\n                    {\n                        \"path\": \"teknium/GPT4-LLM-Cleaned\",\n                        \"type\": \"alpaca\",\n                        \"split\": \"train[:5%]\",\n                    },\n                ],\n                \"num_epochs\": 1,\n                \"max_steps\": 2,\n                \"micro_batch_size\": 2,\n                \"gradient_accumulation_steps\": 2,\n                \"output_dir\": temp_dir,\n                \"dataset_prepared_path\": temp_dir + \"/last_run_prepared\",\n                \"learning_rate\": 0.00001,\n                \"optimizer\": \"adamw_8bit\",\n                
\"lr_scheduler\": \"cosine\",\n                \"flash_attention\": True,\n                \"loss_watchdog_threshold\": 5.0,\n                \"loss_watchdog_patience\": 3,\n                \"bf16\": \"auto\",\n                \"warmup_steps\": 1,\n                \"evals_per_epoch\": 2,\n                \"eval_max_new_tokens\": 128,\n                \"saves_per_epoch\": 1,\n                \"logging_steps\": 1,\n                \"weight_decay\": 0.0,\n                \"use_tensorboard\": True,\n                \"save_first_step\": False,\n            }\n        )\n\n        # write cfg to yaml file\n        Path(temp_dir).mkdir(parents=True, exist_ok=True)\n        with open(Path(temp_dir) / \"config.yaml\", \"w\", encoding=\"utf-8\") as fout:\n            fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))\n\n        execute_subprocess_async(\n            [\n                \"accelerate\",\n                \"launch\",\n                \"--num-processes\",\n                \"2\",\n                \"--main_process_port\",\n                f\"{get_torch_dist_unique_port()}\",\n                \"-m\",\n                \"axolotl.cli.train\",\n                str(Path(temp_dir) / \"config.yaml\"),\n            ]\n        )\n\n        check_tensorboard(temp_dir + \"/runs\", \"eval/loss\", 2.5, \"Eval Loss is too high\")\n\n    def test_eval(self, temp_dir):\n        cfg = DictDefault(\n            {\n                \"base_model\": \"HuggingFaceTB/SmolLM2-135M\",\n                \"load_in_8bit\": False,\n                \"load_in_4bit\": True,\n                \"strict\": False,\n                \"sequence_len\": 2048,\n                \"adapter\": \"qlora\",\n                \"sample_packing\": True,\n                \"eval_sample_packing\": False,\n                \"pad_to_sequence_len\": True,\n                \"lora_r\": 8,\n                \"lora_alpha\": 16,\n                \"lora_dropout\": 0.05,\n                \"lora_target_linear\": True,\n             
   \"lora_modules_to_save\": [\"embed_tokens\", \"lm_head\"],\n                \"val_set_size\": 0.01,\n                \"special_tokens\": {\"pad_token\": \"<|endoftext|>\"},\n                \"datasets\": [\n                    {\n                        \"path\": \"teknium/GPT4-LLM-Cleaned\",\n                        \"type\": \"alpaca\",\n                        \"split\": \"train[:5%]\",\n                    },\n                ],\n                \"num_epochs\": 1,\n                \"max_steps\": 2,\n                \"micro_batch_size\": 2,\n                \"gradient_accumulation_steps\": 2,\n                \"output_dir\": temp_dir,\n                \"dataset_prepared_path\": temp_dir + \"/last_run_prepared\",\n                \"learning_rate\": 0.00001,\n                \"optimizer\": \"adamw_8bit\",\n                \"lr_scheduler\": \"cosine\",\n                \"flash_attention\": True,\n                \"loss_watchdog_threshold\": 5.0,\n                \"loss_watchdog_patience\": 3,\n                \"bf16\": \"auto\",\n                \"warmup_steps\": 1,\n                \"evals_per_epoch\": 2,\n                \"eval_max_new_tokens\": 128,\n                \"saves_per_epoch\": 1,\n                \"logging_steps\": 1,\n                \"weight_decay\": 0.0,\n                \"use_tensorboard\": True,\n                \"save_first_step\": False,\n            }\n        )\n\n        # write cfg to yaml file\n        Path(temp_dir).mkdir(parents=True, exist_ok=True)\n        with open(Path(temp_dir) / \"config.yaml\", \"w\", encoding=\"utf-8\") as fout:\n            fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))\n\n        execute_subprocess_async(\n            [\n                \"accelerate\",\n                \"launch\",\n                \"--num-processes\",\n                \"2\",\n                \"--main_process_port\",\n                f\"{get_torch_dist_unique_port()}\",\n                \"-m\",\n                
\"axolotl.cli.train\",\n                str(Path(temp_dir) / \"config.yaml\"),\n            ]\n        )\n\n        check_tensorboard(temp_dir + \"/runs\", \"eval/loss\", 2.9, \"Eval Loss is too high\")\n"
  },
  {
    "path": "tests/e2e/multigpu/test_fp8_fsdp2.py",
    "content": "\"\"\"Test module for FP8 mixed precision with FSDP2 multi-GPU functionality.\"\"\"\n\nimport os\nfrom pathlib import Path\n\nimport torch\nimport yaml\nfrom accelerate.test_utils import execute_subprocess_async\nfrom tbparse import SummaryReader\nfrom transformers.testing_utils import get_torch_dist_unique_port\n\nfrom axolotl.utils.dict import DictDefault\n\nfrom tests.e2e.utils import most_recent_subdir, require_torch_2_7_0, supports_fp8\n\nAXOLOTL_ROOT = Path(__file__).parent.parent.parent.parent\n\n\ndef verify_fp8_training_success(temp_dir):\n    \"\"\"Verify that FP8 training completed successfully by checking artifacts and loss.\"\"\"\n    output_path = Path(temp_dir)\n\n    model_files = list(output_path.glob(\"*.bin\")) + list(\n        output_path.glob(\"*.safetensors\")\n    )\n    assert len(model_files) > 0, \"No model files found - training may have failed\"\n\n    checkpoint_files = list(output_path.glob(\"checkpoint-*\"))\n    assert len(checkpoint_files) > 0, (\n        \"No checkpoint files found - training may have failed\"\n    )\n\n    tb_log_path = most_recent_subdir(temp_dir + \"/runs\")\n    if tb_log_path:\n        event_files = sorted(os.listdir(tb_log_path))\n        if event_files:\n            event_file = os.path.join(tb_log_path, event_files[0])\n            reader = SummaryReader(event_file)\n            df = reader.scalars\n            train_loss_df = df[df.tag == \"train/train_loss\"]\n            if len(train_loss_df) > 0:\n                final_loss = train_loss_df.value.values[-1]\n                assert not torch.isnan(torch.tensor(final_loss)), (\n                    f\"Training loss is NaN: {final_loss}\"\n                )\n\n\nclass TestFP8FSDP2:\n    \"\"\"Test class for FP8 mixed precision with FSDP2 functionality.\"\"\"\n\n    @require_torch_2_7_0\n    @supports_fp8\n    def test_fp8_fsdp2_smoke(self, temp_dir):\n        \"\"\"Smoke test for 2-GPU FP8 + torch.compile + FSDP2 training\"\"\"\n        cfg = 
DictDefault(\n            {\n                \"base_model\": \"HuggingFaceTB/SmolLM2-135M\",\n                \"tokenizer_type\": \"AutoTokenizer\",\n                \"trust_remote_code\": True,\n                \"sequence_len\": 512,\n                \"val_set_size\": 0.05,\n                \"special_tokens\": {\n                    \"pad_token\": \"<|endoftext|>\",\n                },\n                \"datasets\": [\n                    {\n                        \"path\": \"mhenrichsen/alpaca_2k_test\",\n                        \"type\": \"alpaca\",\n                    },\n                ],\n                \"num_epochs\": 1,\n                \"max_steps\": 3,  # Very short smoke test\n                \"micro_batch_size\": 1,\n                \"gradient_accumulation_steps\": 1,\n                \"output_dir\": temp_dir,\n                \"learning_rate\": 0.00001,\n                \"optimizer\": \"adamw_torch_fused\",  # Use standard optimizer for stability\n                \"lr_scheduler\": \"cosine\",\n                \"sdp_attention\": True,\n                \"pad_to_seq_len\": True,\n                \"sample_packing\": True,\n                # FP8 configuration\n                \"fp8\": True,\n                \"fp8_enable_fsdp_float8_all_gather\": True,\n                \"torch_compile\": True,\n                # FSDP2 configuration\n                \"fsdp_version\": 2,\n                \"fsdp_config\": {\n                    \"offload_params\": False,\n                    \"cpu_ram_efficient_loading\": False,\n                    \"transformer_layer_cls_to_wrap\": \"LlamaDecoderLayer\",\n                    \"state_dict_type\": \"FULL_STATE_DICT\",\n                    \"auto_wrap_policy\": \"TRANSFORMER_BASED_WRAP\",\n                    \"reshard_after_forward\": True,\n                },\n                \"use_tensorboard\": True,\n                \"save_first_step\": False,\n            }\n        )\n\n        # write cfg to yaml file\n        
Path(temp_dir).mkdir(parents=True, exist_ok=True)\n        with open(Path(temp_dir) / \"config.yaml\", \"w\", encoding=\"utf-8\") as fout:\n            fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))\n\n        execute_subprocess_async(\n            [\n                \"axolotl\",\n                \"train\",\n                str(Path(temp_dir) / \"config.yaml\"),\n                \"--num-processes\",\n                \"2\",\n                \"--main-process-port\",\n                f\"{get_torch_dist_unique_port()}\",\n            ]\n        )\n\n        verify_fp8_training_success(temp_dir)\n"
  },
  {
    "path": "tests/e2e/multigpu/test_fsdp1.py",
    "content": "\"\"\"Test module for FSDP1 multi-GPU functionality.\"\"\"\n\nimport os\nfrom pathlib import Path\n\nimport pytest\nimport torch\nimport yaml\nfrom accelerate.test_utils import execute_subprocess_async\nfrom tbparse import SummaryReader\nfrom transformers.testing_utils import get_torch_dist_unique_port\n\nfrom axolotl.utils.dict import DictDefault\n\nfrom tests.e2e.utils import most_recent_subdir\n\nAXOLOTL_ROOT = Path(__file__).parent.parent.parent.parent\n\n\ndef verify_training_success(temp_dir):\n    \"\"\"Verify that training completed successfully by checking artifacts and loss.\"\"\"\n    output_path = Path(temp_dir)\n\n    model_files = list(output_path.glob(\"*.bin\")) + list(\n        output_path.glob(\"*.safetensors\")\n    )\n    assert len(model_files) > 0, \"No model files found - training may have failed\"\n\n    checkpoint_files = list(output_path.glob(\"checkpoint-*\"))\n    assert len(checkpoint_files) > 0, (\n        \"No checkpoint files found - training may have failed\"\n    )\n\n    tb_log_path = most_recent_subdir(temp_dir + \"/runs\")\n    if tb_log_path:\n        event_files = sorted(os.listdir(tb_log_path))\n        if event_files:\n            event_file = os.path.join(tb_log_path, event_files[0])\n            reader = SummaryReader(event_file)\n            df = reader.scalars\n            train_loss_df = df[df.tag == \"train/train_loss\"]\n            if len(train_loss_df) > 0:\n                final_loss = train_loss_df.value.values[-1]\n                assert not torch.isnan(torch.tensor(final_loss)), (\n                    f\"Training loss is NaN: {final_loss}\"\n                )\n\n\nclass TestFSDP1:\n    \"\"\"Test class for FSDP1 functionality.\"\"\"\n\n    @pytest.mark.parametrize(\n        \"fsdp_cpu_ram_efficient_loading\",\n        [True, False],\n    )\n    def test_fft_sft(self, temp_dir, fsdp_cpu_ram_efficient_loading):\n        cfg = DictDefault(\n            {\n                \"base_model\": 
\"Qwen/Qwen2.5-0.5B\",\n                \"sequence_len\": 2048,\n                \"val_set_size\": 0.01,\n                \"datasets\": [\n                    {\n                        \"path\": \"tatsu-lab/alpaca\",\n                        \"type\": \"alpaca\",\n                        \"split\": \"train[:10%]\",\n                    },\n                ],\n                \"num_epochs\": 1,\n                \"max_steps\": 2,\n                \"micro_batch_size\": 2,\n                \"gradient_accumulation_steps\": 1,\n                \"output_dir\": temp_dir,\n                \"learning_rate\": 0.00001,\n                \"optimizer\": \"adamw_torch_fused\",\n                \"lr_scheduler\": \"cosine\",\n                \"flash_attention\": True,\n                \"fsdp_version\": \"1\",\n                \"fsdp_config\": {\n                    \"fsdp_offload_params\": False,\n                    \"fsdp_cpu_ram_efficient_loading\": fsdp_cpu_ram_efficient_loading,\n                    \"fsdp_transformer_layer_cls_to_wrap\": \"Qwen2DecoderLayer\",\n                    \"fsdp_state_dict_type\": \"FULL_STATE_DICT\",\n                    \"fsdp_auto_wrap_policy\": \"TRANSFORMER_BASED_WRAP\",\n                    \"fsdp_sharding_strategy\": \"FULL_SHARD\",\n                    \"fsdp_sync_module_states\": True,\n                    \"fsdp_use_orig_params\": False,\n                },\n                \"use_tensorboard\": True,\n                \"bf16\": True,\n            }\n        )\n\n        # write cfg to yaml file\n        Path(temp_dir).mkdir(parents=True, exist_ok=True)\n        with open(Path(temp_dir) / \"config.yaml\", \"w\", encoding=\"utf-8\") as fout:\n            fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))\n\n        execute_subprocess_async(\n            [\n                \"axolotl\",\n                \"train\",\n                str(Path(temp_dir) / \"config.yaml\"),\n                \"--num-processes\",\n                \"2\",\n         
       \"--main-process-port\",\n                f\"{get_torch_dist_unique_port()}\",\n            ]\n        )\n\n        verify_training_success(temp_dir)\n\n    @pytest.mark.parametrize(\n        \"adapter_config\",\n        [\n            {\n                \"adapter\": \"lora\",\n                \"load_in_4bit\": False,\n            },\n            {\n                \"adapter\": \"qlora\",\n                \"load_in_4bit\": True,\n            },\n        ],\n    )\n    def test_lora_sft(self, temp_dir, adapter_config):\n        cfg = DictDefault(\n            {\n                \"base_model\": \"Qwen/Qwen2.5-0.5B\",\n                \"sequence_len\": 2048,\n                \"val_set_size\": 0.01,\n                \"datasets\": [\n                    {\n                        \"path\": \"tatsu-lab/alpaca\",\n                        \"type\": \"alpaca\",\n                        \"split\": \"train[:10%]\",\n                    },\n                ],\n                \"adapter\": adapter_config[\"adapter\"],\n                \"load_in_4bit\": adapter_config[\"load_in_4bit\"],\n                \"lora_r\": 8,\n                \"lora_alpha\": 16,\n                \"lora_dropout\": 0.05,\n                \"lora_target_linear\": True,\n                \"num_epochs\": 1,\n                \"max_steps\": 2,\n                \"micro_batch_size\": 2,\n                \"gradient_accumulation_steps\": 1,\n                \"output_dir\": temp_dir,\n                \"learning_rate\": 0.00001,\n                \"optimizer\": \"adamw_torch_fused\",\n                \"lr_scheduler\": \"cosine\",\n                \"flash_attention\": True,\n                \"fsdp_version\": \"1\",\n                \"fsdp_config\": {\n                    \"fsdp_offload_params\": False,\n                    \"fsdp_cpu_ram_efficient_loading\": True,\n                    \"fsdp_transformer_layer_cls_to_wrap\": \"Qwen2DecoderLayer\",\n                    \"fsdp_state_dict_type\": 
\"FULL_STATE_DICT\",\n                    \"fsdp_auto_wrap_policy\": \"TRANSFORMER_BASED_WRAP\",\n                    \"fsdp_sharding_strategy\": \"FULL_SHARD\",\n                    \"fsdp_sync_module_states\": True,\n                    \"fsdp_use_orig_params\": False,\n                },\n                \"use_tensorboard\": True,\n                \"bf16\": True,\n            }\n        )\n\n        # write cfg to yaml file\n        Path(temp_dir).mkdir(parents=True, exist_ok=True)\n        with open(Path(temp_dir) / \"config.yaml\", \"w\", encoding=\"utf-8\") as fout:\n            fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))\n\n        execute_subprocess_async(\n            [\n                \"axolotl\",\n                \"train\",\n                str(Path(temp_dir) / \"config.yaml\"),\n                \"--num-processes\",\n                \"2\",\n                \"--main-process-port\",\n                f\"{get_torch_dist_unique_port()}\",\n            ]\n        )\n\n        verify_training_success(temp_dir)\n\n    @pytest.mark.skip(reason=\"slow test, deprecate fsdp1 asap\")\n    def test_dpo_fft(self, temp_dir):\n        cfg = DictDefault(\n            {\n                \"base_model\": \"Qwen/Qwen2.5-0.5B\",\n                \"sequence_len\": 2048,\n                \"val_set_size\": 0.01,\n                \"rl\": \"dpo\",\n                \"chat_template\": \"chatml\",\n                \"datasets\": [\n                    {\n                        \"path\": \"Intel/orca_dpo_pairs\",\n                        \"split\": \"train\",\n                        \"type\": \"chatml.intel\",\n                    },\n                ],\n                \"num_epochs\": 1,\n                \"max_steps\": 2,\n                \"micro_batch_size\": 2,\n                \"gradient_accumulation_steps\": 1,\n                \"output_dir\": temp_dir,\n                \"learning_rate\": 0.00001,\n                \"optimizer\": \"adamw_torch_fused\",\n              
  \"lr_scheduler\": \"cosine\",\n                \"flash_attention\": True,\n                \"fsdp_version\": \"1\",\n                \"fsdp_config\": {\n                    \"fsdp_offload_params\": False,\n                    \"fsdp_cpu_ram_efficient_loading\": True,\n                    \"fsdp_transformer_layer_cls_to_wrap\": \"Qwen2DecoderLayer\",\n                    \"fsdp_state_dict_type\": \"FULL_STATE_DICT\",\n                    \"fsdp_auto_wrap_policy\": \"TRANSFORMER_BASED_WRAP\",\n                    \"fsdp_sharding_strategy\": \"FULL_SHARD\",\n                    \"fsdp_sync_module_states\": True,\n                    \"fsdp_use_orig_params\": False,\n                },\n                \"use_tensorboard\": True,\n            }\n        )\n\n        # write cfg to yaml file\n        Path(temp_dir).mkdir(parents=True, exist_ok=True)\n        with open(Path(temp_dir) / \"config.yaml\", \"w\", encoding=\"utf-8\") as fout:\n            fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))\n\n        execute_subprocess_async(\n            [\n                \"axolotl\",\n                \"train\",\n                str(Path(temp_dir) / \"config.yaml\"),\n                \"--num-processes\",\n                \"2\",\n                \"--main-process-port\",\n                f\"{get_torch_dist_unique_port()}\",\n            ]\n        )\n\n        verify_training_success(temp_dir)\n\n    @pytest.mark.skip(\"broken in transformers v5\")\n    @pytest.mark.parametrize(\n        \"adapter_config\",\n        [\n            {\n                \"adapter\": \"lora\",\n                \"load_in_4bit\": False,\n            },\n            {\n                \"adapter\": \"qlora\",\n                \"load_in_4bit\": True,\n            },\n        ],\n    )\n    def test_dpo_lora(self, temp_dir, adapter_config):\n        cfg = DictDefault(\n            {\n                \"base_model\": \"Qwen/Qwen2.5-0.5B\",\n                \"load_in_4bit\": 
adapter_config[\"load_in_4bit\"],\n                \"rl\": \"dpo\",\n                \"chat_template\": \"chatml\",\n                \"sequence_len\": 2048,\n                \"adapter\": adapter_config[\"adapter\"],\n                \"lora_r\": 8,\n                \"lora_alpha\": 16,\n                \"lora_dropout\": 0.05,\n                \"lora_target_linear\": True,\n                \"val_set_size\": 0.01,\n                \"datasets\": [\n                    {\n                        \"path\": \"Intel/orca_dpo_pairs\",\n                        \"split\": \"train\",\n                        \"type\": \"chatml.intel\",\n                    },\n                ],\n                \"num_epochs\": 1,\n                \"max_steps\": 2,\n                \"micro_batch_size\": 2,\n                \"gradient_accumulation_steps\": 1,\n                \"output_dir\": temp_dir,\n                \"learning_rate\": 0.00001,\n                \"optimizer\": \"adamw_torch_fused\",\n                \"lr_scheduler\": \"cosine\",\n                \"flash_attention\": True,\n                \"fsdp_version\": \"1\",\n                \"fsdp_config\": {\n                    \"fsdp_offload_params\": False,\n                    \"fsdp_cpu_ram_efficient_loading\": True,\n                    \"fsdp_transformer_layer_cls_to_wrap\": \"Qwen2DecoderLayer\",\n                    \"fsdp_state_dict_type\": \"FULL_STATE_DICT\",\n                    \"fsdp_auto_wrap_policy\": \"TRANSFORMER_BASED_WRAP\",\n                    \"fsdp_sharding_strategy\": \"FULL_SHARD\",\n                    \"fsdp_sync_module_states\": True,\n                    \"fsdp_use_orig_params\": False,\n                },\n                \"use_tensorboard\": True,\n                \"bf16\": \"auto\",\n                \"tf32\": True,\n            }\n        )\n\n        # write cfg to yaml file\n        Path(temp_dir).mkdir(parents=True, exist_ok=True)\n        with open(Path(temp_dir) / \"config.yaml\", \"w\", 
encoding=\"utf-8\") as fout:\n            fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))\n\n        execute_subprocess_async(\n            [\n                \"axolotl\",\n                \"train\",\n                str(Path(temp_dir) / \"config.yaml\"),\n                \"--num-processes\",\n                \"2\",\n                \"--main-process-port\",\n                f\"{get_torch_dist_unique_port()}\",\n            ]\n        )\n\n        verify_training_success(temp_dir)\n"
  },
  {
    "path": "tests/e2e/multigpu/test_fsdp2.py",
    "content": "\"\"\"Test module for FSDP2 multi-GPU functionality.\"\"\"\n\nimport os\nfrom pathlib import Path\n\nimport pytest\nimport torch\nimport yaml\nfrom accelerate.test_utils import execute_subprocess_async\nfrom tbparse import SummaryReader\nfrom transformers.testing_utils import get_torch_dist_unique_port\n\nfrom axolotl.utils.dict import DictDefault\n\nfrom tests.e2e.utils import most_recent_subdir, require_torch_2_7_0\n\nAXOLOTL_ROOT = Path(__file__).parent.parent.parent.parent\n\n\ndef verify_training_success(temp_dir):\n    \"\"\"Verify that training completed successfully by checking artifacts and loss.\"\"\"\n    output_path = Path(temp_dir)\n\n    model_files = list(output_path.glob(\"*.bin\")) + list(\n        output_path.glob(\"*.safetensors\")\n    )\n    assert len(model_files) > 0, \"No model files found - training may have failed\"\n\n    checkpoint_files = list(output_path.glob(\"checkpoint-*\"))\n    assert len(checkpoint_files) > 0, (\n        \"No checkpoint files found - training may have failed\"\n    )\n\n    tb_log_path = most_recent_subdir(temp_dir + \"/runs\")\n    if tb_log_path:\n        event_files = sorted(os.listdir(tb_log_path))\n        if event_files:\n            event_file = os.path.join(tb_log_path, event_files[0])\n            reader = SummaryReader(event_file)\n            df = reader.scalars\n            train_loss_df = df[df.tag == \"train/train_loss\"]\n            if len(train_loss_df) > 0:\n                final_loss = train_loss_df.value.values[-1]\n                assert not torch.isnan(torch.tensor(final_loss)), (\n                    f\"Training loss is NaN: {final_loss}\"\n                )\n\n\nclass TestFSDP2:\n    \"\"\"Test class for FSDP2 functionality.\"\"\"\n\n    @require_torch_2_7_0\n    @pytest.mark.parametrize(\n        \"fsdp_cpu_ram_efficient_loading\",\n        [True, False],\n    )\n    def test_fft_sft(self, temp_dir, fsdp_cpu_ram_efficient_loading):\n        cfg = DictDefault(\n            
{\n                \"base_model\": \"Qwen/Qwen2.5-0.5B\",\n                \"sequence_len\": 2048,\n                \"val_set_size\": 0.01,\n                \"datasets\": [\n                    {\n                        \"path\": \"tatsu-lab/alpaca\",\n                        \"type\": \"alpaca\",\n                        \"split\": \"train[:10%]\",\n                    },\n                ],\n                \"num_epochs\": 1,\n                \"max_steps\": 2,\n                \"micro_batch_size\": 2,\n                \"gradient_accumulation_steps\": 1,\n                \"output_dir\": temp_dir,\n                \"learning_rate\": 0.00001,\n                \"optimizer\": \"adamw_torch_fused\",\n                \"lr_scheduler\": \"cosine\",\n                \"flash_attention\": True,\n                \"fsdp_version\": 2,\n                \"fsdp_config\": {\n                    \"offload_params\": False,\n                    \"cpu_ram_efficient_loading\": fsdp_cpu_ram_efficient_loading,\n                    \"transformer_layer_cls_to_wrap\": \"Qwen2DecoderLayer\",\n                    \"state_dict_type\": \"FULL_STATE_DICT\",\n                    \"auto_wrap_policy\": \"TRANSFORMER_BASED_WRAP\",\n                    \"reshard_after_forward\": True,\n                },\n                \"use_tensorboard\": True,\n                \"bf16\": True,\n            }\n        )\n\n        # write cfg to yaml file\n        Path(temp_dir).mkdir(parents=True, exist_ok=True)\n        with open(Path(temp_dir) / \"config.yaml\", \"w\", encoding=\"utf-8\") as fout:\n            fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))\n\n        execute_subprocess_async(\n            [\n                \"axolotl\",\n                \"train\",\n                str(Path(temp_dir) / \"config.yaml\"),\n                \"--num-processes\",\n                \"2\",\n                \"--main-process-port\",\n                f\"{get_torch_dist_unique_port()}\",\n            ]\n        
)\n\n        verify_training_success(temp_dir)\n\n    @require_torch_2_7_0\n    @pytest.mark.parametrize(\"peft_use_dora\", [True, False])\n    def test_lora_sft(self, temp_dir, peft_use_dora):\n        cfg = DictDefault(\n            {\n                \"base_model\": \"Qwen/Qwen2.5-0.5B\",\n                \"sequence_len\": 2048,\n                \"val_set_size\": 0.01,\n                \"datasets\": [\n                    {\n                        \"path\": \"tatsu-lab/alpaca\",\n                        \"type\": \"alpaca\",\n                        \"split\": \"train[:10%]\",\n                    },\n                ],\n                \"peft_use_dora\": peft_use_dora,\n                \"adapter\": \"lora\",\n                \"lora_r\": 8,\n                \"lora_alpha\": 16,\n                \"lora_dropout\": 0.05,\n                \"lora_target_linear\": True,\n                \"num_epochs\": 1,\n                \"max_steps\": 2,\n                \"micro_batch_size\": 2,\n                \"gradient_accumulation_steps\": 1,\n                \"output_dir\": temp_dir,\n                \"learning_rate\": 0.00001,\n                \"optimizer\": \"adamw_torch_fused\",\n                \"lr_scheduler\": \"cosine\",\n                \"flash_attention\": True,\n                \"fsdp_version\": 2,\n                \"fsdp_config\": {\n                    \"offload_params\": False,\n                    \"cpu_ram_efficient_loading\": False,\n                    \"transformer_layer_cls_to_wrap\": \"Qwen2DecoderLayer\",\n                    \"state_dict_type\": \"FULL_STATE_DICT\",\n                    \"auto_wrap_policy\": \"TRANSFORMER_BASED_WRAP\",\n                    \"reshard_after_forward\": True,\n                },\n                \"use_tensorboard\": True,\n                \"bf16\": True,\n                # explicitly disable LORA kernels, as they may be auto-enabled\n                \"lora_mlp_kernel\": False,\n                \"lora_qkv_kernel\": False,\n    
            \"lora_o_kernel\": False,\n            }\n        )\n\n        # write cfg to yaml file\n        Path(temp_dir).mkdir(parents=True, exist_ok=True)\n        with open(Path(temp_dir) / \"config.yaml\", \"w\", encoding=\"utf-8\") as fout:\n            fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))\n\n        execute_subprocess_async(\n            [\n                \"axolotl\",\n                \"train\",\n                str(Path(temp_dir) / \"config.yaml\"),\n                \"--num-processes\",\n                \"2\",\n                \"--main-process-port\",\n                f\"{get_torch_dist_unique_port()}\",\n            ]\n        )\n\n        verify_training_success(temp_dir)\n\n    @require_torch_2_7_0\n    def test_lora_sft_kernels(self, temp_dir):\n        cfg = DictDefault(\n            {\n                \"base_model\": \"Qwen/Qwen2.5-0.5B\",\n                \"sequence_len\": 2048,\n                \"val_set_size\": 0.01,\n                \"datasets\": [\n                    {\n                        \"path\": \"tatsu-lab/alpaca\",\n                        \"type\": \"alpaca\",\n                        \"split\": \"train[:10%]\",\n                    },\n                ],\n                \"adapter\": \"lora\",\n                \"lora_r\": 8,\n                \"lora_alpha\": 16,\n                \"lora_target_linear\": True,\n                \"num_epochs\": 1,\n                \"max_steps\": 2,\n                \"micro_batch_size\": 2,\n                \"gradient_accumulation_steps\": 1,\n                \"output_dir\": temp_dir,\n                \"learning_rate\": 0.00001,\n                \"optimizer\": \"adamw_torch_fused\",\n                \"lr_scheduler\": \"cosine\",\n                \"flash_attention\": True,\n                \"fsdp_version\": 2,\n                \"fsdp_config\": {\n                    \"offload_params\": False,\n                    \"cpu_ram_efficient_loading\": False,\n                    
\"transformer_layer_cls_to_wrap\": \"Qwen2DecoderLayer\",\n                    \"state_dict_type\": \"FULL_STATE_DICT\",\n                    \"auto_wrap_policy\": \"TRANSFORMER_BASED_WRAP\",\n                    \"reshard_after_forward\": True,\n                },\n                \"use_tensorboard\": True,\n                \"bf16\": True,\n                \"lora_mlp_kernel\": True,\n                \"lora_qkv_kernel\": True,\n                \"lora_o_kernel\": True,\n            }\n        )\n\n        # write cfg to yaml file\n        Path(temp_dir).mkdir(parents=True, exist_ok=True)\n        with open(Path(temp_dir) / \"config.yaml\", \"w\", encoding=\"utf-8\") as fout:\n            fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))\n\n        execute_subprocess_async(\n            [\n                \"axolotl\",\n                \"train\",\n                str(Path(temp_dir) / \"config.yaml\"),\n                \"--num-processes\",\n                \"2\",\n                \"--main-process-port\",\n                f\"{get_torch_dist_unique_port()}\",\n            ]\n        )\n\n        verify_training_success(temp_dir)\n\n    @require_torch_2_7_0\n    def test_qlora_sft(self, temp_dir):\n        cfg = DictDefault(\n            {\n                \"base_model\": \"Qwen/Qwen2.5-0.5B\",\n                \"sequence_len\": 2048,\n                \"val_set_size\": 0.01,\n                \"datasets\": [\n                    {\n                        \"path\": \"tatsu-lab/alpaca\",\n                        \"type\": \"alpaca\",\n                        \"split\": \"train[:10%]\",\n                    },\n                ],\n                \"load_in_4bit\": True,\n                \"adapter\": \"qlora\",\n                \"lora_r\": 8,\n                \"lora_alpha\": 16,\n                \"lora_dropout\": 0.05,\n                \"lora_target_linear\": True,\n                \"num_epochs\": 1,\n                \"max_steps\": 2,\n                
\"micro_batch_size\": 2,\n                \"gradient_accumulation_steps\": 1,\n                \"output_dir\": temp_dir,\n                \"learning_rate\": 0.00001,\n                \"optimizer\": \"adamw_torch_fused\",\n                \"lr_scheduler\": \"cosine\",\n                \"flash_attention\": True,\n                \"fsdp_version\": 2,\n                \"fsdp_config\": {\n                    \"offload_params\": False,\n                    \"cpu_ram_efficient_loading\": False,\n                    \"transformer_layer_cls_to_wrap\": \"Qwen2DecoderLayer\",\n                    \"state_dict_type\": \"FULL_STATE_DICT\",\n                    \"auto_wrap_policy\": \"TRANSFORMER_BASED_WRAP\",\n                    \"reshard_after_forward\": True,\n                },\n                \"use_tensorboard\": True,\n                \"bf16\": True,\n            }\n        )\n\n        # write cfg to yaml file\n        Path(temp_dir).mkdir(parents=True, exist_ok=True)\n        with open(Path(temp_dir) / \"config.yaml\", \"w\", encoding=\"utf-8\") as fout:\n            fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))\n\n        execute_subprocess_async(\n            [\n                \"axolotl\",\n                \"train\",\n                str(Path(temp_dir) / \"config.yaml\"),\n                \"--num-processes\",\n                \"2\",\n                \"--main-process-port\",\n                f\"{get_torch_dist_unique_port()}\",\n            ]\n        )\n\n        verify_training_success(temp_dir)\n\n    @require_torch_2_7_0\n    def test_qlora_sft_kernels(self, temp_dir):\n        cfg = DictDefault(\n            {\n                \"base_model\": \"Qwen/Qwen2.5-0.5B\",\n                \"sequence_len\": 2048,\n                \"val_set_size\": 0.01,\n                \"datasets\": [\n                    {\n                        \"path\": \"tatsu-lab/alpaca\",\n                        \"type\": \"alpaca\",\n                        \"split\": 
\"train[:10%]\",\n                    },\n                ],\n                \"load_in_4bit\": True,\n                \"adapter\": \"qlora\",\n                \"lora_r\": 8,\n                \"lora_alpha\": 16,\n                \"lora_target_linear\": True,\n                \"num_epochs\": 1,\n                \"max_steps\": 2,\n                \"micro_batch_size\": 2,\n                \"gradient_accumulation_steps\": 1,\n                \"output_dir\": temp_dir,\n                \"learning_rate\": 0.00001,\n                \"optimizer\": \"adamw_torch_fused\",\n                \"lr_scheduler\": \"cosine\",\n                \"flash_attention\": True,\n                \"fsdp_version\": 2,\n                \"fsdp_config\": {\n                    \"offload_params\": False,\n                    \"cpu_ram_efficient_loading\": False,\n                    \"transformer_layer_cls_to_wrap\": \"Qwen2DecoderLayer\",\n                    \"state_dict_type\": \"FULL_STATE_DICT\",\n                    \"auto_wrap_policy\": \"TRANSFORMER_BASED_WRAP\",\n                    \"reshard_after_forward\": True,\n                },\n                \"use_tensorboard\": True,\n                \"bf16\": True,\n                \"lora_mlp_kernel\": True,\n                \"lora_qkv_kernel\": True,\n                \"lora_o_kernel\": True,\n            }\n        )\n\n        # write cfg to yaml file\n        Path(temp_dir).mkdir(parents=True, exist_ok=True)\n        with open(Path(temp_dir) / \"config.yaml\", \"w\", encoding=\"utf-8\") as fout:\n            fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))\n\n        execute_subprocess_async(\n            [\n                \"axolotl\",\n                \"train\",\n                str(Path(temp_dir) / \"config.yaml\"),\n                \"--num-processes\",\n                \"2\",\n                \"--main-process-port\",\n                f\"{get_torch_dist_unique_port()}\",\n            ]\n        )\n\n        
verify_training_success(temp_dir)\n\n    @pytest.mark.skip(reason=\"slow test w cu129 + torch 2.9.1 + py3.12\")\n    @require_torch_2_7_0\n    def test_dpo_fft(self, temp_dir):\n        cfg = DictDefault(\n            {\n                \"base_model\": \"Qwen/Qwen2.5-0.5B\",\n                \"sequence_len\": 2048,\n                \"val_set_size\": 0.01,\n                \"rl\": \"dpo\",\n                \"chat_template\": \"chatml\",\n                \"datasets\": [\n                    {\n                        \"path\": \"Intel/orca_dpo_pairs\",\n                        \"split\": \"train\",\n                        \"type\": \"chatml.intel\",\n                    },\n                ],\n                \"num_epochs\": 1,\n                \"max_steps\": 2,\n                \"micro_batch_size\": 2,\n                \"gradient_accumulation_steps\": 1,\n                \"output_dir\": temp_dir,\n                \"learning_rate\": 0.00001,\n                \"optimizer\": \"adamw_torch_fused\",\n                \"lr_scheduler\": \"cosine\",\n                \"flash_attention\": True,\n                \"fsdp_version\": 2,\n                \"fsdp_config\": {\n                    \"offload_params\": False,\n                    \"cpu_ram_efficient_loading\": False,\n                    \"transformer_layer_cls_to_wrap\": \"Qwen2DecoderLayer\",\n                    \"state_dict_type\": \"FULL_STATE_DICT\",\n                    \"auto_wrap_policy\": \"TRANSFORMER_BASED_WRAP\",\n                    \"reshard_after_forward\": True,\n                },\n                \"use_tensorboard\": True,\n            }\n        )\n\n        # write cfg to yaml file\n        Path(temp_dir).mkdir(parents=True, exist_ok=True)\n        with open(Path(temp_dir) / \"config.yaml\", \"w\", encoding=\"utf-8\") as fout:\n            fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))\n\n        execute_subprocess_async(\n            [\n                \"axolotl\",\n                
\"train\",\n                str(Path(temp_dir) / \"config.yaml\"),\n                \"--num-processes\",\n                \"2\",\n                \"--main-process-port\",\n                f\"{get_torch_dist_unique_port()}\",\n            ]\n        )\n\n        verify_training_success(temp_dir)\n\n    @pytest.mark.skip(reason=\"slow test w cu129 + torch 2.9.1 + py3.12\")\n    @require_torch_2_7_0\n    def test_dpo_lora(self, temp_dir):\n        cfg = DictDefault(\n            {\n                \"base_model\": \"Qwen/Qwen2.5-0.5B\",\n                \"sequence_len\": 2048,\n                \"rl\": \"dpo\",\n                \"chat_template\": \"chatml\",\n                \"datasets\": [\n                    {\n                        \"path\": \"Intel/orca_dpo_pairs\",\n                        \"split\": \"train\",\n                        \"type\": \"chatml.intel\",\n                    },\n                ],\n                \"adapter\": \"lora\",\n                \"lora_r\": 8,\n                \"lora_alpha\": 16,\n                \"lora_dropout\": 0.05,\n                \"lora_target_linear\": True,\n                \"num_epochs\": 1,\n                \"max_steps\": 2,\n                \"micro_batch_size\": 2,\n                \"gradient_accumulation_steps\": 1,\n                \"output_dir\": temp_dir,\n                \"learning_rate\": 0.00001,\n                \"optimizer\": \"adamw_torch_fused\",\n                \"lr_scheduler\": \"cosine\",\n                \"flash_attention\": True,\n                \"fsdp_version\": 2,\n                \"fsdp_config\": {\n                    \"offload_params\": False,\n                    \"cpu_ram_efficient_loading\": False,\n                    \"transformer_layer_cls_to_wrap\": \"Qwen2DecoderLayer\",\n                    \"state_dict_type\": \"FULL_STATE_DICT\",\n                    \"auto_wrap_policy\": \"TRANSFORMER_BASED_WRAP\",\n                    \"reshard_after_forward\": True,\n                },\n           
     \"use_tensorboard\": True,\n            }\n        )\n\n        # write cfg to yaml file\n        Path(temp_dir).mkdir(parents=True, exist_ok=True)\n        with open(Path(temp_dir) / \"config.yaml\", \"w\", encoding=\"utf-8\") as fout:\n            fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))\n\n        execute_subprocess_async(\n            [\n                \"axolotl\",\n                \"train\",\n                str(Path(temp_dir) / \"config.yaml\"),\n                \"--num-processes\",\n                \"2\",\n                \"--main-process-port\",\n                f\"{get_torch_dist_unique_port()}\",\n            ]\n        )\n\n        verify_training_success(temp_dir)\n"
  },
  {
    "path": "tests/e2e/multigpu/test_gemma3.py",
    "content": "\"\"\"\nE2E tests for multigpu lora gemma3\n\"\"\"\n\nfrom pathlib import Path\n\nimport pytest\nimport yaml\nfrom accelerate.test_utils import execute_subprocess_async\nfrom huggingface_hub import snapshot_download\nfrom transformers.testing_utils import get_torch_dist_unique_port\n\nfrom axolotl.utils.dict import DictDefault\n\nfrom tests.e2e.utils import check_tensorboard\n\nAXOLOTL_ROOT = Path(__file__).parent.parent.parent.parent\n\n\n@pytest.fixture(scope=\"session\", autouse=True)\ndef download_model():\n    # download the model\n    snapshot_download(\"axolotl-mirrors/gemma-3-4b-pt\", repo_type=\"model\")\n\n\n@pytest.mark.skip(reason=\"FIXME\")\nclass TestMultiGPUGemma3:\n    \"\"\"\n    Test case for Gemma3 models using LoRA\n    \"\"\"\n\n    def test_lora_ddp_packed(self, temp_dir):\n        cfg = DictDefault(\n            {\n                \"base_model\": \"axolotl-mirrors/gemma-3-4b-pt\",\n                \"unfrozen_parameters\": [\"model.language_model.*\", \"lm_head\"],\n                \"sequence_len\": 2048,\n                \"ddp_find_unused_parameters\": True,\n                \"sample_packing\": True,\n                \"eval_sample_packing\": False,\n                \"pad_to_sequence_len\": True,\n                \"adapter\": \"lora\",\n                \"lora_r\": 8,\n                \"lora_alpha\": 16,\n                \"lora_dropout\": 0.05,\n                \"lora_target_linear\": True,\n                \"val_set_size\": 0.0,\n                \"chat_template\": \"gemma3\",\n                \"datasets\": [\n                    {\n                        \"path\": \"mlabonne/FineTome-100k\",\n                        \"type\": \"chat_template\",\n                        \"split\": \"train[:10%]\",\n                        \"field_messages\": \"conversations\",\n                        \"message_field_role\": \"from\",\n                        \"message_field_content\": \"value\",\n                    },\n                ],\n 
               \"num_epochs\": 1,\n                \"max_steps\": 2,\n                \"micro_batch_size\": 4,\n                \"gradient_checkpointing\": True,\n                \"gradient_checkpointing_kwargs\": {\n                    \"use_reentrant\": False,\n                },\n                \"gradient_accumulation_steps\": 2,\n                \"output_dir\": temp_dir,\n                \"dataset_prepared_path\": temp_dir + \"/last_run_prepared\",\n                \"learning_rate\": 0.0001,\n                \"optimizer\": \"adamw_8bit\",\n                \"lr_scheduler\": \"cosine\",\n                \"flash_attention\": True,\n                \"use_tensorboard\": True,\n                \"bf16\": True,\n                \"save_first_step\": False,\n            }\n        )\n\n        # write cfg to yaml file\n        Path(temp_dir).mkdir(parents=True, exist_ok=True)\n        with open(Path(temp_dir) / \"config.yaml\", \"w\", encoding=\"utf-8\") as fout:\n            fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))\n\n        execute_subprocess_async(\n            [\n                \"axolotl\",\n                \"train\",\n                str(Path(temp_dir) / \"config.yaml\"),\n                \"--num-processes\",\n                \"2\",\n                \"--main-process-port\",\n                f\"{get_torch_dist_unique_port()}\",\n            ]\n        )\n\n        check_tensorboard(\n            temp_dir + \"/runs\", \"train/train_loss\", 1.8, \"Train Loss (%s) is too high\"\n        )\n"
  },
  {
    "path": "tests/e2e/multigpu/test_llama.py",
    "content": "\"\"\"\nE2E tests for multigpu lora tinyllama\n\"\"\"\n\nfrom pathlib import Path\n\nimport pytest\nimport transformers\nimport yaml\nfrom accelerate.test_utils import execute_subprocess_async\nfrom huggingface_hub import snapshot_download\nfrom packaging import version\nfrom transformers.testing_utils import get_torch_dist_unique_port\n\nfrom axolotl.utils.dict import DictDefault\n\nfrom tests.e2e.utils import check_tensorboard, require_torch_2_6_0\n\nAXOLOTL_ROOT = Path(__file__).parent.parent.parent.parent\n\n\n@pytest.fixture(scope=\"session\", autouse=True)\ndef download_model():\n    # download the model\n    snapshot_download(\"HuggingFaceTB/SmolLM2-135M\")\n\n\ndef transformers_version_eq(required_version):\n    return version.parse(transformers.__version__) == version.parse(required_version)\n\n\nclass TestMultiGPULlama:\n    \"\"\"\n    Test case for Llama models using LoRA\n    \"\"\"\n\n    def test_lora_ddp(self, temp_dir):\n        cfg = DictDefault(\n            {\n                \"base_model\": \"HuggingFaceTB/SmolLM2-135M\",\n                \"sequence_len\": 2048,\n                \"adapter\": \"lora\",\n                \"lora_r\": 8,\n                \"lora_alpha\": 16,\n                \"lora_dropout\": 0.05,\n                \"lora_target_linear\": True,\n                \"val_set_size\": 0.01,\n                \"special_tokens\": {\n                    \"pad_token\": \"<|endoftext|>\",\n                },\n                \"datasets\": [\n                    {\n                        \"path\": \"tatsu-lab/alpaca\",\n                        \"type\": \"alpaca\",\n                        \"split\": \"train[:10%]\",\n                    },\n                ],\n                \"num_epochs\": 1,\n                \"max_steps\": 2,\n                \"micro_batch_size\": 1,\n                \"gradient_accumulation_steps\": 2,\n                # \"gradient_checkpointing\": True,\n                \"output_dir\": temp_dir,\n            
    \"dataset_prepared_path\": temp_dir + \"/last_run_prepared\",\n                \"learning_rate\": 0.00001,\n                \"optimizer\": \"adamw_8bit\",\n                \"lr_scheduler\": \"cosine\",\n                \"flash_attention\": True,\n                \"use_tensorboard\": True,\n                \"bf16\": True,\n                \"save_first_step\": False,\n            }\n        )\n\n        # write cfg to yaml file\n        Path(temp_dir).mkdir(parents=True, exist_ok=True)\n        with open(Path(temp_dir) / \"config.yaml\", \"w\", encoding=\"utf-8\") as fout:\n            fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))\n\n        execute_subprocess_async(\n            [\n                \"axolotl\",\n                \"train\",\n                str(Path(temp_dir) / \"config.yaml\"),\n                \"--num-processes\",\n                \"2\",\n                \"--main-process-port\",\n                f\"{get_torch_dist_unique_port()}\",\n            ]\n        )\n\n        check_tensorboard(\n            temp_dir + \"/runs\", \"train/train_loss\", 2.8, \"Train Loss (%s) is too high\"\n        )\n\n    @pytest.mark.parametrize(\n        \"gradient_accumulation_steps\",\n        [1, 2],\n    )\n    def test_lora_ddp_packed(self, temp_dir, gradient_accumulation_steps):\n        cfg = DictDefault(\n            {\n                \"base_model\": \"HuggingFaceTB/SmolLM2-135M\",\n                \"sequence_len\": 2048,\n                \"sample_packing\": True,\n                \"eval_sample_packing\": False,\n                \"pad_to_sequence_len\": True,\n                \"adapter\": \"lora\",\n                \"lora_r\": 8,\n                \"lora_alpha\": 16,\n                \"lora_dropout\": 0.05,\n                \"lora_target_linear\": True,\n                \"val_set_size\": 0.05,\n                \"special_tokens\": {\n                    \"pad_token\": \"<|endoftext|>\",\n                },\n                \"datasets\": [\n             
       {\n                        \"path\": \"tatsu-lab/alpaca\",\n                        \"type\": \"alpaca\",\n                        \"split\": \"train[:20%]\",\n                    },\n                ],\n                \"num_epochs\": 1,\n                \"max_steps\": 2,\n                \"micro_batch_size\": 1,\n                \"gradient_accumulation_steps\": gradient_accumulation_steps,\n                # \"gradient_checkpointing\": True,\n                \"output_dir\": temp_dir,\n                \"dataset_prepared_path\": temp_dir + \"/last_run_prepared\",\n                \"learning_rate\": 0.00001,\n                \"optimizer\": \"adamw_8bit\",\n                \"lr_scheduler\": \"cosine\",\n                \"flash_attention\": True,\n                \"use_tensorboard\": True,\n                \"bf16\": True,\n                \"save_first_step\": False,\n            }\n        )\n\n        # write cfg to yaml file\n        Path(temp_dir).mkdir(parents=True, exist_ok=True)\n        with open(Path(temp_dir) / \"config.yaml\", \"w\", encoding=\"utf-8\") as fout:\n            fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))\n\n        execute_subprocess_async(\n            [\n                \"axolotl\",\n                \"train\",\n                str(Path(temp_dir) / \"config.yaml\"),\n                \"--num-processes\",\n                \"2\",\n                \"--main-process-port\",\n                f\"{get_torch_dist_unique_port()}\",\n            ]\n        )\n\n        check_tensorboard(\n            temp_dir + \"/runs\", \"train/train_loss\", 2.3, \"Train Loss (%s) is too high\"\n        )\n\n    def test_dpo_lora_ddp(self, temp_dir):\n        cfg = DictDefault(\n            {\n                \"base_model\": \"HuggingFaceTB/SmolLM2-135M\",\n                \"sequence_len\": 2048,\n                \"sample_packing\": False,\n                \"eval_sample_packing\": False,\n                \"pad_to_sequence_len\": True,\n               
 \"load_in_8bit\": True,\n                \"adapter\": \"lora\",\n                \"lora_r\": 8,\n                \"lora_alpha\": 16,\n                \"lora_dropout\": 0.05,\n                \"lora_target_linear\": True,\n                \"val_set_size\": 0.01,\n                \"special_tokens\": {\n                    \"pad_token\": \"<|endoftext|>\",\n                },\n                \"rl\": \"dpo\",\n                \"chat_template\": \"chatml\",\n                \"datasets\": [\n                    {\n                        \"path\": \"fozziethebeat/alpaca_messages_2k_dpo_test\",\n                        \"type\": \"chat_template.default\",\n                        \"field_messages\": \"conversation\",\n                        \"field_chosen\": \"chosen\",\n                        \"field_rejected\": \"rejected\",\n                        \"message_field_role\": \"role\",\n                        \"message_field_content\": \"content\",\n                        \"roles\": {\n                            \"system\": [\"system\"],\n                            \"user\": [\"user\"],\n                            \"assistant\": [\"assistant\"],\n                        },\n                    },\n                ],\n                \"num_epochs\": 1,\n                \"max_steps\": 2,\n                \"micro_batch_size\": 2,\n                \"gradient_accumulation_steps\": 2,\n                \"gradient_checkpointing\": False,\n                \"output_dir\": temp_dir,\n                \"dataset_prepared_path\": temp_dir + \"/last_run_prepared\",\n                \"warmup_steps\": 0,\n                \"learning_rate\": 0.00001,\n                \"optimizer\": \"adamw_8bit\",\n                \"lr_scheduler\": \"cosine\",\n                \"flash_attention\": True,\n                \"use_tensorboard\": True,\n                \"bf16\": True,\n                \"save_first_step\": False,\n            }\n        )\n\n        # write cfg to yaml file\n        
Path(temp_dir).mkdir(parents=True, exist_ok=True)\n        with open(Path(temp_dir) / \"config.yaml\", \"w\", encoding=\"utf-8\") as fout:\n            fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))\n\n        execute_subprocess_async(\n            [\n                \"axolotl\",\n                \"train\",\n                str(Path(temp_dir) / \"config.yaml\"),\n                \"--num-processes\",\n                \"2\",\n                \"--main-process-port\",\n                f\"{get_torch_dist_unique_port()}\",\n            ]\n        )\n\n        loss_threshold = 2.3\n        check_tensorboard(\n            temp_dir + \"/runs\",\n            \"train/train_loss\",\n            loss_threshold,\n            \"Train Loss (%s) is too high\",\n        )\n\n    def test_dpo_qlora_ddp(self, temp_dir):\n        cfg = DictDefault(\n            {\n                \"base_model\": \"HuggingFaceTB/SmolLM2-135M\",\n                \"sequence_len\": 2048,\n                \"sample_packing\": False,\n                \"eval_sample_packing\": False,\n                \"pad_to_sequence_len\": True,\n                \"load_in_4bit\": True,\n                \"adapter\": \"qlora\",\n                \"lora_r\": 8,\n                \"lora_alpha\": 16,\n                \"lora_dropout\": 0.05,\n                \"lora_target_linear\": True,\n                \"val_set_size\": 0.01,\n                \"special_tokens\": {\n                    \"pad_token\": \"<|endoftext|>\",\n                },\n                \"rl\": \"dpo\",\n                \"chat_template\": \"chatml\",\n                \"datasets\": [\n                    {\n                        \"path\": \"fozziethebeat/alpaca_messages_2k_dpo_test\",\n                        \"type\": \"chat_template.default\",\n                        \"field_messages\": \"conversation\",\n                        \"field_chosen\": \"chosen\",\n                        \"field_rejected\": \"rejected\",\n                        
\"message_field_role\": \"role\",\n                        \"message_field_content\": \"content\",\n                        \"roles\": {\n                            \"system\": [\"system\"],\n                            \"user\": [\"user\"],\n                            \"assistant\": [\"assistant\"],\n                        },\n                    },\n                ],\n                \"num_epochs\": 1,\n                \"max_steps\": 2,\n                \"micro_batch_size\": 2,\n                \"gradient_accumulation_steps\": 2,\n                \"gradient_checkpointing\": False,\n                \"output_dir\": temp_dir,\n                \"dataset_prepared_path\": temp_dir + \"/last_run_prepared\",\n                \"warmup_steps\": 0,\n                \"learning_rate\": 0.00001,\n                \"optimizer\": \"adamw_8bit\",\n                \"lr_scheduler\": \"cosine\",\n                \"flash_attention\": True,\n                \"use_tensorboard\": True,\n                \"bf16\": True,\n                \"save_first_step\": False,\n            }\n        )\n\n        # write cfg to yaml file\n        Path(temp_dir).mkdir(parents=True, exist_ok=True)\n        with open(Path(temp_dir) / \"config.yaml\", \"w\", encoding=\"utf-8\") as fout:\n            fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))\n\n        execute_subprocess_async(\n            [\n                \"axolotl\",\n                \"train\",\n                str(Path(temp_dir) / \"config.yaml\"),\n                \"--num-processes\",\n                \"2\",\n                \"--main-process-port\",\n                f\"{get_torch_dist_unique_port()}\",\n            ]\n        )\n\n        loss_threshold = 2.3\n        check_tensorboard(\n            temp_dir + \"/runs\",\n            \"train/train_loss\",\n            loss_threshold,\n            \"Train Loss (%s) is too high\",\n        )\n\n    @pytest.mark.parametrize(\n        \"gradient_accumulation_steps\",\n        [1, 2],\n 
   )\n    def test_fsdp(self, temp_dir, gradient_accumulation_steps):\n        cfg = DictDefault(\n            {\n                \"base_model\": \"HuggingFaceTB/SmolLM2-135M\",\n                \"sequence_len\": 2048,\n                \"val_set_size\": 0.01,\n                \"special_tokens\": {\n                    \"pad_token\": \"<|endoftext|>\",\n                },\n                \"datasets\": [\n                    {\n                        \"path\": \"tatsu-lab/alpaca\",\n                        \"type\": \"alpaca\",\n                        \"split\": \"train[:10%]\",\n                    },\n                ],\n                \"num_epochs\": 1,\n                \"max_steps\": 2,\n                \"micro_batch_size\": 2,\n                \"gradient_accumulation_steps\": gradient_accumulation_steps,\n                # \"gradient_checkpointing\": True,\n                \"output_dir\": temp_dir,\n                \"dataset_prepared_path\": temp_dir + \"/last_run_prepared\",\n                \"learning_rate\": 0.00001,\n                \"optimizer\": \"adamw_torch_fused\",\n                \"lr_scheduler\": \"cosine\",\n                \"flash_attention\": True,\n                \"fsdp\": [\n                    \"full_shard\",\n                    \"auto_wrap\",\n                ],\n                \"fsdp_config\": {\n                    \"fsdp_offload_params\": False,\n                    \"fsdp_sync_module_states\": True,\n                    \"fsdp_use_orig_params\": False,\n                    \"fsdp_cpu_ram_efficient_loading\": False,\n                    \"fsdp_transformer_layer_cls_to_wrap\": \"LlamaDecoderLayer\",\n                    \"fsdp_state_dict_type\": \"FULL_STATE_DICT\",\n                    \"fsdp_auto_wrap_policy\": \"TRANSFORMER_BASED_WRAP\",\n                },\n                \"use_tensorboard\": True,\n                \"seed\": 42,\n                \"save_first_step\": False,\n            }\n        )\n\n        # write cfg to yaml 
file\n        Path(temp_dir).mkdir(parents=True, exist_ok=True)\n        with open(Path(temp_dir) / \"config.yaml\", \"w\", encoding=\"utf-8\") as fout:\n            fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))\n\n        execute_subprocess_async(\n            [\n                \"axolotl\",\n                \"train\",\n                str(Path(temp_dir) / \"config.yaml\"),\n                \"--num-processes\",\n                \"2\",\n                \"--main-process-port\",\n                f\"{get_torch_dist_unique_port()}\",\n            ]\n        )\n\n        check_tensorboard(\n            temp_dir + \"/runs\", \"train/train_loss\", 2.3, \"Train Loss (%s) is too high\"\n        )\n\n    @pytest.mark.parametrize(\n        \"fsdp_state_dict_type\",\n        [\n            \"FULL_STATE_DICT\",\n            # \"SHARDED_STATE_DICT\",  # not supported since intermediate checkpoints fail with fsdp1\n        ],\n    )\n    def test_fsdp_packed(self, temp_dir, fsdp_state_dict_type):\n        cfg = DictDefault(\n            {\n                \"base_model\": \"HuggingFaceTB/SmolLM2-135M\",\n                \"sample_packing\": True,\n                \"pad_to_sequence_len\": True,\n                \"sequence_len\": 1024,\n                \"val_set_size\": 0.05,\n                \"special_tokens\": {\n                    \"pad_token\": \"<|endoftext|>\",\n                },\n                \"datasets\": [\n                    {\n                        \"path\": \"tatsu-lab/alpaca\",\n                        \"type\": \"alpaca\",\n                        \"split\": \"train[:10%]\",\n                    },\n                ],\n                \"num_epochs\": 1,\n                \"max_steps\": 3,\n                \"save_steps\": 2,\n                \"micro_batch_size\": 2,\n                \"gradient_accumulation_steps\": 2,\n                # \"gradient_checkpointing\": True,\n                \"output_dir\": temp_dir,\n                
\"dataset_prepared_path\": temp_dir + \"/last_run_prepared\",\n                \"learning_rate\": 0.00001,\n                \"optimizer\": \"adamw_torch_fused\",\n                \"lr_scheduler\": \"cosine\",\n                \"flash_attention\": True,\n                \"fsdp\": [\n                    \"full_shard\",\n                    \"auto_wrap\",\n                ],\n                \"fsdp_config\": {\n                    \"fsdp_offload_params\": False,\n                    \"fsdp_sync_module_states\": True,\n                    \"fsdp_use_orig_params\": False,\n                    \"fsdp_cpu_ram_efficient_loading\": False,\n                    \"fsdp_transformer_layer_cls_to_wrap\": \"LlamaDecoderLayer\",\n                    \"fsdp_state_dict_type\": fsdp_state_dict_type,\n                    \"fsdp_auto_wrap_policy\": \"TRANSFORMER_BASED_WRAP\",\n                },\n                \"use_tensorboard\": True,\n                \"save_first_step\": False,\n            }\n        )\n\n        # write cfg to yaml file\n        Path(temp_dir).mkdir(parents=True, exist_ok=True)\n        with open(Path(temp_dir) / \"config.yaml\", \"w\", encoding=\"utf-8\") as fout:\n            fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))\n\n        execute_subprocess_async(\n            [\n                \"axolotl\",\n                \"train\",\n                str(Path(temp_dir) / \"config.yaml\"),\n                \"--num-processes\",\n                \"2\",\n                \"--main-process-port\",\n                f\"{get_torch_dist_unique_port()}\",\n            ]\n        )\n\n        check_tensorboard(\n            temp_dir + \"/runs\", \"train/train_loss\", 2.3, \"Train Loss (%s) is too high\"\n        )\n\n    @require_torch_2_6_0\n    @pytest.mark.parametrize(\n        \"attention_backend\",\n        [\"flash\", \"flex\"],\n    )\n    @pytest.mark.parametrize(\n        \"fsdp_reshard_after_forward\",\n        [True, False],\n    )\n    def 
test_fsdp2_packed(\n        self, temp_dir, attention_backend, fsdp_reshard_after_forward\n    ):\n        cfg = DictDefault(\n            {\n                \"base_model\": \"HuggingFaceTB/SmolLM2-135M\",\n                \"sample_packing\": True,\n                \"pad_to_sequence_len\": True,\n                \"sequence_len\": 2048,\n                \"val_set_size\": 0.1,\n                \"special_tokens\": {\n                    \"pad_token\": \"<|endoftext|>\",\n                },\n                \"datasets\": [\n                    {\n                        \"path\": \"tatsu-lab/alpaca\",\n                        \"type\": \"alpaca\",\n                        \"split\": \"train[:10%]\",\n                    },\n                ],\n                \"num_epochs\": 1,\n                \"max_steps\": 2,\n                \"micro_batch_size\": 4,\n                \"gradient_accumulation_steps\": 2,\n                \"gradient_checkpointing\": True,\n                \"output_dir\": temp_dir,\n                \"dataset_prepared_path\": temp_dir + \"/last_run_prepared\",\n                \"learning_rate\": 0.00001,\n                \"optimizer\": \"adamw_torch_8bit\",\n                \"lr_scheduler\": \"cosine\",\n                \"fsdp\": [\n                    \"auto_wrap\",\n                ],\n                \"fsdp_config\": {\n                    \"fsdp_version\": 2,\n                    # \"fsdp_forward_prefetch\": True,  # not yet implemented in accelerate\n                    \"fsdp_offload_params\": False,\n                    \"fsdp_cpu_ram_efficient_loading\": False,\n                    \"fsdp_transformer_layer_cls_to_wrap\": \"LlamaDecoderLayer\",\n                    \"fsdp_state_dict_type\": \"SHARDED_STATE_DICT\",\n                    \"fsdp_auto_wrap_policy\": \"TRANSFORMER_BASED_WRAP\",\n                    \"fsdp_reshard_after_forward\": fsdp_reshard_after_forward,\n                },\n                \"use_tensorboard\": True,\n                
\"save_first_step\": False,\n            }\n        )\n        if attention_backend == \"flash\":\n            cfg.flash_attention = True\n        elif attention_backend == \"flex\":\n            cfg.flex_attention = True\n\n        # write cfg to yaml file\n        Path(temp_dir).mkdir(parents=True, exist_ok=True)\n        with open(Path(temp_dir) / \"config.yaml\", \"w\", encoding=\"utf-8\") as fout:\n            fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))\n\n        execute_subprocess_async(\n            [\n                \"axolotl\",\n                \"train\",\n                str(Path(temp_dir) / \"config.yaml\"),\n                \"--num-processes\",\n                \"2\",\n                \"--main-process-port\",\n                f\"{get_torch_dist_unique_port()}\",\n            ]\n        )\n\n        check_tensorboard(\n            temp_dir + \"/runs\", \"train/train_loss\", 2.1, \"Train Loss (%s) is too high\"\n        )\n\n    def test_fsdp_qlora_prequant_packed(self, temp_dir):\n        cfg = DictDefault(\n            {\n                \"base_model\": \"axolotl-ai-co/SmolLM2-135M-bnb-nf4-bf16\",\n                \"adapter\": \"qlora\",\n                \"mean_resizing_embeddings\": True,\n                \"load_in_4bit\": True,\n                \"lora_r\": 8,\n                \"lora_alpha\": 16,\n                \"lora_dropout\": 0.05,\n                \"lora_target_linear\": True,\n                # \"lora_modules_to_save\": [\n                #     \"embed_tokens\",\n                #     \"lm_head\",\n                # ],\n                \"sample_packing\": True,\n                \"eval_sample_packing\": False,\n                \"pad_to_sequence_len\": True,\n                \"sequence_len\": 1024,\n                \"val_set_size\": 0.01,\n                \"special_tokens\": {\n                    \"pad_token\": \"<|endoftext|>\",\n                },\n                \"datasets\": [\n                    {\n                        
\"path\": \"tatsu-lab/alpaca\",\n                        \"type\": \"alpaca\",\n                        \"split\": \"train[:10%]\",\n                    },\n                ],\n                \"num_epochs\": 1,\n                \"max_steps\": 2,\n                \"micro_batch_size\": 2,\n                \"gradient_accumulation_steps\": 2,\n                # \"gradient_checkpointing\": True,\n                \"output_dir\": temp_dir,\n                \"dataset_prepared_path\": temp_dir + \"/last_run_prepared\",\n                \"learning_rate\": 0.00001,\n                \"optimizer\": \"adamw_torch_fused\",\n                \"lr_scheduler\": \"cosine\",\n                \"flash_attention\": True,\n                \"fsdp\": [\n                    \"full_shard\",\n                    \"auto_wrap\",\n                ],\n                \"fsdp_config\": {\n                    \"fsdp_offload_params\": False,\n                    \"fsdp_sync_module_states\": True,\n                    \"fsdp_use_orig_params\": False,\n                    \"fsdp_cpu_ram_efficient_loading\": True,\n                    \"fsdp_transformer_layer_cls_to_wrap\": \"LlamaDecoderLayer\",\n                    \"fsdp_state_dict_type\": \"FULL_STATE_DICT\",\n                    \"fsdp_auto_wrap_policy\": \"TRANSFORMER_BASED_WRAP\",\n                },\n                \"use_tensorboard\": True,\n                \"save_first_step\": False,\n            }\n        )\n\n        # write cfg to yaml file\n        Path(temp_dir).mkdir(parents=True, exist_ok=True)\n        with open(Path(temp_dir) / \"config.yaml\", \"w\", encoding=\"utf-8\") as fout:\n            fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))\n\n        execute_subprocess_async(\n            [\n                \"axolotl\",\n                \"train\",\n                str(Path(temp_dir) / \"config.yaml\"),\n                \"--num-processes\",\n                \"2\",\n                \"--main-process-port\",\n                
f\"{get_torch_dist_unique_port()}\",\n            ]\n        )\n\n        check_tensorboard(\n            temp_dir + \"/runs\", \"train/train_loss\", 2.3, \"Train Loss (%s) is too high\"\n        )\n\n    @pytest.mark.parametrize(\n        \"gradient_accumulation_steps\",\n        [1, 2],\n    )\n    @pytest.mark.parametrize(\n        \"deepspeed\",\n        [\n            \"deepspeed_configs/zero3_bf16.json\",\n            \"deepspeed_configs/zero3_bf16_cpuoffload_all.json\",\n            # \"deepspeed_configs/zero3_bf16_cpuoffload_params.json\",\n        ],\n    )\n    @pytest.mark.parametrize(\n        \"qlora\",\n        [True, False],\n    )\n    def test_ds_zero3_packed(\n        self, temp_dir, gradient_accumulation_steps, deepspeed, qlora\n    ):\n        if qlora:\n            adapter = {\n                \"adapter\": \"qlora\",\n                \"lora_r\": 8,\n                \"lora_alpha\": 16,\n                \"lora_dropout\": 0.05,\n                \"lora_target_linear\": True,\n                \"load_in_4bit\": True,\n            }\n        else:\n            adapter = {}\n        cfg = DictDefault(\n            {\n                \"base_model\": \"HuggingFaceTB/SmolLM2-135M\",\n                \"sample_packing\": True,\n                \"pad_to_sequence_len\": True,\n                \"sequence_len\": 1024,\n                \"val_set_size\": 0.05,\n                \"special_tokens\": {\n                    \"pad_token\": \"<|endoftext|>\",\n                },\n                \"datasets\": [\n                    {\n                        \"path\": \"tatsu-lab/alpaca\",\n                        \"type\": \"alpaca\",\n                        \"split\": \"train[:10%]\",\n                    },\n                ],\n                \"num_epochs\": 1,\n                \"max_steps\": 2,\n                \"micro_batch_size\": 1,\n                \"gradient_accumulation_steps\": gradient_accumulation_steps,\n                \"output_dir\": temp_dir,\n        
        \"dataset_prepared_path\": temp_dir + \"/last_run_prepared\",\n                \"learning_rate\": 0.00001,\n                \"optimizer\": \"adamw_torch_fused\",\n                \"lr_scheduler\": \"cosine\",\n                \"flash_attention\": True,\n                \"deepspeed\": str(AXOLOTL_ROOT / deepspeed),\n                \"use_tensorboard\": True,\n                \"save_first_step\": False,\n                **adapter,\n            }\n        )\n\n        # write cfg to yaml file\n        Path(temp_dir).mkdir(parents=True, exist_ok=True)\n        with open(Path(temp_dir) / \"config.yaml\", \"w\", encoding=\"utf-8\") as fout:\n            fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))\n\n        execute_subprocess_async(\n            [\n                \"axolotl\",\n                \"train\",\n                str(Path(temp_dir) / \"config.yaml\"),\n                \"--num-processes\",\n                \"2\",\n                \"--main-process-port\",\n                f\"{get_torch_dist_unique_port()}\",\n            ]\n        )\n\n        check_tensorboard(\n            temp_dir + \"/runs\", \"train/train_loss\", 2.45, \"Train Loss (%s) is too high\"\n        )\n\n    @pytest.mark.parametrize(\n        \"gradient_accumulation_steps\",\n        [1, 2],\n    )\n    @pytest.mark.parametrize(\n        \"qlora\",\n        [True, False],\n    )\n    def test_ds_zero2_packed(self, temp_dir, gradient_accumulation_steps, qlora):\n        if qlora:\n            adapter = {\n                \"adapter\": \"qlora\",\n                \"lora_r\": 8,\n                \"lora_alpha\": 16,\n                \"lora_dropout\": 0.05,\n                \"lora_target_linear\": True,\n                \"load_in_4bit\": True,\n            }\n        else:\n            adapter = {}\n        cfg = DictDefault(\n            {\n                \"base_model\": \"HuggingFaceTB/SmolLM2-135M\",\n                \"sample_packing\": True,\n                
\"pad_to_sequence_len\": True,\n                \"sequence_len\": 1024,\n                \"val_set_size\": 0.01,\n                \"special_tokens\": {\n                    \"pad_token\": \"<|endoftext|>\",\n                },\n                \"datasets\": [\n                    {\n                        \"path\": \"tatsu-lab/alpaca\",\n                        \"type\": \"alpaca\",\n                        \"split\": \"train[:10%]\",\n                    },\n                ],\n                \"num_epochs\": 1,\n                \"max_steps\": 2,\n                \"micro_batch_size\": 1,\n                \"gradient_accumulation_steps\": gradient_accumulation_steps,\n                \"output_dir\": temp_dir,\n                \"dataset_prepared_path\": temp_dir + \"/last_run_prepared\",\n                \"learning_rate\": 0.00001,\n                \"optimizer\": \"adamw_torch_fused\",\n                \"lr_scheduler\": \"cosine\",\n                \"flash_attention\": True,\n                \"deepspeed\": str(AXOLOTL_ROOT / \"deepspeed_configs/zero2.json\"),\n                \"use_tensorboard\": True,\n                \"seed\": 42,\n                \"save_first_step\": False,\n                **adapter,\n            }\n        )\n\n        # write cfg to yaml file\n        Path(temp_dir).mkdir(parents=True, exist_ok=True)\n        with open(Path(temp_dir) / \"config.yaml\", \"w\", encoding=\"utf-8\") as fout:\n            fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))\n\n        execute_subprocess_async(\n            [\n                \"axolotl\",\n                \"train\",\n                str(Path(temp_dir) / \"config.yaml\"),\n                \"--num-processes\",\n                \"2\",\n                \"--main-process-port\",\n                f\"{get_torch_dist_unique_port()}\",\n            ]\n        )\n\n        check_tensorboard(\n            temp_dir + \"/runs\", \"train/train_loss\", 2.3, \"Train Loss (%s) is too high\"\n        )\n\n    
@pytest.mark.parametrize(\n        \"gradient_accumulation_steps\",\n        [1, 2],\n    )\n    @pytest.mark.parametrize(\n        \"qlora\",\n        [True, False],\n    )\n    def test_ds_zero1_packed(self, temp_dir, gradient_accumulation_steps, qlora):\n        if qlora:\n            adapter = {\n                \"adapter\": \"qlora\",\n                \"lora_r\": 8,\n                \"lora_alpha\": 16,\n                \"lora_dropout\": 0.05,\n                \"lora_target_linear\": True,\n                \"load_in_4bit\": True,\n            }\n        else:\n            adapter = {}\n        cfg = DictDefault(\n            {\n                \"base_model\": \"HuggingFaceTB/SmolLM2-135M\",\n                \"sample_packing\": True,\n                \"pad_to_sequence_len\": True,\n                \"sequence_len\": 1024,\n                \"val_set_size\": 0.01,\n                \"special_tokens\": {\n                    \"pad_token\": \"<|endoftext|>\",\n                },\n                \"datasets\": [\n                    {\n                        \"path\": \"tatsu-lab/alpaca\",\n                        \"type\": \"alpaca\",\n                        \"split\": \"train[:10%]\",\n                    },\n                ],\n                \"num_epochs\": 1,\n                \"max_steps\": 2,\n                \"micro_batch_size\": 1,\n                \"gradient_accumulation_steps\": gradient_accumulation_steps,\n                \"output_dir\": temp_dir,\n                \"dataset_prepared_path\": temp_dir + \"/last_run_prepared\",\n                \"learning_rate\": 0.00001,\n                \"optimizer\": \"adamw_torch_fused\",\n                \"lr_scheduler\": \"cosine\",\n                \"flash_attention\": True,\n                \"deepspeed\": str(AXOLOTL_ROOT / \"deepspeed_configs/zero1.json\"),\n                \"use_tensorboard\": True,\n                \"save_first_step\": False,\n                **adapter,\n            }\n        )\n\n        # 
write cfg to yaml file\n        Path(temp_dir).mkdir(parents=True, exist_ok=True)\n        with open(Path(temp_dir) / \"config.yaml\", \"w\", encoding=\"utf-8\") as fout:\n            fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))\n\n        execute_subprocess_async(\n            [\n                \"axolotl\",\n                \"train\",\n                str(Path(temp_dir) / \"config.yaml\"),\n                \"--num-processes\",\n                \"2\",\n                \"--main-process-port\",\n                f\"{get_torch_dist_unique_port()}\",\n            ]\n        )\n\n        check_tensorboard(\n            temp_dir + \"/runs\", \"train/train_loss\", 2.5, \"Train Loss (%s) is too high\"\n        )\n\n    @pytest.mark.skip(\n        reason=\"fix untrained tokens brittle with lots of edge cases in latest transformers\"\n    )\n    def test_fix_untrained_tokens(self, temp_dir):\n        cfg = DictDefault(\n            {\n                \"base_model\": \"HuggingFaceTB/SmolLM2-135M\",\n                \"fix_untrained_tokens\": True,\n                \"sequence_len\": 512,\n                \"val_set_size\": 0.0,\n                \"special_tokens\": {\n                    \"pad_token\": \"<|endoftext|>\",\n                    \"bos_token\": \"<|custom_im_start|>\",\n                    \"eos_token\": \"<|custom_im_end|>\",\n                },\n                \"datasets\": [\n                    {\n                        \"chat_template\": \"jinja\",\n                        \"chat_template_jinja\": \"{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|custom_im_start|>' + message['role'] + '\\n' + message['content'] + '<|custom_im_end|>' + '\\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|custom_im_start|>assistant\\n' }}{% endif %}\",\n                        \"path\": \"mlabonne/FineTome-100k\",\n                        \"type\": \"chat_template\",\n                   
     \"split\": \"train[:10%]\",\n                        \"field_messages\": \"conversations\",\n                        \"message_field_role\": \"from\",\n                        \"message_field_content\": \"value\",\n                    },\n                ],\n                \"num_epochs\": 1,\n                \"max_steps\": 2,\n                \"micro_batch_size\": 1,\n                \"gradient_accumulation_steps\": 1,\n                # \"gradient_checkpointing\": True,\n                \"output_dir\": temp_dir,\n                \"learning_rate\": 0.00001,\n                \"optimizer\": \"adamw_torch_fused\",\n                \"lr_scheduler\": \"cosine\",\n                \"flash_attention\": True,\n                \"sample_packing\": True,\n                \"bf16\": True,\n                # \"deepspeed\": str(AXOLOTL_ROOT / \"deepspeed_configs/zero1.json\"),\n                \"use_tensorboard\": True,\n                \"save_first_step\": False,\n            }\n        )\n\n        # write cfg to yaml file\n        Path(temp_dir).mkdir(parents=True, exist_ok=True)\n        with open(Path(temp_dir) / \"config.yaml\", \"w\", encoding=\"utf-8\") as fout:\n            fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))\n\n        execute_subprocess_async(\n            [\n                \"axolotl\",\n                \"train\",\n                str(Path(temp_dir) / \"config.yaml\"),\n                \"--num-processes\",\n                \"2\",\n                \"--main-process-port\",\n                f\"{get_torch_dist_unique_port()}\",\n            ]\n        )\n\n        check_tensorboard(\n            temp_dir + \"/runs\", \"train/train_loss\", 4.0, \"Train Loss (%s) is too high\"\n        )\n"
  },
  {
    "path": "tests/e2e/multigpu/test_locking.py",
    "content": "\"\"\"Tests for FileLockLoader class.\"\"\"\n\nimport tempfile\nimport threading\nimport time\nfrom pathlib import Path\nfrom unittest.mock import MagicMock, Mock, patch\n\nimport pytest\n\nfrom axolotl.utils.data.lock import FileLockLoader\nfrom axolotl.utils.dict import DictDefault\n\n\nclass TestFileLockLoader:\n    \"\"\"Class with tests for FileLockLoader.\"\"\"\n\n    @pytest.fixture\n    def temp_dir(self):\n        \"\"\"Create a temporary directory for testing.\"\"\"\n        with tempfile.TemporaryDirectory() as tmp_dir:\n            yield Path(tmp_dir)\n\n    @pytest.fixture\n    def cfg(self, temp_dir):\n        \"\"\"Create a test configuration.\"\"\"\n        return DictDefault({\"dataset_prepared_path\": str(temp_dir)})\n\n    @pytest.fixture\n    def loader(self, cfg):\n        \"\"\"Create a FileLockLoader instance for testing.\"\"\"\n        return FileLockLoader(cfg)\n\n    def test_load_first_process(self, loader):\n        \"\"\"Test load() when no ready flag exists (first process).\"\"\"\n        mock_load_fn = Mock(return_value=\"test_data\")\n\n        result = loader.load(mock_load_fn)\n\n        # Should call the load function\n        mock_load_fn.assert_called_once()\n        assert result == \"test_data\"\n\n        # Should create the ready flag\n        assert loader.ready_flag_path.exists()\n\n    def test_load_subsequent_process(self, loader):\n        \"\"\"Test load() when ready flag already exists (subsequent process).\"\"\"\n        # Create ready flag first\n        loader.ready_flag_path.touch()\n\n        mock_load_fn = Mock(return_value=\"loaded_data\")\n\n        result = loader.load(mock_load_fn)\n\n        # Should still call load function (to load the prepared data)\n        mock_load_fn.assert_called_once()\n        assert result == \"loaded_data\"\n\n    def test_load_concurrent_processes(self, cfg):\n        \"\"\"Test that concurrent processes coordinate correctly.\"\"\"\n        results = []\n        
call_count = 0\n\n        def slow_load_fn():\n            nonlocal call_count\n            call_count += 1\n            time.sleep(0.1)  # Simulate slow loading\n            return f\"data_{call_count}\"\n\n        def worker():\n            loader = FileLockLoader(cfg)\n            result = loader.load(slow_load_fn)\n            results.append(result)\n\n        # Start multiple threads simultaneously\n        threads = [threading.Thread(target=worker) for _ in range(3)]\n        for t in threads:\n            t.start()\n        for t in threads:\n            t.join()\n\n        # Only one thread should have done the initial loading\n        # All should return data, but the load function should be called\n        # once by the first process and once by each subsequent process\n        assert len(results) == 3\n        assert all(result.startswith(\"data_\") for result in results)\n\n    @patch(\"time.sleep\")\n    def test_load_waiting_for_ready_flag(self, mock_sleep, loader):\n        \"\"\"Test that processes wait for the ready flag to appear.\"\"\"\n        mock_load_fn = Mock(return_value=\"waiting_data\")\n        mock_ready_flag_path = Mock()\n        exists_call_count = 0\n\n        def mock_exists():\n            nonlocal exists_call_count\n            exists_call_count += 1\n\n            if exists_call_count == 1:\n                # First check: ready flag exists (not first process)\n                return True\n            if exists_call_count <= 3:\n                # While loop checks: flag doesn't exist yet\n                return False\n            return True\n\n        mock_ready_flag_path.exists.side_effect = mock_exists\n\n        # Replace the ready_flag_path with our mock\n        original_path = loader.ready_flag_path\n        loader.ready_flag_path = mock_ready_flag_path\n\n        try:\n            result = loader.load(mock_load_fn)\n        finally:\n            # Restore original path\n            loader.ready_flag_path = 
original_path\n\n        # Should have slept twice while waiting\n        assert mock_sleep.call_count == 2\n        mock_sleep.assert_called_with(1)\n\n        # Should eventually call load function\n        mock_load_fn.assert_called_once()\n        assert result == \"waiting_data\"\n\n    def test_complete_workflow_with_cleanup(self, loader):\n        \"\"\"Test the complete load -> cleanup workflow.\"\"\"\n        mock_load_fn = Mock(return_value=\"test_data\")\n\n        # First process calls load (this should set up counter)\n        result = loader.load(mock_load_fn)\n        assert result == \"test_data\"\n        assert loader.ready_flag_path.exists()\n        assert loader.counter_path.exists()\n\n        # Cleanup should remove everything since there's only one process\n        loader.cleanup()\n        assert not loader.ready_flag_path.exists()\n        assert not loader.counter_path.exists()\n\n    def test_multiple_processes_workflow(self, loader):\n        \"\"\"Test workflow with multiple processes.\"\"\"\n        # Simulate multiple processes by manually setting up counter\n        loader.ready_flag_path.touch()\n        loader.counter_path.write_text(\"3\")  # 3 processes\n\n        # First process cleanup\n        loader.cleanup()\n        assert loader.ready_flag_path.exists()\n        assert loader.counter_path.read_text().strip() == \"2\"\n\n        # Second process cleanup\n        loader.cleanup()\n        assert loader.ready_flag_path.exists()\n        assert loader.counter_path.read_text().strip() == \"1\"\n\n        # Last process cleanup\n        loader.cleanup()\n        assert not loader.ready_flag_path.exists()\n        assert not loader.counter_path.exists()\n\n    def test_load_exception_handling(self, loader):\n        \"\"\"Test behavior when load_fn raises an exception.\"\"\"\n\n        def failing_load_fn():\n            raise ValueError(\"Load failed\")\n\n        with pytest.raises(ValueError, match=\"Load failed\"):\n         
   loader.load(failing_load_fn)\n\n        # Ready flag should not be created on failure\n        assert not loader.ready_flag_path.exists()\n\n    def test_file_lock_called(self, loader):\n        \"\"\"Test that FileLock is properly used.\"\"\"\n        mock_load_fn = Mock(return_value=\"locked_data\")\n\n        with patch(\"axolotl.utils.data.lock.FileLock\") as mock_filelock:\n            mock_context = MagicMock()\n            mock_filelock.return_value.__enter__ = Mock(return_value=mock_context)\n            mock_filelock.return_value.__exit__ = Mock(return_value=None)\n\n            loader.load(mock_load_fn)\n\n            # Verify FileLock was called with correct path\n            mock_filelock.assert_called_once_with(str(loader.lock_file_path))\n\n            # Verify context manager was used\n            mock_filelock.return_value.__enter__.assert_called_once()\n            mock_filelock.return_value.__exit__.assert_called_once()\n"
  },
  {
    "path": "tests/e2e/multigpu/test_ray.py",
    "content": "\"\"\"\nE2E tests for multigpu post-training use Ray Train\n\"\"\"\n\nfrom pathlib import Path\n\nimport pytest\nimport yaml\nfrom accelerate.test_utils import execute_subprocess_async\n\nfrom axolotl.utils.dict import DictDefault\n\nfrom tests.e2e.utils import (\n    check_tensorboard,\n    require_torch_2_7_0,\n)\n\nAXOLOTL_ROOT = Path(__file__).parent.parent.parent.parent\n\n\nclass TestMultiGPURay:\n    \"\"\"\n    Test cases for AnyScale Ray post training\n    \"\"\"\n\n    @require_torch_2_7_0\n    def test_lora_ddp(self, temp_dir):\n        cfg = DictDefault(\n            {\n                \"base_model\": \"HuggingFaceTB/SmolLM2-135M\",\n                \"sequence_len\": 1024,\n                \"adapter\": \"lora\",\n                \"lora_r\": 8,\n                \"lora_alpha\": 16,\n                \"lora_dropout\": 0.05,\n                \"lora_target_linear\": True,\n                \"val_set_size\": 0.05,\n                \"special_tokens\": {\n                    \"pad_token\": \"<|endoftext|>\",\n                },\n                \"datasets\": [\n                    {\n                        \"path\": \"tatsu-lab/alpaca\",\n                        \"type\": \"alpaca\",\n                        \"split\": \"train[:10%]\",\n                    },\n                ],\n                \"num_epochs\": 1,\n                \"max_steps\": 2,\n                \"micro_batch_size\": 4,\n                \"gradient_accumulation_steps\": 2,\n                \"output_dir\": temp_dir,\n                \"dataset_prepared_path\": temp_dir + \"/last_run_prepared\",\n                \"learning_rate\": 0.00001,\n                \"optimizer\": \"adamw_8bit\",\n                \"lr_scheduler\": \"cosine\",\n                \"flash_attention\": True,\n                \"use_tensorboard\": True,\n                \"use_ray\": True,\n                \"ray_num_workers\": 2,\n                \"save_first_step\": False,\n            }\n        )\n\n        # 
write cfg to yaml file\n        Path(temp_dir).mkdir(parents=True, exist_ok=True)\n        with open(Path(temp_dir) / \"config.yaml\", \"w\", encoding=\"utf-8\") as fout:\n            fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))\n\n        execute_subprocess_async(\n            [\n                \"axolotl\",\n                \"train\",\n                str(Path(temp_dir) / \"config.yaml\"),\n                \"--use-ray\",\n                \"--ray-num-workers\",\n                \"2\",\n            ]\n        )\n\n        check_tensorboard(\n            temp_dir + \"/runs\", \"train/train_loss\", 2.3, \"Train Loss (%s) is too high\"\n        )\n\n    @require_torch_2_7_0\n    @pytest.mark.parametrize(\n        \"gradient_accumulation_steps\",\n        [1, 2],\n    )\n    def test_ds_zero2_packed(self, temp_dir, gradient_accumulation_steps):\n        cfg = DictDefault(\n            {\n                \"base_model\": \"HuggingFaceTB/SmolLM2-135M\",\n                \"sample_packing\": True,\n                \"pad_to_sequence_len\": True,\n                \"sequence_len\": 1024,\n                \"val_set_size\": 0.01,\n                \"special_tokens\": {\n                    \"pad_token\": \"<|endoftext|>\",\n                },\n                \"datasets\": [\n                    {\n                        \"path\": \"tatsu-lab/alpaca\",\n                        \"type\": \"alpaca\",\n                        \"split\": \"train[:10%]\",\n                    },\n                ],\n                \"num_epochs\": 1,\n                \"max_steps\": 2,\n                \"micro_batch_size\": 1,\n                \"gradient_accumulation_steps\": gradient_accumulation_steps,\n                \"output_dir\": temp_dir,\n                \"dataset_prepared_path\": temp_dir + \"/last_run_prepared\",\n                \"learning_rate\": 0.00001,\n                \"optimizer\": \"adamw_torch\",\n                \"lr_scheduler\": \"cosine\",\n                
\"flash_attention\": True,\n                \"deepspeed\": str(AXOLOTL_ROOT / \"deepspeed_configs/zero2.json\"),\n                \"use_tensorboard\": True,\n                \"save_first_step\": False,\n            }\n        )\n\n        # write cfg to yaml file\n        Path(temp_dir).mkdir(parents=True, exist_ok=True)\n        with open(Path(temp_dir) / \"config.yaml\", \"w\", encoding=\"utf-8\") as fout:\n            fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))\n\n        execute_subprocess_async(\n            [\n                \"axolotl\",\n                \"train\",\n                str(Path(temp_dir) / \"config.yaml\"),\n                \"--use-ray\",\n                \"--ray-num-workers\",\n                \"2\",\n            ]\n        )\n\n        check_tensorboard(\n            temp_dir + \"/runs\", \"train/train_loss\", 2.3, \"Train Loss (%s) is too high\"\n        )\n\n    @require_torch_2_7_0\n    @pytest.mark.parametrize(\n        \"gradient_accumulation_steps\",\n        [1, 2],\n    )\n    def test_sft_fsdp2_packed(self, temp_dir, gradient_accumulation_steps):\n        cfg = DictDefault(\n            {\n                \"base_model\": \"HuggingFaceTB/SmolLM2-135M\",\n                \"sample_packing\": True,\n                \"pad_to_sequence_len\": True,\n                \"sequence_len\": 1024,\n                \"val_set_size\": 0.01,\n                \"special_tokens\": {\n                    \"pad_token\": \"<|endoftext|>\",\n                },\n                \"datasets\": [\n                    {\n                        \"path\": \"tatsu-lab/alpaca\",\n                        \"type\": \"alpaca\",\n                        \"split\": \"train[:10%]\",\n                    },\n                ],\n                \"num_epochs\": 1,\n                \"max_steps\": 2,\n                \"micro_batch_size\": 1,\n                \"gradient_accumulation_steps\": gradient_accumulation_steps,\n                \"output_dir\": temp_dir,\n     
           \"dataset_prepared_path\": temp_dir + \"/last_run_prepared\",\n                \"learning_rate\": 0.00001,\n                \"optimizer\": \"adamw_torch\",\n                \"lr_scheduler\": \"cosine\",\n                \"flash_attention\": True,\n                \"fsdp_version\": 2,\n                \"fsdp_config\": {\n                    \"offload_params\": False,\n                    \"cpu_ram_efficient_loading\": False,\n                    \"transformer_layer_cls_to_wrap\": \"LlamaDecoderLayer\",\n                    \"state_dict_type\": \"FULL_STATE_DICT\",\n                    \"auto_wrap_policy\": \"TRANSFORMER_BASED_WRAP\",\n                    \"reshard_after_forward\": True,\n                },\n                \"use_tensorboard\": True,\n                \"save_first_step\": False,\n            }\n        )\n\n        # write cfg to yaml file\n        Path(temp_dir).mkdir(parents=True, exist_ok=True)\n        with open(Path(temp_dir) / \"config.yaml\", \"w\", encoding=\"utf-8\") as fout:\n            fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))\n\n        execute_subprocess_async(\n            [\n                \"axolotl\",\n                \"train\",\n                str(Path(temp_dir) / \"config.yaml\"),\n                \"--use-ray\",\n                \"--ray-num-workers\",\n                \"2\",\n            ]\n        )\n\n        check_tensorboard(\n            temp_dir + \"/runs\", \"train/train_loss\", 2.3, \"Train Loss (%s) is too high\"\n        )\n"
  },
  {
    "path": "tests/e2e/multigpu/test_tp.py",
    "content": "\"\"\"multigpu e2e test for tensor parallelism.\"\"\"\n\nfrom pathlib import Path\n\nimport pytest\nimport yaml\nfrom accelerate.test_utils import execute_subprocess_async, get_torch_dist_unique_port\n\nfrom axolotl.utils.dict import DictDefault\n\nfrom tests.e2e.utils import check_tensorboard, require_torch_2_7_0\n\n\nclass TestTensorParallel:\n    \"\"\"Test class for Tensor Parallel functionality.\"\"\"\n\n    @pytest.mark.skip(\n        reason=\"TP doesn't work with models with tied weights (embeddings)\"\n    )\n    @require_torch_2_7_0\n    def test_fft_sft(self, temp_dir):\n        cfg = DictDefault(\n            {\n                \"base_model\": \"Qwen/Qwen2.5-0.5B\",\n                \"sequence_len\": 2048,\n                \"val_set_size\": 0.01,\n                \"datasets\": [\n                    {\n                        \"path\": \"tatsu-lab/alpaca\",\n                        \"type\": \"alpaca\",\n                        \"split\": \"train[:10%]\",\n                    },\n                ],\n                \"num_epochs\": 1,\n                \"max_steps\": 2,\n                \"micro_batch_size\": 2,\n                \"gradient_accumulation_steps\": 1,\n                \"output_dir\": temp_dir,\n                \"learning_rate\": 0.00001,\n                \"optimizer\": \"adamw_torch\",\n                \"tensor_parallel_size\": 2,\n                \"lr_scheduler\": \"cosine\",\n                \"flash_attention\": True,\n                \"use_tensorboard\": True,\n                \"bf16\": True,\n            }\n        )\n\n        # write cfg to yaml file\n        Path(temp_dir).mkdir(parents=True, exist_ok=True)\n        with open(Path(temp_dir) / \"config.yaml\", \"w\", encoding=\"utf-8\") as fout:\n            fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))\n\n        execute_subprocess_async(\n            [\n                \"axolotl\",\n                \"train\",\n                str(Path(temp_dir) / 
\"config.yaml\"),\n                \"--num-processes\",\n                \"2\",\n                \"--main-process-port\",\n                f\"{get_torch_dist_unique_port()}\",\n            ]\n        )\n\n        check_tensorboard(\n            temp_dir + \"/runs\", \"train/train_loss\", 1.0, \"Train Loss (%s) is too high\"\n        )\n"
  },
  {
    "path": "tests/e2e/patched/__init__.py",
    "content": ""
  },
  {
    "path": "tests/e2e/patched/lora_kernels/__init__.py",
    "content": ""
  },
  {
    "path": "tests/e2e/patched/lora_kernels/test_lora_kernel_patching.py",
    "content": "\"\"\"Integration tests for LoRA activation and attention kernels.\"\"\"\n\nfrom pathlib import Path\n\nimport pytest\nimport torch\nimport yaml\nfrom accelerate.state import PartialState\nfrom peft import PeftModelForCausalLM, get_peft_config\nfrom transformers import AutoModelForCausalLM, LlamaForCausalLM\nfrom transformers.models.llama.configuration_llama import LlamaConfig\nfrom transformers.models.llama.modeling_llama import LlamaAttention\nfrom transformers.models.qwen3_moe.modeling_qwen3_moe import Qwen3MoeAttention\n\nfrom axolotl.cli.config import load_cfg\nfrom axolotl.kernels.lora import (\n    apply_lora_mlp_geglu,\n    apply_lora_mlp_swiglu,\n    apply_lora_o,\n    apply_lora_qkv,\n)\nfrom axolotl.loaders.model import ModelLoader\nfrom axolotl.loaders.tokenizer import load_tokenizer\nfrom axolotl.monkeypatch.lora_kernels import (\n    apply_lora_kernel_patches,\n    find_self_attn_in_layer,\n    get_attention_cls_from_config,\n    get_layers,\n    patch_self_attn_lora,\n)\nfrom axolotl.utils.dict import DictDefault\n\nMODEL_CONFIGS = [\n    {\n        \"name\": \"trl-internal-testing/tiny-MistralForCausalLM-0.2\",\n        \"expected_activation\": apply_lora_mlp_swiglu,\n        \"dtype\": torch.float16,\n    },\n    {\n        \"name\": \"trl-internal-testing/tiny-Qwen2ForCausalLM-2.5\",\n        \"expected_activation\": apply_lora_mlp_swiglu,\n        \"dtype\": torch.float16,\n    },\n    {\n        \"name\": \"HuggingFaceTB/SmolLM2-135M\",\n        \"expected_activation\": apply_lora_mlp_swiglu,\n        \"dtype\": torch.float32,\n    },\n    {\n        \"name\": \"trl-internal-testing/tiny-Gemma2ForCausalLM\",\n        \"expected_activation\": apply_lora_mlp_geglu,\n        \"dtype\": torch.float16,\n    },\n]\n\n\n@pytest.fixture(autouse=True)\ndef init_accelerate():\n    \"\"\"Initialize Accelerate state before tests.\"\"\"\n    _ = PartialState()\n\n\n@pytest.fixture\ndef small_llama_model():\n    \"\"\"Create a small LLaMA 
model for testing.\"\"\"\n    config = {\n        \"vocab_size\": 100,\n        \"hidden_size\": 128,\n        \"intermediate_size\": 256,\n        \"num_hidden_layers\": 2,\n        \"num_attention_heads\": 4,\n    }\n\n    return LlamaForCausalLM(LlamaConfig(**config))\n\n\n@pytest.mark.parametrize(\n    \"model_name,attention_cls\",\n    [\n        (\"HuggingFaceTB/SmolLM2-135M\", LlamaAttention),\n        (\"Qwen/Qwen3-30B-A3B\", Qwen3MoeAttention),\n    ],\n)\ndef test_attention_patching_integration(model_name, attention_cls):\n    \"\"\"Test attention patching in integration context.\"\"\"\n    cfg = DictDefault({\"base_model\": model_name})\n\n    # Store the original implementation\n    original_forward = attention_cls.forward\n\n    # Apply patch\n    patch_self_attn_lora(cfg)\n\n    # Get the new forward method\n    patched_forward = attention_cls.forward\n\n    # Check the forward method was replaced\n    assert original_forward is not patched_forward\n    assert patched_forward.__name__ == \"axolotl_attn_forward\"\n\n    # Check original implementation was stored\n    assert hasattr(attention_cls, \"_original_forward\")\n\n    # Clean up\n    attention_cls.forward = original_forward\n    delattr(attention_cls, \"_original_forward\")\n\n\ndef test_swiglu_mlp_integration(small_llama_model):\n    \"\"\"Test SwiGLU activation in LoRA MLP context.\"\"\"\n    peft_config = get_peft_config(\n        {\n            \"peft_type\": \"LORA\",\n            \"task_type\": \"CAUSAL_LM\",\n            \"r\": 8,\n            \"lora_alpha\": 16,\n            \"target_modules\": [\"gate_proj\", \"up_proj\", \"down_proj\"],\n            \"lora_dropout\": 0,\n            \"bias\": \"none\",\n        }\n    )\n    model = PeftModelForCausalLM(small_llama_model, peft_config).to(\"cuda\")\n    cfg = DictDefault({\"lora_mlp_kernel\": True})\n\n    # Apply patches\n    patched_model = apply_lora_kernel_patches(model, cfg)\n\n    # Verify patches\n    layer = 
patched_model.model.model.layers[0]\n    assert layer.mlp.forward.__func__ is apply_lora_mlp_swiglu\n\n    # Test forward pass\n    batch_size, seq_len = 2, 10\n    hidden_states = torch.randn(\n        batch_size, seq_len, model.config.hidden_size, device=model.device\n    )\n    position_ids = (\n        torch.arange(seq_len, device=model.device).unsqueeze(0).expand(batch_size, -1)\n    )\n    cos, sin = model.model.model.rotary_emb(hidden_states, position_ids)\n\n    inputs = {\n        \"hidden_states\": hidden_states,\n        \"attention_mask\": None,\n        \"position_embeddings\": (cos, sin),\n        \"output_attentions\": False,\n        \"use_cache\": False,\n        \"past_key_value\": None,\n    }\n\n    # Compare outputs\n    with torch.no_grad():\n        original_output = model.model.model.layers[0](**inputs)[0]\n        patched_output = layer(**inputs)[0]\n\n    assert torch.allclose(original_output, patched_output, rtol=1e-4)\n\n\ndef test_geglu_model_integration():\n    \"\"\"Test GeGLU activation with Gemma model.\"\"\"\n    model = AutoModelForCausalLM.from_pretrained(\n        \"trl-internal-testing/tiny-Gemma2ForCausalLM\",\n        dtype=torch.float16,\n        device_map=\"cuda:0\",\n    )\n    peft_config = get_peft_config(\n        {\n            \"peft_type\": \"LORA\",\n            \"task_type\": \"CAUSAL_LM\",\n            \"r\": 8,\n            \"lora_alpha\": 16,\n            \"target_modules\": [\"gate_proj\", \"up_proj\", \"down_proj\"],\n            \"lora_dropout\": 0,\n            \"bias\": \"none\",\n        }\n    )\n    model = PeftModelForCausalLM(model, peft_config)\n\n    cfg = DictDefault({\"lora_mlp_kernel\": True})\n    patched_model = apply_lora_kernel_patches(model, cfg)\n\n    # Verify patches\n    layer = patched_model.model.model.layers[0]\n    assert layer.mlp.forward.__func__ is apply_lora_mlp_geglu\n\n    # Test end-to-end\n    inputs = torch.randint(0, 100, (1, 20), device=model.device, dtype=torch.long)\n    
with torch.no_grad():\n        original_output = model(inputs).logits\n        patched_output = patched_model(inputs).logits\n\n    assert torch.allclose(original_output, patched_output, rtol=1e-4)\n\n\n@pytest.mark.parametrize(\n    \"model_name,expected_activation\",\n    [\n        (\"HuggingFaceTB/SmolLM2-135M\", apply_lora_mlp_swiglu),\n        (\"mhenrichsen/gemma-2b\", apply_lora_mlp_geglu),\n    ],\n)\ndef test_model_specific_activation(model_name, expected_activation):\n    \"\"\"Test that each model type gets the correct activation function.\"\"\"\n    model = AutoModelForCausalLM.from_pretrained(model_name)\n    peft_config = get_peft_config(\n        {\n            \"peft_type\": \"LORA\",\n            \"task_type\": \"CAUSAL_LM\",\n            \"r\": 8,\n            \"lora_alpha\": 16,\n            \"target_modules\": [\"gate_proj\", \"up_proj\", \"down_proj\"],\n            \"lora_dropout\": 0,\n            \"bias\": \"none\",\n        }\n    )\n    model = PeftModelForCausalLM(model, peft_config)\n    cfg = DictDefault({\"lora_mlp_kernel\": True})\n\n    patched_model = apply_lora_kernel_patches(model, cfg)\n    layer = patched_model.model.model.layers[0]\n    assert layer.mlp.forward.__func__ is expected_activation\n\n\ndef test_kernel_patch_conditions():\n    \"\"\"Test various conditions that should prevent kernel patching.\"\"\"\n    test_configs = [\n        # Dropout prevents patching\n        {\n            \"peft_type\": \"LORA\",\n            \"task_type\": \"CAUSAL_LM\",\n            \"r\": 8,\n            \"lora_alpha\": 16,\n            \"target_modules\": [\"gate_proj\", \"up_proj\", \"down_proj\"],\n            \"lora_dropout\": 0.1,\n            \"bias\": \"none\",\n        },\n        # Bias prevents patching\n        {\n            \"peft_type\": \"LORA\",\n            \"task_type\": \"CAUSAL_LM\",\n            \"r\": 8,\n            \"lora_alpha\": 16,\n            \"target_modules\": [\"gate_proj\", \"up_proj\", \"down_proj\"],\n   
         \"lora_dropout\": 0,\n            \"bias\": \"lora_only\",\n        },\n    ]\n\n    for config in test_configs:\n        model = AutoModelForCausalLM.from_pretrained(\"HuggingFaceTB/SmolLM2-135M\")\n        peft_config = get_peft_config(config)\n        model = PeftModelForCausalLM(model, peft_config)\n        cfg = DictDefault({\"lora_mlp_kernel\": True})\n\n        # Should not patch\n        patched_model = apply_lora_kernel_patches(model, cfg)\n        layer = patched_model.model.model.layers[0].mlp\n\n        # Verify no patches applied\n        assert layer.forward.__func__ is not apply_lora_mlp_swiglu\n        assert layer.forward.__func__ is not apply_lora_mlp_geglu\n\n\ndef test_kernel_config_options():\n    \"\"\"Test that kernel configuration options are respected.\"\"\"\n    # Test different configurations\n    test_configs = [\n        (\n            {\"lora_mlp_kernel\": True, \"lora_qkv_kernel\": False, \"lora_o_kernel\": False},\n            lambda layer: (\n                layer.mlp.forward.__func__ is apply_lora_mlp_swiglu\n                and layer.self_attn.apply_qkv.__func__ is not apply_lora_qkv\n                and layer.self_attn.apply_o.__func__ is not apply_lora_o\n            ),\n        ),\n        (\n            {\"lora_mlp_kernel\": False, \"lora_qkv_kernel\": True, \"lora_o_kernel\": False},\n            lambda layer: (\n                layer.mlp.forward.__func__ is not apply_lora_mlp_swiglu\n                and layer.self_attn.apply_qkv.__func__ is apply_lora_qkv\n                and layer.self_attn.apply_o.__func__ is not apply_lora_o\n            ),\n        ),\n        (\n            {\"lora_mlp_kernel\": False, \"lora_qkv_kernel\": False, \"lora_o_kernel\": True},\n            lambda layer: (\n                layer.mlp.forward.__func__ is not apply_lora_mlp_swiglu\n                and layer.self_attn.apply_qkv.__func__ is not apply_lora_qkv\n                and layer.self_attn.apply_o.__func__ is apply_lora_o\n          
  ),\n        ),\n    ]\n\n    for config_dict, check_fn in test_configs:\n        # Create fresh model for each test\n        config = {\n            \"vocab_size\": 100,\n            \"hidden_size\": 128,\n            \"intermediate_size\": 256,\n            \"num_hidden_layers\": 2,\n            \"num_attention_heads\": 4,\n        }\n        small_llama_model = LlamaForCausalLM(LlamaConfig(**config))\n\n        peft_config = get_peft_config(\n            {\n                \"peft_type\": \"LORA\",\n                \"task_type\": \"CAUSAL_LM\",\n                \"r\": 8,\n                \"lora_alpha\": 16,\n                \"target_modules\": [\n                    \"gate_proj\",\n                    \"up_proj\",\n                    \"down_proj\",\n                    \"q_proj\",\n                    \"k_proj\",\n                    \"v_proj\",\n                    \"o_proj\",\n                ],\n                \"lora_dropout\": 0,\n                \"bias\": \"none\",\n            }\n        )\n        model = PeftModelForCausalLM(small_llama_model, peft_config).to(\"cuda\")\n        cfg = DictDefault(config_dict)\n        patched_model = apply_lora_kernel_patches(model, cfg)\n\n        # Verify only requested optimizations were applied\n        for layer in patched_model.model.model.layers:\n            assert check_fn(layer), f\"Failed for config: {config_dict}\"\n\n        # Clean up\n        del model\n        del small_llama_model\n        del patched_model\n\n\ndef get_lora_config():\n    \"\"\"Get standard LoRA configuration for testing.\"\"\"\n    return {\n        \"peft_type\": \"LORA\",\n        \"task_type\": \"CAUSAL_LM\",\n        \"r\": 8,\n        \"lora_alpha\": 16,\n        \"target_modules\": [\"gate_proj\", \"up_proj\", \"down_proj\"],\n        \"lora_dropout\": 0,\n        \"bias\": \"none\",\n    }\n\n\ndef get_test_inputs(model, seq_length=20):\n    \"\"\"Generate test inputs for model evaluation.\"\"\"\n    return torch.randint(\n     
   0,\n        model.config.vocab_size,\n        (1, seq_length),\n        device=model.device,\n        dtype=torch.long,\n    )\n\n\n@pytest.mark.parametrize(\"model_config\", MODEL_CONFIGS)\ndef test_model_architecture(model_config):\n    \"\"\"Test LoRA kernel patches across different model architectures.\"\"\"\n    # Load model with appropriate dtype\n    model = AutoModelForCausalLM.from_pretrained(\n        model_config[\"name\"], torch_dtype=model_config[\"dtype\"], device_map=\"cuda:0\"\n    )\n\n    # Apply LoRA configuration\n    peft_config = get_peft_config(get_lora_config())\n    model = PeftModelForCausalLM(model, peft_config)\n\n    # Apply kernel patches\n    cfg = DictDefault({\"lora_mlp_kernel\": True})\n    patched_model = apply_lora_kernel_patches(model, cfg)\n\n    # Verify correct activation function\n    layer = patched_model.model.model.layers[0]\n    assert layer.mlp.forward.__func__ is model_config[\"expected_activation\"], (\n        f\"Wrong activation for {model_config['name']}\"\n    )\n\n    # Test forward pass\n    inputs = get_test_inputs(model)\n    with torch.no_grad():\n        original_output = model(inputs).logits\n        patched_output = patched_model(inputs).logits\n\n    # Check outputs match\n    assert torch.allclose(original_output, patched_output, rtol=1e-4), (\n        f\"Outputs don't match for {model_config['name']}\"\n    )\n\n\ndef test_kernel_training_integration(temp_dir):\n    \"\"\"Test model loading with kernel patches enabled.\"\"\"\n    from axolotl.cli.utils import load_model_and_tokenizer\n\n    # Create minimal config\n    cfg = DictDefault(\n        {\n            \"base_model\": \"HuggingFaceTB/SmolLM2-135M\",\n            \"tokenizer_config\": \"HuggingFaceTB/SmolLM2-135M\",\n            \"learning_rate\": 0.000001,\n            \"datasets\": [\n                {\n                    \"path\": \"mhenrichsen/alpaca_2k_test\",\n                    \"type\": \"alpaca\",\n                }\n            
],\n            \"micro_batch_size\": 1,\n            \"gradient_accumulation_steps\": 1,\n            \"adapter\": \"lora\",\n            \"lora_r\": 8,\n            \"lora_alpha\": 16,\n            \"lora_dropout\": 0.0,\n            \"lora_target_linear\": True,\n            \"sequence_len\": 1024,\n            \"lora_mlp_kernel\": True,\n            \"lora_qkv_kernel\": True,\n            \"lora_o_kernel\": True,\n        }\n    )\n\n    # Write cfg to yaml file\n    path = Path(temp_dir) / \"config.yaml\"\n    with open(path, \"w\", encoding=\"utf-8\") as fout:\n        fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))\n\n    # Load config\n    cfg = load_cfg(str(path))\n\n    # Load model\n    model, _, _ = load_model_and_tokenizer(cfg=cfg)\n\n    # Verify correct activation function\n    layer = model.model.model.layers[0]\n    assert layer.mlp.forward.__func__ is apply_lora_mlp_swiglu\n\n\ndef test_kernel_training_integration_auto_enable(temp_dir):\n    \"\"\"Test model loading with auto-enabled kernel patches.\"\"\"\n    # Create minimal config without explicitly setting kernel options\n    cfg = DictDefault(\n        {\n            \"base_model\": \"HuggingFaceTB/SmolLM2-135M\",\n            \"tokenizer_config\": \"HuggingFaceTB/SmolLM2-135M\",\n            \"learning_rate\": 0.000001,\n            \"datasets\": [\n                {\n                    \"path\": \"mhenrichsen/alpaca_2k_test\",\n                    \"type\": \"alpaca\",\n                }\n            ],\n            \"micro_batch_size\": 1,\n            \"gradient_accumulation_steps\": 1,\n            \"adapter\": \"lora\",\n            \"lora_r\": 8,\n            \"lora_alpha\": 16,\n            \"lora_dropout\": 0.0,\n            \"lora_target_linear\": True,\n            \"sequence_len\": 1024,\n        }\n    )\n\n    # Write cfg to yaml file\n    path = Path(temp_dir) / \"config.yaml\"\n    with open(path, \"w\", encoding=\"utf-8\") as fout:\n        
fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))\n\n    # Load config\n    cfg = load_cfg(str(path))\n\n    # Verify kernel options were auto-enabled in the config\n    assert cfg.lora_mlp_kernel is True\n    assert cfg.lora_qkv_kernel is True\n    assert cfg.lora_o_kernel is True\n\n    # Get the attention class before patching to check for side effects\n    attention_cls = get_attention_cls_from_config(cfg)\n\n    # Store original state before patching\n    original_forward_method = attention_cls.forward\n\n    # Load the model (this should trigger the patches)\n    tokenizer = load_tokenizer(cfg)\n    model, _ = ModelLoader(cfg, tokenizer).load()\n\n    # Test side effects of patch_self_attn_lora\n    assert hasattr(attention_cls, \"_original_forward\")\n    assert attention_cls.forward != original_forward_method\n\n    # Find at least one self-attention module and verify it has the patched methods\n    found_patched_attn = False\n    for layer in model.model.model.layers:\n        if hasattr(layer, \"self_attn\"):\n            self_attn = layer.self_attn\n            if all(\n                hasattr(self_attn, proj)\n                for proj in [\"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\"]\n            ):\n                # These methods should be added by apply_lora_kernel_patches\n                assert hasattr(self_attn, \"apply_qkv\") and callable(self_attn.apply_qkv)\n                assert hasattr(self_attn, \"apply_o\") and callable(self_attn.apply_o)\n\n                found_patched_attn = True\n                break\n\n    assert found_patched_attn\n\n\ndef test_kernel_training_integration_dropout_non_zero(temp_dir):\n    \"\"\"Test model loading with dropout non-zero should not patch.\"\"\"\n\n    from axolotl.cli.utils import load_model_and_tokenizer\n\n    # Create minimal config\n    cfg = DictDefault(\n        {\n            \"base_model\": \"HuggingFaceTB/SmolLM2-135M\",\n            \"tokenizer_config\": 
\"HuggingFaceTB/SmolLM2-135M\",\n            \"learning_rate\": 0.000001,\n            \"datasets\": [\n                {\n                    \"path\": \"mhenrichsen/alpaca_2k_test\",\n                    \"type\": \"alpaca\",\n                }\n            ],\n            \"micro_batch_size\": 1,\n            \"gradient_accumulation_steps\": 1,\n            \"adapter\": \"lora\",\n            \"lora_r\": 8,\n            \"lora_alpha\": 16,\n            \"lora_dropout\": 0.1,\n            \"lora_target_linear\": True,\n            \"sequence_len\": 1024,\n        }\n    )\n\n    # Write cfg to yaml file\n    path = Path(temp_dir) / \"config.yaml\"\n    with open(path, \"w\", encoding=\"utf-8\") as fout:\n        fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))\n\n    # Load config\n    cfg = load_cfg(str(path))\n\n    # Get original attention class\n    attention_cls = get_attention_cls_from_config(cfg)\n\n    # Store original state before patching\n    original_forward_method = attention_cls.forward\n\n    # Load model\n    model, tokenizer, _ = load_model_and_tokenizer(cfg=cfg)\n\n    # We call modelloader as that's where the patches are applied\n    # despite the fact that we're not using it to load the model\n    model_loader = ModelLoader(cfg, tokenizer)\n\n    # Apply patch\n    model_loader.patch_manager._apply_self_attention_lora_patch()\n\n    # Verify patch was not applied\n    assert attention_cls.forward == original_forward_method\n\n    # Apply apply_lora_kernel_patches\n    model_loader.patch_manager._apply_lora_kernel_patch(model)\n\n    # Verify patch was not applied\n    layers = get_layers(model)\n    for layer in layers:\n        for self_attn in find_self_attn_in_layer(layer):\n            assert not hasattr(self_attn, \"apply_qkv\")\n            assert not hasattr(self_attn, \"apply_o\")\n"
  },
  {
    "path": "tests/e2e/patched/test_4d_multipack_llama.py",
    "content": "\"\"\"\nE2E tests for multipack fft llama using 4d attention masks\n\"\"\"\n\nimport unittest\n\nfrom axolotl.common.datasets import load_datasets\nfrom axolotl.train import train\nfrom axolotl.utils.config import normalize_config, validate_config\nfrom axolotl.utils.dict import DictDefault\n\nfrom ..utils import check_model_output_exists, with_temp_dir\n\n\nclass Test4dMultipackLlama(unittest.TestCase):\n    \"\"\"\n    Test case for Llama models using 4d attention with multipack\n    \"\"\"\n\n    @with_temp_dir\n    def test_sdp_lora_packing(self, temp_dir):\n        cfg = DictDefault(\n            {\n                \"base_model\": \"HuggingFaceTB/SmolLM2-135M\",\n                \"flash_attention\": False,\n                \"sdp_attention\": True,\n                \"sample_packing\": True,\n                \"pad_to_sequence_len\": True,\n                \"load_in_8bit\": True,\n                \"adapter\": \"lora\",\n                \"lora_r\": 32,\n                \"lora_alpha\": 16,\n                \"lora_dropout\": 0.05,\n                \"lora_target_linear\": True,\n                \"sequence_len\": 1024,\n                \"val_set_size\": 0.02,\n                \"special_tokens\": {\n                    \"pad_token\": \"<|endoftext|>\",\n                },\n                \"datasets\": [\n                    {\n                        \"path\": \"mhenrichsen/alpaca_2k_test\",\n                        \"type\": \"alpaca\",\n                    },\n                ],\n                \"num_epochs\": 2,\n                \"micro_batch_size\": 2,\n                \"gradient_accumulation_steps\": 1,\n                \"output_dir\": temp_dir,\n                \"learning_rate\": 0.00001,\n                \"optimizer\": \"adamw_torch_fused\",\n                \"lr_scheduler\": \"cosine\",\n                \"max_steps\": 5,\n                \"save_steps\": 3,\n                \"eval_steps\": 4,\n                \"fp16\": True,\n                
\"save_first_step\": False,\n            }\n        )\n        cfg = validate_config(cfg)\n        normalize_config(cfg)\n        dataset_meta = load_datasets(cfg=cfg)\n\n        train(cfg=cfg, dataset_meta=dataset_meta)\n        check_model_output_exists(temp_dir, cfg)\n\n    @with_temp_dir\n    def test_torch_lora_packing(self, temp_dir):\n        cfg = DictDefault(\n            {\n                \"base_model\": \"HuggingFaceTB/SmolLM2-135M\",\n                \"flash_attention\": False,\n                \"sdp_attention\": False,\n                \"sample_packing\": True,\n                \"pad_to_sequence_len\": True,\n                \"sequence_len\": 1024,\n                \"load_in_8bit\": True,\n                \"adapter\": \"lora\",\n                \"lora_r\": 32,\n                \"lora_alpha\": 16,\n                \"lora_dropout\": 0.05,\n                \"lora_target_linear\": True,\n                \"val_set_size\": 0.02,\n                \"special_tokens\": {\n                    \"pad_token\": \"<|endoftext|>\",\n                },\n                \"datasets\": [\n                    {\n                        \"path\": \"mhenrichsen/alpaca_2k_test\",\n                        \"type\": \"alpaca\",\n                    },\n                ],\n                \"num_epochs\": 2,\n                \"micro_batch_size\": 2,\n                \"gradient_accumulation_steps\": 1,\n                \"output_dir\": temp_dir,\n                \"learning_rate\": 0.00001,\n                \"optimizer\": \"adamw_torch_fused\",\n                \"lr_scheduler\": \"cosine\",\n                \"max_steps\": 5,\n                \"save_steps\": 3,\n                \"eval_steps\": 4,\n                \"fp16\": True,\n                \"save_first_step\": False,\n            }\n        )\n        cfg = validate_config(cfg)\n        normalize_config(cfg)\n        dataset_meta = load_datasets(cfg=cfg)\n\n        train(cfg=cfg, dataset_meta=dataset_meta)\n        
check_model_output_exists(temp_dir, cfg)\n"
  },
  {
    "path": "tests/e2e/patched/test_activation_checkpointing.py",
    "content": "\"\"\"\nE2E tests for activation checkpointing\n\"\"\"\n\nimport pytest\nimport transformers\nfrom torch.utils.checkpoint import checkpoint\n\nfrom axolotl.common.datasets import load_datasets\nfrom axolotl.train import train\nfrom axolotl.utils.config import normalize_config, validate_config\nfrom axolotl.utils.dict import DictDefault\n\nfrom ..utils import check_model_output_exists\n\n\n@pytest.fixture()\ndef fix_checkpoint_after_test():\n    yield\n    transformers.modeling_utils.checkpoint = checkpoint\n\n\nclass TestActivationCheckpointing:\n    \"\"\"\n    E2E tests for activation checkpointing\n    \"\"\"\n\n    @pytest.mark.parametrize(\n        \"gradient_checkpointing\",\n        [\"offload\", \"offload_disk\"],\n    )\n    def test_activation_checkpointing_offload(\n        self,\n        temp_dir,\n        fix_checkpoint_after_test,\n        gradient_checkpointing,\n    ):\n        cfg = DictDefault(\n            {\n                \"base_model\": \"HuggingFaceTB/SmolLM2-135M\",\n                \"sequence_len\": 1024,\n                \"val_set_size\": 0.0,\n                \"special_tokens\": {\n                    \"pad_token\": \"<|endoftext|>\",\n                    \"eos_token\": \"<|im_end|>\",\n                },\n                \"datasets\": [\n                    {\n                        \"chat_template\": \"chatml\",\n                        \"path\": \"mlabonne/FineTome-100k\",\n                        \"type\": \"chat_template\",\n                        \"split\": \"train[:10%]\",\n                        \"field_messages\": \"conversations\",\n                        \"message_field_role\": \"from\",\n                        \"message_field_content\": \"value\",\n                    },\n                ],\n                \"num_epochs\": 1,\n                \"max_steps\": 5,\n                \"micro_batch_size\": 1,\n                \"gradient_accumulation_steps\": 1,\n                \"output_dir\": temp_dir,\n         
       \"learning_rate\": 0.00001,\n                \"optimizer\": \"adamw_8bit\",\n                \"lr_scheduler\": \"cosine\",\n                \"flash_attention\": True,\n                \"sample_packing\": True,\n                \"bf16\": True,\n                \"gradient_checkpointing\": gradient_checkpointing,\n                \"save_first_step\": False,\n                \"dataset_num_proc\": 4,\n            }\n        )\n\n        cfg = validate_config(cfg)\n        normalize_config(cfg)\n        dataset_meta = load_datasets(cfg=cfg)\n\n        train(cfg=cfg, dataset_meta=dataset_meta)\n        check_model_output_exists(temp_dir, cfg)\n"
  },
  {
    "path": "tests/e2e/patched/test_cli_integrations.py",
    "content": "\"\"\"\ntest cases to make sure the plugin args are loaded from the config file\n\"\"\"\n\nfrom pathlib import Path\n\nimport yaml\n\nfrom axolotl.cli.config import load_cfg\nfrom axolotl.utils.dict import DictDefault\n\n\nclass TestPluginArgs:\n    \"\"\"\n    test class for plugin args loaded from the config file\n    \"\"\"\n\n    def test_liger_plugin_args(self, temp_dir):\n        test_cfg = DictDefault(\n            {\n                \"base_model\": \"HuggingFaceTB/SmolLM2-135M\",\n                \"learning_rate\": 0.000001,\n                \"micro_batch_size\": 1,\n                \"gradient_accumulation_steps\": 1,\n                \"datasets\": [\n                    {\n                        \"path\": \"mhenrichsen/alpaca_2k_test\",\n                        \"type\": \"alpaca\",\n                    },\n                ],\n                \"plugins\": [\"axolotl.integrations.liger.LigerPlugin\"],\n                \"liger_layer_norm\": True,\n                \"liger_rope\": True,\n                \"liger_rms_norm\": False,\n                \"liger_glu_activation\": True,\n                \"liger_fused_linear_cross_entropy\": True,\n            }\n        )\n\n        with open(Path(temp_dir) / \"config.yaml\", \"w\", encoding=\"utf-8\") as fout:\n            fout.write(yaml.dump(test_cfg.to_dict()))\n        cfg = load_cfg(str(Path(temp_dir) / \"config.yaml\"))\n        assert cfg.liger_layer_norm is True\n        assert cfg.liger_rope is True\n        assert cfg.liger_rms_norm is False\n        assert cfg.liger_glu_activation is True\n        assert cfg.liger_fused_linear_cross_entropy is True\n"
  },
  {
    "path": "tests/e2e/patched/test_fa_xentropy.py",
    "content": "\"\"\"\nE2E tests for lora llama\n\"\"\"\n\nimport pytest\nfrom transformers.utils import is_torch_bf16_gpu_available\n\nfrom axolotl.common.datasets import load_datasets\nfrom axolotl.train import train\nfrom axolotl.utils.config import normalize_config, validate_config\nfrom axolotl.utils.dict import DictDefault\n\nfrom ..utils import check_model_output_exists, check_tensorboard\n\n\nclass TestFAXentropyLlama:\n    \"\"\"\n    Test case for Llama models using LoRA w multipack\n    \"\"\"\n\n    @pytest.mark.parametrize(\n        \"gradient_accumulation_steps\",\n        [1, 4],\n    )\n    def test_lora_packing_fa_cross_entropy(self, temp_dir, gradient_accumulation_steps):\n        cfg = DictDefault(\n            {\n                \"base_model\": \"HuggingFaceTB/SmolLM2-135M\",\n                \"sequence_len\": 1024,\n                \"sample_packing\": True,\n                \"flash_attention\": True,\n                \"flash_attn_cross_entropy\": True,\n                \"load_in_8bit\": True,\n                \"adapter\": \"lora\",\n                \"lora_r\": 8,\n                \"lora_alpha\": 16,\n                \"lora_dropout\": 0.05,\n                \"lora_target_linear\": True,\n                \"val_set_size\": 0.05,\n                \"special_tokens\": {\n                    \"pad_token\": \"<|endoftext|>\",\n                },\n                \"chat_template\": \"chatml\",\n                \"datasets\": [\n                    {\n                        \"path\": \"mlabonne/FineTome-100k\",\n                        \"field_messages\": \"conversations\",\n                        \"message_field_content\": \"value\",\n                        \"message_field_role\": \"from\",\n                        \"type\": \"chat_template\",\n                        \"split\": \"train[:2%]\",\n                    },\n                ],\n                \"num_epochs\": 1,\n                \"max_steps\": 5,\n                \"save_steps\": 5,\n       
         \"micro_batch_size\": 2,\n                \"gradient_accumulation_steps\": gradient_accumulation_steps,\n                \"output_dir\": temp_dir,\n                \"learning_rate\": 0.00001,\n                \"optimizer\": \"adamw_8bit\",\n                \"lr_scheduler\": \"cosine\",\n                \"use_tensorboard\": True,\n                \"save_first_step\": False,\n            }\n        )\n        if is_torch_bf16_gpu_available():\n            cfg.bf16 = True\n        else:\n            cfg.fp16 = True\n\n        cfg = validate_config(cfg)\n        normalize_config(cfg)\n\n        dataset_meta = load_datasets(cfg=cfg)\n\n        train(cfg=cfg, dataset_meta=dataset_meta)\n        check_model_output_exists(temp_dir, cfg)\n\n        check_tensorboard(\n            temp_dir + \"/runs\", \"train/train_loss\", 1.5, \"Train Loss (%s) is too high\"\n        )\n"
  },
  {
    "path": "tests/e2e/patched/test_falcon_samplepack.py",
    "content": "\"\"\"\nE2E tests for falcon\n\"\"\"\n\nimport unittest\n\nimport pytest\n\nfrom axolotl.common.datasets import load_datasets\nfrom axolotl.train import train\nfrom axolotl.utils.config import normalize_config, validate_config\nfrom axolotl.utils.dict import DictDefault\n\nfrom ..utils import check_model_output_exists, with_temp_dir\n\n\nclass TestFalconPatched(unittest.TestCase):\n    \"\"\"\n    Test case for Falcon models\n    \"\"\"\n\n    @pytest.mark.skip(reason=\"no tiny models for testing with safetensors\")\n    @with_temp_dir\n    def test_qlora(self, temp_dir):\n        cfg = DictDefault(\n            {\n                \"base_model\": \"illuin/tiny-random-FalconForCausalLM\",\n                \"flash_attention\": True,\n                \"sample_packing\": True,\n                \"sequence_len\": 2048,\n                \"load_in_4bit\": True,\n                \"adapter\": \"qlora\",\n                \"lora_r\": 16,\n                \"lora_alpha\": 32,\n                \"lora_dropout\": 0.1,\n                \"lora_target_linear\": True,\n                \"lora_modules_to_save\": [\"word_embeddings\", \"lm_head\"],\n                \"val_set_size\": 0.05,\n                \"special_tokens\": {\n                    \"bos_token\": \"<|endoftext|>\",\n                    \"pad_token\": \"<|endoftext|>\",\n                },\n                \"datasets\": [\n                    {\n                        \"path\": \"mhenrichsen/alpaca_2k_test\",\n                        \"type\": \"alpaca\",\n                    },\n                ],\n                \"num_epochs\": 2,\n                \"micro_batch_size\": 2,\n                \"gradient_accumulation_steps\": 1,\n                \"output_dir\": temp_dir,\n                \"learning_rate\": 0.00001,\n                \"optimizer\": \"adamw_bnb_8bit\",\n                \"lr_scheduler\": \"cosine\",\n                \"max_steps\": 20,\n                \"save_steps\": 10,\n                
\"eval_steps\": 10,\n                \"bf16\": \"auto\",\n                \"save_first_step\": False,\n            }\n        )\n        cfg = validate_config(cfg)\n        normalize_config(cfg)\n        dataset_meta = load_datasets(cfg=cfg)\n\n        train(cfg=cfg, dataset_meta=dataset_meta)\n        check_model_output_exists(temp_dir, cfg)\n\n    @pytest.mark.skip(reason=\"no tiny models for testing with safetensors\")\n    @with_temp_dir\n    def test_ft(self, temp_dir):\n        cfg = DictDefault(\n            {\n                \"base_model\": \"illuin/tiny-random-FalconForCausalLM\",\n                \"flash_attention\": True,\n                \"sample_packing\": True,\n                \"sequence_len\": 2048,\n                \"val_set_size\": 0.05,\n                \"special_tokens\": {\n                    \"bos_token\": \"<|endoftext|>\",\n                    \"pad_token\": \"<|endoftext|>\",\n                },\n                \"datasets\": [\n                    {\n                        \"path\": \"mhenrichsen/alpaca_2k_test\",\n                        \"type\": \"alpaca\",\n                    },\n                ],\n                \"num_epochs\": 2,\n                \"micro_batch_size\": 2,\n                \"gradient_accumulation_steps\": 1,\n                \"output_dir\": temp_dir,\n                \"learning_rate\": 0.00001,\n                \"optimizer\": \"adamw_bnb_8bit\",\n                \"lr_scheduler\": \"cosine\",\n                \"max_steps\": 20,\n                \"save_steps\": 10,\n                \"eval_steps\": 10,\n                \"bf16\": \"auto\",\n                \"save_first_step\": False,\n            }\n        )\n        cfg = validate_config(cfg)\n        normalize_config(cfg)\n        dataset_meta = load_datasets(cfg=cfg)\n\n        train(cfg=cfg, dataset_meta=dataset_meta)\n        check_model_output_exists(temp_dir, cfg)\n"
  },
  {
    "path": "tests/e2e/patched/test_flattening.py",
    "content": "\"\"\"\nE2E tests for flattening batches\n\"\"\"\n\nimport pytest\nfrom transformers.utils import is_torch_bf16_gpu_available\n\nfrom axolotl.common.datasets import load_datasets\nfrom axolotl.train import train\nfrom axolotl.utils.config import normalize_config, validate_config\nfrom axolotl.utils.dict import DictDefault\n\nfrom ..utils import check_model_output_exists, check_tensorboard\n\n\nclass TestFAFlattening:\n    \"\"\"\n    Test case for Llama models using LoRA w batch flattening\n    \"\"\"\n\n    @pytest.mark.parametrize(\n        \"gradient_accumulation_steps\",\n        [1, 4],\n    )\n    def test_lora_packing_flattening(self, temp_dir, gradient_accumulation_steps):\n        cfg = DictDefault(\n            {\n                \"base_model\": \"HuggingFaceTB/SmolLM2-135M\",\n                \"sequence_len\": 1024,\n                \"batch_flattening\": True,\n                \"flash_attention\": True,\n                \"load_in_8bit\": True,\n                \"adapter\": \"lora\",\n                \"lora_r\": 8,\n                \"lora_alpha\": 16,\n                \"lora_dropout\": 0.05,\n                \"lora_target_linear\": True,\n                \"val_set_size\": 0.05,\n                \"special_tokens\": {\n                    \"pad_token\": \"<|endoftext|>\",\n                },\n                \"chat_template\": \"chatml\",\n                \"datasets\": [\n                    {\n                        \"path\": \"mlabonne/FineTome-100k\",\n                        \"field_messages\": \"conversations\",\n                        \"message_field_content\": \"value\",\n                        \"message_field_role\": \"from\",\n                        \"type\": \"chat_template\",\n                        \"split\": \"train[:2%]\",\n                    },\n                ],\n                \"num_epochs\": 1,\n                \"max_steps\": 5,\n                \"save_steps\": 5,\n                \"micro_batch_size\": 2,\n          
      \"gradient_accumulation_steps\": gradient_accumulation_steps,\n                \"output_dir\": temp_dir,\n                \"learning_rate\": 0.00001,\n                \"optimizer\": \"adamw_8bit\",\n                \"lr_scheduler\": \"cosine\",\n                \"use_tensorboard\": True,\n                \"save_first_step\": False,\n            }\n        )\n        if is_torch_bf16_gpu_available():\n            cfg.bf16 = True\n        else:\n            cfg.fp16 = True\n\n        cfg = validate_config(cfg)\n        normalize_config(cfg)\n\n        dataset_meta = load_datasets(cfg=cfg)\n\n        train(cfg=cfg, dataset_meta=dataset_meta)\n        check_model_output_exists(temp_dir, cfg)\n\n        check_tensorboard(\n            temp_dir + \"/runs\", \"train/train_loss\", 1.5, \"Train Loss (%s) is too high\"\n        )\n"
  },
  {
    "path": "tests/e2e/patched/test_fsdp2_qlora.py",
    "content": "\"\"\"Integration tests for FSDP2 Params4bit patches.\"\"\"\n\nimport pytest\nfrom torch.distributed.fsdp._fully_shard._fsdp_param import FSDPParam\n\n\nclass TestFSDPPatchIntegration:\n    \"\"\"Test FSDP patch integration.\"\"\"\n\n    @pytest.mark.integration\n    def test_fsdp2_init_patches(self):\n        \"\"\"Test that all patches can be applied together.\"\"\"\n        from axolotl.monkeypatch.fsdp2_qlora import (\n            apply_init_sharded_param_patch,\n            apply_init_unsharded_param_patch,\n        )\n\n        original_init_sharded = FSDPParam._init_sharded_param\n        original_init_unsharded = FSDPParam.init_unsharded_param\n\n        # Apply patches\n        apply_init_sharded_param_patch()\n        apply_init_unsharded_param_patch()\n\n        assert FSDPParam._init_sharded_param != original_init_sharded, (\n            \"_init_sharded_param was not patched\"\n        )\n        assert FSDPParam.init_unsharded_param != original_init_unsharded, (\n            \"init_unsharded_param was not patched\"\n        )\n"
  },
  {
    "path": "tests/e2e/patched/test_fused_llama.py",
    "content": "\"\"\"\nE2E tests for lora llama\n\"\"\"\n\nimport unittest\n\nimport pytest\nfrom transformers.utils import is_torch_bf16_gpu_available\n\nfrom axolotl.common.datasets import load_datasets\nfrom axolotl.train import train\nfrom axolotl.utils.config import normalize_config, validate_config\nfrom axolotl.utils.dict import DictDefault\n\nfrom ..utils import check_model_output_exists, with_temp_dir\n\n\n@pytest.mark.skip(\"FIXME, mostly underused functionality\")\nclass TestFusedLlama(unittest.TestCase):\n    \"\"\"\n    Test case for Llama models using Fused layers\n    \"\"\"\n\n    @with_temp_dir\n    def test_fft_packing(self, temp_dir):\n        cfg = DictDefault(\n            {\n                \"base_model\": \"HuggingFaceTB/SmolLM2-135M\",\n                \"flash_attention\": True,\n                \"pad_to_sequence_len\": True,\n                \"flash_attn_fuse_mlp\": True,\n                \"sample_packing\": True,\n                \"sequence_len\": 1024,\n                \"val_set_size\": 0.02,\n                \"special_tokens\": {\n                    \"pad_token\": \"<|endoftext|>\",\n                },\n                \"datasets\": [\n                    {\n                        \"path\": \"mhenrichsen/alpaca_2k_test\",\n                        \"type\": \"alpaca\",\n                    },\n                ],\n                \"num_epochs\": 2,\n                \"micro_batch_size\": 2,\n                \"gradient_accumulation_steps\": 1,\n                \"output_dir\": temp_dir,\n                \"learning_rate\": 0.00001,\n                \"optimizer\": \"adamw_torch_fused\",\n                \"lr_scheduler\": \"cosine\",\n                \"max_steps\": 10,\n                \"save_steps\": 5,\n                \"eval_steps\": 5,\n                \"save_first_step\": False,\n            }\n        )\n        if is_torch_bf16_gpu_available():\n            cfg.bf16 = True\n        else:\n            cfg.fp16 = True\n        cfg = 
validate_config(cfg)\n        normalize_config(cfg)\n        dataset_meta = load_datasets(cfg=cfg)\n\n        train(cfg=cfg, dataset_meta=dataset_meta)\n        check_model_output_exists(temp_dir, cfg)\n"
  },
  {
    "path": "tests/e2e/patched/test_llama_s2_attention.py",
    "content": "\"\"\"\nE2E tests for llama w/ S2 attn\n\"\"\"\n\nimport unittest\n\nimport pytest\n\nfrom axolotl.common.datasets import load_datasets\nfrom axolotl.train import train\nfrom axolotl.utils.config import normalize_config, validate_config\nfrom axolotl.utils.dict import DictDefault\n\nfrom ..utils import check_model_output_exists, with_temp_dir\n\n\n@pytest.mark.skip(reason=\"FIXME?\")\nclass TestLlamaShiftedSparseAttention(unittest.TestCase):\n    \"\"\"\n    Test case for Llama models using S2 Attn\n    \"\"\"\n\n    @with_temp_dir\n    def test_lora_s2_attn(self, temp_dir):\n        cfg = DictDefault(\n            {\n                \"base_model\": \"HuggingFaceTB/SmolLM2-135M\",\n                \"tokenizer_type\": \"AutoTokenizer\",\n                \"sequence_len\": 16384,\n                \"sample_packing\": False,\n                \"flash_attention\": True,\n                \"s2_attention\": True,\n                \"load_in_8bit\": True,\n                \"adapter\": \"lora\",\n                \"lora_r\": 32,\n                \"lora_alpha\": 16,\n                \"lora_dropout\": 0.05,\n                \"lora_target_linear\": True,\n                \"val_set_size\": 0.02,\n                \"special_tokens\": {\n                    \"pad_token\": \"<|endoftext|>\",\n                },\n                \"datasets\": [\n                    {\n                        \"path\": \"Yukang/LongAlpaca-12k\",\n                        \"type\": \"alpaca\",\n                    },\n                ],\n                \"num_epochs\": 2,\n                \"micro_batch_size\": 1,\n                \"gradient_accumulation_steps\": 1,\n                \"output_dir\": temp_dir,\n                \"learning_rate\": 0.00001,\n                \"optimizer\": \"adamw_torch_fused\",\n                \"lr_scheduler\": \"cosine\",\n                \"max_steps\": 10,\n                \"save_steps\": 5,\n                \"eval_steps\": 5,\n                \"bf16\": 
\"auto\",\n                \"save_first_step\": False,\n            }\n        )\n\n        cfg = validate_config(cfg)\n        normalize_config(cfg)\n        dataset_meta = load_datasets(cfg=cfg)\n\n        train(cfg=cfg, dataset_meta=dataset_meta)\n        check_model_output_exists(temp_dir, cfg)\n\n    @with_temp_dir\n    def test_fft_s2_attn(self, temp_dir):\n        cfg = DictDefault(\n            {\n                \"base_model\": \"HuggingFaceTB/SmolLM2-135M\",\n                \"tokenizer_type\": \"AutoTokenizer\",\n                \"sequence_len\": 16384,\n                \"sample_packing\": False,\n                \"flash_attention\": True,\n                \"s2_attention\": True,\n                \"val_set_size\": 0.02,\n                \"special_tokens\": {\n                    \"pad_token\": \"<|endoftext|>\",\n                },\n                \"datasets\": [\n                    {\n                        \"path\": \"Yukang/LongAlpaca-12k\",\n                        \"type\": \"alpaca\",\n                    },\n                ],\n                \"num_epochs\": 2,\n                \"micro_batch_size\": 1,\n                \"gradient_accumulation_steps\": 1,\n                \"output_dir\": temp_dir,\n                \"learning_rate\": 0.00001,\n                \"optimizer\": \"adamw_torch_fused\",\n                \"lr_scheduler\": \"cosine\",\n                \"max_steps\": 10,\n                \"save_steps\": 5,\n                \"eval_steps\": 5,\n                \"bf16\": \"auto\",\n                \"save_first_step\": False,\n            }\n        )\n\n        cfg = validate_config(cfg)\n        normalize_config(cfg)\n        dataset_meta = load_datasets(cfg=cfg)\n\n        train(cfg=cfg, dataset_meta=dataset_meta)\n        check_model_output_exists(temp_dir, cfg)\n"
  },
  {
    "path": "tests/e2e/patched/test_lora_llama_multipack.py",
    "content": "\"\"\"\nE2E tests for lora llama\n\"\"\"\n\nimport unittest\n\nfrom transformers.utils import is_torch_bf16_gpu_available\n\nfrom axolotl.common.datasets import load_datasets\nfrom axolotl.train import train\nfrom axolotl.utils.config import normalize_config, validate_config\nfrom axolotl.utils.dict import DictDefault\n\nfrom ..utils import check_model_output_exists, with_temp_dir\n\n\nclass TestLoraLlama(unittest.TestCase):\n    \"\"\"\n    Test case for Llama models using LoRA w multipack\n    \"\"\"\n\n    @with_temp_dir\n    def test_lora_packing(self, temp_dir):\n        cfg = DictDefault(\n            {\n                \"base_model\": \"HuggingFaceTB/SmolLM2-135M\",\n                \"tokenizer_type\": \"AutoTokenizer\",\n                \"sequence_len\": 1024,\n                \"sample_packing\": True,\n                \"flash_attention\": True,\n                \"load_in_8bit\": True,\n                \"adapter\": \"lora\",\n                \"lora_r\": 32,\n                \"lora_alpha\": 64,\n                \"lora_dropout\": 0.05,\n                \"lora_target_linear\": True,\n                \"val_set_size\": 0.2,\n                \"special_tokens\": {\n                    \"pad_token\": \"<|endoftext|>\",\n                },\n                \"datasets\": [\n                    {\n                        \"path\": \"mhenrichsen/alpaca_2k_test\",\n                        \"type\": \"alpaca\",\n                    },\n                ],\n                \"num_epochs\": 2,\n                \"max_steps\": 20,\n                \"save_steps\": 10,\n                \"micro_batch_size\": 8,\n                \"gradient_accumulation_steps\": 1,\n                \"output_dir\": temp_dir,\n                \"learning_rate\": 0.00001,\n                \"optimizer\": \"adamw_torch_fused\",\n                \"lr_scheduler\": \"cosine\",\n                \"save_first_step\": False,\n            }\n        )\n        if is_torch_bf16_gpu_available():\n  
          cfg.bf16 = True\n        else:\n            cfg.fp16 = True\n\n        cfg = validate_config(cfg)\n        normalize_config(cfg)\n        dataset_meta = load_datasets(cfg=cfg)\n\n        train(cfg=cfg, dataset_meta=dataset_meta)\n        check_model_output_exists(temp_dir, cfg)\n"
  },
  {
    "path": "tests/e2e/patched/test_mistral_samplepack.py",
    "content": "\"\"\"\nE2E tests for mistral w/ sample packing\n\"\"\"\n\nimport unittest\n\nfrom axolotl.common.datasets import load_datasets\nfrom axolotl.train import train\nfrom axolotl.utils.config import normalize_config, validate_config\nfrom axolotl.utils.dict import DictDefault\n\nfrom ..utils import check_model_output_exists, require_torch_2_6_0, with_temp_dir\n\n\nclass TestMistral(unittest.TestCase):\n    \"\"\"\n    Test case for Mistral models using sample packing\n    \"\"\"\n\n    @require_torch_2_6_0\n    @with_temp_dir\n    def test_lora_packing(self, temp_dir):\n        cfg = DictDefault(\n            {\n                \"base_model\": \"trl-internal-testing/tiny-MistralForCausalLM-0.2\",\n                \"flash_attention\": True,\n                \"sample_packing\": True,\n                \"sequence_len\": 1024,\n                \"load_in_8bit\": True,\n                \"adapter\": \"lora\",\n                \"lora_r\": 32,\n                \"lora_alpha\": 64,\n                \"lora_dropout\": 0.05,\n                \"lora_target_linear\": True,\n                \"val_set_size\": 0.05,\n                \"special_tokens\": {\n                    \"unk_token\": \"<unk>\",\n                    \"bos_token\": \"<s>\",\n                    \"eos_token\": \"</s>\",\n                },\n                \"datasets\": [\n                    {\n                        \"path\": \"mhenrichsen/alpaca_2k_test\",\n                        \"type\": \"alpaca\",\n                    },\n                ],\n                \"num_epochs\": 2,\n                \"micro_batch_size\": 2,\n                \"gradient_accumulation_steps\": 1,\n                \"output_dir\": temp_dir,\n                \"learning_rate\": 0.00001,\n                \"optimizer\": \"adamw_torch_fused\",\n                \"lr_scheduler\": \"cosine\",\n                \"max_steps\": 5,\n                \"save_steps\": 3,\n                \"eval_steps\": 4,\n                \"bf16\": \"auto\",\n                
\"save_first_step\": False,\n            }\n        )\n        cfg = validate_config(cfg)\n        normalize_config(cfg)\n        dataset_meta = load_datasets(cfg=cfg)\n\n        train(cfg=cfg, dataset_meta=dataset_meta)\n        check_model_output_exists(temp_dir, cfg)\n\n    @with_temp_dir\n    def test_ft_packing(self, temp_dir):\n        cfg = DictDefault(\n            {\n                \"base_model\": \"trl-internal-testing/tiny-MistralForCausalLM-0.2\",\n                \"flash_attention\": True,\n                \"sample_packing\": True,\n                \"sequence_len\": 1024,\n                \"val_set_size\": 0.05,\n                \"special_tokens\": {\n                    \"unk_token\": \"<unk>\",\n                    \"bos_token\": \"<s>\",\n                    \"eos_token\": \"</s>\",\n                },\n                \"datasets\": [\n                    {\n                        \"path\": \"mhenrichsen/alpaca_2k_test\",\n                        \"type\": \"alpaca\",\n                    },\n                ],\n                \"num_epochs\": 2,\n                \"micro_batch_size\": 2,\n                \"gradient_accumulation_steps\": 1,\n                \"output_dir\": temp_dir,\n                \"learning_rate\": 0.00001,\n                \"optimizer\": \"adamw_torch_fused\",\n                \"lr_scheduler\": \"cosine\",\n                \"max_steps\": 5,\n                \"save_steps\": 3,\n                \"eval_steps\": 4,\n                \"bf16\": \"auto\",\n                \"save_first_step\": False,\n            }\n        )\n        cfg = validate_config(cfg)\n        normalize_config(cfg)\n        dataset_meta = load_datasets(cfg=cfg)\n\n        train(cfg=cfg, dataset_meta=dataset_meta)\n        check_model_output_exists(temp_dir, cfg)\n"
  },
  {
    "path": "tests/e2e/patched/test_mixtral_samplepack.py",
    "content": "\"\"\"\nE2E tests for mixtral\n\"\"\"\n\nimport unittest\n\nfrom axolotl.common.datasets import load_datasets\nfrom axolotl.train import train\nfrom axolotl.utils.config import normalize_config, validate_config\nfrom axolotl.utils.dict import DictDefault\n\nfrom ..utils import check_model_output_exists, with_temp_dir\n\n\nclass TestMixtral(unittest.TestCase):\n    \"\"\"\n    Test case for Mixtral models using sample packing\n    \"\"\"\n\n    @with_temp_dir\n    def test_qlora(self, temp_dir):\n        cfg = DictDefault(\n            {\n                \"base_model\": \"hf-internal-testing/Mixtral-tiny\",\n                \"tokenizer_config\": \"LoneStriker/Mixtral-8x7B-v0.1-HF\",\n                \"flash_attention\": True,\n                \"sample_packing\": True,\n                \"sequence_len\": 2048,\n                \"load_in_4bit\": True,\n                \"adapter\": \"qlora\",\n                \"lora_r\": 16,\n                \"lora_alpha\": 32,\n                \"lora_dropout\": 0.1,\n                \"lora_target_linear\": True,\n                \"val_set_size\": 0.05,\n                \"special_tokens\": {},\n                \"datasets\": [\n                    {\n                        \"path\": \"mhenrichsen/alpaca_2k_test\",\n                        \"type\": \"alpaca\",\n                    },\n                ],\n                \"num_epochs\": 2,\n                \"micro_batch_size\": 2,\n                \"gradient_accumulation_steps\": 1,\n                \"output_dir\": temp_dir,\n                \"learning_rate\": 0.00001,\n                \"optimizer\": \"adamw_bnb_8bit\",\n                \"lr_scheduler\": \"cosine\",\n                \"max_steps\": 5,\n                \"save_steps\": 3,\n                \"eval_steps\": 4,\n                \"bf16\": \"auto\",\n                \"save_first_step\": False,\n            }\n        )\n        cfg = validate_config(cfg)\n        normalize_config(cfg)\n        dataset_meta = 
load_datasets(cfg=cfg)\n\n        train(cfg=cfg, dataset_meta=dataset_meta)\n        check_model_output_exists(temp_dir, cfg)\n\n    @with_temp_dir\n    def test_ft(self, temp_dir):\n        cfg = DictDefault(\n            {\n                \"base_model\": \"hf-internal-testing/Mixtral-tiny\",\n                \"tokenizer_config\": \"LoneStriker/Mixtral-8x7B-v0.1-HF\",\n                \"flash_attention\": True,\n                \"sample_packing\": True,\n                \"sequence_len\": 2048,\n                \"val_set_size\": 0.05,\n                \"special_tokens\": {},\n                \"datasets\": [\n                    {\n                        \"path\": \"mhenrichsen/alpaca_2k_test\",\n                        \"type\": \"alpaca\",\n                    },\n                ],\n                \"num_epochs\": 2,\n                \"micro_batch_size\": 2,\n                \"gradient_accumulation_steps\": 1,\n                \"output_dir\": temp_dir,\n                \"learning_rate\": 0.00001,\n                \"optimizer\": \"adamw_bnb_8bit\",\n                \"lr_scheduler\": \"cosine\",\n                \"max_steps\": 5,\n                \"save_steps\": 3,\n                \"eval_steps\": 4,\n                \"bf16\": \"auto\",\n                \"save_first_step\": False,\n            }\n        )\n        cfg = validate_config(cfg)\n        normalize_config(cfg)\n        dataset_meta = load_datasets(cfg=cfg)\n\n        train(cfg=cfg, dataset_meta=dataset_meta)\n        check_model_output_exists(temp_dir, cfg)\n"
  },
  {
    "path": "tests/e2e/patched/test_model_patches.py",
    "content": "\"\"\"\nE2E smoke tests to check that the monkeypatches are in place for certain configurations\n\"\"\"\n\nimport unittest\n\nimport transformers\n\nfrom axolotl.loaders import ModelLoader, load_tokenizer\nfrom axolotl.utils.config import normalize_config, validate_config\nfrom axolotl.utils.dict import DictDefault\n\nfrom ..utils import with_temp_dir\n\n\nclass TestModelPatches(unittest.TestCase):\n    \"\"\"\n    TestCases for the multipack monkey patches\n    \"\"\"\n\n    @with_temp_dir\n    def test_mixtral_multipack(self, temp_dir):\n        cfg = DictDefault(\n            {\n                \"base_model\": \"hf-internal-testing/Mixtral-tiny\",\n                \"tokenizer_config\": \"LoneStriker/Mixtral-8x7B-v0.1-HF\",\n                \"flash_attention\": True,\n                \"sample_packing\": True,\n                \"sequence_len\": 2048,\n                \"val_set_size\": 0.02,\n                \"special_tokens\": {},\n                \"datasets\": [\n                    {\n                        \"path\": \"mhenrichsen/alpaca_2k_test\",\n                        \"type\": \"alpaca\",\n                    },\n                ],\n                \"num_epochs\": 2,\n                \"micro_batch_size\": 2,\n                \"gradient_accumulation_steps\": 1,\n                \"output_dir\": temp_dir,\n                \"learning_rate\": 0.00001,\n                \"optimizer\": \"adamw_bnb_8bit\",\n                \"lr_scheduler\": \"cosine\",\n                \"max_steps\": 20,\n                \"save_steps\": 10,\n                \"eval_steps\": 10,\n                \"save_first_step\": False,\n            }\n        )\n        cfg = validate_config(cfg)\n        normalize_config(cfg)\n        tokenizer = load_tokenizer(cfg)\n        ModelLoader(cfg, tokenizer, inference=False).load()\n\n    @with_temp_dir\n    def test_mistral_multipack(self, temp_dir):\n        cfg = DictDefault(\n            {\n                \"base_model\": 
\"trl-internal-testing/tiny-MistralForCausalLM-0.2\",\n                \"flash_attention\": True,\n                \"sample_packing\": True,\n                \"sequence_len\": 2048,\n                \"val_set_size\": 0.02,\n                \"special_tokens\": {},\n                \"datasets\": [\n                    {\n                        \"path\": \"mhenrichsen/alpaca_2k_test\",\n                        \"type\": \"alpaca\",\n                    },\n                ],\n                \"num_epochs\": 2,\n                \"micro_batch_size\": 2,\n                \"gradient_accumulation_steps\": 1,\n                \"output_dir\": temp_dir,\n                \"learning_rate\": 0.00001,\n                \"optimizer\": \"adamw_bnb_8bit\",\n                \"lr_scheduler\": \"cosine\",\n                \"max_steps\": 20,\n                \"save_steps\": 10,\n                \"eval_steps\": 10,\n                \"save_first_step\": False,\n            }\n        )\n        cfg = validate_config(cfg)\n        normalize_config(cfg)\n        tokenizer = load_tokenizer(cfg)\n        ModelLoader(cfg, tokenizer, inference=False).load()\n\n        assert (\n            \"torch.jit\"\n            in transformers.modeling_flash_attention_utils._get_unpad_data.__module__\n        )\n"
  },
  {
    "path": "tests/e2e/patched/test_peft_embeddings.py",
    "content": "\"\"\"\nTest case for handling embeddings when using peft\n\"\"\"\n\nimport torch\n\nfrom axolotl.train import setup_model_and_tokenizer\nfrom axolotl.utils.config import normalize_config, validate_config\nfrom axolotl.utils.dict import DictDefault\n\n\nclass TestLlamaPeftEmbeddings:\n    \"\"\"\n    test class for handling embeddings when using peft\n    \"\"\"\n\n    def test_peft_embeddings_upcast(self, temp_dir):\n        cfg = DictDefault(\n            {\n                \"base_model\": \"HuggingFaceTB/SmolLM2-135M\",\n                \"load_in_4bit\": True,\n                \"adapter\": \"qlora\",\n                \"lora_r\": 8,\n                \"lora_alpha\": 16,\n                \"lora_target_linear\": True,\n                \"trust_remote_code\": True,\n                \"sequence_len\": 512,\n                \"val_set_size\": 0.01,\n                \"special_tokens\": {\n                    \"pad_token\": \"<|endoftext|>\",\n                },\n                \"datasets\": [\n                    {\n                        \"path\": \"mhenrichsen/alpaca_2k_test\",\n                        \"type\": \"alpaca\",\n                    },\n                ],\n                \"num_epochs\": 1,\n                \"max_steps\": 2,\n                \"micro_batch_size\": 1,\n                \"gradient_accumulation_steps\": 1,\n                \"output_dir\": temp_dir,\n                \"learning_rate\": 0.00001,\n                \"optimizer\": \"adamw_8bit\",\n                \"lr_scheduler\": \"cosine\",\n                \"flash_attention\": True,\n                \"sample_packing\": False,\n                \"bf16\": \"auto\",\n                \"embeddings_skip_upcast\": True,\n                \"save_first_step\": False,\n            }\n        )\n\n        cfg = validate_config(cfg)\n        normalize_config(cfg)\n\n        model, _, _, _ = setup_model_and_tokenizer(cfg)\n\n        # Check if the embeddings are upcast correctly\n        # only 
embed_tokens is a parameter that may be upcast\n        # embeddings_skip_upcast is set, so embed_tokens and lm_head must stay bfloat16\n        assert model.base_model.model.model.embed_tokens.weight.dtype == torch.bfloat16\n        assert model.base_model.model.lm_head.weight.dtype == torch.bfloat16\n"
  },
  {
    "path": "tests/e2e/patched/test_phi_multipack.py",
    "content": "\"\"\"\nE2E tests for phi w/ multipack\n\"\"\"\n\nimport unittest\n\nfrom axolotl.common.datasets import load_datasets\nfrom axolotl.train import train\nfrom axolotl.utils.config import normalize_config, validate_config\nfrom axolotl.utils.dict import DictDefault\n\nfrom ..utils import check_model_output_exists, with_temp_dir\n\n\nclass TestPhiMultipack(unittest.TestCase):\n    \"\"\"\n    Test case for Phi models (microsoft/phi-1_5) using multipack\n    \"\"\"\n\n    @with_temp_dir\n    def test_ft_packed(self, temp_dir):\n        cfg = DictDefault(\n            {\n                \"base_model\": \"microsoft/phi-1_5\",\n                \"model_type\": \"PhiForCausalLM\",\n                \"tokenizer_type\": \"AutoTokenizer\",\n                \"sequence_len\": 1024,\n                \"sample_packing\": True,\n                \"flash_attention\": True,\n                \"pad_to_sequence_len\": True,\n                \"load_in_8bit\": False,\n                \"adapter\": None,\n                \"val_set_size\": 0.05,\n                \"special_tokens\": {\n                    \"pad_token\": \"<|endoftext|>\",\n                },\n                \"datasets\": [\n                    {\n                        \"path\": \"mhenrichsen/alpaca_2k_test\",\n                        \"type\": \"alpaca\",\n                    },\n                ],\n                \"dataset_shard_num\": 10,\n                \"dataset_shard_idx\": 0,\n                \"num_epochs\": 1,\n                \"micro_batch_size\": 1,\n                \"gradient_accumulation_steps\": 1,\n                \"output_dir\": temp_dir,\n                \"learning_rate\": 0.00001,\n                \"optimizer\": \"adamw_bnb_8bit\",\n                \"lr_scheduler\": \"cosine\",\n                \"max_steps\": 5,\n                \"eval_steps\": 3,\n                \"save_steps\": 4,\n                \"bf16\": \"auto\",\n                \"save_first_step\": False,\n            }\n        )\n\n        cfg = validate_config(cfg)\n      
  normalize_config(cfg)\n        dataset_meta = load_datasets(cfg=cfg)\n\n        train(cfg=cfg, dataset_meta=dataset_meta)\n        check_model_output_exists(temp_dir, cfg)\n\n    @with_temp_dir\n    def test_qlora_packed(self, temp_dir):\n        cfg = DictDefault(\n            {\n                \"base_model\": \"microsoft/phi-1_5\",\n                \"model_type\": \"PhiForCausalLM\",\n                \"tokenizer_type\": \"AutoTokenizer\",\n                \"sequence_len\": 1024,\n                \"sample_packing\": True,\n                \"flash_attention\": True,\n                \"pad_to_sequence_len\": True,\n                \"load_in_4bit\": True,\n                \"adapter\": \"qlora\",\n                \"lora_r\": 64,\n                \"lora_alpha\": 32,\n                \"lora_dropout\": 0.05,\n                \"lora_target_linear\": True,\n                \"val_set_size\": 0.02,\n                \"special_tokens\": {\n                    \"pad_token\": \"<|endoftext|>\",\n                },\n                \"datasets\": [\n                    {\n                        \"path\": \"mhenrichsen/alpaca_2k_test\",\n                        \"type\": \"alpaca\",\n                    },\n                ],\n                \"dataset_shard_num\": 10,\n                \"dataset_shard_idx\": 0,\n                \"num_epochs\": 1,\n                \"micro_batch_size\": 1,\n                \"gradient_accumulation_steps\": 1,\n                \"output_dir\": temp_dir,\n                \"learning_rate\": 0.00001,\n                \"optimizer\": \"adamw_bnb_8bit\",\n                \"lr_scheduler\": \"cosine\",\n                \"max_steps\": 5,\n                \"eval_steps\": 3,\n                \"save_steps\": 4,\n                \"bf16\": \"auto\",\n                \"save_first_step\": False,\n            }\n        )\n\n        cfg = validate_config(cfg)\n        normalize_config(cfg)\n        dataset_meta = load_datasets(cfg=cfg)\n\n        train(cfg=cfg, 
dataset_meta=dataset_meta)\n        check_model_output_exists(temp_dir, cfg)\n"
  },
  {
    "path": "tests/e2e/patched/test_resume.py",
    "content": "\"\"\"\nE2E tests for resuming training\n\"\"\"\n\nimport os\nimport re\nimport subprocess\n\nfrom transformers.utils import is_torch_bf16_gpu_available\n\nfrom axolotl.common.datasets import load_datasets\nfrom axolotl.train import train\nfrom axolotl.utils.callbacks.tokens_per_second import TOKENS_STATE_FILE\nfrom axolotl.utils.config import normalize_config, validate_config\nfrom axolotl.utils.dict import DictDefault\n\nfrom ..utils import check_model_output_exists, most_recent_subdir, require_torch_2_6_0\n\n\nclass TestResumeLlama:\n    \"\"\"\n    Test case for resuming training of llama models\n    \"\"\"\n\n    @require_torch_2_6_0\n    def test_resume_lora_packed(self, temp_dir):\n        cfg = DictDefault(\n            {\n                \"base_model\": \"HuggingFaceTB/SmolLM2-135M\",\n                \"sequence_len\": 1024,\n                \"sample_packing\": True,\n                \"flash_attention\": True,\n                \"load_in_8bit\": True,\n                \"adapter\": \"lora\",\n                \"lora_r\": 8,\n                \"lora_alpha\": 16,\n                \"lora_dropout\": 0.05,\n                \"lora_target_linear\": True,\n                \"val_set_size\": 0.001,\n                \"special_tokens\": {\n                    \"pad_token\": \"<|endoftext|>\",\n                },\n                \"datasets\": [\n                    {\n                        \"path\": \"tatsu-lab/alpaca\",\n                        \"type\": \"alpaca\",\n                        \"split\": \"train[:10%]\",\n                    },\n                ],\n                \"num_epochs\": 2,\n                \"micro_batch_size\": 1,\n                \"gradient_accumulation_steps\": 1,\n                \"output_dir\": temp_dir,\n                \"learning_rate\": 0.00001,\n                \"optimizer\": \"adamw_8bit\",\n                \"lr_scheduler\": \"cosine\",\n                \"save_steps\": 3,\n                \"save_total_limit\": 5,\n       
         \"max_steps\": 15,\n                \"use_tensorboard\": True,\n                \"save_first_step\": False,\n                \"include_tkps\": True,\n            }\n        )\n        if is_torch_bf16_gpu_available():\n            cfg.bf16 = True\n        else:\n            cfg.fp16 = True\n        cfg = validate_config(cfg)\n        normalize_config(cfg)\n        dataset_meta = load_datasets(cfg=cfg)\n\n        initial_total_num_tokens = cfg.total_num_tokens\n        assert initial_total_num_tokens is not None, (\n            \"total_num_tokens should be calculated during load_datasets\"\n        )\n\n        train(cfg=cfg, dataset_meta=dataset_meta)\n\n        checkpoint_path = f\"{temp_dir}/checkpoint-9\"\n        tokens_state_path = os.path.join(checkpoint_path, TOKENS_STATE_FILE)\n        assert os.path.isfile(tokens_state_path), (\n            f\"{TOKENS_STATE_FILE} should exist in checkpoint at {tokens_state_path}\"\n        )\n\n        resume_cfg = cfg | DictDefault(\n            {\n                \"resume_from_checkpoint\": f\"{temp_dir}/checkpoint-9/\",\n            }\n        )\n        normalize_config(resume_cfg)\n\n        assert resume_cfg.total_num_tokens == initial_total_num_tokens, (\n            f\"total_num_tokens should be preserved on resume. \"\n            f\"Expected {initial_total_num_tokens}, got {resume_cfg.total_num_tokens}\"\n        )\n\n        resume_dataset_meta = load_datasets(cfg=resume_cfg)\n\n        assert resume_cfg.total_num_tokens == initial_total_num_tokens, (\n            f\"total_num_tokens should not be recalculated when resuming. \"\n            f\"Expected {initial_total_num_tokens}, got {resume_cfg.total_num_tokens}\"\n        )\n\n        train(cfg=resume_cfg, dataset_meta=resume_dataset_meta)\n\n        assert resume_cfg.total_num_tokens == initial_total_num_tokens, (\n            f\"total_num_tokens should remain unchanged after resume training. 
\"\n            f\"Expected {initial_total_num_tokens}, got {resume_cfg.total_num_tokens}\"\n        )\n        check_model_output_exists(temp_dir, cfg)\n\n        tb_log_path_1 = most_recent_subdir(temp_dir + \"/runs\")\n        cmd = f\"tensorboard --inspect  --logdir {tb_log_path_1}\"\n        res = subprocess.run(\n            cmd, shell=True, text=True, capture_output=True, check=True\n        )\n        pattern = r\"first_step\\s+(\\d+)\"\n        first_steps = int(re.findall(pattern, res.stdout)[0])\n        assert first_steps == 10\n"
  },
  {
    "path": "tests/e2e/patched/test_unsloth_integration.py",
    "content": "\"\"\"Test module for checking whether the integration of Unsloth with Hugging Face Transformers is working as expected.\"\"\"\n\nimport unittest\n\nimport pytest\n\n\n@pytest.mark.skip(\n    reason=\"Unsloth integration will be broken going into latest transformers\"\n)\nclass TestUnslothIntegration(unittest.TestCase):\n    \"\"\"Unsloth monkeypatch integration tests.\"\"\"\n\n    def test_is_self_attn_patchable(self):\n        from axolotl.monkeypatch.unsloth_ import check_self_attn_is_patchable\n\n        # ensures transformers' self-attention code still matches our patching code\n        self.assertTrue(\n            check_self_attn_is_patchable(),\n            \"HF transformers self attention code has changed and isn't patchable\",\n        )\n"
  },
  {
    "path": "tests/e2e/patched/test_unsloth_qlora.py",
    "content": "\"\"\"\ne2e tests for unsloth qlora\n\"\"\"\n\nimport pytest\n\nfrom axolotl.common.datasets import load_datasets\nfrom axolotl.train import train\nfrom axolotl.utils.config import normalize_config, validate_config\nfrom axolotl.utils.dict import DictDefault\n\nfrom ..utils import check_model_output_exists, check_tensorboard\n\n\n@pytest.mark.skip(\n    reason=\"Unsloth integration will be broken going into latest transformers\"\n)\nclass TestUnslothQLoRA:\n    \"\"\"\n    Test class for Unsloth QLoRA Llama models\n    \"\"\"\n\n    @pytest.mark.parametrize(\n        \"sample_packing\",\n        [True, False],\n    )\n    def test_unsloth_llama_qlora_fa2(self, temp_dir, sample_packing):\n        cfg = DictDefault(\n            {\n                \"base_model\": \"HuggingFaceTB/SmolLM2-135M\",\n                \"sequence_len\": 1024,\n                \"sample_packing\": sample_packing,\n                \"flash_attention\": True,\n                \"unsloth_lora_mlp\": True,\n                \"unsloth_lora_qkv\": True,\n                \"unsloth_lora_o\": True,\n                \"load_in_4bit\": True,\n                \"adapter\": \"qlora\",\n                \"lora_r\": 16,\n                \"lora_alpha\": 16,\n                \"lora_dropout\": 0.05,\n                \"lora_target_linear\": True,\n                \"val_set_size\": 0.05,\n                \"special_tokens\": {\n                    \"pad_token\": \"<|endoftext|>\",\n                },\n                \"datasets\": [\n                    {\n                        \"path\": \"mhenrichsen/alpaca_2k_test\",\n                        \"type\": \"alpaca\",\n                    },\n                ],\n                \"num_epochs\": 1,\n                \"max_steps\": 5,\n                \"save_steps\": 10,\n                \"micro_batch_size\": 4,\n                \"gradient_accumulation_steps\": 2,\n                \"output_dir\": temp_dir,\n                \"learning_rate\": 0.00001,\n        
        \"optimizer\": \"adamw_8bit\",\n                \"lr_scheduler\": \"cosine\",\n                \"use_tensorboard\": True,\n                \"bf16\": \"auto\",\n                \"save_first_step\": False,\n            }\n        )\n\n        cfg = validate_config(cfg)\n        normalize_config(cfg)\n        dataset_meta = load_datasets(cfg=cfg)\n\n        train(cfg=cfg, dataset_meta=dataset_meta)\n        check_model_output_exists(temp_dir, cfg)\n\n        check_tensorboard(\n            temp_dir + \"/runs\", \"train/train_loss\", 2.0, \"Train Loss (%s) is too high\"\n        )\n\n    def test_unsloth_llama_qlora_unpacked(self, temp_dir):\n        cfg = DictDefault(\n            {\n                \"base_model\": \"HuggingFaceTB/SmolLM2-135M\",\n                \"sequence_len\": 1024,\n                \"unsloth_lora_mlp\": True,\n                \"unsloth_lora_qkv\": True,\n                \"unsloth_lora_o\": True,\n                \"sample_packing\": False,\n                \"load_in_4bit\": True,\n                \"adapter\": \"qlora\",\n                \"lora_r\": 16,\n                \"lora_alpha\": 16,\n                \"lora_dropout\": 0.05,\n                \"lora_target_linear\": True,\n                \"val_set_size\": 0.05,\n                \"special_tokens\": {\n                    \"pad_token\": \"<|endoftext|>\",\n                },\n                \"datasets\": [\n                    {\n                        \"path\": \"mhenrichsen/alpaca_2k_test\",\n                        \"type\": \"alpaca\",\n                    },\n                ],\n                \"num_epochs\": 1,\n                \"max_steps\": 5,\n                \"save_steps\": 10,\n                \"micro_batch_size\": 4,\n                \"gradient_accumulation_steps\": 2,\n                \"output_dir\": temp_dir,\n                \"learning_rate\": 0.00001,\n                \"optimizer\": \"adamw_8bit\",\n                \"lr_scheduler\": \"cosine\",\n                
\"use_tensorboard\": True,\n                \"bf16\": \"auto\",\n                \"save_first_step\": False,\n            }\n        )\n\n        cfg = validate_config(cfg)\n        normalize_config(cfg)\n        dataset_meta = load_datasets(cfg=cfg)\n\n        train(cfg=cfg, dataset_meta=dataset_meta)\n        check_model_output_exists(temp_dir, cfg)\n\n        check_tensorboard(\n            temp_dir + \"/runs\", \"train/train_loss\", 2.0, \"Train Loss (%s) is too high\"\n        )\n\n    @pytest.mark.parametrize(\n        \"sdp_attention\",\n        [True, False],\n    )\n    def test_unsloth_llama_qlora_unpacked_no_fa2_fp16(self, temp_dir, sdp_attention):\n        cfg = DictDefault(\n            {\n                \"base_model\": \"HuggingFaceTB/SmolLM2-135M\",\n                \"sequence_len\": 1024,\n                \"unsloth_lora_mlp\": True,\n                \"unsloth_lora_qkv\": True,\n                \"unsloth_lora_o\": True,\n                \"sample_packing\": False,\n                \"load_in_4bit\": True,\n                \"adapter\": \"qlora\",\n                \"lora_r\": 16,\n                \"lora_alpha\": 16,\n                \"lora_dropout\": 0.05,\n                \"lora_target_linear\": True,\n                \"val_set_size\": 0.05,\n                \"special_tokens\": {\n                    \"pad_token\": \"<|endoftext|>\",\n                },\n                \"datasets\": [\n                    {\n                        \"path\": \"mhenrichsen/alpaca_2k_test\",\n                        \"type\": \"alpaca\",\n                    },\n                ],\n                \"num_epochs\": 1,\n                \"max_steps\": 5,\n                \"save_steps\": 10,\n                \"micro_batch_size\": 4,\n                \"gradient_accumulation_steps\": 2,\n                \"sdp_attention\": sdp_attention,\n                \"output_dir\": temp_dir,\n                \"learning_rate\": 0.00001,\n                \"optimizer\": \"adamw_8bit\",\n      
          \"lr_scheduler\": \"cosine\",\n                \"use_tensorboard\": True,\n                \"fp16\": True,\n                \"save_first_step\": False,\n            }\n        )\n\n        cfg = validate_config(cfg)\n        normalize_config(cfg)\n        dataset_meta = load_datasets(cfg=cfg)\n\n        train(cfg=cfg, dataset_meta=dataset_meta)\n        check_model_output_exists(temp_dir, cfg)\n\n        check_tensorboard(\n            temp_dir + \"/runs\", \"train/train_loss\", 2.0, \"Train Loss (%s) is too high\"\n        )\n"
  },
  {
    "path": "tests/e2e/solo/__init__.py",
    "content": ""
  },
  {
    "path": "tests/e2e/solo/test_flex.py",
    "content": "\"\"\"\nE2E tests for packed training w/ flex attention\n\"\"\"\n\nimport unittest\n\nfrom transformers.utils import is_torch_bf16_gpu_available\n\nfrom axolotl.common.datasets import load_datasets\nfrom axolotl.train import train\nfrom axolotl.utils.config import normalize_config, validate_config\nfrom axolotl.utils.dict import DictDefault\n\nfrom ..utils import check_tensorboard, require_torch_2_6_0, with_temp_dir\n\n\nclass TestPackedFlex(unittest.TestCase):\n    \"\"\"\n    Test case for Packed training of llama models\n    \"\"\"\n\n    @require_torch_2_6_0\n    @with_temp_dir\n    def test_loss_llama(self, temp_dir):\n        cfg = DictDefault(\n            {\n                \"base_model\": \"HuggingFaceTB/SmolLM2-135M\",\n                \"sequence_len\": 1024,\n                \"sample_packing\": True,\n                \"flex_attention\": True,\n                \"val_set_size\": 0.0,\n                \"special_tokens\": {\n                    \"pad_token\": \"<|endoftext|>\",\n                },\n                \"datasets\": [\n                    {\n                        \"path\": \"tatsu-lab/alpaca\",\n                        \"type\": \"alpaca\",\n                        \"split\": \"train[:10%]\",\n                    },\n                ],\n                \"num_epochs\": 1,\n                \"micro_batch_size\": 2,\n                \"gradient_accumulation_steps\": 2,\n                \"output_dir\": temp_dir,\n                \"learning_rate\": 0.00001,\n                \"optimizer\": \"adamw_torch_fused\",\n                \"lr_scheduler\": \"cosine\",\n                \"max_steps\": 5,\n                \"use_tensorboard\": True,\n                \"save_first_step\": False,\n            }\n        )\n        if is_torch_bf16_gpu_available():\n            cfg.bf16 = True\n        else:\n            cfg.fp16 = True\n\n        cfg = validate_config(cfg)\n        normalize_config(cfg)\n        dataset_meta = load_datasets(cfg=cfg)\n\n 
       train(cfg=cfg, dataset_meta=dataset_meta)\n\n        check_tensorboard(\n            temp_dir + \"/runs\", \"train/train_loss\", 2.1, \"Train Loss (%s) is too high\"\n        )\n"
  },
  {
    "path": "tests/e2e/solo/test_relora_llama.py",
    "content": "\"\"\"\nE2E tests for relora llama\n\"\"\"\n\nimport unittest\nfrom pathlib import Path\n\nfrom axolotl.common.datasets import load_datasets\nfrom axolotl.train import train\nfrom axolotl.utils.config import normalize_config, validate_config\nfrom axolotl.utils.dict import DictDefault\n\nfrom ..utils import check_model_output_exists, check_tensorboard, with_temp_dir\n\n\nclass TestReLoraLlama(unittest.TestCase):\n    \"\"\"\n    Test case for Llama models using LoRA\n    \"\"\"\n\n    @with_temp_dir\n    def test_relora(self, temp_dir):\n        cfg = DictDefault(\n            {\n                \"base_model\": \"HuggingFaceTB/SmolLM2-135M\",\n                \"sequence_len\": 2048,\n                \"sample_packing\": True,\n                \"pad_to_sequence_len\": True,\n                \"flash_attention\": True,\n                \"load_in_8bit\": True,\n                \"adapter\": \"lora\",\n                \"lora_r\": 8,\n                \"lora_alpha\": 16,\n                \"lora_dropout\": 0.05,\n                \"lora_target_modules\": [\"q_proj\", \"v_proj\"],\n                \"relora\": True,\n                \"jagged_restart_steps\": 50,\n                \"jagged_restart_warmup_steps\": 10,\n                \"jagged_restart_anneal_steps\": 10,\n                \"relora_prune_ratio\": 0.9,\n                \"relora_cpu_offload\": True,\n                \"val_set_size\": 0.0,\n                \"special_tokens\": {\n                    \"pad_token\": \"<|endoftext|>\",\n                },\n                \"chat_template\": \"chatml\",\n                \"datasets\": [\n                    {\n                        \"path\": \"mlabonne/FineTome-100k\",\n                        \"type\": \"chat_template\",\n                        \"split\": \"train[:10%]\",\n                        \"field_messages\": \"conversations\",\n                        \"message_field_role\": \"from\",\n                        \"message_field_content\": 
\"value\",\n                    },\n                ],\n                \"warmup_steps\": 10,\n                \"num_epochs\": 2,\n                \"max_steps\": 105,  # at least 2x jagged_restart_steps\n                \"micro_batch_size\": 2,\n                \"gradient_accumulation_steps\": 1,\n                \"output_dir\": temp_dir,\n                \"learning_rate\": 0.00001,\n                \"optimizer\": \"adamw_8bit\",\n                \"lr_scheduler\": \"cosine\",\n                \"use_tensorboard\": True,\n                \"save_first_step\": False,\n            }\n        )\n\n        cfg = validate_config(cfg)\n        normalize_config(cfg)\n        dataset_meta = load_datasets(cfg=cfg)\n\n        train(cfg=cfg, dataset_meta=dataset_meta)\n        check_model_output_exists(Path(temp_dir) / \"checkpoint-100/adapter\", cfg)\n        assert (Path(temp_dir) / \"checkpoint-100/relora/model.safetensors\").exists(), (\n            \"Relora model checkpoint not found\"\n        )\n\n        check_tensorboard(\n            temp_dir + \"/runs\", \"train/grad_norm\", 0.2, \"grad_norm is too high\"\n        )\n"
  },
  {
    "path": "tests/e2e/test_activation_offloading.py",
    "content": "\"\"\"\nE2E tests for activation offloading\n\"\"\"\n\nimport pytest\n\nfrom axolotl.common.datasets import load_datasets\nfrom axolotl.train import train\nfrom axolotl.utils.config import normalize_config, validate_config\nfrom axolotl.utils.dict import DictDefault\n\nfrom .utils import check_model_output_exists\n\n\nclass TestActivationOffloading:\n    \"\"\"\n    E2E test cases for activation offloading\n    \"\"\"\n\n    @pytest.mark.parametrize(\n        \"adapter\",\n        [\"lora\", \"qlora\", None],\n    )\n    def test_activation_offloading(\n        self,\n        temp_dir,\n        adapter,\n    ):\n        cfg = DictDefault(\n            {\n                \"base_model\": \"HuggingFaceTB/SmolLM2-135M\",\n                \"sequence_len\": 1024,\n                \"val_set_size\": 0.0,\n                \"special_tokens\": {\n                    \"pad_token\": \"<|endoftext|>\",\n                    \"eos_token\": \"<|im_end|>\",\n                },\n                \"datasets\": [\n                    {\n                        \"chat_template\": \"chatml\",\n                        \"path\": \"mlabonne/FineTome-100k\",\n                        \"type\": \"chat_template\",\n                        \"split\": \"train[:10%]\",\n                        \"field_messages\": \"conversations\",\n                        \"message_field_role\": \"from\",\n                        \"message_field_content\": \"value\",\n                    },\n                ],\n                \"num_epochs\": 1,\n                \"max_steps\": 2,\n                \"micro_batch_size\": 1,\n                \"gradient_accumulation_steps\": 1,\n                \"output_dir\": temp_dir,\n                \"learning_rate\": 0.00001,\n                \"optimizer\": \"adamw_torch_8bit\",\n                \"lr_scheduler\": \"cosine\",\n                \"flash_attention\": True,\n                \"sample_packing\": True,\n                \"bf16\": \"auto\",\n                
\"gradient_checkpointing\": True,\n                \"activation_offloading\": True,\n                \"save_first_step\": False,\n                \"lora_r\": 8,\n                \"lora_alpha\": 16,\n                \"lora_target_linear\": True,\n            }\n        )\n        if adapter == \"lora\":\n            cfg[\"adapter\"] = \"lora\"\n        if adapter == \"qlora\":\n            cfg[\"adapter\"] = \"qlora\"\n            cfg[\"load_in_4bit\"] = True\n\n        cfg = validate_config(cfg)\n        normalize_config(cfg)\n        dataset_meta = load_datasets(cfg=cfg)\n\n        train(cfg=cfg, dataset_meta=dataset_meta)\n        check_model_output_exists(temp_dir, cfg)\n"
  },
  {
    "path": "tests/e2e/test_deepseekv3.py",
    "content": "\"\"\"\nE2E tests for deepseekv3\n\"\"\"\n\nfrom pathlib import Path\n\nimport pytest\n\nfrom axolotl.common.datasets import load_datasets\nfrom axolotl.train import train\nfrom axolotl.utils.config import normalize_config, validate_config\nfrom axolotl.utils.dict import DictDefault\n\nfrom tests.hf_offline_utils import enable_hf_offline\n\n\nclass TestDeepseekV3:\n    \"\"\"\n    Test case for DeepseekV3 models\n    \"\"\"\n\n    @enable_hf_offline\n    @pytest.mark.parametrize(\n        \"sample_packing\",\n        [True, False],\n    )\n    def test_lora_deepseekv3(self, temp_dir, sample_packing):\n        cfg = DictDefault(\n            {\n                \"base_model\": \"axolotl-ai-co/DeepSeek-V3-11M\",\n                \"trust_remote_code\": True,\n                \"sample_packing\": sample_packing,\n                \"flash_attention\": True,\n                \"sequence_len\": 2048,\n                \"adapter\": \"lora\",\n                \"lora_r\": 8,\n                \"lora_alpha\": 16,\n                \"lora_dropout\": 0.05,\n                \"lora_target_linear\": True,\n                \"val_set_size\": 0,\n                \"datasets\": [\n                    {\n                        \"path\": \"mlabonne/FineTome-100k\",\n                        \"type\": \"chat_template\",\n                        \"field_messages\": \"conversations\",\n                        \"message_property_mappings\": {\n                            \"role\": \"from\",\n                            \"content\": \"value\",\n                        },\n                        \"drop_system_message\": True,\n                        \"split\": \"train[:1%]\",\n                    },\n                ],\n                \"special_tokens\": {\n                    \"bos_token\": \"<｜begin▁of▁sentence｜>\",\n                    \"eos_token\": \"<｜end▁of▁sentence｜>\",\n                },\n                \"chat_template\": \"deepseek_v3\",\n                \"num_epochs\": 
1,\n                \"micro_batch_size\": 1,\n                \"gradient_accumulation_steps\": 2,\n                \"output_dir\": temp_dir,\n                \"learning_rate\": 0.00001,\n                \"optimizer\": \"adamw_bnb_8bit\",\n                \"lr_scheduler\": \"cosine\",\n                \"max_steps\": 5,\n                \"bf16\": True,\n                \"save_first_step\": False,\n            }\n        )\n        cfg = validate_config(cfg)\n        normalize_config(cfg)\n        dataset_meta = load_datasets(cfg=cfg)\n\n        train(cfg=cfg, dataset_meta=dataset_meta)\n        assert (Path(temp_dir) / \"adapter_model.safetensors\").exists()\n\n    @enable_hf_offline\n    @pytest.mark.parametrize(\n        \"sample_packing\",\n        [True, False],\n    )\n    def test_fft_deepseekv3(self, temp_dir, sample_packing):\n        cfg = DictDefault(\n            {\n                \"base_model\": \"axolotl-ai-co/DeepSeek-V3-11M\",\n                \"trust_remote_code\": True,\n                \"sample_packing\": sample_packing,\n                \"flash_attention\": True,\n                \"sequence_len\": 2048,\n                \"val_set_size\": 0,\n                \"datasets\": [\n                    {\n                        \"path\": \"mlabonne/FineTome-100k\",\n                        \"type\": \"chat_template\",\n                        \"field_messages\": \"conversations\",\n                        \"message_field_role\": \"from\",\n                        \"message_field_content\": \"value\",\n                        \"split\": \"train[:1%]\",\n                    },\n                ],\n                \"chat_template\": \"deepseek_v3\",\n                \"special_tokens\": {\n                    \"bos_token\": \"<｜begin▁of▁sentence｜>\",\n                    \"eos_token\": \"<｜end▁of▁sentence｜>\",\n                },\n                \"num_epochs\": 1,\n                \"micro_batch_size\": 1,\n                \"gradient_accumulation_steps\": 
2,\n                \"output_dir\": temp_dir,\n                \"learning_rate\": 0.00001,\n                \"optimizer\": \"adamw_bnb_8bit\",\n                \"lr_scheduler\": \"cosine\",\n                \"max_steps\": 5,\n                \"bf16\": True,\n                \"save_first_step\": False,\n            }\n        )\n        cfg = validate_config(cfg)\n        normalize_config(cfg)\n        dataset_meta = load_datasets(cfg=cfg)\n\n        train(cfg=cfg, dataset_meta=dataset_meta)\n        assert (Path(temp_dir) / \"model.safetensors\").exists()\n"
  },
  {
    "path": "tests/e2e/test_diffusion.py",
    "content": "\"\"\"E2E smoke test for diffusion training plugin.\"\"\"\n\nfrom axolotl.common.datasets import load_datasets\nfrom axolotl.train import train\nfrom axolotl.utils.config import normalize_config, validate_config\nfrom axolotl.utils.dict import DictDefault\n\nfrom tests.e2e.utils import check_model_output_exists\n\n\nclass TestDiffusion:\n    \"\"\"Test case for diffusion training plugin.\"\"\"\n\n    def test_diffusion_smoke_test(self, temp_dir):\n        \"\"\"\n        Smoke test for diffusion training to ensure the plugin loads and trains without\n        error.\n        \"\"\"\n        cfg = DictDefault(\n            {\n                \"base_model\": \"HuggingFaceTB/SmolLM2-135M\",\n                \"tokenizer_type\": \"AutoTokenizer\",\n                \"trust_remote_code\": True,\n                \"sequence_len\": 256,\n                \"val_set_size\": 0.1,\n                \"special_tokens\": {\n                    \"pad_token\": \"<|endoftext|>\",\n                },\n                \"datasets\": [\n                    {\n                        \"path\": \"mhenrichsen/alpaca_2k_test\",\n                        \"type\": \"alpaca\",\n                    },\n                ],\n                \"num_epochs\": 1,\n                \"max_steps\": 3,\n                \"micro_batch_size\": 1,\n                \"gradient_accumulation_steps\": 1,\n                \"output_dir\": temp_dir,\n                \"learning_rate\": 0.0001,\n                \"optimizer\": \"adamw_torch\",\n                \"lr_scheduler\": \"cosine\",\n                \"bf16\": True,\n                \"save_first_step\": False,\n                \"logging_steps\": 1,\n                \"eval_steps\": 3,\n                # Diffusion-specific config\n                \"plugins\": [\"axolotl.integrations.diffusion.DiffusionPlugin\"],\n                \"diffusion\": {\n                    # sample generation\n                    \"generate_samples\": True,\n                    
\"generation_interval\": 1,\n                    \"num_generation_samples\": 1,\n                    \"generation_steps\": 2,\n                    \"generation_max_length\": 32,\n                    \"generation_temperature\": 0.0,\n                    # training-specific\n                    \"mask_token_id\": 16,\n                    \"eps\": 1e-3,\n                    \"importance_weighting\": False,\n                },\n            }\n        )\n\n        cfg = validate_config(cfg)\n        normalize_config(cfg)\n        dataset_meta = load_datasets(cfg=cfg)\n\n        train(cfg=cfg, dataset_meta=dataset_meta)\n        check_model_output_exists(temp_dir, cfg)\n\n    def test_diffusion_sft_labels(self, temp_dir):\n        \"\"\"Test that diffusion training properly handles SFT data with labels.\"\"\"\n        cfg = DictDefault(\n            {\n                \"base_model\": \"HuggingFaceTB/SmolLM2-135M\",\n                \"tokenizer_type\": \"AutoTokenizer\",\n                \"trust_remote_code\": True,\n                \"sequence_len\": 256,\n                \"val_set_size\": 0.1,\n                \"special_tokens\": {\n                    \"pad_token\": \"<|endoftext|>\",\n                },\n                \"datasets\": [\n                    {\n                        \"path\": \"mhenrichsen/alpaca_2k_test\",\n                        \"type\": \"alpaca\",\n                    },\n                ],\n                \"num_epochs\": 1,\n                \"max_steps\": 3,\n                \"micro_batch_size\": 1,\n                \"gradient_accumulation_steps\": 1,\n                \"output_dir\": temp_dir,\n                \"learning_rate\": 0.0001,\n                \"optimizer\": \"adamw_torch\",\n                \"lr_scheduler\": \"cosine\",\n                \"bf16\": True,\n                \"save_first_step\": False,\n                \"logging_steps\": 1,\n                \"eval_steps\": 2,\n                # Diffusion-specific config\n                
\"plugins\": [\"axolotl.integrations.diffusion.DiffusionPlugin\"],\n                \"diffusion\": {\n                    # sample generation\n                    \"generate_samples\": True,\n                    \"generation_interval\": 1,\n                    \"num_generation_samples\": 1,\n                    \"generation_steps\": 2,\n                    \"generation_max_length\": 32,\n                    \"generation_temperature\": 0.0,\n                    # training-specific\n                    \"mask_token_id\": 16,\n                    \"eps\": 1e-3,\n                    \"importance_weighting\": True,\n                },\n                # Ensure we have proper SFT labels\n                \"train_on_inputs\": False,\n            }\n        )\n\n        cfg = validate_config(cfg)\n        normalize_config(cfg)\n        dataset_meta = load_datasets(cfg=cfg)\n\n        # Verify that the dataset has labels\n        sample = dataset_meta.train_dataset[0]\n        assert \"labels\" in sample, \"SFT dataset should have labels\"\n\n        # Check that some labels are -100 (prompt tokens)\n        labels = sample[\"labels\"]\n        if hasattr(labels, \"tolist\"):\n            labels = labels.tolist()\n        assert -100 in labels, \"SFT dataset should have -100 labels for prompt tokens\"\n\n        train(cfg=cfg, dataset_meta=dataset_meta)\n        check_model_output_exists(temp_dir, cfg)\n"
  },
  {
    "path": "tests/e2e/test_dpo.py",
    "content": "\"\"\"E2E tests for lora llama\"\"\"\n\nimport unittest\nfrom pathlib import Path\n\nimport pytest\n\nfrom axolotl.cli.args import TrainerCliArgs\nfrom axolotl.common.datasets import load_preference_datasets\nfrom axolotl.train import train\nfrom axolotl.utils.config import normalize_config, validate_config\nfrom axolotl.utils.dict import DictDefault\n\nfrom .utils import check_model_output_exists, with_temp_dir\n\n\nclass TestDPOLlamaLora(unittest.TestCase):\n    \"\"\"\n    Test case for DPO Llama models using LoRA\n    \"\"\"\n\n    @with_temp_dir\n    def test_dpo_lora(self, temp_dir):\n        cfg = DictDefault(\n            {\n                \"base_model\": \"HuggingFaceTB/SmolLM2-135M\",\n                \"tokenizer_type\": \"AutoTokenizer\",\n                \"sequence_len\": 1024,\n                \"load_in_8bit\": True,\n                \"adapter\": \"lora\",\n                \"lora_r\": 64,\n                \"lora_alpha\": 32,\n                \"lora_dropout\": 0.1,\n                \"lora_target_linear\": True,\n                \"special_tokens\": {\n                    \"pad_token\": \"<|endoftext|>\",\n                },\n                \"rl\": \"dpo\",\n                \"datasets\": [\n                    {\n                        \"path\": \"arcee-ai/distilabel-intel-orca-dpo-pairs-binarized\",\n                        \"type\": \"chatml.ultra\",\n                        \"split\": \"train\",\n                    },\n                ],\n                \"num_epochs\": 1,\n                \"micro_batch_size\": 4,\n                \"gradient_accumulation_steps\": 1,\n                \"output_dir\": temp_dir,\n                \"learning_rate\": 0.00001,\n                \"optimizer\": \"paged_adamw_8bit\",\n                \"lr_scheduler\": \"cosine\",\n                \"max_steps\": 20,\n                \"save_steps\": 10,\n                \"warmup_steps\": 5,\n                \"gradient_checkpointing\": True,\n                
\"gradient_checkpointing_kwargs\": {\"use_reentrant\": True},\n                \"save_first_step\": False,\n            }\n        )\n\n        cfg = validate_config(cfg)\n        normalize_config(cfg)\n        cli_args = TrainerCliArgs()\n        dataset_meta = load_preference_datasets(cfg=cfg, cli_args=cli_args)\n\n        train(cfg=cfg, dataset_meta=dataset_meta)\n        check_model_output_exists(Path(temp_dir) / \"checkpoint-20\", cfg)\n\n    @with_temp_dir\n    def test_dpo_nll_lora(self, temp_dir):\n        cfg = DictDefault(\n            {\n                \"base_model\": \"HuggingFaceTB/SmolLM2-135M\",\n                \"tokenizer_type\": \"AutoTokenizer\",\n                \"sequence_len\": 1024,\n                \"load_in_8bit\": True,\n                \"adapter\": \"lora\",\n                \"lora_r\": 64,\n                \"lora_alpha\": 32,\n                \"lora_dropout\": 0.1,\n                \"lora_target_linear\": True,\n                \"special_tokens\": {\n                    \"pad_token\": \"<|endoftext|>\",\n                },\n                \"rl\": \"dpo\",\n                \"rpo_alpha\": 0.5,\n                \"datasets\": [\n                    {\n                        \"path\": \"arcee-ai/distilabel-intel-orca-dpo-pairs-binarized\",\n                        \"type\": \"chatml.ultra\",\n                        \"split\": \"train\",\n                    },\n                ],\n                \"num_epochs\": 1,\n                \"micro_batch_size\": 4,\n                \"gradient_accumulation_steps\": 1,\n                \"output_dir\": temp_dir,\n                \"learning_rate\": 0.00001,\n                \"optimizer\": \"paged_adamw_8bit\",\n                \"lr_scheduler\": \"cosine\",\n                \"max_steps\": 20,\n                \"save_steps\": 10,\n                \"warmup_steps\": 5,\n                \"gradient_checkpointing\": True,\n                \"gradient_checkpointing_kwargs\": {\"use_reentrant\": True},\n        
        \"save_first_step\": False,\n            }\n        )\n\n        cfg = validate_config(cfg)\n        normalize_config(cfg)\n        cli_args = TrainerCliArgs()\n        dataset_meta = load_preference_datasets(cfg=cfg, cli_args=cli_args)\n\n        train(cfg=cfg, dataset_meta=dataset_meta)\n        check_model_output_exists(Path(temp_dir) / \"checkpoint-20\", cfg)\n\n    @with_temp_dir\n    def test_dpo_use_weighting(self, temp_dir):\n        cfg = DictDefault(\n            {\n                \"base_model\": \"HuggingFaceTB/SmolLM2-135M\",\n                \"tokenizer_type\": \"AutoTokenizer\",\n                \"sequence_len\": 1024,\n                \"load_in_8bit\": True,\n                \"adapter\": \"lora\",\n                \"lora_r\": 64,\n                \"lora_alpha\": 32,\n                \"lora_dropout\": 0.1,\n                \"lora_target_linear\": True,\n                \"special_tokens\": {\n                    \"pad_token\": \"<|endoftext|>\",\n                },\n                \"rl\": \"dpo\",\n                \"dpo_use_weighting\": True,\n                \"datasets\": [\n                    {\n                        \"path\": \"arcee-ai/distilabel-intel-orca-dpo-pairs-binarized\",\n                        \"type\": \"chatml.ultra\",\n                        \"split\": \"train\",\n                    },\n                ],\n                \"num_epochs\": 1,\n                \"micro_batch_size\": 4,\n                \"gradient_accumulation_steps\": 1,\n                \"output_dir\": temp_dir,\n                \"learning_rate\": 0.00001,\n                \"optimizer\": \"paged_adamw_8bit\",\n                \"lr_scheduler\": \"cosine\",\n                \"max_steps\": 20,\n                \"save_steps\": 10,\n                \"warmup_steps\": 5,\n                \"gradient_checkpointing\": True,\n                \"gradient_checkpointing_kwargs\": {\"use_reentrant\": True},\n                \"save_first_step\": False,\n            }\n     
   )\n\n        cfg = validate_config(cfg)\n        normalize_config(cfg)\n        cli_args = TrainerCliArgs()\n        dataset_meta = load_preference_datasets(cfg=cfg, cli_args=cli_args)\n\n        train(cfg=cfg, dataset_meta=dataset_meta)\n        check_model_output_exists(Path(temp_dir) / \"checkpoint-20\", cfg)\n\n    @pytest.mark.skip(\"kto_pair no longer supported in trl\")\n    @with_temp_dir\n    def test_kto_pair_lora(self, temp_dir):\n        cfg = DictDefault(\n            {\n                \"base_model\": \"HuggingFaceTB/SmolLM2-135M\",\n                \"tokenizer_type\": \"AutoTokenizer\",\n                \"sequence_len\": 1024,\n                \"load_in_8bit\": True,\n                \"adapter\": \"lora\",\n                \"lora_r\": 64,\n                \"lora_alpha\": 32,\n                \"lora_dropout\": 0.1,\n                \"lora_target_linear\": True,\n                \"special_tokens\": {\n                    \"pad_token\": \"<|endoftext|>\",\n                },\n                \"rl\": \"kto_pair\",\n                \"datasets\": [\n                    {\n                        \"path\": \"arcee-ai/distilabel-intel-orca-dpo-pairs-binarized\",\n                        \"type\": \"chatml.ultra\",\n                        \"split\": \"train\",\n                    },\n                ],\n                \"num_epochs\": 1,\n                \"micro_batch_size\": 4,\n                \"gradient_accumulation_steps\": 1,\n                \"output_dir\": temp_dir,\n                \"learning_rate\": 0.00001,\n                \"optimizer\": \"paged_adamw_8bit\",\n                \"lr_scheduler\": \"cosine\",\n                \"max_steps\": 20,\n                \"save_steps\": 10,\n                \"warmup_steps\": 5,\n                \"gradient_checkpointing\": True,\n                \"gradient_checkpointing_kwargs\": {\"use_reentrant\": True},\n                \"save_first_step\": False,\n            }\n        )\n\n        cfg = 
validate_config(cfg)\n        normalize_config(cfg)\n        cli_args = TrainerCliArgs()\n        dataset_meta = load_preference_datasets(cfg=cfg, cli_args=cli_args)\n\n        train(cfg=cfg, dataset_meta=dataset_meta)\n        check_model_output_exists(Path(temp_dir) / \"checkpoint-20\", cfg)\n\n    @with_temp_dir\n    def test_ipo_lora(self, temp_dir):\n        cfg = DictDefault(\n            {\n                \"base_model\": \"HuggingFaceTB/SmolLM2-135M\",\n                \"tokenizer_type\": \"AutoTokenizer\",\n                \"sequence_len\": 1024,\n                \"load_in_8bit\": True,\n                \"adapter\": \"lora\",\n                \"lora_r\": 64,\n                \"lora_alpha\": 32,\n                \"lora_dropout\": 0.1,\n                \"lora_target_linear\": True,\n                \"special_tokens\": {\n                    \"pad_token\": \"<|endoftext|>\",\n                },\n                \"rl\": \"ipo\",\n                \"datasets\": [\n                    {\n                        \"path\": \"arcee-ai/distilabel-intel-orca-dpo-pairs-binarized\",\n                        \"type\": \"chatml.ultra\",\n                        \"split\": \"train\",\n                    },\n                ],\n                \"num_epochs\": 1,\n                \"micro_batch_size\": 4,\n                \"gradient_accumulation_steps\": 1,\n                \"output_dir\": temp_dir,\n                \"learning_rate\": 0.00001,\n                \"optimizer\": \"paged_adamw_8bit\",\n                \"lr_scheduler\": \"cosine\",\n                \"max_steps\": 20,\n                \"save_steps\": 10,\n                \"warmup_steps\": 5,\n                \"gradient_checkpointing\": True,\n                \"gradient_checkpointing_kwargs\": {\"use_reentrant\": True},\n                \"save_first_step\": False,\n            }\n        )\n\n        cfg = validate_config(cfg)\n        normalize_config(cfg)\n        cli_args = TrainerCliArgs()\n        dataset_meta 
= load_preference_datasets(cfg=cfg, cli_args=cli_args)\n\n        train(cfg=cfg, dataset_meta=dataset_meta)\n        check_model_output_exists(Path(temp_dir) / \"checkpoint-20\", cfg)\n\n    @with_temp_dir\n    def test_orpo_lora(self, temp_dir):\n        cfg = DictDefault(\n            {\n                \"base_model\": \"HuggingFaceTB/SmolLM2-135M\",\n                \"tokenizer_type\": \"AutoTokenizer\",\n                \"sequence_len\": 1024,\n                \"load_in_8bit\": True,\n                \"adapter\": \"lora\",\n                \"lora_r\": 64,\n                \"lora_alpha\": 32,\n                \"lora_dropout\": 0.1,\n                \"lora_target_linear\": True,\n                \"special_tokens\": {\n                    \"pad_token\": \"<|endoftext|>\",\n                },\n                \"rl\": \"orpo\",\n                \"orpo_alpha\": 0.1,\n                \"remove_unused_columns\": False,\n                \"chat_template\": \"chatml\",\n                \"datasets\": [\n                    {\n                        \"path\": \"argilla/distilabel-capybara-dpo-7k-binarized\",\n                        \"type\": \"chat_template.argilla\",\n                        \"split\": \"train\",\n                    },\n                ],\n                \"num_epochs\": 1,\n                \"micro_batch_size\": 4,\n                \"gradient_accumulation_steps\": 1,\n                \"output_dir\": temp_dir,\n                \"learning_rate\": 0.00001,\n                \"optimizer\": \"paged_adamw_8bit\",\n                \"lr_scheduler\": \"cosine\",\n                \"max_steps\": 20,\n                \"save_steps\": 10,\n                \"warmup_steps\": 5,\n                \"gradient_checkpointing\": True,\n                \"gradient_checkpointing_kwargs\": {\"use_reentrant\": True},\n                \"save_first_step\": False,\n            }\n        )\n\n        cfg = validate_config(cfg)\n        normalize_config(cfg)\n        cli_args = 
TrainerCliArgs()\n        dataset_meta = load_preference_datasets(cfg=cfg, cli_args=cli_args)\n\n        train(cfg=cfg, dataset_meta=dataset_meta)\n        check_model_output_exists(Path(temp_dir) / \"checkpoint-20\", cfg)\n\n    @pytest.mark.skip(reason=\"Fix the implementation\")\n    @with_temp_dir\n    def test_kto_lora(self, temp_dir):\n        cfg = DictDefault(\n            {\n                \"base_model\": \"HuggingFaceTB/SmolLM2-135M\",\n                \"tokenizer_type\": \"LlamaTokenizer\",\n                \"sequence_len\": 1024,\n                \"load_in_8bit\": True,\n                \"adapter\": \"lora\",\n                \"lora_r\": 64,\n                \"lora_alpha\": 32,\n                \"lora_dropout\": 0.1,\n                \"lora_target_linear\": True,\n                \"special_tokens\": {\n                    \"pad_token\": \"<|endoftext|>\",\n                },\n                \"rl\": \"kto\",\n                \"rl_beta\": 0.5,\n                \"kto_desirable_weight\": 1.0,\n                \"kto_undesirable_weight\": 1.0,\n                \"remove_unused_columns\": False,\n                \"datasets\": [\n                    # {\n                    #     \"path\": \"argilla/kto-mix-15k\",\n                    #     \"type\": \"chatml.argilla_chat\",\n                    #     \"split\": \"train\",\n                    # },\n                    {\n                        \"path\": \"argilla/ultrafeedback-binarized-preferences-cleaned-kto\",\n                        \"type\": \"chatml.ultra\",\n                        \"split\": \"train\",\n                    },\n                    # {\n                    #     \"path\": \"argilla/kto-mix-15k\",\n                    #     \"type\": \"llama3.argilla_chat\",\n                    #     \"split\": \"train\",\n                    # },\n                    {\n                        \"path\": \"argilla/ultrafeedback-binarized-preferences-cleaned-kto\",\n                        \"type\": 
\"llama3.ultra\",\n                        \"split\": \"train\",\n                    },\n                ],\n                \"num_epochs\": 1,\n                \"micro_batch_size\": 4,\n                \"gradient_accumulation_steps\": 1,\n                \"output_dir\": temp_dir,\n                \"learning_rate\": 0.00001,\n                \"optimizer\": \"paged_adamw_8bit\",\n                \"lr_scheduler\": \"cosine\",\n                \"max_steps\": 20,\n                \"save_steps\": 10,\n                \"warmup_steps\": 5,\n                \"gradient_checkpointing\": True,\n                \"gradient_checkpointing_kwargs\": {\"use_reentrant\": True},\n                \"save_first_step\": False,\n            }\n        )\n\n        cfg = validate_config(cfg)\n        normalize_config(cfg)\n        cli_args = TrainerCliArgs()\n        dataset_meta = load_preference_datasets(cfg=cfg, cli_args=cli_args)\n\n        train(cfg=cfg, dataset_meta=dataset_meta)\n        check_model_output_exists(Path(temp_dir) / \"checkpoint-20\", cfg)\n"
  },
  {
    "path": "tests/e2e/test_embeddings_lr.py",
    "content": "\"\"\"\nE2E tests for llama pretrain\n\"\"\"\n\nimport unittest\n\nfrom axolotl.common.datasets import load_datasets\nfrom axolotl.train import train\nfrom axolotl.utils.config import normalize_config, validate_config\nfrom axolotl.utils.dict import DictDefault\n\nfrom .utils import check_model_output_exists, check_tensorboard, with_temp_dir\n\n\nclass TestEmbeddingsLrScale(unittest.TestCase):\n    \"\"\"\n    Test case for embedding_lr*\n    \"\"\"\n\n    @with_temp_dir\n    def test_train_w_embedding_lr_scale(self, temp_dir):\n        cfg = DictDefault(\n            {\n                \"base_model\": \"HuggingFaceTB/SmolLM2-135M\",\n                \"flash_attention\": True,\n                \"sequence_len\": 1024,\n                \"sample_packing\": True,\n                \"special_tokens\": {\n                    \"pad_token\": \"<|endoftext|>\",\n                },\n                \"datasets\": [\n                    {\n                        \"path\": \"mhenrichsen/alpaca_2k_test\",\n                        \"type\": \"alpaca\",\n                    },\n                ],\n                \"max_steps\": 5,\n                \"num_epochs\": 1,\n                \"micro_batch_size\": 1,\n                \"gradient_accumulation_steps\": 1,\n                \"val_set_size\": 0.0,\n                \"output_dir\": temp_dir,\n                \"learning_rate\": 0.00001,\n                \"optimizer\": \"adamw_torch_fused\",\n                \"embedding_lr_scale\": 0.5,\n                \"lr_scheduler\": \"cosine\",\n                \"bf16\": \"auto\",\n                \"use_tensorboard\": True,\n                \"save_first_step\": False,\n            }\n        )\n\n        cfg = validate_config(cfg)\n        normalize_config(cfg)\n        dataset_meta = load_datasets(cfg=cfg)\n\n        train(cfg=cfg, dataset_meta=dataset_meta)\n        check_model_output_exists(temp_dir, cfg)\n\n        check_tensorboard(\n            temp_dir + \"/runs\", 
\"train/train_loss\", 2.0, \"Loss is too high\"\n        )\n\n    @with_temp_dir\n    def test_train_w_embedding_lr(self, temp_dir):\n        cfg = DictDefault(\n            {\n                \"base_model\": \"HuggingFaceTB/SmolLM2-135M\",\n                \"flash_attention\": True,\n                \"sequence_len\": 1024,\n                \"sample_packing\": True,\n                \"special_tokens\": {\n                    \"pad_token\": \"<|endoftext|>\",\n                },\n                \"datasets\": [\n                    {\n                        \"path\": \"mhenrichsen/alpaca_2k_test\",\n                        \"type\": \"alpaca\",\n                    },\n                ],\n                \"max_steps\": 5,\n                \"num_epochs\": 1,\n                \"micro_batch_size\": 1,\n                \"gradient_accumulation_steps\": 1,\n                \"val_set_size\": 0.0,\n                \"output_dir\": temp_dir,\n                \"learning_rate\": 0.00001,\n                \"optimizer\": \"adamw_torch_fused\",\n                \"embedding_lr\": 0.000005,\n                \"lr_scheduler\": \"cosine\",\n                \"bf16\": \"auto\",\n                \"use_tensorboard\": True,\n                \"save_first_step\": False,\n            }\n        )\n        cfg = validate_config(cfg)\n        normalize_config(cfg)\n        dataset_meta = load_datasets(cfg=cfg)\n\n        train(cfg=cfg, dataset_meta=dataset_meta)\n        check_model_output_exists(temp_dir, cfg)\n\n        check_tensorboard(\n            temp_dir + \"/runs\", \"train/train_loss\", 2.0, \"Loss is too high\"\n        )\n"
  },
  {
    "path": "tests/e2e/test_evaluate.py",
    "content": "\"\"\"E2E smoke test for evaluate CLI command\"\"\"\n\nfrom pathlib import Path\n\nimport yaml\nfrom accelerate.test_utils import execute_subprocess_async\nfrom transformers.testing_utils import get_torch_dist_unique_port\n\nfrom axolotl.utils.dict import DictDefault\n\n\nclass TestE2eEvaluate:\n    \"\"\"Test cases for evaluate CLI\"\"\"\n\n    def test_evaluate(self, temp_dir):\n        cfg = DictDefault(\n            {\n                \"base_model\": \"HuggingFaceTB/SmolLM2-135M\",\n                \"sequence_len\": 1024,\n                \"val_set_size\": 0.02,\n                \"special_tokens\": {\n                    \"pad_token\": \"<|endoftext|>\",\n                },\n                \"datasets\": [\n                    {\n                        \"path\": \"mhenrichsen/alpaca_2k_test\",\n                        \"type\": \"alpaca\",\n                    },\n                ],\n                \"num_epochs\": 1,\n                \"micro_batch_size\": 8,\n                \"gradient_accumulation_steps\": 1,\n                \"output_dir\": temp_dir,\n                \"learning_rate\": 0.00001,\n                \"optimizer\": \"adamw_torch_fused\",\n                \"lr_scheduler\": \"cosine\",\n                \"max_steps\": 20,\n                \"save_first_step\": False,\n            }\n        )\n\n        # write cfg to yaml file\n        Path(temp_dir).mkdir(parents=True, exist_ok=True)\n        with open(Path(temp_dir) / \"config.yaml\", \"w\", encoding=\"utf-8\") as fout:\n            fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))\n\n        execute_subprocess_async(\n            [\n                \"accelerate\",\n                \"launch\",\n                \"--num-processes\",\n                \"2\",\n                \"--main_process_port\",\n                f\"{get_torch_dist_unique_port()}\",\n                \"-m\",\n                \"axolotl.cli.evaluate\",\n                str(Path(temp_dir) / \"config.yaml\"),\n    
        ]\n        )\n"
  },
  {
    "path": "tests/e2e/test_falcon.py",
    "content": "\"\"\"\nE2E tests for falcon\n\"\"\"\n\nimport unittest\n\nimport pytest\n\nfrom axolotl.common.datasets import load_datasets\nfrom axolotl.train import train\nfrom axolotl.utils.config import normalize_config, validate_config\nfrom axolotl.utils.dict import DictDefault\n\nfrom .utils import check_model_output_exists, with_temp_dir\n\n\nclass TestFalcon(unittest.TestCase):\n    \"\"\"\n    Test case for falcon\n    \"\"\"\n\n    @pytest.mark.skip(reason=\"no tiny models for testing with safetensors\")\n    @with_temp_dir\n    def test_lora(self, temp_dir):\n        cfg = DictDefault(\n            {\n                \"base_model\": \"illuin/tiny-random-FalconForCausalLM\",\n                \"flash_attention\": True,\n                \"sequence_len\": 1024,\n                \"load_in_8bit\": True,\n                \"adapter\": \"lora\",\n                \"lora_r\": 32,\n                \"lora_alpha\": 64,\n                \"lora_dropout\": 0.05,\n                \"lora_target_linear\": True,\n                \"lora_modules_to_save\": [\n                    \"word_embeddings\",\n                    \"lm_head\",\n                ],\n                \"val_set_size\": 0.02,\n                \"special_tokens\": {\n                    \"bos_token\": \"<|endoftext|>\",\n                    \"pad_token\": \"<|endoftext|>\",\n                },\n                \"datasets\": [\n                    {\n                        \"path\": \"mhenrichsen/alpaca_2k_test\",\n                        \"type\": \"alpaca\",\n                    },\n                ],\n                \"num_epochs\": 2,\n                \"micro_batch_size\": 2,\n                \"gradient_accumulation_steps\": 1,\n                \"output_dir\": temp_dir,\n                \"learning_rate\": 0.00001,\n                \"optimizer\": \"adamw_torch_fused\",\n                \"lr_scheduler\": \"cosine\",\n                \"max_steps\": 20,\n                \"save_steps\": 10,\n                
\"eval_steps\": 10,\n                \"bf16\": \"auto\",\n                \"save_first_step\": False,\n            }\n        )\n\n        cfg = validate_config(cfg)\n        normalize_config(cfg)\n        dataset_meta = load_datasets(cfg=cfg)\n\n        train(cfg=cfg, dataset_meta=dataset_meta)\n        check_model_output_exists(temp_dir, cfg)\n\n    @pytest.mark.skip(reason=\"no tiny models for testing with safetensors\")\n    @with_temp_dir\n    def test_lora_added_vocab(self, temp_dir):\n        cfg = DictDefault(\n            {\n                \"base_model\": \"illuin/tiny-random-FalconForCausalLM\",\n                \"flash_attention\": True,\n                \"sequence_len\": 1024,\n                \"load_in_8bit\": True,\n                \"adapter\": \"lora\",\n                \"lora_r\": 32,\n                \"lora_alpha\": 64,\n                \"lora_dropout\": 0.05,\n                \"lora_target_linear\": True,\n                \"lora_modules_to_save\": [\n                    \"word_embeddings\",\n                    \"lm_head\",\n                ],\n                \"val_set_size\": 0.02,\n                \"special_tokens\": {\n                    \"bos_token\": \"<|endoftext|>\",\n                    \"pad_token\": \"<|endoftext|>\",\n                },\n                \"tokens\": [\n                    \"<|im_start|>\",\n                    \"<|im_end|>\",\n                ],\n                \"datasets\": [\n                    {\n                        \"path\": \"mhenrichsen/alpaca_2k_test\",\n                        \"type\": \"alpaca\",\n                    },\n                ],\n                \"num_epochs\": 2,\n                \"micro_batch_size\": 2,\n                \"gradient_accumulation_steps\": 1,\n                \"output_dir\": temp_dir,\n                \"learning_rate\": 0.00001,\n                \"optimizer\": \"adamw_torch_fused\",\n                \"lr_scheduler\": \"cosine\",\n                \"max_steps\": 20,\n            
    \"save_steps\": 10,\n                \"eval_steps\": 10,\n                \"bf16\": \"auto\",\n                \"save_first_step\": False,\n            }\n        )\n\n        cfg = validate_config(cfg)\n        normalize_config(cfg)\n        dataset_meta = load_datasets(cfg=cfg)\n\n        train(cfg=cfg, dataset_meta=dataset_meta)\n        check_model_output_exists(temp_dir, cfg)\n\n    @pytest.mark.skip(reason=\"no tiny models for testing with safetensors\")\n    @with_temp_dir\n    def test_ft(self, temp_dir):\n        cfg = DictDefault(\n            {\n                \"base_model\": \"illuin/tiny-random-FalconForCausalLM\",\n                \"flash_attention\": True,\n                \"sequence_len\": 1024,\n                \"val_set_size\": 0.02,\n                \"special_tokens\": {\n                    \"bos_token\": \"<|endoftext|>\",\n                    \"pad_token\": \"<|endoftext|>\",\n                },\n                \"datasets\": [\n                    {\n                        \"path\": \"mhenrichsen/alpaca_2k_test\",\n                        \"type\": \"alpaca\",\n                    },\n                ],\n                \"num_epochs\": 2,\n                \"micro_batch_size\": 2,\n                \"gradient_accumulation_steps\": 1,\n                \"output_dir\": temp_dir,\n                \"learning_rate\": 0.00001,\n                \"optimizer\": \"adamw_torch_fused\",\n                \"lr_scheduler\": \"cosine\",\n                \"max_steps\": 20,\n                \"save_steps\": 10,\n                \"eval_steps\": 10,\n                \"bf16\": \"auto\",\n                \"save_first_step\": False,\n            }\n        )\n\n        cfg = validate_config(cfg)\n        normalize_config(cfg)\n        dataset_meta = load_datasets(cfg=cfg)\n\n        train(cfg=cfg, dataset_meta=dataset_meta)\n        check_model_output_exists(temp_dir, cfg)\n"
  },
  {
    "path": "tests/e2e/test_gemma2.py",
    "content": "\"\"\"\nE2E tests for gemma2\n\"\"\"\n\nfrom pathlib import Path\n\nimport pytest\n\nfrom axolotl.common.datasets import load_datasets\nfrom axolotl.train import train\nfrom axolotl.utils.config import normalize_config, validate_config\nfrom axolotl.utils.dict import DictDefault\n\n\nclass TestGemma2:\n    \"\"\"\n    Test case for Gemma2 models\n    \"\"\"\n\n    @pytest.mark.parametrize(\n        \"sample_packing\",\n        [True, False],\n    )\n    def test_lora_gemma2(self, temp_dir, sample_packing):\n        cfg = DictDefault(\n            {\n                \"base_model\": \"axolotl-ai-co/gemma-2-33M\",\n                \"trust_remote_code\": True,\n                \"sample_packing\": sample_packing,\n                \"flash_attention\": True,\n                \"sequence_len\": 2048,\n                \"adapter\": \"lora\",\n                \"lora_r\": 8,\n                \"lora_alpha\": 16,\n                \"lora_dropout\": 0.05,\n                \"lora_target_linear\": True,\n                \"val_set_size\": 0,\n                \"datasets\": [\n                    {\n                        \"path\": \"mlabonne/FineTome-100k\",\n                        \"type\": \"chat_template\",\n                        \"field_messages\": \"conversations\",\n                        \"message_property_mappings\": {\n                            \"role\": \"from\",\n                            \"content\": \"value\",\n                        },\n                        \"drop_system_message\": True,\n                        \"split\": \"train[:1%]\",\n                    },\n                ],\n                \"special_tokens\": {\n                    \"bos_token\": \"<bos>\",\n                    \"eos_token\": \"<eos>\",\n                },\n                \"chat_template\": \"gemma\",  # gemma2's template is same as gemma\n                \"num_epochs\": 1,\n                \"micro_batch_size\": 1,\n                \"gradient_accumulation_steps\": 
2,\n                \"output_dir\": temp_dir,\n                \"learning_rate\": 0.00001,\n                \"optimizer\": \"adamw_bnb_8bit\",\n                \"lr_scheduler\": \"cosine\",\n                \"max_steps\": 5,\n                \"bf16\": True,\n            }\n        )\n        cfg = validate_config(cfg)\n        normalize_config(cfg)\n        dataset_meta = load_datasets(cfg=cfg)\n\n        train(cfg=cfg, dataset_meta=dataset_meta)\n        assert (Path(temp_dir) / \"adapter_model.safetensors\").exists()\n\n    @pytest.mark.parametrize(\n        \"sample_packing\",\n        [True, False],\n    )\n    def test_fft_gemma2(self, temp_dir, sample_packing):\n        cfg = DictDefault(\n            {\n                \"base_model\": \"axolotl-ai-co/gemma-2-33M\",\n                \"trust_remote_code\": True,\n                \"sample_packing\": sample_packing,\n                \"flash_attention\": True,\n                \"sequence_len\": 2048,\n                \"val_set_size\": 0,\n                \"datasets\": [\n                    {\n                        \"path\": \"mlabonne/FineTome-100k\",\n                        \"type\": \"chat_template\",\n                        \"field_messages\": \"conversations\",\n                        \"message_property_mappings\": {\n                            \"role\": \"from\",\n                            \"content\": \"value\",\n                        },\n                        \"split\": \"train[:1%]\",\n                        \"drop_system_message\": True,\n                    },\n                ],\n                \"chat_template\": \"gemma\",  # gemma2's template is same as gemma\n                \"special_tokens\": {\n                    \"bos_token\": \"<bos>\",\n                    \"eos_token\": \"<eos>\",\n                },\n                \"num_epochs\": 1,\n                \"micro_batch_size\": 1,\n                \"gradient_accumulation_steps\": 2,\n                \"output_dir\": temp_dir,\n     
           \"learning_rate\": 0.00001,\n                \"optimizer\": \"adamw_bnb_8bit\",\n                \"lr_scheduler\": \"cosine\",\n                \"max_steps\": 5,\n                \"bf16\": True,\n            }\n        )\n        cfg = validate_config(cfg)\n        normalize_config(cfg)\n        dataset_meta = load_datasets(cfg=cfg)\n\n        train(cfg=cfg, dataset_meta=dataset_meta)\n        assert (Path(temp_dir) / \"model.safetensors\").exists()\n"
  },
  {
    "path": "tests/e2e/test_gemma3_text.py",
    "content": "\"\"\"\nE2E tests for gemma3_text\n\"\"\"\n\nfrom pathlib import Path\n\nimport pytest\n\nfrom axolotl.common.datasets import load_datasets\nfrom axolotl.train import train\nfrom axolotl.utils.config import normalize_config, validate_config\nfrom axolotl.utils.dict import DictDefault\n\n\nclass TestGemma3Text:\n    \"\"\"\n    Test case for Gemma3Text models\n    \"\"\"\n\n    @pytest.mark.parametrize(\n        \"sample_packing\",\n        [True, False],\n    )\n    def test_lora_gemma3_text(self, temp_dir, sample_packing):\n        cfg = DictDefault(\n            {\n                \"base_model\": \"axolotl-ai-co/gemma-3-34M\",\n                \"trust_remote_code\": True,\n                \"sample_packing\": sample_packing,\n                \"flash_attention\": True,\n                \"sequence_len\": 2048,\n                \"adapter\": \"lora\",\n                \"lora_r\": 8,\n                \"lora_alpha\": 16,\n                \"lora_dropout\": 0.05,\n                \"lora_target_linear\": True,\n                \"val_set_size\": 0,\n                \"datasets\": [\n                    {\n                        \"path\": \"mlabonne/FineTome-100k\",\n                        \"type\": \"chat_template\",\n                        \"field_messages\": \"conversations\",\n                        \"message_property_mappings\": {\n                            \"role\": \"from\",\n                            \"content\": \"value\",\n                        },\n                        \"split\": \"train[:1%]\",\n                    },\n                ],\n                \"special_tokens\": {\n                    \"bos_token\": \"<bos>\",\n                    \"eos_token\": \"<eos>\",\n                },\n                \"chat_template\": \"gemma3\",\n                \"num_epochs\": 1,\n                \"micro_batch_size\": 1,\n                \"gradient_accumulation_steps\": 2,\n                \"output_dir\": temp_dir,\n                
\"learning_rate\": 0.00001,\n                \"optimizer\": \"adamw_bnb_8bit\",\n                \"lr_scheduler\": \"cosine\",\n                \"max_steps\": 5,\n                \"bf16\": True,\n                \"save_first_step\": False,\n            }\n        )\n        cfg = validate_config(cfg)\n        normalize_config(cfg)\n        dataset_meta = load_datasets(cfg=cfg)\n\n        train(cfg=cfg, dataset_meta=dataset_meta)\n        assert (Path(temp_dir) / \"adapter_model.safetensors\").exists()\n\n    @pytest.mark.parametrize(\n        \"sample_packing\",\n        [True, False],\n    )\n    def test_fft_gemma3_text(self, temp_dir, sample_packing):\n        cfg = DictDefault(\n            {\n                \"base_model\": \"axolotl-ai-co/gemma-3-34M\",\n                \"trust_remote_code\": True,\n                \"sample_packing\": sample_packing,\n                \"flash_attention\": True,\n                \"sequence_len\": 2048,\n                \"val_set_size\": 0,\n                \"datasets\": [\n                    {\n                        \"path\": \"mlabonne/FineTome-100k\",\n                        \"type\": \"chat_template\",\n                        \"field_messages\": \"conversations\",\n                        \"message_property_mappings\": {\n                            \"role\": \"from\",\n                            \"content\": \"value\",\n                        },\n                        \"split\": \"train[:1%]\",\n                    },\n                ],\n                \"chat_template\": \"gemma3\",\n                \"special_tokens\": {\n                    \"bos_token\": \"<bos>\",\n                    \"eos_token\": \"<eos>\",\n                },\n                \"num_epochs\": 1,\n                \"micro_batch_size\": 1,\n                \"gradient_accumulation_steps\": 2,\n                \"output_dir\": temp_dir,\n                \"learning_rate\": 0.00001,\n                \"optimizer\": \"adamw_bnb_8bit\",\n              
  \"lr_scheduler\": \"cosine\",\n                \"max_steps\": 5,\n                \"bf16\": True,\n                \"save_first_step\": False,\n            }\n        )\n        cfg = validate_config(cfg)\n        normalize_config(cfg)\n        dataset_meta = load_datasets(cfg=cfg)\n\n        train(cfg=cfg, dataset_meta=dataset_meta)\n        assert (Path(temp_dir) / \"model.safetensors\").exists()\n"
  },
  {
    "path": "tests/e2e/test_imports.py",
    "content": "\"\"\"\ntest module to import various submodules that have historically broken due to dependency issues\n\"\"\"\n\nimport unittest\n\n\nclass TestImports(unittest.TestCase):\n    \"\"\"\n    Test class to import various submodules that have historically broken due to dependency issues\n    \"\"\"\n\n    def test_import_causal_trainer(self):\n        pass\n\n    def test_import_rl_trainer(self):\n        pass\n"
  },
  {
    "path": "tests/e2e/test_llama.py",
    "content": "\"\"\"\nE2E tests for llama\n\"\"\"\n\nimport pytest\n\nfrom axolotl.common.datasets import load_datasets\nfrom axolotl.train import train\nfrom axolotl.utils.config import normalize_config, validate_config\nfrom axolotl.utils.dict import DictDefault\n\nfrom tests.e2e.utils import check_model_output_exists\n\n\nclass TestLlama:\n    \"\"\"\n    Test case for Llama models\n    \"\"\"\n\n    def test_fft_trust_remote_code(self, temp_dir):\n        cfg = DictDefault(\n            {\n                \"base_model\": \"HuggingFaceTB/SmolLM2-135M\",\n                \"tokenizer_type\": \"AutoTokenizer\",\n                \"trust_remote_code\": True,\n                \"sequence_len\": 512,\n                \"val_set_size\": 0.02,\n                \"special_tokens\": {\n                    \"pad_token\": \"<|endoftext|>\",\n                },\n                \"datasets\": [\n                    {\n                        \"path\": \"mhenrichsen/alpaca_2k_test\",\n                        \"type\": \"alpaca\",\n                    },\n                ],\n                \"num_epochs\": 1,\n                \"max_steps\": 5,\n                \"micro_batch_size\": 2,\n                \"gradient_accumulation_steps\": 1,\n                \"output_dir\": temp_dir,\n                \"learning_rate\": 0.00001,\n                \"optimizer\": \"adamw_bnb_8bit\",\n                \"lr_scheduler\": \"cosine\",\n                \"flash_attention\": True,\n                \"sample_packing\": True,\n                \"bf16\": True,\n                \"save_first_step\": False,\n            }\n        )\n\n        cfg = validate_config(cfg)\n        normalize_config(cfg)\n        dataset_meta = load_datasets(cfg=cfg)\n\n        train(cfg=cfg, dataset_meta=dataset_meta)\n        check_model_output_exists(temp_dir, cfg)\n\n    def test_fix_untrained_tokens(self, temp_dir):\n        cfg = DictDefault(\n            {\n                \"base_model\": 
\"HuggingFaceTB/SmolLM2-135M\",\n                \"fix_untrained_tokens\": True,\n                \"sequence_len\": 512,\n                \"val_set_size\": 0.0,\n                \"special_tokens\": {\n                    \"pad_token\": \"<|endoftext|>\",\n                    \"bos_token\": \"<|custom_im_start|>\",\n                    \"eos_token\": \"<|custom_im_end|>\",\n                },\n                \"datasets\": [\n                    {\n                        \"chat_template\": \"jinja\",\n                        \"chat_template_jinja\": \"{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|custom_im_start|>' + message['role'] + '\\n' + message['content'] + '<|custom_im_end|>' + '\\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|custom_im_start|>assistant\\n' }}{% endif %}\",\n                        \"path\": \"mlabonne/FineTome-100k\",\n                        \"type\": \"chat_template\",\n                        \"split\": \"train[:10%]\",\n                        \"field_messages\": \"conversations\",\n                        \"message_field_role\": \"from\",\n                        \"message_field_content\": \"value\",\n                    },\n                ],\n                \"num_epochs\": 1,\n                \"max_steps\": 5,\n                \"micro_batch_size\": 1,\n                \"gradient_accumulation_steps\": 1,\n                \"output_dir\": temp_dir,\n                \"learning_rate\": 0.00001,\n                \"optimizer\": \"adamw_8bit\",\n                \"lr_scheduler\": \"cosine\",\n                \"flash_attention\": True,\n                \"sample_packing\": True,\n                \"bf16\": True,\n                \"save_first_step\": False,\n            }\n        )\n\n        cfg = validate_config(cfg)\n        normalize_config(cfg)\n        dataset_meta = load_datasets(cfg=cfg)\n\n        train(cfg=cfg, dataset_meta=dataset_meta)\n        
check_model_output_exists(temp_dir, cfg)\n\n    def test_fix_untrained_tokens_already_trained(self, temp_dir):\n        cfg = DictDefault(\n            {\n                \"base_model\": \"HuggingFaceTB/SmolLM2-135M\",\n                \"fix_untrained_tokens\": True,\n                \"sequence_len\": 512,\n                \"val_set_size\": 0.0,\n                \"special_tokens\": {\n                    \"pad_token\": \"<|endoftext|>\",\n                },\n                \"chat_template\": \"chatml\",\n                \"datasets\": [\n                    {\n                        \"path\": \"mlabonne/FineTome-100k\",\n                        \"type\": \"chat_template\",\n                        \"split\": \"train[:10%]\",\n                        \"field_messages\": \"conversations\",\n                        \"message_field_role\": \"from\",\n                        \"message_field_content\": \"value\",\n                    },\n                ],\n                \"num_epochs\": 1,\n                \"max_steps\": 5,\n                \"micro_batch_size\": 1,\n                \"gradient_accumulation_steps\": 1,\n                \"output_dir\": temp_dir,\n                \"learning_rate\": 0.00001,\n                \"optimizer\": \"adamw_8bit\",\n                \"lr_scheduler\": \"cosine\",\n                \"flash_attention\": True,\n                \"sample_packing\": True,\n                \"bf16\": True,\n                \"save_first_step\": False,\n            }\n        )\n\n        cfg = validate_config(cfg)\n        normalize_config(cfg)\n        dataset_meta = load_datasets(cfg=cfg)\n\n        train(cfg=cfg, dataset_meta=dataset_meta)\n        check_model_output_exists(temp_dir, cfg)\n\n    @pytest.mark.parametrize(\"tf32\", [\"auto\", False])\n    def test_batch_flattening(self, tf32, temp_dir):\n        cfg = DictDefault(\n            {\n                \"base_model\": \"HuggingFaceTB/SmolLM2-135M\",\n                \"trust_remote_code\": True,\n     
           \"sequence_len\": 512,\n                \"val_set_size\": 0.01,\n                \"special_tokens\": {\n                    \"pad_token\": \"<|endoftext|>\",\n                },\n                \"datasets\": [\n                    {\n                        \"path\": \"mhenrichsen/alpaca_2k_test\",\n                        \"type\": \"alpaca\",\n                    },\n                ],\n                \"num_epochs\": 1,\n                \"max_steps\": 5,\n                \"micro_batch_size\": 4,\n                \"gradient_accumulation_steps\": 1,\n                \"output_dir\": temp_dir,\n                \"learning_rate\": 0.00001,\n                \"optimizer\": \"adamw_8bit\",\n                \"lr_scheduler\": \"cosine\",\n                \"flash_attention\": True,\n                \"sample_packing\": False,\n                \"batch_flattening\": True,\n                \"bf16\": True,\n                \"tf32\": tf32,\n                \"save_first_step\": False,\n            }\n        )\n\n        cfg = validate_config(cfg)\n        normalize_config(cfg)\n        dataset_meta = load_datasets(cfg=cfg)\n\n        train(cfg=cfg, dataset_meta=dataset_meta)\n        check_model_output_exists(temp_dir, cfg)\n"
  },
  {
    "path": "tests/e2e/test_llama_pretrain.py",
    "content": "\"\"\"E2E tests for llama pretrain\"\"\"\n\nimport pytest\n\nfrom axolotl.common.datasets import load_datasets\nfrom axolotl.train import train\nfrom axolotl.utils.config import normalize_config, validate_config\nfrom axolotl.utils.dict import DictDefault\n\nfrom .utils import check_model_output_exists, check_tensorboard\n\n\nclass TestPretrainLlama:\n    \"\"\"Test case for Llama models w pretraining\"\"\"\n\n    @pytest.mark.parametrize(\n        (\"sample_packing\", \"pretrain_multipack_attn\"),\n        [\n            (False, False),\n            (True, True),\n            (True, False),\n        ],\n    )\n    def test_pretrain(self, temp_dir, sample_packing, pretrain_multipack_attn):\n        cfg = DictDefault(\n            {\n                \"base_model\": \"HuggingFaceTB/SmolLM2-135M\",\n                \"flash_attention\": True,\n                \"sequence_len\": 1024,\n                \"sample_packing\": sample_packing,\n                \"pretrain_multipack_attn\": pretrain_multipack_attn,\n                \"dataset_num_proc\": 1,\n                \"special_tokens\": {\n                    \"pad_token\": \"<|endoftext|>\",\n                },\n                \"pretraining_dataset\": [\n                    {\n                        \"path\": \"allenai/c4\",\n                        \"name\": \"en\",\n                        \"type\": \"pretrain\",\n                    }\n                ],\n                \"max_steps\": 5,\n                \"num_epochs\": 1,\n                \"micro_batch_size\": 2,\n                \"gradient_accumulation_steps\": 1,\n                \"val_set_size\": 0.0,\n                \"output_dir\": temp_dir,\n                \"learning_rate\": 0.00001,\n                \"optimizer\": \"adamw_torch_fused\",\n                \"lr_scheduler\": \"cosine\",\n                \"bf16\": \"auto\",\n                \"use_tensorboard\": True,\n                \"save_first_step\": False,\n            }\n        )\n\n        
cfg = validate_config(cfg)\n        normalize_config(cfg)\n        dataset_meta = load_datasets(cfg=cfg)\n\n        train(cfg=cfg, dataset_meta=dataset_meta)\n        check_model_output_exists(temp_dir, cfg)\n        loss_threshold = 3.6\n        if sample_packing and not pretrain_multipack_attn:\n            loss_threshold = 6.5\n        check_tensorboard(\n            temp_dir + \"/runs\",\n            \"train/train_loss\",\n            loss_threshold,\n            \"Train Loss (%s) is too high\",\n        )\n"
  },
  {
    "path": "tests/e2e/test_llama_vision.py",
    "content": "\"\"\"\nE2E tests for lora llama\n\"\"\"\n\nimport unittest\n\nfrom axolotl.common.datasets import load_datasets\nfrom axolotl.train import train\nfrom axolotl.utils.config import normalize_config, validate_config\nfrom axolotl.utils.dict import DictDefault\n\nfrom .utils import check_model_output_exists, with_temp_dir\n\n\nclass TestLlamaVision(unittest.TestCase):\n    \"\"\"\n    Test case for Llama Vision models\n    \"\"\"\n\n    @with_temp_dir\n    def test_lora_llama_vision_text_only_dataset(self, temp_dir):\n        cfg = DictDefault(\n            {\n                \"base_model\": \"axolotl-ai-co/Llama-3.2-39M-Vision\",\n                \"processor_type\": \"AutoProcessor\",\n                \"skip_prepare_dataset\": True,\n                \"remove_unused_columns\": False,\n                \"sample_packing\": False,\n                \"sequence_len\": 1024,\n                \"adapter\": \"lora\",\n                \"lora_r\": 8,\n                \"lora_alpha\": 16,\n                \"lora_dropout\": 0.05,\n                \"lora_target_modules\": r\"model.language_model.layers.[\\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj\",\n                \"val_set_size\": 0,\n                \"chat_template\": \"llama3_2_vision\",\n                \"datasets\": [\n                    {\n                        \"path\": \"LDJnr/Puffin\",\n                        \"type\": \"chat_template\",\n                        \"field_messages\": \"conversations\",\n                        \"message_field_role\": \"from\",\n                        \"message_field_content\": \"value\",\n                    },\n                ],\n                \"num_epochs\": 1,\n                \"micro_batch_size\": 1,\n                \"gradient_accumulation_steps\": 2,\n                \"output_dir\": temp_dir,\n                \"learning_rate\": 0.00001,\n                \"optimizer\": \"adamw_bnb_8bit\",\n                \"lr_scheduler\": \"cosine\",\n            
    \"max_steps\": 5,\n                \"bf16\": True,\n                \"save_first_step\": False,\n            }\n        )\n\n        cfg = validate_config(cfg)\n        normalize_config(cfg)\n        dataset_meta = load_datasets(cfg=cfg)\n\n        train(cfg=cfg, dataset_meta=dataset_meta)\n        check_model_output_exists(temp_dir, cfg)\n\n    @with_temp_dir\n    def test_lora_llama_vision_multimodal_dataset(self, temp_dir):\n        cfg = DictDefault(\n            {\n                \"base_model\": \"axolotl-ai-co/Llama-3.2-39M-Vision\",\n                \"processor_type\": \"AutoProcessor\",\n                \"skip_prepare_dataset\": True,\n                \"remove_unused_columns\": False,\n                \"sample_packing\": False,\n                \"sequence_len\": 1024,\n                \"adapter\": \"lora\",\n                \"lora_r\": 8,\n                \"lora_alpha\": 16,\n                \"lora_dropout\": 0.05,\n                \"lora_target_modules\": r\"model.language_model.layers.[\\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj\",\n                \"val_set_size\": 0,\n                \"chat_template\": \"llama3_2_vision\",\n                \"datasets\": [\n                    {\n                        \"path\": \"axolotl-ai-co/llava-instruct-mix-vsft-small\",\n                        \"type\": \"chat_template\",\n                        \"split\": \"train\",\n                        \"field_messages\": \"messages\",\n                    },\n                ],\n                \"num_epochs\": 1,\n                \"micro_batch_size\": 1,\n                \"gradient_accumulation_steps\": 2,\n                \"output_dir\": temp_dir,\n                \"learning_rate\": 0.00001,\n                \"optimizer\": \"adamw_bnb_8bit\",\n                \"lr_scheduler\": \"cosine\",\n                \"max_steps\": 5,\n                \"bf16\": True,\n                \"save_first_step\": False,\n            }\n        )\n        cfg = 
validate_config(cfg)\n        normalize_config(cfg)\n        dataset_meta = load_datasets(cfg=cfg)\n\n        train(cfg=cfg, dataset_meta=dataset_meta)\n        check_model_output_exists(temp_dir, cfg)\n"
  },
  {
    "path": "tests/e2e/test_load_model.py",
    "content": "\"\"\"Module for testing ModelLoader.\"\"\"\n\nimport shutil\nimport tempfile\n\nimport pytest\nimport torch\n\nfrom axolotl.loaders import ModelLoader, load_tokenizer\nfrom axolotl.utils.dict import DictDefault\n\n\n@pytest.fixture(name=\"temp_dir\")\ndef fixture_temp_dir():\n    temp_dir = tempfile.mkdtemp()\n    yield temp_dir\n    shutil.rmtree(temp_dir)\n\n\nclass TestLoadModelUtils:\n    \"\"\"\n    Testing module testing ModelLoader.\n    \"\"\"\n\n    def setup_method(self):\n        # load config\n        self.cfg = DictDefault(\n            {\n                \"base_model\": \"HuggingFaceTB/SmolLM2-135M\",\n                \"tokenizer_type\": \"AutoTokenizer\",\n                \"tokenizer_config\": \"HuggingFaceTB/SmolLM2-135M\",\n                \"sequence_len\": 1024,\n                \"load_in_8bit\": False,\n                \"adapter\": \"lora\",\n                \"lora_r\": 8,\n                \"lora_alpha\": 16,\n                \"lora_dropout\": 0.05,\n                \"lora_target_linear\": True,\n                \"val_set_size\": 0.02,\n                \"special_tokens\": {\n                    \"pad_token\": \"<|endoftext|>\",\n                },\n                \"datasets\": [\n                    {\n                        \"path\": \"mhenrichsen/alpaca_2k_test\",\n                        \"type\": \"alpaca\",\n                    },\n                ],\n                \"num_epochs\": 1,\n                \"micro_batch_size\": 8,\n                \"gradient_accumulation_steps\": 1,\n                \"learning_rate\": 0.00001,\n                \"optimizer\": \"adamw_torch_fused\",\n                \"lr_scheduler\": \"cosine\",\n                \"tensor_parallel_size\": 1,\n                \"context_parallel_size\": 1,\n            }\n        )\n        self.model_loader = ModelLoader(\n            cfg=self.cfg,\n            tokenizer=\"\",\n            inference=False,\n            reference_model=True,\n        )\n\n    
@pytest.mark.parametrize(\"embedding_modules\", [\"embed_tokens\", \"lm_head\"])\n    @pytest.mark.parametrize(\n        \"dist_dtype\", [torch.bfloat16, torch.float16, torch.float32]\n    )\n    @pytest.mark.parametrize(\"before_kbit_train_or_finetune\", [True, False])\n    def test_convert_embedding_modules_dtype(\n        self, temp_dir, embedding_modules, dist_dtype, before_kbit_train_or_finetune\n    ):\n        self.cfg.output_dir = temp_dir\n        self.model_loader.tokenizer = load_tokenizer(self.cfg)\n        self.model_loader.load()\n        self.model_loader._convert_embedding_modules_dtype(\n            embedding_modules, dist_dtype, before_kbit_train_or_finetune\n        )\n        for name, module in self.model_loader.model.named_modules():\n            if (\n                \"norm\" in name\n                or (before_kbit_train_or_finetune and name.endswith(\".gate\"))\n                or (\n                    any(m in name for m in embedding_modules)\n                    and hasattr(module, \"weight\")\n                )\n            ):\n                for _, param in module.named_parameters():\n                    assert param.dtype == dist_dtype\n"
  },
  {
    "path": "tests/e2e/test_lora_llama.py",
    "content": "\"\"\"\nE2E tests for lora llama\n\"\"\"\n\nimport unittest\n\nfrom axolotl.common.datasets import load_datasets\nfrom axolotl.train import train\nfrom axolotl.utils.config import normalize_config, validate_config\nfrom axolotl.utils.dict import DictDefault\n\nfrom .utils import check_model_output_exists, with_temp_dir\n\n\nclass TestLoraLlama(unittest.TestCase):\n    \"\"\"\n    Test case for Llama models using LoRA\n    \"\"\"\n\n    @with_temp_dir\n    def test_lora(self, temp_dir):\n        cfg = DictDefault(\n            {\n                \"base_model\": \"HuggingFaceTB/SmolLM2-135M\",\n                \"tokenizer_type\": \"AutoTokenizer\",\n                \"sequence_len\": 1024,\n                \"load_in_8bit\": True,\n                \"adapter\": \"lora\",\n                \"lora_r\": 8,\n                \"lora_alpha\": 16,\n                \"lora_dropout\": 0.05,\n                \"lora_target_linear\": True,\n                \"val_set_size\": 0.02,\n                \"special_tokens\": {\n                    \"pad_token\": \"<|endoftext|>\",\n                },\n                \"datasets\": [\n                    {\n                        \"path\": \"mhenrichsen/alpaca_2k_test\",\n                        \"type\": \"alpaca\",\n                    },\n                ],\n                \"num_epochs\": 1,\n                \"micro_batch_size\": 2,\n                \"gradient_accumulation_steps\": 1,\n                \"output_dir\": temp_dir,\n                \"learning_rate\": 0.00001,\n                \"optimizer\": \"adamw_torch_fused\",\n                \"lr_scheduler\": \"cosine\",\n                \"max_steps\": 5,\n                \"save_first_step\": False,\n            }\n        )\n\n        cfg = validate_config(cfg)\n        normalize_config(cfg)\n        dataset_meta = load_datasets(cfg=cfg)\n\n        train(cfg=cfg, dataset_meta=dataset_meta)\n        check_model_output_exists(temp_dir, cfg)\n"
  },
  {
    "path": "tests/e2e/test_mamba.py",
    "content": "\"\"\"\nE2E tests for lora llama\n\"\"\"\n\nimport unittest\n\nimport pytest\n\nfrom axolotl.common.datasets import load_datasets\nfrom axolotl.train import train\nfrom axolotl.utils.config import normalize_config, validate_config\nfrom axolotl.utils.dict import DictDefault\n\nfrom .utils import check_model_output_exists, with_temp_dir\n\n\n@pytest.mark.skip(reason=\"skipping until upstreamed into transformers\")\nclass TestMamba(unittest.TestCase):\n    \"\"\"\n    Test case for Mamba models\n    \"\"\"\n\n    @with_temp_dir\n    def test_fft(self, temp_dir):\n        cfg = DictDefault(\n            {\n                \"base_model\": \"state-spaces/mamba-130m\",\n                \"model_type\": \"MambaLMHeadModel\",\n                \"tokenizer_type\": \"AutoTokenizer\",\n                \"tokenizer_config\": \"EleutherAI/gpt-neox-20b\",\n                \"flash_attention\": False,\n                \"sequence_len\": 1024,\n                \"load_in_8bit\": False,\n                \"val_set_size\": 0.0,\n                \"datasets\": [\n                    {\n                        \"path\": \"mhenrichsen/alpaca_2k_test\",\n                        \"type\": \"alpaca\",\n                    },\n                ],\n                \"gradient_checkpointing\": False,\n                \"num_epochs\": 2,\n                \"micro_batch_size\": 2,\n                \"gradient_accumulation_steps\": 1,\n                \"output_dir\": temp_dir,\n                \"learning_rate\": 0.00001,\n                \"optimizer\": \"adamw_torch_fused\",\n                \"lr_scheduler\": \"cosine\",\n                \"max_steps\": 20,\n                \"save_steps\": 10,\n                \"eval_steps\": None,\n                \"save_first_step\": False,\n            }\n        )\n\n        cfg = validate_config(cfg)\n        normalize_config(cfg)\n        dataset_meta = load_datasets(cfg=cfg)\n\n        train(cfg=cfg, dataset_meta=dataset_meta)\n        
check_model_output_exists(temp_dir, cfg)\n"
  },
  {
    "path": "tests/e2e/test_mistral.py",
    "content": "\"\"\"\nE2E tests for lora llama\n\"\"\"\n\nimport unittest\n\nfrom transformers.utils import is_torch_bf16_gpu_available\n\nfrom axolotl.common.datasets import load_datasets\nfrom axolotl.train import train\nfrom axolotl.utils.config import normalize_config, validate_config\nfrom axolotl.utils.dict import DictDefault\n\nfrom .utils import check_model_output_exists, with_temp_dir\n\n\nclass TestMistral(unittest.TestCase):\n    \"\"\"\n    Test case for Llama models using LoRA\n    \"\"\"\n\n    @with_temp_dir\n    def test_lora(self, temp_dir):\n        cfg = DictDefault(\n            {\n                \"base_model\": \"trl-internal-testing/tiny-MistralForCausalLM-0.2\",\n                \"flash_attention\": True,\n                \"sequence_len\": 1024,\n                \"load_in_8bit\": True,\n                \"adapter\": \"lora\",\n                \"lora_r\": 32,\n                \"lora_alpha\": 64,\n                \"lora_dropout\": 0.05,\n                \"lora_target_linear\": True,\n                \"val_set_size\": 0.02,\n                \"special_tokens\": {\n                    \"unk_token\": \"<unk>\",\n                    \"bos_token\": \"<s>\",\n                    \"eos_token\": \"</s>\",\n                },\n                \"datasets\": [\n                    {\n                        \"path\": \"mhenrichsen/alpaca_2k_test\",\n                        \"type\": \"alpaca\",\n                    },\n                ],\n                \"num_epochs\": 2,\n                \"micro_batch_size\": 2,\n                \"gradient_accumulation_steps\": 1,\n                \"output_dir\": temp_dir,\n                \"learning_rate\": 0.00001,\n                \"optimizer\": \"adamw_torch_fused\",\n                \"lr_scheduler\": \"cosine\",\n                \"max_steps\": 20,\n                \"save_steps\": 10,\n                \"eval_steps\": 10,\n                \"save_first_step\": False,\n            }\n        )\n\n        cfg = 
validate_config(cfg)\n        normalize_config(cfg)\n        dataset_meta = load_datasets(cfg=cfg)\n\n        train(cfg=cfg, dataset_meta=dataset_meta)\n        check_model_output_exists(temp_dir, cfg)\n\n    @with_temp_dir\n    def test_ft(self, temp_dir):\n        cfg = DictDefault(\n            {\n                \"base_model\": \"trl-internal-testing/tiny-MistralForCausalLM-0.2\",\n                \"flash_attention\": True,\n                \"sequence_len\": 1024,\n                \"val_set_size\": 0.02,\n                \"special_tokens\": {\n                    \"unk_token\": \"<unk>\",\n                    \"bos_token\": \"<s>\",\n                    \"eos_token\": \"</s>\",\n                },\n                \"datasets\": [\n                    {\n                        \"path\": \"mhenrichsen/alpaca_2k_test\",\n                        \"type\": \"alpaca\",\n                    },\n                ],\n                \"num_epochs\": 2,\n                \"micro_batch_size\": 2,\n                \"gradient_accumulation_steps\": 1,\n                \"output_dir\": temp_dir,\n                \"learning_rate\": 0.00001,\n                \"optimizer\": \"adamw_torch_fused\",\n                \"lr_scheduler\": \"cosine\",\n                \"max_steps\": 20,\n                \"save_steps\": 10,\n                \"eval_steps\": 10,\n                \"save_first_step\": False,\n            }\n        )\n        if is_torch_bf16_gpu_available():\n            cfg.bf16 = True\n        else:\n            cfg.fp16 = True\n\n        cfg = validate_config(cfg)\n        normalize_config(cfg)\n        dataset_meta = load_datasets(cfg=cfg)\n\n        train(cfg=cfg, dataset_meta=dataset_meta)\n        check_model_output_exists(temp_dir, cfg)\n"
  },
  {
    "path": "tests/e2e/test_mixtral.py",
    "content": "\"\"\"\nE2E tests for mixtral\n\"\"\"\n\nimport unittest\n\nimport torch\nfrom transformers.utils import is_torch_bf16_gpu_available\n\nfrom axolotl.common.datasets import load_datasets\nfrom axolotl.train import train\nfrom axolotl.utils.config import normalize_config, validate_config\nfrom axolotl.utils.dict import DictDefault\n\nfrom .utils import check_model_output_exists, with_temp_dir\n\n\nclass TestMixtral(unittest.TestCase):\n    \"\"\"\n    Test case for Llama models using LoRA\n    \"\"\"\n\n    @with_temp_dir\n    def test_qlora_w_fa2(self, temp_dir):\n        cfg = DictDefault(\n            {\n                \"base_model\": \"hf-internal-testing/Mixtral-tiny\",\n                \"tokenizer_config\": \"LoneStriker/Mixtral-8x7B-v0.1-HF\",\n                \"flash_attention\": True,\n                \"sequence_len\": 1024,\n                \"load_in_4bit\": True,\n                \"adapter\": \"qlora\",\n                \"lora_r\": 4,\n                \"lora_alpha\": 8,\n                \"lora_dropout\": 0.1,\n                \"lora_target_modules\": [\n                    \"o_proj\",\n                    \"w3\",\n                    \"k_proj\",\n                    \"v_proj\",\n                    \"w1\",\n                    \"q_proj\",\n                    \"w2\",\n                ],\n                \"val_set_size\": 0.02,\n                \"special_tokens\": {},\n                \"datasets\": [\n                    {\n                        \"path\": \"mhenrichsen/alpaca_2k_test\",\n                        \"type\": \"alpaca\",\n                    },\n                ],\n                \"num_epochs\": 2,\n                \"micro_batch_size\": 2,\n                \"gradient_accumulation_steps\": 1,\n                \"output_dir\": temp_dir,\n                \"learning_rate\": 0.00001,\n                \"optimizer\": \"adamw_bnb_8bit\",\n                \"lr_scheduler\": \"cosine\",\n                \"max_steps\": 20,\n               
 \"save_steps\": 10,\n                \"eval_steps\": 10,\n                \"save_first_step\": False,\n            }\n        )\n\n        cfg = validate_config(cfg)\n        normalize_config(cfg)\n        dataset_meta = load_datasets(cfg=cfg)\n\n        model, _, _ = train(cfg=cfg, dataset_meta=dataset_meta)\n        assert (\n            model.base_model.model.model.layers[0].block_sparse_moe.gate.weight.dtype\n            == torch.float32\n        )\n        check_model_output_exists(temp_dir, cfg)\n\n    @with_temp_dir\n    def test_qlora_wo_fa2(self, temp_dir):\n        cfg = DictDefault(\n            {\n                \"base_model\": \"hf-internal-testing/Mixtral-tiny\",\n                \"tokenizer_config\": \"LoneStriker/Mixtral-8x7B-v0.1-HF\",\n                \"flash_attention\": False,\n                \"sequence_len\": 1024,\n                \"load_in_4bit\": True,\n                \"adapter\": \"qlora\",\n                \"lora_r\": 4,\n                \"lora_alpha\": 8,\n                \"lora_dropout\": 0.1,\n                \"lora_target_modules\": [\n                    \"o_proj\",\n                    \"w3\",\n                    \"k_proj\",\n                    \"v_proj\",\n                    \"w1\",\n                    \"q_proj\",\n                    \"w2\",\n                ],\n                \"val_set_size\": 0.02,\n                \"special_tokens\": {},\n                \"datasets\": [\n                    {\n                        \"path\": \"mhenrichsen/alpaca_2k_test\",\n                        \"type\": \"alpaca\",\n                    },\n                ],\n                \"num_epochs\": 2,\n                \"micro_batch_size\": 2,\n                \"gradient_accumulation_steps\": 1,\n                \"output_dir\": temp_dir,\n                \"learning_rate\": 0.00001,\n                \"optimizer\": \"adamw_bnb_8bit\",\n                \"lr_scheduler\": \"cosine\",\n                \"max_steps\": 20,\n                
\"save_steps\": 10,\n                \"eval_steps\": 10,\n                \"save_first_step\": False,\n            }\n        )\n\n        cfg = validate_config(cfg)\n        normalize_config(cfg)\n        dataset_meta = load_datasets(cfg=cfg)\n\n        model, _, _ = train(cfg=cfg, dataset_meta=dataset_meta)\n        assert (\n            model.base_model.model.model.layers[0].block_sparse_moe.gate.weight.dtype\n            == torch.float32\n        )\n        check_model_output_exists(temp_dir, cfg)\n\n    @with_temp_dir\n    def test_16bit_lora_w_fa2(self, temp_dir):\n        cfg = DictDefault(\n            {\n                \"base_model\": \"hf-internal-testing/Mixtral-tiny\",\n                \"tokenizer_config\": \"LoneStriker/Mixtral-8x7B-v0.1-HF\",\n                \"flash_attention\": True,\n                \"sequence_len\": 1024,\n                \"adapter\": \"lora\",\n                \"lora_r\": 4,\n                \"lora_alpha\": 8,\n                \"lora_dropout\": 0.1,\n                \"lora_target_modules\": [\n                    \"o_proj\",\n                    \"w3\",\n                    \"k_proj\",\n                    \"v_proj\",\n                    \"w1\",\n                    \"q_proj\",\n                    \"w2\",\n                ],\n                \"val_set_size\": 0.02,\n                \"special_tokens\": {},\n                \"datasets\": [\n                    {\n                        \"path\": \"mhenrichsen/alpaca_2k_test\",\n                        \"type\": \"alpaca\",\n                    },\n                ],\n                \"num_epochs\": 2,\n                \"micro_batch_size\": 2,\n                \"gradient_accumulation_steps\": 1,\n                \"output_dir\": temp_dir,\n                \"learning_rate\": 0.00001,\n                \"optimizer\": \"adamw_bnb_8bit\",\n                \"lr_scheduler\": \"cosine\",\n                \"max_steps\": 20,\n                \"save_steps\": 10,\n                
\"eval_steps\": 10,\n                \"save_first_step\": False,\n            }\n        )\n        if is_torch_bf16_gpu_available():\n            cfg.bf16 = True\n        else:\n            cfg.fp16 = True\n\n        cfg = validate_config(cfg)\n        normalize_config(cfg)\n        dataset_meta = load_datasets(cfg=cfg)\n\n        model, _, _ = train(cfg=cfg, dataset_meta=dataset_meta)\n        assert (\n            model.base_model.model.model.layers[0].block_sparse_moe.gate.weight.dtype\n            == torch.float32\n        )\n        check_model_output_exists(temp_dir, cfg)\n\n    @with_temp_dir\n    def test_16bit_lora_wo_fa2(self, temp_dir):\n        cfg = DictDefault(\n            {\n                \"base_model\": \"hf-internal-testing/Mixtral-tiny\",\n                \"tokenizer_config\": \"LoneStriker/Mixtral-8x7B-v0.1-HF\",\n                \"flash_attention\": False,\n                \"sequence_len\": 1024,\n                \"adapter\": \"lora\",\n                \"lora_r\": 4,\n                \"lora_alpha\": 8,\n                \"lora_dropout\": 0.1,\n                \"lora_target_modules\": [\n                    \"o_proj\",\n                    \"w3\",\n                    \"k_proj\",\n                    \"v_proj\",\n                    \"w1\",\n                    \"q_proj\",\n                    \"w2\",\n                ],\n                \"val_set_size\": 0.02,\n                \"special_tokens\": {},\n                \"datasets\": [\n                    {\n                        \"path\": \"mhenrichsen/alpaca_2k_test\",\n                        \"type\": \"alpaca\",\n                    },\n                ],\n                \"num_epochs\": 2,\n                \"micro_batch_size\": 2,\n                \"gradient_accumulation_steps\": 1,\n                \"output_dir\": temp_dir,\n                \"learning_rate\": 0.00001,\n                \"optimizer\": \"adamw_bnb_8bit\",\n                \"lr_scheduler\": \"cosine\",\n                
\"max_steps\": 20,\n                \"save_steps\": 10,\n                \"eval_steps\": 10,\n                \"save_first_step\": False,\n            }\n        )\n\n        cfg = validate_config(cfg)\n        normalize_config(cfg)\n        if is_torch_bf16_gpu_available():\n            cfg.bf16 = True\n        else:\n            cfg.fp16 = True\n        dataset_meta = load_datasets(cfg=cfg)\n\n        model, _, _ = train(cfg=cfg, dataset_meta=dataset_meta)\n        assert (\n            model.base_model.model.model.layers[0].block_sparse_moe.gate.weight.dtype\n            == torch.float32\n        )\n        check_model_output_exists(temp_dir, cfg)\n\n    @with_temp_dir\n    def test_ft(self, temp_dir):\n        cfg = DictDefault(\n            {\n                \"base_model\": \"hf-internal-testing/Mixtral-tiny\",\n                \"tokenizer_config\": \"LoneStriker/Mixtral-8x7B-v0.1-HF\",\n                \"flash_attention\": True,\n                \"sequence_len\": 1024,\n                \"val_set_size\": 0.02,\n                \"special_tokens\": {},\n                \"datasets\": [\n                    {\n                        \"path\": \"mhenrichsen/alpaca_2k_test\",\n                        \"type\": \"alpaca\",\n                    },\n                ],\n                \"num_epochs\": 2,\n                \"micro_batch_size\": 2,\n                \"gradient_accumulation_steps\": 1,\n                \"output_dir\": temp_dir,\n                \"learning_rate\": 0.00001,\n                \"optimizer\": \"adamw_bnb_8bit\",\n                \"lr_scheduler\": \"cosine\",\n                \"max_steps\": 20,\n                \"save_steps\": 10,\n                \"eval_steps\": 10,\n                \"save_first_step\": False,\n            }\n        )\n        if is_torch_bf16_gpu_available():\n            cfg.bf16 = True\n        else:\n            cfg.fp16 = True\n\n        cfg = validate_config(cfg)\n        normalize_config(cfg)\n        dataset_meta = 
load_datasets(cfg=cfg)\n\n        train(cfg=cfg, dataset_meta=dataset_meta)\n        check_model_output_exists(temp_dir, cfg)\n"
  },
  {
    "path": "tests/e2e/test_optimizers.py",
    "content": "\"\"\"\nE2E tests for custom optimizers using Llama\n\"\"\"\n\nimport unittest\n\nimport pytest\n\nfrom axolotl.common.datasets import load_datasets\nfrom axolotl.train import train\nfrom axolotl.utils.config import normalize_config, validate_config\nfrom axolotl.utils.dict import DictDefault\n\nfrom .utils import (\n    check_model_output_exists,\n    require_torch_2_5_1,\n    require_torch_2_6_0,\n    require_torch_2_7_0,\n    with_temp_dir,\n)\n\n\nclass TestCustomOptimizers(unittest.TestCase):\n    \"\"\"\n    Test case for Llama models using LoRA\n    \"\"\"\n\n    @with_temp_dir\n    def test_optimi_adamw(self, temp_dir):\n        cfg = DictDefault(\n            {\n                \"base_model\": \"HuggingFaceTB/SmolLM2-135M\",\n                \"model_type\": \"AutoModelForCausalLM\",\n                \"tokenizer_type\": \"AutoTokenizer\",\n                \"sequence_len\": 1024,\n                \"load_in_8bit\": True,\n                \"adapter\": \"lora\",\n                \"lora_r\": 8,\n                \"lora_alpha\": 16,\n                \"lora_dropout\": 0.05,\n                \"lora_target_linear\": True,\n                \"val_set_size\": 0.02,\n                \"special_tokens\": {\n                    \"pad_token\": \"<|endoftext|>\",\n                },\n                \"datasets\": [\n                    {\n                        \"path\": \"mhenrichsen/alpaca_2k_test\",\n                        \"type\": \"alpaca\",\n                    },\n                ],\n                \"num_epochs\": 1,\n                \"micro_batch_size\": 8,\n                \"gradient_accumulation_steps\": 1,\n                \"output_dir\": temp_dir,\n                \"learning_rate\": 0.00001,\n                \"optimizer\": \"optimi_adamw\",\n                \"max_steps\": 5,\n                \"lr_scheduler\": \"cosine\",\n                \"save_first_step\": False,\n            }\n        )\n\n        cfg = validate_config(cfg)\n        
normalize_config(cfg)\n        dataset_meta = load_datasets(cfg=cfg)\n\n        _, _, trainer = train(cfg=cfg, dataset_meta=dataset_meta)\n        check_model_output_exists(temp_dir, cfg)\n        assert trainer.optimizer.optimizer.__class__.__name__ == \"AdamW\"\n\n    @with_temp_dir\n    @require_torch_2_5_1\n    def test_adopt_adamw(self, temp_dir):\n        cfg = DictDefault(\n            {\n                \"base_model\": \"HuggingFaceTB/SmolLM2-135M\",\n                \"model_type\": \"AutoModelForCausalLM\",\n                \"tokenizer_type\": \"AutoTokenizer\",\n                \"sequence_len\": 1024,\n                \"load_in_8bit\": True,\n                \"adapter\": \"lora\",\n                \"lora_r\": 8,\n                \"lora_alpha\": 16,\n                \"lora_dropout\": 0.05,\n                \"lora_target_linear\": True,\n                \"val_set_size\": 0.02,\n                \"special_tokens\": {\n                    \"pad_token\": \"<|endoftext|>\",\n                },\n                \"datasets\": [\n                    {\n                        \"path\": \"mhenrichsen/alpaca_2k_test\",\n                        \"type\": \"alpaca\",\n                    },\n                ],\n                \"num_epochs\": 1,\n                \"max_steps\": 5,\n                \"micro_batch_size\": 8,\n                \"gradient_accumulation_steps\": 1,\n                \"output_dir\": temp_dir,\n                \"learning_rate\": 0.00001,\n                \"optimizer\": \"adopt_adamw\",\n                \"lr_scheduler\": \"cosine\",\n                \"save_first_step\": False,\n            }\n        )\n\n        cfg = validate_config(cfg)\n        normalize_config(cfg)\n        dataset_meta = load_datasets(cfg=cfg)\n\n        _, _, trainer = train(cfg=cfg, dataset_meta=dataset_meta)\n        check_model_output_exists(temp_dir, cfg)\n        assert \"ADOPT\" in trainer.optimizer.optimizer.__class__.__name__\n\n    @with_temp_dir\n    
@require_torch_2_5_1\n    def test_muon(self, temp_dir):\n        cfg = DictDefault(\n            {\n                \"base_model\": \"HuggingFaceTB/SmolLM2-135M\",\n                \"model_type\": \"AutoModelForCausalLM\",\n                \"tokenizer_type\": \"AutoTokenizer\",\n                \"sequence_len\": 1024,\n                \"load_in_8bit\": True,\n                \"adapter\": \"lora\",\n                \"lora_r\": 8,\n                \"lora_alpha\": 16,\n                \"lora_dropout\": 0.05,\n                \"lora_target_linear\": True,\n                \"val_set_size\": 0.02,\n                \"special_tokens\": {\n                    \"pad_token\": \"<|endoftext|>\",\n                },\n                \"datasets\": [\n                    {\n                        \"path\": \"mhenrichsen/alpaca_2k_test\",\n                        \"type\": \"alpaca\",\n                    },\n                ],\n                \"num_epochs\": 1,\n                \"max_steps\": 5,\n                \"micro_batch_size\": 8,\n                \"gradient_accumulation_steps\": 1,\n                \"output_dir\": temp_dir,\n                \"learning_rate\": 0.00001,\n                \"optimizer\": \"muon\",\n                \"lr_scheduler\": \"cosine\",\n                \"weight_decay\": 0.01,\n                \"save_first_step\": False,\n            }\n        )\n\n        cfg = validate_config(cfg)\n        normalize_config(cfg)\n        dataset_meta = load_datasets(cfg=cfg)\n\n        _, _, trainer = train(cfg=cfg, dataset_meta=dataset_meta)\n        check_model_output_exists(temp_dir, cfg)\n        assert \"Muon\" in trainer.optimizer.optimizer.__class__.__name__\n\n    @with_temp_dir\n    @require_torch_2_7_0\n    def test_dion(self, temp_dir):\n        cfg = DictDefault(\n            {\n                \"base_model\": \"HuggingFaceTB/SmolLM2-135M\",\n                \"model_type\": \"AutoModelForCausalLM\",\n                \"tokenizer_type\": 
\"AutoTokenizer\",\n                \"sequence_len\": 1024,\n                \"val_set_size\": 0.0,\n                \"special_tokens\": {\n                    \"pad_token\": \"<|endoftext|>\",\n                },\n                \"datasets\": [\n                    {\n                        \"path\": \"mhenrichsen/alpaca_2k_test\",\n                        \"type\": \"alpaca\",\n                    },\n                ],\n                \"num_epochs\": 1,\n                \"max_steps\": 5,\n                \"micro_batch_size\": 8,\n                \"gradient_accumulation_steps\": 1,\n                \"output_dir\": temp_dir,\n                \"learning_rate\": 0.00001,\n                \"optimizer\": \"dion\",\n                \"dion_lr\": 0.01,\n                \"dion_momentum\": 0.95,\n                \"lr_scheduler\": \"cosine\",\n                \"weight_decay\": 0.01,\n                \"save_first_step\": False,\n            }\n        )\n\n        cfg = validate_config(cfg)\n        normalize_config(cfg)\n        dataset_meta = load_datasets(cfg=cfg)\n\n        _, _, trainer = train(cfg=cfg, dataset_meta=dataset_meta)\n        check_model_output_exists(temp_dir, cfg)\n        assert \"Dion\" in trainer.optimizer.optimizer.__class__.__name__\n\n    @with_temp_dir\n    def test_fft_schedule_free_adamw(self, temp_dir):\n        cfg = DictDefault(\n            {\n                \"base_model\": \"HuggingFaceTB/SmolLM2-135M\",\n                \"model_type\": \"AutoModelForCausalLM\",\n                \"sequence_len\": 1024,\n                \"val_set_size\": 0.01,\n                \"special_tokens\": {\n                    \"pad_token\": \"<|endoftext|>\",\n                },\n                \"datasets\": [\n                    {\n                        \"path\": \"mhenrichsen/alpaca_2k_test\",\n                        \"type\": \"alpaca\",\n                    },\n                ],\n                \"num_epochs\": 1,\n                \"micro_batch_size\": 
2,\n                \"gradient_accumulation_steps\": 2,\n                \"output_dir\": temp_dir,\n                \"learning_rate\": 0.00001,\n                \"optimizer\": \"schedule_free_adamw\",\n                \"lr_scheduler\": \"constant\",\n                \"max_steps\": 10,\n                \"save_first_step\": False,\n            }\n        )\n\n        cfg = validate_config(cfg)\n        normalize_config(cfg)\n        dataset_meta = load_datasets(cfg=cfg)\n\n        train(cfg=cfg, dataset_meta=dataset_meta)\n        check_model_output_exists(temp_dir, cfg)\n\n    @with_temp_dir\n    @require_torch_2_6_0\n    def test_came_pytorch(self, temp_dir):\n        cfg = DictDefault(\n            {\n                \"base_model\": \"JackFram/llama-68m\",\n                \"tokenizer_type\": \"LlamaTokenizer\",\n                \"sequence_len\": 1024,\n                \"load_in_8bit\": True,\n                \"adapter\": \"lora\",\n                \"lora_r\": 8,\n                \"lora_alpha\": 16,\n                \"lora_dropout\": 0.05,\n                \"lora_target_linear\": True,\n                \"val_set_size\": 0.1,\n                \"special_tokens\": {\n                    \"unk_token\": \"<unk>\",\n                    \"bos_token\": \"<s>\",\n                    \"eos_token\": \"</s>\",\n                },\n                \"datasets\": [\n                    {\n                        \"path\": \"mhenrichsen/alpaca_2k_test\",\n                        \"type\": \"alpaca\",\n                    },\n                ],\n                \"num_epochs\": 1,\n                \"micro_batch_size\": 8,\n                \"gradient_accumulation_steps\": 1,\n                \"output_dir\": temp_dir,\n                \"learning_rate\": 0.00001,\n                \"optimizer\": \"came_pytorch\",\n                \"adam_beta3\": 0.9999,\n                \"adam_epsilon2\": 1e-16,\n                \"max_steps\": 5,\n                \"lr_scheduler\": \"cosine\",\n         
       \"save_first_step\": False,\n            }\n        )\n\n        cfg = validate_config(cfg)\n        normalize_config(cfg)\n        dataset_meta = load_datasets(cfg=cfg)\n\n        train(cfg=cfg, dataset_meta=dataset_meta)\n        check_model_output_exists(temp_dir, cfg)\n\n\n@require_torch_2_7_0\n@pytest.mark.parametrize(\n    \"optimizer_name,expected_class,learning_rate\",\n    [\n        (\"flash_adamw\", \"FlashAdamW\", 0.00001),\n        (\"flash_adam\", \"FlashAdam\", 0.00001),\n        (\"flash_sgd\", \"FlashSGD\", 0.01),\n        (\"flash_sgdw\", \"FlashSGDW\", 0.01),\n        (\"flash_lion\", \"FlashLion\", 0.0001),\n    ],\n)\ndef test_flash_optimizers(tmp_path, optimizer_name, expected_class, learning_rate):\n    temp_dir = str(tmp_path)\n    cfg = DictDefault(\n        {\n            \"base_model\": \"HuggingFaceTB/SmolLM2-135M\",\n            \"model_type\": \"AutoModelForCausalLM\",\n            \"tokenizer_type\": \"AutoTokenizer\",\n            \"sequence_len\": 1024,\n            \"load_in_8bit\": True,\n            \"adapter\": \"lora\",\n            \"lora_r\": 8,\n            \"lora_alpha\": 16,\n            \"lora_dropout\": 0.05,\n            \"lora_target_linear\": True,\n            \"val_set_size\": 0.02,\n            \"special_tokens\": {\n                \"pad_token\": \"<|endoftext|>\",\n            },\n            \"datasets\": [\n                {\n                    \"path\": \"mhenrichsen/alpaca_2k_test\",\n                    \"type\": \"alpaca\",\n                },\n            ],\n            \"num_epochs\": 1,\n            \"micro_batch_size\": 8,\n            \"gradient_accumulation_steps\": 1,\n            \"output_dir\": temp_dir,\n            \"learning_rate\": learning_rate,\n            \"optimizer\": optimizer_name,\n            \"max_steps\": 5,\n            \"lr_scheduler\": \"cosine\",\n            \"save_first_step\": False,\n        }\n    )\n\n    cfg = validate_config(cfg)\n    normalize_config(cfg)\n    
dataset_meta = load_datasets(cfg=cfg)\n\n    _, _, trainer = train(cfg=cfg, dataset_meta=dataset_meta)\n    check_model_output_exists(temp_dir, cfg)\n    assert trainer.optimizer.optimizer.__class__.__name__ == expected_class\n"
  },
  {
    "path": "tests/e2e/test_packing_loss.py",
    "content": "\"\"\"\nE2E tests for packed training\n\"\"\"\n\nimport unittest\n\nfrom transformers.utils import is_torch_bf16_gpu_available\n\nfrom axolotl.common.datasets import load_datasets\nfrom axolotl.train import train\nfrom axolotl.utils.config import normalize_config, validate_config\nfrom axolotl.utils.dict import DictDefault\n\nfrom .utils import check_tensorboard, with_temp_dir\n\n\nclass TestPackedLlama(unittest.TestCase):\n    \"\"\"\n    Test case for Packed training of llama models\n    \"\"\"\n\n    @with_temp_dir\n    def test_loss_packed(self, temp_dir):\n        cfg = DictDefault(\n            {\n                \"base_model\": \"HuggingFaceTB/SmolLM2-135M\",\n                \"sequence_len\": 1024,\n                \"sample_packing\": True,\n                \"flash_attention\": True,\n                \"val_set_size\": 0.0,\n                \"special_tokens\": {\n                    \"pad_token\": \"<|endoftext|>\",\n                },\n                \"datasets\": [\n                    {\n                        \"path\": \"tatsu-lab/alpaca\",\n                        \"type\": \"alpaca\",\n                        \"split\": \"train[:10%]\",\n                    },\n                ],\n                \"num_epochs\": 1,\n                \"micro_batch_size\": 2,\n                \"gradient_accumulation_steps\": 2,\n                \"output_dir\": temp_dir,\n                \"learning_rate\": 0.00001,\n                \"optimizer\": \"adamw_torch_fused\",\n                \"lr_scheduler\": \"cosine\",\n                \"max_steps\": 5,\n                \"use_tensorboard\": True,\n                \"save_first_step\": False,\n            }\n        )\n        if is_torch_bf16_gpu_available():\n            cfg.bf16 = True\n        else:\n            cfg.fp16 = True\n\n        cfg = validate_config(cfg)\n        normalize_config(cfg)\n        dataset_meta = load_datasets(cfg=cfg)\n\n        train(cfg=cfg, dataset_meta=dataset_meta)\n\n        
check_tensorboard(\n            temp_dir + \"/runs\", \"train/train_loss\", 2.0, \"Train Loss (%s) is too high\"\n        )\n"
  },
  {
    "path": "tests/e2e/test_phi.py",
    "content": "\"\"\"\nE2E tests for Phi models (full fine-tune and QLoRA)\n\"\"\"\n\nimport unittest\n\nfrom axolotl.common.datasets import load_datasets\nfrom axolotl.train import train\nfrom axolotl.utils.config import normalize_config, validate_config\nfrom axolotl.utils.dict import DictDefault\n\nfrom .utils import check_model_output_exists, with_temp_dir\n\n\nclass TestPhi(unittest.TestCase):\n    \"\"\"\n    Test case for Phi-1.5 models\n    \"\"\"\n\n    @with_temp_dir\n    def test_phi_ft(self, temp_dir):\n        cfg = DictDefault(\n            {\n                \"base_model\": \"microsoft/phi-1_5\",\n                \"model_type\": \"AutoModelForCausalLM\",\n                \"tokenizer_type\": \"AutoTokenizer\",\n                \"sequence_len\": 2048,\n                \"sample_packing\": False,\n                \"load_in_8bit\": False,\n                \"adapter\": None,\n                \"val_set_size\": 0.02,\n                \"special_tokens\": {\n                    \"pad_token\": \"<|endoftext|>\",\n                },\n                \"datasets\": [\n                    {\n                        \"path\": \"mhenrichsen/alpaca_2k_test\",\n                        \"type\": \"alpaca\",\n                    },\n                ],\n                \"dataset_shard_num\": 10,\n                \"dataset_shard_idx\": 0,\n                \"num_epochs\": 1,\n                \"micro_batch_size\": 1,\n                \"gradient_accumulation_steps\": 1,\n                \"output_dir\": temp_dir,\n                \"learning_rate\": 0.00001,\n                \"optimizer\": \"paged_adamw_8bit\",\n                \"lr_scheduler\": \"cosine\",\n                \"flash_attention\": True,\n                \"max_steps\": 10,\n                \"save_steps\": 10,\n                \"eval_steps\": 10,\n                \"bf16\": \"auto\",\n                \"save_first_step\": False,\n            }\n        )\n        cfg = validate_config(cfg)\n        normalize_config(cfg)\n        dataset_meta = 
load_datasets(cfg=cfg)\n\n        train(cfg=cfg, dataset_meta=dataset_meta)\n        check_model_output_exists(temp_dir, cfg)\n\n    @with_temp_dir\n    def test_phi_qlora(self, temp_dir):\n        cfg = DictDefault(\n            {\n                \"base_model\": \"microsoft/phi-1_5\",\n                \"model_type\": \"AutoModelForCausalLM\",\n                \"tokenizer_type\": \"AutoTokenizer\",\n                \"sequence_len\": 2048,\n                \"sample_packing\": False,\n                \"load_in_4bit\": True,\n                \"adapter\": \"qlora\",\n                \"lora_r\": 64,\n                \"lora_alpha\": 32,\n                \"lora_dropout\": 0.05,\n                \"lora_target_linear\": True,\n                \"val_set_size\": 0.02,\n                \"special_tokens\": {\n                    \"pad_token\": \"<|endoftext|>\",\n                },\n                \"datasets\": [\n                    {\n                        \"path\": \"mhenrichsen/alpaca_2k_test\",\n                        \"type\": \"alpaca\",\n                    },\n                ],\n                \"dataset_shard_num\": 10,\n                \"dataset_shard_idx\": 0,\n                \"num_epochs\": 1,\n                \"micro_batch_size\": 1,\n                \"gradient_accumulation_steps\": 1,\n                \"output_dir\": temp_dir,\n                \"learning_rate\": 0.00001,\n                \"optimizer\": \"paged_adamw_8bit\",\n                \"lr_scheduler\": \"cosine\",\n                \"flash_attention\": True,\n                \"max_steps\": 10,\n                \"save_steps\": 10,\n                \"eval_steps\": 10,\n                \"bf16\": \"auto\",\n                \"save_first_step\": False,\n            }\n        )\n        cfg = validate_config(cfg)\n        normalize_config(cfg)\n        dataset_meta = load_datasets(cfg=cfg)\n\n        train(cfg=cfg, dataset_meta=dataset_meta)\n        check_model_output_exists(temp_dir, cfg)\n"
  },
  {
    "path": "tests/e2e/test_preprocess.py",
    "content": "\"\"\"E2E Test the preprocess cli\"\"\"\n\nfrom pathlib import Path\n\nimport yaml\nfrom accelerate.test_utils import execute_subprocess_async\n\nfrom axolotl.utils.dict import DictDefault\n\nAXOLOTL_ROOT = Path(__file__).parent.parent.parent\n\n\nclass TestPreprocess:\n    \"\"\"test cases for preprocess\"\"\"\n\n    def test_w_deepspeed(self, temp_dir):\n        \"\"\"make sure preprocess doesn't choke when using deepspeed in the config\"\"\"\n\n        cfg = DictDefault(\n            {\n                \"base_model\": \"Qwen/Qwen2.5-0.5B\",\n                \"sequence_len\": 2048,\n                \"val_set_size\": 0.01,\n                \"datasets\": [\n                    {\n                        \"path\": \"tatsu-lab/alpaca\",\n                        \"type\": \"alpaca\",\n                        \"split\": \"train[:10%]\",\n                    },\n                ],\n                \"num_epochs\": 1,\n                \"micro_batch_size\": 2,\n                \"gradient_accumulation_steps\": 1,\n                \"output_dir\": temp_dir,\n                \"learning_rate\": 0.00001,\n                \"optimizer\": \"adamw_torch_fused\",\n                \"lr_scheduler\": \"cosine\",\n                \"flash_attention\": True,\n                \"bf16\": \"auto\",\n                \"deepspeed\": str(AXOLOTL_ROOT / \"deepspeed_configs/zero1.json\"),\n                \"dataset_prepared_path\": temp_dir + \"/last_run_prepared\",\n            }\n        )\n\n        # write cfg to yaml file\n        Path(temp_dir).mkdir(parents=True, exist_ok=True)\n        with open(Path(temp_dir) / \"config.yaml\", \"w\", encoding=\"utf-8\") as fout:\n            fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))\n\n        execute_subprocess_async(\n            [\n                \"axolotl\",\n                \"preprocess\",\n                str(Path(temp_dir) / \"config.yaml\"),\n            ]\n        )\n\n        assert (Path(temp_dir) / 
\"last_run_prepared\").exists()\n"
  },
  {
    "path": "tests/e2e/test_process_reward_model_smollm2.py",
    "content": "\"\"\"\nE2E tests for process reward model training with SmolLM2\n\"\"\"\n\nimport unittest\n\nfrom axolotl.common.datasets import load_datasets\nfrom axolotl.train import train\nfrom axolotl.utils.config import normalize_config, validate_config\nfrom axolotl.utils.dict import DictDefault\n\nfrom .utils import check_model_output_exists, check_tensorboard, with_temp_dir\n\n\nclass TestProcessRewardSmolLM2(unittest.TestCase):\n    \"\"\"\n    Test case for SmolLM2 process reward models (token classification)\n    \"\"\"\n\n    @with_temp_dir\n    def test_prm(self, temp_dir):\n        cfg = DictDefault(\n            {\n                \"base_model\": \"HuggingFaceTB/SmolLM2-135M\",\n                \"model_type\": \"AutoModelForTokenClassification\",\n                \"num_labels\": 2,\n                \"process_reward_model\": True,\n                \"sequence_len\": 512,\n                \"val_set_size\": 0.0,\n                \"datasets\": [\n                    {\n                        \"path\": \"trl-lib/math_shepherd\",\n                        \"type\": \"stepwise_supervised\",\n                        \"step_separator\": \"\\n\",\n                        \"split\": \"train[:10%]\",\n                    },\n                ],\n                \"max_steps\": 100,\n                \"num_epochs\": 1,\n                \"micro_batch_size\": 4,\n                \"gradient_accumulation_steps\": 1,\n                \"output_dir\": temp_dir,\n                \"learning_rate\": 0.0005,\n                \"optimizer\": \"adamw_torch\",\n                \"lr_scheduler\": \"cosine\",\n                \"gradient_checkpointing\": True,\n                \"warmup_ratio\": 0.1,\n                \"use_tensorboard\": True,\n                \"special_tokens\": {\"pad_token\": \"<|endoftext|>\"},\n                \"seed\": 42,\n                \"save_first_step\": False,\n            }\n        )\n        cfg = validate_config(cfg)\n        normalize_config(cfg)\n        dataset_meta = 
load_datasets(cfg=cfg)\n\n        train(cfg=cfg, dataset_meta=dataset_meta)\n        check_tensorboard(\n            temp_dir + \"/runs\", \"train/train_loss\", 2.7, \"Train Loss (%s) is too high\"\n        )\n\n        check_model_output_exists(temp_dir, cfg)\n"
  },
  {
    "path": "tests/e2e/test_profiler.py",
    "content": "\"\"\"\ne2e gpu test for the pytorch profiler callback\n\"\"\"\n\nfrom pathlib import Path\n\nimport pytest\n\nfrom axolotl.common.datasets import load_datasets\nfrom axolotl.train import train\nfrom axolotl.utils.config import normalize_config, validate_config\nfrom axolotl.utils.dict import DictDefault\n\n\n@pytest.fixture(name=\"profiler_base_cfg\")\ndef fixture_profiler_base_cfg():\n    cfg = DictDefault(\n        base_model=\"HuggingFaceTB/SmolLM2-135M\",\n        tokenizer_type=\"AutoTokenizer\",\n        sequence_len=1024,\n        load_in_8bit=True,\n        adapter=\"lora\",\n        lora_r=8,\n        lora_alpha=16,\n        lora_dropout=0.05,\n        lora_target_linear=True,\n        val_set_size=0.02,\n        special_tokens={\"pad_token\": \"<|endoftext|>\"},\n        datasets=[\n            {\n                \"path\": \"mhenrichsen/alpaca_2k_test\",\n                \"type\": \"alpaca\",\n            },\n        ],\n        num_epochs=1,\n        micro_batch_size=2,\n        gradient_accumulation_steps=1,\n        learning_rate=0.00001,\n        optimizer=\"adamw_torch_fused\",\n        lr_scheduler=\"cosine\",\n    )\n    return cfg\n\n\nclass TestProfiler:\n    \"\"\"\n    test cases for the pytorch profiler callback\n    \"\"\"\n\n    def test_profiler_saves(self, profiler_base_cfg, temp_dir):\n        cfg = profiler_base_cfg | DictDefault(\n            output_dir=temp_dir,\n            max_steps=5,\n            profiler_steps=3,\n        )\n\n        cfg = validate_config(cfg)\n        normalize_config(cfg)\n        dataset_meta = load_datasets(cfg=cfg)\n\n        train(cfg=cfg, dataset_meta=dataset_meta)\n        assert (Path(temp_dir) / \"snapshot.pickle\").exists()\n\n    def test_profiler_saves_w_start(self, profiler_base_cfg, temp_dir):\n        cfg = profiler_base_cfg | DictDefault(\n            output_dir=temp_dir,\n            max_steps=5,\n            profiler_steps=3,\n            profiler_steps_start=1,\n        )\n\n  
      cfg = validate_config(cfg)\n        normalize_config(cfg)\n        dataset_meta = load_datasets(cfg=cfg)\n\n        train(cfg=cfg, dataset_meta=dataset_meta)\n        assert (Path(temp_dir) / \"snapshot.pickle\").exists()\n\n    @pytest.mark.parametrize(\n        \"profiler_steps_start\",\n        [3, 5],\n    )\n    def test_profiler_saves_past_end(\n        self, profiler_base_cfg, temp_dir, profiler_steps_start\n    ):\n        cfg = profiler_base_cfg | DictDefault(\n            output_dir=temp_dir,\n            max_steps=5,\n            profiler_steps=3,\n            profiler_steps_start=profiler_steps_start,\n        )\n\n        cfg = validate_config(cfg)\n        normalize_config(cfg)\n        dataset_meta = load_datasets(cfg=cfg)\n\n        train(cfg=cfg, dataset_meta=dataset_meta)\n        assert (Path(temp_dir) / \"snapshot.pickle\").exists()\n\n    def test_profiler_never_started(self, profiler_base_cfg, temp_dir):\n        cfg = profiler_base_cfg | DictDefault(\n            output_dir=temp_dir,\n            max_steps=5,\n            profiler_steps=3,\n            profiler_steps_start=6,\n        )\n\n        cfg = validate_config(cfg)\n        normalize_config(cfg)\n        dataset_meta = load_datasets(cfg=cfg)\n\n        train(cfg=cfg, dataset_meta=dataset_meta)\n        assert not (Path(temp_dir) / \"snapshot.pickle\").exists()\n"
  },
  {
    "path": "tests/e2e/test_qat.py",
    "content": "\"\"\"\nE2E tests for QAT\n\"\"\"\n\nfrom pathlib import Path\n\nfrom axolotl.common.datasets import load_datasets, load_preference_datasets\nfrom axolotl.train import train\nfrom axolotl.utils.config import normalize_config, validate_config\nfrom axolotl.utils.dict import DictDefault\nfrom axolotl.utils.schemas.enums import TorchAOQuantDType\nfrom axolotl.utils.schemas.quantization import QATConfig, validate_ao_dtype\n\nfrom .utils import check_model_output_exists, check_tensorboard\n\n\nclass TestQATLlama:\n    \"\"\"\n    Test case for QAT Llama models\n    \"\"\"\n\n    def test_qat(self, temp_dir):\n        cfg = DictDefault(\n            {\n                \"base_model\": \"HuggingFaceTB/SmolLM2-135M\",\n                \"tokenizer_type\": \"AutoTokenizer\",\n                \"sequence_len\": 1024,\n                \"special_tokens\": {\n                    \"pad_token\": \"<|endoftext|>\",\n                },\n                \"datasets\": [\n                    {\n                        \"path\": \"mlabonne/FineTome-100k\",\n                        \"type\": \"chat_template\",\n                        \"field_messages\": \"conversations\",\n                        \"message_property_mappings\": {\n                            \"role\": \"from\",\n                            \"content\": \"value\",\n                        },\n                        \"drop_system_message\": True,\n                        \"split\": \"train[:1%]\",\n                    },\n                ],\n                \"chat_template\": \"chatml\",\n                \"qat\": {\n                    \"quantize_embedding\": True,\n                    \"activation_dtype\": \"int8\",\n                    \"weight_dtype\": \"int4\",\n                    \"group_size\": 8,\n                },\n                \"num_epochs\": 1,\n                \"micro_batch_size\": 1,\n                \"gradient_accumulation_steps\": 2,\n                \"output_dir\": temp_dir,\n             
   \"learning_rate\": 0.00001,\n                \"optimizer\": \"adamw_bnb_8bit\",\n                \"lr_scheduler\": \"cosine\",\n                \"max_steps\": 5,\n                \"bf16\": True,\n                \"save_first_step\": False,\n            }\n        )\n        cfg = validate_config(cfg)\n        normalize_config(cfg)\n        dataset_meta = load_datasets(cfg=cfg)\n\n        train(cfg=cfg, dataset_meta=dataset_meta)\n        check_model_output_exists(Path(temp_dir) / \"checkpoint-5\", cfg)\n\n    def test_qat_dpo(self, temp_dir):\n        cfg = DictDefault(\n            {\n                \"base_model\": \"HuggingFaceTB/SmolLM2-135M\",\n                \"sequence_len\": 2048,\n                \"sample_packing\": False,\n                \"eval_sample_packing\": False,\n                \"pad_to_sequence_len\": True,\n                \"val_set_size\": 0.01,\n                \"special_tokens\": {\n                    \"pad_token\": \"<|endoftext|>\",\n                },\n                \"rl\": \"dpo\",\n                \"chat_template\": \"chatml\",\n                \"datasets\": [\n                    {\n                        \"path\": \"fozziethebeat/alpaca_messages_2k_dpo_test\",\n                        \"type\": \"chat_template.default\",\n                        \"field_messages\": \"conversation\",\n                        \"field_chosen\": \"chosen\",\n                        \"field_rejected\": \"rejected\",\n                        \"message_field_role\": \"role\",\n                        \"message_field_content\": \"content\",\n                        \"roles\": {\n                            \"system\": [\"system\"],\n                            \"user\": [\"user\"],\n                            \"assistant\": [\"assistant\"],\n                        },\n                    },\n                ],\n                \"num_epochs\": 1,\n                \"max_steps\": 5,\n                \"micro_batch_size\": 2,\n                
\"gradient_accumulation_steps\": 2,\n                \"output_dir\": temp_dir,\n                \"warmup_steps\": 0,\n                \"learning_rate\": 0.00001,\n                \"optimizer\": \"adamw_torch_fused\",\n                \"lr_scheduler\": \"cosine\",\n                \"flash_attention\": True,\n                \"use_tensorboard\": True,\n                \"bf16\": True,\n                \"qat\": {\n                    \"quantize_embedding\": True,\n                    \"activation_dtype\": \"int8\",\n                    \"weight_dtype\": \"int4\",\n                    \"group_size\": 8,\n                },\n                \"save_first_step\": False,\n            }\n        )\n        cfg = validate_config(cfg)\n        normalize_config(cfg)\n        dataset_meta = load_preference_datasets(cfg=cfg)\n\n        train(cfg=cfg, dataset_meta=dataset_meta)\n        check_model_output_exists(Path(temp_dir) / \"checkpoint-5\", cfg)\n\n        loss_threshold = 2.3\n        check_tensorboard(\n            temp_dir + \"/runs\",\n            \"train/train_loss\",\n            loss_threshold,\n            \"Train Loss (%s) is too high\",\n        )\n\n\nclass TestMXFP4Schema:\n    \"\"\"Test MXFP4 schema validation\"\"\"\n\n    def test_validate_mxfp4_dtype(self):\n        result = validate_ao_dtype(\"mxfp4\")\n        assert result == TorchAOQuantDType.mxfp4\n\n    def test_qat_config_with_mxfp4(self):\n        \"\"\"Test QATConfig accepts mxfp4 weight_dtype\"\"\"\n        config = QATConfig(\n            weight_dtype=\"mxfp4\",\n            group_size=32,\n            quantize_embedding=False,\n        )\n        assert config.weight_dtype == TorchAOQuantDType.mxfp4\n        assert config.group_size == 32\n\n    def test_qat_config_mxfp4_invalid_group_size(self):\n        \"\"\"Test that the QATConfig schema does not reject an mxfp4-incompatible group_size\"\"\"\n        # Note: Schema validation doesn't check group_size compatibility,\n        # that happens in 
get_quantization_config\n        config = QATConfig(\n            weight_dtype=\"mxfp4\",\n            group_size=16,  # Invalid for mxfp4, but schema allows it\n        )\n        assert config.group_size == 16  # Schema accepts it\n        # Actual validation happens at runtime in get_quantization_config\n"
  },
  {
    "path": "tests/e2e/test_quantization.py",
    "content": "\"\"\"\nTests for axolotl.utils.quantization\n\"\"\"\n\nimport pytest\nimport torch\nfrom torch import nn\nfrom torchao.prototype.qat import MXFakeQuantizeConfig\nfrom torchao.quantization import LinearActivationQuantizedTensor\nfrom torchao.quantization.qat.embedding import FakeQuantizedEmbedding\nfrom torchao.quantization.qat.linear import FakeQuantizedLinear\nfrom torchao.quantization.quant_api import (\n    Float8DynamicActivationFloat8WeightConfig,\n    Float8DynamicActivationInt4WeightConfig,\n    Int8DynamicActivationInt4WeightConfig,\n)\nfrom torchao.quantization.quantize_.workflows.int4.int4_tensor import Int4Tensor\nfrom transformers import AutoModelForCausalLM\nfrom transformers.trainer_callback import TrainerState\n\nfrom axolotl.utils.callbacks.qat import QATCallback\nfrom axolotl.utils.quantization import (\n    convert_qat_model,\n    get_quantization_config,\n    prepare_model_for_qat,\n    quantize_model,\n)\nfrom axolotl.utils.schemas.enums import TorchAOQuantDType\nfrom axolotl.utils.schemas.quantization import QATConfig\n\nfrom tests.e2e.utils import (\n    require_torch_2_8_0,\n    requires_cuda_ge_8_9,\n    requires_sm_ge_100,\n)\n\n\n@pytest.fixture()\ndef model():\n    dummy_model = AutoModelForCausalLM.from_pretrained(\n        \"Qwen/Qwen2-0.5B\",\n        device_map=\"auto\",\n        dtype=torch.bfloat16,\n    )\n    with torch.device(dummy_model.device):\n        dummy_model.model.embed_tokens = torch.nn.Embedding(\n            dummy_model.model.embed_tokens.weight.shape[0],\n            dummy_model.model.embed_tokens.weight.shape[1],\n            dtype=dummy_model.model.embed_tokens.weight.dtype,\n        )\n    yield dummy_model\n    del dummy_model\n\n\nptq_config_test_cases = [\n    # weight_dtype, activation_dtype, group_size, expected_type\n    (\n        TorchAOQuantDType.int4,\n        TorchAOQuantDType.int8,\n        None,\n        Int8DynamicActivationInt4WeightConfig,\n    ),\n    (\n        
TorchAOQuantDType.float8_e4m3fn,\n        TorchAOQuantDType.float8_e4m3fn,\n        None,\n        Float8DynamicActivationFloat8WeightConfig,\n    ),\n    (\n        TorchAOQuantDType.int4,\n        TorchAOQuantDType.float8_e4m3fn,\n        None,\n        Float8DynamicActivationInt4WeightConfig,\n    ),\n]\n\nptq_test_cases = [\n    # weight_dtype, activation_dtype, group_size, quantize_embedding, expected_exception, expected_tensor_class\n    (TorchAOQuantDType.int4, None, 4, True, None, Int4Tensor),\n    (\n        TorchAOQuantDType.int4,\n        TorchAOQuantDType.int8,\n        8,\n        False,\n        None,\n        LinearActivationQuantizedTensor,\n    ),\n    # (\n    #     TorchAOQuantDType.int4,\n    #     TorchAOQuantDType.float8_e4m3fn,\n    #     None,\n    #     False,\n    #     None,\n    #     Int4Tensor,\n    # ),\n    (TorchAOQuantDType.int4, None, None, False, None, Int4Tensor),\n    # Deprecated configs\n    (TorchAOQuantDType.int8, None, 8, False, ValueError, None),\n    (TorchAOQuantDType.int4, TorchAOQuantDType.int4, 8, False, ValueError, None),\n    (TorchAOQuantDType.int8, TorchAOQuantDType.int8, 8, True, ValueError, None),\n]\n\n\nclass TestQuantization:\n    \"\"\"\n    Test quantization utilities\n    \"\"\"\n\n    @pytest.mark.parametrize(\n        \"weight_dtype,activation_dtype,group_size,expected_type\",\n        ptq_config_test_cases,\n    )\n    @requires_cuda_ge_8_9\n    @require_torch_2_8_0\n    def test_get_ptq_config(\n        self, weight_dtype, activation_dtype, group_size, expected_type\n    ):\n        config = get_quantization_config(weight_dtype, activation_dtype, group_size)\n        assert isinstance(config, expected_type)\n\n    @require_torch_2_8_0\n    @requires_sm_ge_100\n    def test_get_ptq_config_mxfp4(self):\n        config = get_quantization_config(TorchAOQuantDType.mxfp4, None, 32)\n        assert isinstance(config, MXFakeQuantizeConfig)\n        assert config.block_size == 32\n\n    @require_torch_2_8_0\n  
  @requires_sm_ge_100\n    def test_get_ptq_config_mxfp4_invalid_group_size(self):\n        with pytest.raises(\n            ValueError, match=\"MXFP4 quantization must use a block_size\"\n        ):\n            get_quantization_config(TorchAOQuantDType.mxfp4, None, 16)\n\n    @requires_cuda_ge_8_9\n    @require_torch_2_8_0\n    def test_get_ptq_config_int4_weight_only(self):\n        from torchao.quantization.quant_api import Int4WeightOnlyConfig\n\n        config = get_quantization_config(TorchAOQuantDType.int4, None, 4)\n        assert isinstance(config, Int4WeightOnlyConfig)\n\n    @pytest.mark.parametrize(\n        \"weight_dtype,activation_dtype,group_size,quantize_embedding,expected_exception,expected_tensor_class\",\n        ptq_test_cases,\n    )\n    @requires_cuda_ge_8_9\n    @require_torch_2_8_0\n    def test_quantize_model_for_ptq(\n        self,\n        model,\n        weight_dtype,\n        activation_dtype,\n        group_size,\n        quantize_embedding,\n        expected_exception,\n        expected_tensor_class,\n    ):\n        if expected_exception:\n            with pytest.raises(expected_exception):\n                quantize_model(\n                    model,\n                    weight_dtype,\n                    group_size,\n                    activation_dtype,\n                    quantize_embedding,\n                )\n        else:\n            quantize_model(\n                model, weight_dtype, group_size, activation_dtype, quantize_embedding\n            )\n            if quantize_embedding:\n                assert isinstance(\n                    model.model.embed_tokens.weight, expected_tensor_class\n                ), \"Embedding weight should be quantized\"\n            for child in list(model.children()):\n                if isinstance(child, torch.nn.Linear):\n                    assert isinstance(child.weight, expected_tensor_class)\n\n    @require_torch_2_8_0\n    @requires_sm_ge_100\n    def 
test_quantize_model_for_ptq_fp8(\n        self,\n        model,\n    ):\n        from torchao.quantization.quantize_.workflows.float8.float8_tensor import (\n            Float8Tensor,\n            QuantizeTensorToFloat8Kwargs,\n        )\n\n        quantize_model(\n            model,\n            TorchAOQuantDType.float8_e4m3fn,\n            None,\n            TorchAOQuantDType.float8_e4m3fn,\n        )\n        for child in list(model.children()):\n            if isinstance(child, torch.nn.Linear):\n                assert isinstance(child.weight, Float8Tensor)\n                assert child.weight.act_quant_kwargs is not None and isinstance(\n                    child.weight.act_quant_kwargs, QuantizeTensorToFloat8Kwargs\n                )\n\n    @require_torch_2_8_0\n    @requires_sm_ge_100\n    def test_quantize_model_for_ptq_nvfp4(\n        self,\n        model,\n    ):\n        from torchao.prototype.mx_formats.nvfp4_tensor import (\n            NVFP4Tensor,\n            QuantizeTensorToNVFP4Kwargs,\n        )\n\n        quantize_model(model, TorchAOQuantDType.nvfp4, 16, TorchAOQuantDType.nvfp4)\n        for child in list(model.children()):\n            if isinstance(child, torch.nn.Linear):\n                assert isinstance(child.weight, NVFP4Tensor)\n                assert child.weight.act_quant_kwargs is not None and isinstance(\n                    child.weight.act_quant_kwargs, QuantizeTensorToNVFP4Kwargs\n                )\n\n    @pytest.mark.parametrize(\n        \"weight_dtype,activation_dtype,group_size,quantize_embedding\",\n        [\n            (TorchAOQuantDType.int4, None, 8, False),\n            (TorchAOQuantDType.int4, None, 16, True),\n            (TorchAOQuantDType.int4, TorchAOQuantDType.int8, 8, False),\n            (TorchAOQuantDType.int4, TorchAOQuantDType.int8, 16, True),\n            (\n                TorchAOQuantDType.float8_e4m3fn,\n                TorchAOQuantDType.float8_e4m3fn,\n                None,\n                False,\n     
       ),\n            (TorchAOQuantDType.int4, TorchAOQuantDType.float8_e4m3fn, None, True),\n        ],\n    )\n    @require_torch_2_8_0\n    @requires_cuda_ge_8_9\n    def test_prepare_model_for_qat(\n        self, model, weight_dtype, activation_dtype, group_size, quantize_embedding\n    ):\n        prepare_model_for_qat(\n            model,\n            weight_dtype,\n            group_size,\n            activation_dtype,\n            quantize_embedding,\n        )\n        if quantize_embedding:\n            assert isinstance(model.model.embed_tokens, FakeQuantizedEmbedding)\n            assert hasattr(model.model.embed_tokens, \"weight_fake_quantizer\")\n            assert (\n                model.model.embed_tokens.weight_fake_quantizer.config.dtype\n                == weight_dtype.value\n            )\n            if group_size:\n                assert (\n                    model.model.embed_tokens.weight_fake_quantizer.config.group_size\n                    == group_size\n                )\n\n        for child in list(model.children()):\n            if isinstance(child, torch.nn.Linear):\n                assert isinstance(child, FakeQuantizedLinear)\n                assert hasattr(child, \"weight_fake_quantizer\")\n                assert child.weight_fake_quantizer.config.dtype == weight_dtype.value\n                if group_size:\n                    assert child.weight_fake_quantizer.config.group_size == group_size\n                if activation_dtype:\n                    assert hasattr(child, \"activation_fake_quantizer\")\n                    assert (\n                        child.activation_fake_quantizer.config.dtype\n                        == activation_dtype.value\n                    )\n                else:\n                    assert child.activation_fake_quantizer is None\n\n    @pytest.mark.parametrize(\n        \"weight_dtype,activation_dtype,group_size,quantize_embedding\",\n        [\n            (TorchAOQuantDType.mxfp4, None, 32, 
False),\n            (TorchAOQuantDType.mxfp4, None, 32, True),\n        ],\n    )\n    @require_torch_2_8_0\n    @requires_sm_ge_100\n    def test_prepare_model_for_qat_mxfp4(\n        self, model, weight_dtype, activation_dtype, group_size, quantize_embedding\n    ):\n        prepare_model_for_qat(\n            model,\n            weight_dtype,\n            group_size,\n            activation_dtype,\n            quantize_embedding,\n        )\n\n        if quantize_embedding:\n            assert isinstance(model.model.embed_tokens, FakeQuantizedEmbedding)\n            assert hasattr(model.model.embed_tokens, \"weight_fake_quantizer\")\n\n        for child in list(model.children()):\n            if isinstance(child, torch.nn.Linear):\n                assert isinstance(child, FakeQuantizedLinear)\n                assert hasattr(child, \"weight_fake_quantizer\")\n\n    @require_torch_2_8_0\n    @requires_cuda_ge_8_9\n    def test_convert_qat_model(self, model):\n        config = QATConfig(\n            weight_dtype=\"int4\",\n            activation_dtype=\"int8\",\n            group_size=8,\n            quantize_embedding=True,\n        )\n\n        # quantize model for qat\n        prepare_model_for_qat(\n            model,\n            config.weight_dtype,\n            config.group_size,\n            config.activation_dtype,\n            config.quantize_embedding,\n        )\n\n        assert isinstance(model.model.embed_tokens, FakeQuantizedEmbedding)\n        assert isinstance(model.lm_head, FakeQuantizedLinear)\n\n        # apply conversion\n        convert_qat_model(\n            model,\n            config.quantize_embedding,\n        )\n        # ensure modules have been swapped out\n        assert not isinstance(model.model.embed_tokens, FakeQuantizedEmbedding)\n        assert not isinstance(model.lm_head, FakeQuantizedLinear)\n\n        # ensure weights have been quantized\n        assert isinstance(model.model.embed_tokens.weight, nn.Parameter)\n        
assert isinstance(model.lm_head.weight, nn.Parameter)\n\n\nclass TestQuantizationCallback:\n    \"\"\"\n    Test QATCallback\n    \"\"\"\n\n    @pytest.fixture()\n    def trainer_state(self):\n        return TrainerState(\n            global_step=0,\n        )\n\n    @require_torch_2_8_0\n    def test_qat_callback_fake_quant_after_n_steps(self, model, trainer_state):\n        cfg = QATConfig(\n            weight_dtype=\"int4\",\n            activation_dtype=\"int8\",\n            group_size=8,\n            quantize_embedding=True,\n            fake_quant_after_n_steps=100,\n        )\n\n        prepare_model_for_qat(\n            model,\n            cfg.weight_dtype,\n            cfg.group_size,\n            cfg.activation_dtype,\n            cfg.quantize_embedding,\n        )\n\n        # ensure model has been quantized\n        assert isinstance(model.model.embed_tokens, FakeQuantizedEmbedding)\n        assert model.model.embed_tokens.weight_fake_quantizer.enabled\n        assert isinstance(model.lm_head, FakeQuantizedLinear)\n        assert model.lm_head.weight_fake_quantizer.enabled\n\n        qat_callback = QATCallback(cfg)\n\n        # simulate first training step\n        qat_callback.on_step_begin(\n            args=None,\n            state=trainer_state,\n            control=None,\n            model=model,\n        )\n\n        # quantization should have been disabled\n        assert not model.model.embed_tokens.weight_fake_quantizer.enabled\n        assert not model.lm_head.weight_fake_quantizer.enabled\n\n        trainer_state.global_step = 100\n        qat_callback.on_step_begin(\n            args=None,\n            state=trainer_state,\n            control=None,\n            model=model,\n        )\n\n        # quantization should have been enabled\n        assert model.model.embed_tokens.weight_fake_quantizer.enabled\n        assert model.lm_head.weight_fake_quantizer.enabled\n\n    @require_torch_2_8_0\n    def 
test_qat_callback_fake_quant_after_n_steps_is_none(self, model, trainer_state):\n        cfg = QATConfig(\n            weight_dtype=\"int4\",\n            activation_dtype=\"int8\",\n            group_size=8,\n            quantize_embedding=True,\n            fake_quant_after_n_steps=None,\n        )\n\n        prepare_model_for_qat(\n            model,\n            cfg.weight_dtype,\n            cfg.group_size,\n            cfg.activation_dtype,\n            cfg.quantize_embedding,\n        )\n\n        # ensure model has been quantized\n        assert isinstance(model.model.embed_tokens, FakeQuantizedEmbedding)\n        assert model.model.embed_tokens.weight_fake_quantizer.enabled\n        assert isinstance(model.lm_head, FakeQuantizedLinear)\n        assert model.lm_head.weight_fake_quantizer.enabled\n\n        qat_callback = QATCallback(cfg)\n        # simulate first training step\n        qat_callback.on_step_begin(\n            args=None,\n            state=trainer_state,\n            control=None,\n            model=model,\n        )\n\n        # quantization should be enabled from the get-go\n        assert model.model.embed_tokens.weight_fake_quantizer.enabled\n        assert model.lm_head.weight_fake_quantizer.enabled\n"
  },
  {
    "path": "tests/e2e/test_qwen.py",
    "content": "\"\"\"\nE2E tests for qwen\n\"\"\"\n\nfrom pathlib import Path\n\nimport pytest\nimport yaml\nfrom accelerate.test_utils import execute_subprocess_async\nfrom transformers.testing_utils import get_torch_dist_unique_port\n\nfrom axolotl.utils.dict import DictDefault\n\n\nclass TestE2eQwen:\n    \"\"\"\n    Test cases for qwen models\n    \"\"\"\n\n    @pytest.mark.parametrize(\"base_model\", [\"Qwen/Qwen2-0.5B\", \"Qwen/Qwen2.5-0.5B\"])\n    def test_dpo(self, base_model, temp_dir):\n        cfg = DictDefault(\n            {\n                \"base_model\": base_model,\n                \"rl\": \"dpo\",\n                \"chat_template\": \"qwen_25\",\n                \"sequence_len\": 2048,\n                \"val_set_size\": 0.0,\n                \"datasets\": [\n                    {\n                        \"path\": \"fozziethebeat/alpaca_messages_2k_dpo_test\",\n                        \"split\": \"train\",\n                        \"type\": \"chat_template.default\",\n                        \"field_messages\": \"conversation\",\n                        \"field_chosen\": \"chosen\",\n                        \"field_rejected\": \"rejected\",\n                        \"message_property_mappings\": {\n                            \"role\": \"role\",\n                            \"content\": \"content\",\n                        },\n                        \"roles\": {\n                            \"system\": [\"system\"],\n                            \"user\": [\"user\"],\n                            \"assistant\": [\"assistant\"],\n                        },\n                    },\n                ],\n                \"num_epochs\": 1,\n                \"max_steps\": 5,\n                \"warmup_steps\": 20,\n                \"micro_batch_size\": 2,\n                \"gradient_accumulation_steps\": 2,\n                \"output_dir\": temp_dir,\n                \"learning_rate\": 0.00001,\n                \"optimizer\": \"adamw_bnb_8bit\",\n        
        \"lr_scheduler\": \"cosine\",\n                \"flash_attention\": True,\n                \"bf16\": \"auto\",\n                \"tf32\": True,\n                \"gradient_checkpointing\": True,\n                \"save_first_step\": False,\n            }\n        )\n\n        # write cfg to yaml file\n        Path(temp_dir).mkdir(parents=True, exist_ok=True)\n        with open(Path(temp_dir) / \"config.yaml\", \"w\", encoding=\"utf-8\") as fout:\n            fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))\n\n        execute_subprocess_async(\n            [\n                \"accelerate\",\n                \"launch\",\n                \"--num-processes\",\n                \"2\",\n                \"--main_process_port\",\n                f\"{get_torch_dist_unique_port()}\",\n                \"-m\",\n                \"axolotl.cli.train\",\n                str(Path(temp_dir) / \"config.yaml\"),\n            ]\n        )\n"
  },
  {
    "path": "tests/e2e/test_reward_model_smollm2.py",
    "content": "\"\"\"\nE2E tests for reward model lora llama\n\"\"\"\n\nimport unittest\n\nfrom axolotl.common.datasets import load_datasets\nfrom axolotl.train import train\nfrom axolotl.utils.config import normalize_config, validate_config\nfrom axolotl.utils.dict import DictDefault\n\nfrom .utils import check_model_output_exists, check_tensorboard, with_temp_dir\n\n\nclass TestRewardModelLoraSmolLM2(unittest.TestCase):\n    \"\"\"\n    Test case for Llama reward models using LoRA\n    \"\"\"\n\n    @with_temp_dir\n    def test_rm_lora(self, temp_dir):\n        cfg = DictDefault(\n            {\n                \"base_model\": \"HuggingFaceTB/SmolLM2-135M\",\n                \"model_type\": \"AutoModelForSequenceClassification\",\n                \"num_labels\": 1,\n                \"chat_template\": \"alpaca\",\n                \"reward_model\": True,\n                \"sequence_len\": 2048,\n                \"pad_to_sequence_len\": True,\n                \"adapter\": \"lora\",\n                \"lora_r\": 8,\n                \"lora_alpha\": 16,\n                \"lora_dropout\": 0.05,\n                \"lora_target_linear\": True,\n                \"val_set_size\": 0.0,\n                \"special_tokens\": {\n                    \"pad_token\": \"<|endoftext|>\",\n                },\n                \"datasets\": [\n                    {\n                        \"path\": \"argilla/distilabel-intel-orca-dpo-pairs\",\n                        \"type\": \"bradley_terry.chat_template\",\n                        \"split\": \"train[:10%]\",\n                    },\n                ],\n                \"lora_modules_to_save\": [\"embed_tokens\", \"lm_head\"],\n                \"remove_unused_columns\": False,\n                \"max_steps\": 10,\n                \"num_epochs\": 1,\n                \"micro_batch_size\": 4,\n                \"gradient_accumulation_steps\": 1,\n                \"output_dir\": temp_dir,\n                \"learning_rate\": 0.00001,\n        
        \"optimizer\": \"adamw_torch\",\n                \"lr_scheduler\": \"cosine\",\n                \"gradient_checkpointing\": True,\n                \"warmup_ratio\": 0.1,\n                \"use_tensorboard\": True,\n                \"save_first_step\": False,\n            }\n        )\n        cfg = validate_config(cfg)\n        normalize_config(cfg)\n        dataset_meta = load_datasets(cfg=cfg)\n\n        train(cfg=cfg, dataset_meta=dataset_meta)\n        check_tensorboard(\n            temp_dir + \"/runs\", \"train/train_loss\", 2.5, \"Train Loss (%s) is too high\"\n        )\n        check_model_output_exists(temp_dir, cfg)\n"
  },
  {
    "path": "tests/e2e/test_save_first_step.py",
    "content": "\"\"\"\nE2E tests for the save_first_step callback\n\"\"\"\n\nimport unittest\nfrom pathlib import Path\n\nimport pytest\n\nfrom axolotl.common.datasets import load_datasets\nfrom axolotl.train import train\nfrom axolotl.utils.config import normalize_config, validate_config\nfrom axolotl.utils.dict import DictDefault\n\nfrom .utils import check_model_output_exists, with_temp_dir\n\n\nclass TestSaveFirstStepCallback(unittest.TestCase):\n    \"\"\"Test cases for save_first_step callback config.\"\"\"\n\n    @with_temp_dir\n    def test_save_first_step(self, temp_dir):\n        cfg = DictDefault(\n            {\n                \"base_model\": \"HuggingFaceTB/SmolLM2-135M\",\n                \"tokenizer_type\": \"AutoTokenizer\",\n                \"sequence_len\": 512,\n                \"val_set_size\": 0.02,\n                \"special_tokens\": {\n                    \"pad_token\": \"<|endoftext|>\",\n                },\n                \"datasets\": [\n                    {\n                        \"path\": \"mhenrichsen/alpaca_2k_test\",\n                        \"type\": \"alpaca\",\n                    },\n                ],\n                \"num_epochs\": 1,\n                \"max_steps\": 3,\n                \"micro_batch_size\": 2,\n                \"gradient_accumulation_steps\": 1,\n                \"output_dir\": temp_dir,\n                \"learning_rate\": 0.00001,\n                \"optimizer\": \"adamw_bnb_8bit\",\n                \"lr_scheduler\": \"cosine\",\n                \"flash_attention\": True,\n                \"sample_packing\": True,\n                \"bf16\": True,\n                \"save_first_step\": True,\n            }\n        )\n\n        cfg = validate_config(cfg)\n        normalize_config(cfg)\n        dataset_meta = load_datasets(cfg=cfg)\n\n        train(cfg=cfg, dataset_meta=dataset_meta)\n        check_model_output_exists(str(Path(temp_dir) / 
\"checkpoint-1\"), cfg)\n\n    @with_temp_dir\n    def test_no_save_first_step(self, 
temp_dir):\n        cfg = DictDefault(\n            {\n                \"base_model\": \"HuggingFaceTB/SmolLM2-135M\",\n                \"tokenizer_type\": \"AutoTokenizer\",\n                \"sequence_len\": 512,\n                \"val_set_size\": 0.02,\n                \"special_tokens\": {\n                    \"pad_token\": \"<|endoftext|>\",\n                },\n                \"datasets\": [\n                    {\n                        \"path\": \"mhenrichsen/alpaca_2k_test\",\n                        \"type\": \"alpaca\",\n                    },\n                ],\n                \"num_epochs\": 1,\n                \"max_steps\": 3,\n                \"micro_batch_size\": 2,\n                \"gradient_accumulation_steps\": 1,\n                \"output_dir\": temp_dir,\n                \"learning_rate\": 0.00001,\n                \"optimizer\": \"adamw_bnb_8bit\",\n                \"lr_scheduler\": \"cosine\",\n                \"flash_attention\": True,\n                \"sample_packing\": True,\n                \"bf16\": True,\n                \"save_first_step\": False,\n            }\n        )\n\n        cfg = validate_config(cfg)\n        normalize_config(cfg)\n        dataset_meta = load_datasets(cfg=cfg)\n\n        train(cfg=cfg, dataset_meta=dataset_meta)\n        with pytest.raises(AssertionError):\n            check_model_output_exists(str(Path(temp_dir) / \"checkpoint-1\"), cfg)\n"
  },
  {
    "path": "tests/e2e/test_schedulers.py",
    "content": "\"\"\"\nE2E tests for custom schedulers using Llama\n\"\"\"\n\nimport unittest\n\nfrom axolotl.common.datasets import load_datasets\nfrom axolotl.train import train\nfrom axolotl.utils.config import normalize_config, validate_config\nfrom axolotl.utils.dict import DictDefault\n\nfrom .utils import check_model_output_exists, with_temp_dir\n\n\nclass TestCustomSchedulers(unittest.TestCase):\n    \"\"\"\n    Test case for Llama models using LoRA\n    \"\"\"\n\n    @with_temp_dir\n    def test_rex_scheduler(self, temp_dir):\n        cfg = DictDefault(\n            {\n                \"base_model\": \"HuggingFaceTB/SmolLM2-135M\",\n                \"tokenizer_type\": \"AutoTokenizer\",\n                \"sequence_len\": 1024,\n                \"load_in_8bit\": True,\n                \"adapter\": \"lora\",\n                \"lora_r\": 8,\n                \"lora_alpha\": 16,\n                \"lora_dropout\": 0.05,\n                \"lora_target_linear\": True,\n                \"val_set_size\": 0.02,\n                \"special_tokens\": {\n                    \"pad_token\": \"<|endoftext|>\",\n                },\n                \"datasets\": [\n                    {\n                        \"path\": \"mhenrichsen/alpaca_2k_test\",\n                        \"type\": \"alpaca\",\n                    },\n                ],\n                \"num_epochs\": 1,\n                \"micro_batch_size\": 8,\n                \"gradient_accumulation_steps\": 1,\n                \"output_dir\": temp_dir,\n                \"learning_rate\": 0.00001,\n                \"optimizer\": \"adamw_torch_fused\",\n                \"max_steps\": 20,\n                \"lr_scheduler\": \"rex\",\n                \"warmup_steps\": 5,\n                \"cosine_min_lr_ratio\": 0.05,\n                \"save_first_step\": False,\n            }\n        )\n\n        cfg = validate_config(cfg)\n        normalize_config(cfg)\n        dataset_meta = load_datasets(cfg=cfg)\n\n        
train(cfg=cfg, dataset_meta=dataset_meta)\n        check_model_output_exists(temp_dir, cfg)\n"
  },
  {
    "path": "tests/e2e/test_streaming.py",
    "content": "\"\"\"E2E tests for streaming dataset functionality\"\"\"\n\n# pylint: disable=duplicate-code\n\nimport pytest\n\nfrom axolotl.common.datasets import load_datasets\nfrom axolotl.train import train\nfrom axolotl.utils.config import normalize_config, validate_config\nfrom axolotl.utils.dict import DictDefault\n\nfrom .utils import check_model_output_exists, check_tensorboard\n\n\nclass TestStreamingDatasets:\n    \"\"\"Test case for streaming datasets\"\"\"\n\n    @pytest.mark.parametrize(\n        \"sample_packing\",\n        [True, False],\n    )\n    def test_streaming_dataset(self, temp_dir, sample_packing):\n        \"\"\"Test streaming datasets\"\"\"\n\n        cfg = DictDefault(\n            {\n                \"base_model\": \"HuggingFaceTB/SmolLM2-135M\",\n                \"flash_attention\": True,\n                \"sequence_len\": 1024,\n                \"sample_packing\": sample_packing,\n                \"pretrain_multipack_attn\": sample_packing,\n                \"streaming_multipack_buffer_size\": 10000,\n                \"dataset_num_proc\": 1,\n                \"special_tokens\": {\n                    \"pad_token\": \"<|endoftext|>\",\n                },\n                \"datasets\": [\n                    {\n                        \"path\": \"mhenrichsen/alpaca_2k_test\",\n                        \"type\": \"alpaca\",\n                    },\n                ],\n                # Streaming config\n                \"streaming\": True,\n                \"max_steps\": 3,\n                \"micro_batch_size\": 1,\n                \"gradient_accumulation_steps\": 1,\n                \"val_set_size\": 0.0,\n                \"output_dir\": temp_dir,\n                \"learning_rate\": 0.00001,\n                \"optimizer\": \"adamw_torch_fused\",\n                \"lr_scheduler\": \"cosine\",\n                \"bf16\": \"auto\",\n                \"use_tensorboard\": True,\n                \"save_first_step\": False,\n            }\n    
    )\n\n        cfg = validate_config(cfg)\n        normalize_config(cfg)\n        dataset_meta = load_datasets(cfg=cfg)\n\n        train(cfg=cfg, dataset_meta=dataset_meta)\n        check_model_output_exists(temp_dir, cfg)\n\n        # Verify training actually happened by checking loss decrease\n        check_tensorboard(\n            temp_dir + \"/runs\",\n            \"train/train_loss\",\n            3.0,\n            \"Train Loss (%s) is too high\",\n        )\n"
  },
  {
    "path": "tests/e2e/test_tokenizer.py",
    "content": "\"\"\"\ne2e test for saving the tokenizer\n\"\"\"\n\nfrom unittest.mock import patch\n\nfrom axolotl.common.datasets import load_datasets\nfrom axolotl.train import train\nfrom axolotl.utils.config import normalize_config, validate_config\nfrom axolotl.utils.dict import DictDefault\n\nfrom tests.e2e.utils import check_model_output_exists\n\n\ndef test_tokenizer_no_save_jinja_files(temp_dir):\n    # pylint: disable=duplicate-code\n    cfg = DictDefault(\n        {\n            \"base_model\": \"HuggingFaceTB/SmolLM2-135M\",\n            \"tokenizer_type\": \"AutoTokenizer\",\n            \"sequence_len\": 1024,\n            \"load_in_8bit\": True,\n            \"adapter\": \"lora\",\n            \"lora_r\": 8,\n            \"lora_alpha\": 16,\n            \"lora_dropout\": 0.05,\n            \"lora_target_linear\": True,\n            \"val_set_size\": 0.02,\n            \"special_tokens\": {\n                \"pad_token\": \"<|endoftext|>\",\n            },\n            \"chat_template\": \"chatml\",\n            \"datasets\": [\n                {\n                    \"path\": \"mhenrichsen/alpaca_2k_test\",\n                    \"type\": \"alpaca\",\n                },\n            ],\n            \"num_epochs\": 1,\n            \"micro_batch_size\": 2,\n            \"gradient_accumulation_steps\": 1,\n            \"output_dir\": temp_dir,\n            \"learning_rate\": 0.00001,\n            \"optimizer\": \"adamw_torch_fused\",\n            \"lr_scheduler\": \"cosine\",\n            \"max_steps\": 5,\n            \"save_first_step\": False,\n            \"fp16\": False,\n            \"tokenizer_save_jinja_files\": False,\n        }\n    )\n\n    cfg = validate_config(cfg)\n    normalize_config(cfg)\n    dataset_meta = load_datasets(cfg=cfg)\n\n    with patch(\"axolotl.train.execute_training\"):\n        train(cfg=cfg, dataset_meta=dataset_meta)\n\n    check_model_output_exists(temp_dir, cfg)\n    with open(f\"{temp_dir}/tokenizer_config.json\", 
\"r\", encoding=\"utf-8\") as f:\n        tokenizer_config = f.read()\n        assert \"chat_template\" in tokenizer_config\n"
  },
  {
    "path": "tests/e2e/utils.py",
    "content": "\"\"\"\nhelper utils for tests\n\"\"\"\n\nimport importlib.util\nimport os\nimport shutil\nimport tempfile\nimport unittest\nfrom functools import wraps\nfrom pathlib import Path\n\nimport torch\nfrom packaging import version\nfrom tbparse import SummaryReader\n\nfrom axolotl.utils.dict import DictDefault\n\n\ndef with_temp_dir(test_func):\n    @wraps(test_func)\n    def wrapper(*args, **kwargs):\n        # Create a temporary directory\n        temp_dir = tempfile.mkdtemp()\n        try:\n            # Pass the temporary directory to the test function\n            test_func(*args, temp_dir=temp_dir, **kwargs)\n        finally:\n            # Clean up the directory after the test\n            shutil.rmtree(temp_dir)\n\n    return wrapper\n\n\ndef most_recent_subdir(path):\n    base_path = Path(path)\n    subdirectories = [d for d in base_path.iterdir() if d.is_dir()]\n    if not subdirectories:\n        return None\n    subdir = max(subdirectories, key=os.path.getctime)\n\n    return subdir\n\n\ndef require_torch_2_4_1(test_case):\n    \"\"\"\n    Decorator marking a test that requires torch >= 2.4.1\n    \"\"\"\n\n    def is_min_2_4_1():\n        torch_version = version.parse(torch.__version__)\n        return torch_version >= version.parse(\"2.4.1\")\n\n    return unittest.skipUnless(is_min_2_4_1(), \"test requires torch>=2.4.1\")(test_case)\n\n\ndef require_torch_2_5_1(test_case):\n    \"\"\"\n    Decorator marking a test that requires torch >= 2.5.1\n    \"\"\"\n\n    def is_min_2_5_1():\n        torch_version = version.parse(torch.__version__)\n        return torch_version >= version.parse(\"2.5.1\")\n\n    return unittest.skipUnless(is_min_2_5_1(), \"test requires torch>=2.5.1\")(test_case)\n\n\ndef require_torch_2_6_0(test_case):\n    \"\"\"\n    Decorator marking a test that requires torch >= 2.6.0\n    \"\"\"\n\n    def is_min_2_6_0():\n        torch_version = version.parse(torch.__version__)\n        return torch_version >= 
version.parse(\"2.6.0\")\n\n    return unittest.skipUnless(is_min_2_6_0(), \"test requires torch>=2.6.0\")(test_case)\n\n\ndef require_torch_2_7_0(test_case):\n    \"\"\"\n    Decorator marking a test that requires torch >= 2.7.0\n    \"\"\"\n\n    def is_min_2_7_0():\n        torch_version = version.parse(torch.__version__)\n        return torch_version >= version.parse(\"2.7.0\")\n\n    return unittest.skipUnless(is_min_2_7_0(), \"test requires torch>=2.7.0\")(test_case)\n\n\ndef require_torch_2_8_0(test_case):\n    \"\"\"\n    Decorator marking a test that requires torch >= 2.8.0\n    \"\"\"\n\n    def is_min_2_8_0():\n        torch_version = version.parse(torch.__version__)\n        return torch_version >= version.parse(\"2.8.0\")\n\n    return unittest.skipUnless(is_min_2_8_0(), \"test requires torch>=2.8.0\")(test_case)\n\n\ndef require_torch_lt_2_6_0(test_case):\n    \"\"\"\n    Decorator marking a test that requires torch < 2.6.0\n    \"\"\"\n\n    def is_max_2_6_0():\n        torch_version = version.parse(torch.__version__)\n        return torch_version < version.parse(\"2.6.0\")\n\n    return unittest.skipUnless(is_max_2_6_0(), \"test requires torch<2.6.0\")(test_case)\n\n\ndef require_vllm(test_case):\n    \"\"\"\n    Decorator marking a test that requires vllm to be installed\n    \"\"\"\n\n    def is_vllm_installed():\n        return importlib.util.find_spec(\"vllm\") is not None\n\n    return unittest.skipUnless(\n        is_vllm_installed(), \"test requires vllm to be installed\"\n    )(test_case)\n\n\ndef require_llmcompressor(test_case):\n    \"\"\"\n    Decorator marking a test that requires llmcompressor to be installed\n    \"\"\"\n\n    def is_llmcompressor_installed():\n        return importlib.util.find_spec(\"llmcompressor\") is not None\n\n    return unittest.skipUnless(\n        is_llmcompressor_installed(), \"test requires llmcompressor to be installed\"\n    )(test_case)\n\n\ndef requires_sm_ge_100(test_case):\n    is_sm_ge_100 = (\n 
       torch.cuda.is_available()\n        and torch.version.cuda\n        and torch.cuda.get_device_capability() >= (10, 0)\n    )\n    return unittest.skipUnless(is_sm_ge_100, \"test requires sm>=100\")(test_case)\n\n\ndef requires_cuda_ge_8_9(test_case):\n    is_cuda_ge_8_9 = (\n        torch.cuda.is_available()\n        and torch.version.cuda\n        and torch.cuda.get_device_capability() >= (8, 9)\n    )\n    return unittest.skipUnless(is_cuda_ge_8_9, \"test requires cuda>=8.9\")(test_case)\n\n\ndef is_hopper():\n    compute_capability = torch.cuda.get_device_capability()\n    return compute_capability == (9, 0)\n\n\ndef require_hopper(test_case):\n    return unittest.skipUnless(is_hopper(), \"test requires h100/hopper GPU\")(test_case)\n\n\ndef supports_fp8(test_case):\n    compute_capability = torch.cuda.get_device_capability()\n    return unittest.skipUnless(\n        compute_capability >= (9, 0), \"test requires h100 or newer GPU\"\n    )(test_case)\n\n\ndef check_tensorboard(\n    temp_run_dir: str,\n    tag: str,\n    lt_val: float,\n    assertion_err: str,\n    rtol: float = 0.02,\n    gt_zero: bool = True,\n) -> None:\n    \"\"\"\n    helper function to parse and check tensorboard logs\n    \"\"\"\n    tb_log_path = most_recent_subdir(temp_run_dir)\n    event_file = os.path.join(tb_log_path, sorted(os.listdir(tb_log_path))[0])\n    reader = SummaryReader(event_file)\n    df = reader.scalars\n    df = df[(df.tag == tag)]\n    lt_val = (1 + rtol) * lt_val\n    if \"%s\" in assertion_err:\n        assert df.value.values[-1] < lt_val, assertion_err % df.value.values[-1]\n    else:\n        assert df.value.values[-1] < lt_val, assertion_err\n    if gt_zero:\n        assert df.value.values[-1] > 1e-5, \"Expected loss to be greater than zero\"\n\n\ndef check_model_output_exists(temp_dir: str, cfg: DictDefault) -> None:\n    \"\"\"\n    helper function to check if a model output file exists after training\n\n    checks based on adapter or not (always 
safetensors in Transformers V5)\n    \"\"\"\n\n    if not cfg.adapter:\n        assert (Path(temp_dir) / \"model.safetensors\").exists()\n    else:\n        assert (Path(temp_dir) / \"adapter_model.safetensors\").exists()\n"
  },
  {
    "path": "tests/fixtures/alpaca/alpaca.json",
    "content": "[\n  {\n    \"instruction\": \"You will be given a series of words. Output these words in reverse order, with each word on its own line.\",\n    \"input\": \"Words: ['Hello', 'world'].\",\n    \"output\": \"['world', 'Hello']\"\n  },\n  {\n    \"instruction\": \"In this task, you're given a short description of an event. Your job is to order the steps involved in the event from first to last. Note that there may be multiple correct answers for each event.\",\n    \"input\": \"Description: A man walks into a bar and orders a drink. He pays for his drink and leaves the bar.\",\n    \"output\": \"1. The man walks into the bar.\\n2. He orders a drink.\\n3. He pays for his drink.\\n4. He leaves the bar.\"\n  }\n]\n"
  },
  {
    "path": "tests/fixtures/conversation.json",
    "content": "{\n  \"conversations\": [\n    {\n      \"from\": \"human\",\n      \"value\": \"how could i implement a minesweeper algorithm that utilises algebraic topology to solve boards?\"\n    },\n    {\n      \"from\": \"gpt\",\n      \"value\": \"It's possible to use algebraic topology to solve Minesweeper, but it would likely be quite complex. Algebraic topology is a branch of mathematics that studies the properties of spaces that are preserved under continuous deformations, such as stretching and bending. In the context of Minesweeper, this could be used to analyze the structure of the game board and determine the locations of mines based on certain topological invariants. However, implementing this would require a deep understanding of both algebraic topology and the game of Minesweeper. It's not something that can be explained in a simple answer.\"\n    },\n    {\n      \"from\": \"human\",\n      \"value\": \"could you give me the gist of how it could though?\"\n    },\n    {\n      \"from\": \"gpt\",\n      \"value\": \"Sure. The basic idea would be to use algebraic topology to analyze the structure of the Minesweeper game board and determine the locations of mines based on certain topological invariants. This could involve representing the game board as a topological space and using topological invariants such as homotopy groups and Betti numbers to infer information about the location of mines on the board. However, as I mentioned before, implementing this would require a deep understanding of both algebraic topology and the game of Minesweeper, so it's not something that can be explained in a simple answer.\"\n    },\n    {\n      \"from\": \"human\",\n      \"value\": \"give me a specific example\"\n    },\n    {\n      \"from\": \"gpt\",\n      \"value\": \"As a specific example, suppose we have a Minesweeper game board with the following configuration:\\n```\\n 1  1  1  1  1\\n 1  ?  ?  ?  1\\n 1  ?  ?  ?  1\\n 1  ?  ?  ?  
1\\n 1  1  1  1  1\\n```\\nIn this case, the game board can be represented as a topological space, where each square on the board corresponds to a point in the space. The mines on the board can be thought of as \\\"holes\\\" in the space, which are points that are not part of the space. Using algebraic topology, we can analyze the structure of this space and determine the locations of the mines based on certain topological invariants.\\n\\nFor instance, one way to do this is to compute the homotopy groups of the space. Homotopy groups are a way of describing the holes in a space, and they can be used to infer information about the locations of the mines on the Minesweeper board. In the case of the board above, we can compute the first homotopy group of the space, which describes the number and \\\"shape\\\" of the holes in the space. In this case, the first homotopy group would tell us that there are four holes in the space, arranged in a square shape.\\n\\nBased on this information, we can infer that the mines in the game are located in the four corners of the board, as shown below:\\n```\\n 1  1  1  1  1\\n 1  M  ?  ?  1\\n 1  ?  ?  ?  1\\n 1  ?  ?  M  1\\n 1  1  1  1  1\\n```\\nOf course, this is just one example of how algebraic topology could be used to solve Minesweeper. There are many other possible approaches, and the specific details of how to implement them would depend on the specific game board and the topological invariants being used. As I mentioned before, implementing a solution like this would require a deep understanding of both algebraic topology and the game of Minesweeper.\"\n    },\n    {\n      \"from\": \"human\",\n      \"value\": \"assuming i was someone with a deep understanding of both, could you explain the same concept to me with that assumption?\"\n    },\n    {\n      \"from\": \"gpt\",\n      \"value\": \"Sure. 
Assuming you have a deep understanding of both algebraic topology and the game of Minesweeper, here is how you might use algebraic topology to solve a Minesweeper game board:\\n\\n1. Represent the Minesweeper game board as a topological space. Each square on the board corresponds to a point in the space, and the mines on the board can be thought of as \\\"holes\\\" in the space, which are points that are not part of the space.\\n2. Use topological invariants to infer information about the locations of the mines on the board. There are many possible topological invariants that could be used for this, such as homotopy groups, Betti numbers, or persistent homology.\\n3. Use the information inferred from the topological invariants to determine the locations of the mines on the board. For example, if the first homotopy group of the space indicates that there are four holes in the space, arranged in a square shape, then we can infer that the mines are located in the four corners of the board.\\n4. Use this information to solve the Minesweeper game. Once you know the locations of the mines, you can use this information to avoid them and clear the board.\\n\\nOf course, this is just one possible approach to using algebraic topology to solve Minesweeper. There may be other ways to do it, depending on the specific game board and the topological invariants being used. 
However, this should give you a general idea of how algebraic topology can be used to solve Minesweeper.\"\n    },\n    {\n      \"from\": \"human\",\n      \"value\": \"how would i do it with persistent homology or any of the other topological invariants?\"\n    },\n    {\n      \"from\": \"gpt\",\n      \"value\": \"To use persistent homology or any of the other topological invariants to solve a Minesweeper game board, you would follow the same basic steps as described above, but with some differences in the specific details.\\n\\nFor instance, here is how you might use persistent homology to solve a Minesweeper game board:\\n\\n1. Represent the Minesweeper game board as a topological space. Each square on the board corresponds to a point in the space, and the mines on the board can be thought of as \\\"holes\\\" in the space, which are points that are not part of the space.\\n2. Compute the persistent homology of the space. Persistent homology is a way of describing the holes in a space over time, and it can be used to infer information about the locations of the mines on the Minesweeper board.\\n3. Use the information from the persistent homology to determine the locations of the mines on the board. For example, if the persistent homology indicates that there are four holes in the space, arranged in a square shape, then we can infer that the mines are located in the four corners of the board.\\n4. Use this information to solve the Minesweeper game. Once you know the locations of the mines, you can use this information to avoid them and clear the board.\\n\\nOf course, this is just one possible approach to using persistent homology to solve Minesweeper. There may be other ways to do it, depending on the specific game board and the specific details of the persistent homology computation. 
However, this should give you a general idea of how persistent homology can be used to solve Minesweeper.\"\n    },\n    {\n      \"from\": \"human\",\n      \"value\": \"how could i implement this in python?\"\n    },\n    {\n      \"from\": \"gpt\",\n      \"value\": \"To implement a Minesweeper solver that uses algebraic topology in Python, you would need to do the following:\\n\\n1. Install a Python package for algebraic topology. There are several options available, such as Gudhi, Dionysus, or scikit-tda. These packages provide algorithms and data structures for computing topological invariants, such as homotopy groups, Betti numbers, or persistent homology.\\n2. Represent the Minesweeper game board as a topological space. This could be done using a data structure provided by the algebraic topology package, such as a simplicial complex or a Vietoris-Rips complex. Each square on the board would correspond to a point in the space, and the mines on the board would be represented as \\\"holes\\\" in the space.\\n3. Compute the topological invariants of the space. This could be done using algorithms provided by the algebraic topology package. For example, you could compute the persistent homology of the space using the persistent homology algorithm.\\n4. Use the information from the topological invariants to determine the locations of the mines on the board. This could be done using a combination of mathematical reasoning and programming. For example, if the persistent homology indicates that there are four holes in the space, arranged in a square shape, then you could use a piece of code to infer that the mines are located in the four corners of the board.\\n5. Use this information to solve the Minesweeper game. 
Once you know the locations of the mines, you could use a piece of code to automatically click on the safe squares and clear the board.\\n\\nOf course, this is just one possible approach to implementing a Minesweeper solver that uses algebraic topology in Python. There may be other ways to do it, depending on the specific details of the implementation. However, this should give you a general idea of how to get started with such a project.\"\n    }\n  ]\n}\n"
  },
  {
    "path": "tests/fixtures/conversation.missingturns.json",
    "content": "{\n  \"conversations\": [\n    {\n      \"from\": \"human\",\n      \"value\": \"how could i implement a minesweeper algorithm that utilises algebraic topology to solve boards?\"\n    },\n    {\n      \"from\": \"gpt\",\n      \"value\": \"\"\n    },\n    {\n      \"from\": \"human\",\n      \"value\": \"could you give me the gist of how it could though?\"\n    },\n    {\n      \"from\": \"gpt\",\n      \"value\": \"Sure. The basic idea would be to use algebraic topology to analyze the structure of the Minesweeper game board and determine the locations of mines based on certain topological invariants. This could involve representing the game board as a topological space and using topological invariants such as homotopy groups and Betti numbers to infer information about the location of mines on the board. However, as I mentioned before, implementing this would require a deep understanding of both algebraic topology and the game of Minesweeper, so it's not something that can be explained in a simple answer.\"\n    },\n    {\n      \"from\": \"human\",\n      \"value\": \"give me a specific example\"\n    },\n    {\n      \"from\": \"gpt\",\n      \"value\": \"As a specific example, suppose we have a Minesweeper game board with the following configuration:\\n```\\n 1  1  1  1  1\\n 1  ?  ?  ?  1\\n 1  ?  ?  ?  1\\n 1  ?  ?  ?  1\\n 1  1  1  1  1\\n```\\nIn this case, the game board can be represented as a topological space, where each square on the board corresponds to a point in the space. The mines on the board can be thought of as \\\"holes\\\" in the space, which are points that are not part of the space. Using algebraic topology, we can analyze the structure of this space and determine the locations of the mines based on certain topological invariants.\\n\\nFor instance, one way to do this is to compute the homotopy groups of the space. 
Homotopy groups are a way of describing the holes in a space, and they can be used to infer information about the locations of the mines on the Minesweeper board. In the case of the board above, we can compute the first homotopy group of the space, which describes the number and \\\"shape\\\" of the holes in the space. In this case, the first homotopy group would tell us that there are four holes in the space, arranged in a square shape.\\n\\nBased on this information, we can infer that the mines in the game are located in the four corners of the board, as shown below:\\n```\\n 1  1  1  1  1\\n 1  M  ?  ?  1\\n 1  ?  ?  ?  1\\n 1  ?  ?  M  1\\n 1  1  1  1  1\\n```\\nOf course, this is just one example of how algebraic topology could be used to solve Minesweeper. There are many other possible approaches, and the specific details of how to implement them would depend on the specific game board and the topological invariants being used. As I mentioned before, implementing a solution like this would require a deep understanding of both algebraic topology and the game of Minesweeper.\"\n    },\n    {\n      \"from\": \"human\",\n      \"value\": \"assuming i was someone with a deep understanding of both, could you explain the same concept to me with that assumption?\"\n    },\n    {\n      \"from\": \"gpt\",\n      \"value\": \"Sure. Assuming you have a deep understanding of both algebraic topology and the game of Minesweeper, here is how you might use algebraic topology to solve a Minesweeper game board:\\n\\n1. Represent the Minesweeper game board as a topological space. Each square on the board corresponds to a point in the space, and the mines on the board can be thought of as \\\"holes\\\" in the space, which are points that are not part of the space.\\n2. Use topological invariants to infer information about the locations of the mines on the board. 
There are many possible topological invariants that could be used for this, such as homotopy groups, Betti numbers, or persistent homology.\\n3. Use the information inferred from the topological invariants to determine the locations of the mines on the board. For example, if the first homotopy group of the space indicates that there are four holes in the space, arranged in a square shape, then we can infer that the mines are located in the four corners of the board.\\n4. Use this information to solve the Minesweeper game. Once you know the locations of the mines, you can use this information to avoid them and clear the board.\\n\\nOf course, this is just one possible approach to using algebraic topology to solve Minesweeper. There may be other ways to do it, depending on the specific game board and the topological invariants being used. However, this should give you a general idea of how algebraic topology can be used to solve Minesweeper.\"\n    },\n    {\n      \"from\": \"human\",\n      \"value\": \"how would i do it with persistent homology or any of the other topological invariants?\"\n    },\n    {\n      \"from\": \"gpt\",\n      \"value\": \"To use persistent homology or any of the other topological invariants to solve a Minesweeper game board, you would follow the same basic steps as described above, but with some differences in the specific details.\\n\\nFor instance, here is how you might use persistent homology to solve a Minesweeper game board:\\n\\n1. Represent the Minesweeper game board as a topological space. Each square on the board corresponds to a point in the space, and the mines on the board can be thought of as \\\"holes\\\" in the space, which are points that are not part of the space.\\n2. Compute the persistent homology of the space. Persistent homology is a way of describing the holes in a space over time, and it can be used to infer information about the locations of the mines on the Minesweeper board.\\n3. 
Use the information from the persistent homology to determine the locations of the mines on the board. For example, if the persistent homology indicates that there are four holes in the space, arranged in a square shape, then we can infer that the mines are located in the four corners of the board.\\n4. Use this information to solve the Minesweeper game. Once you know the locations of the mines, you can use this information to avoid them and clear the board.\\n\\nOf course, this is just one possible approach to using persistent homology to solve Minesweeper. There may be other ways to do it, depending on the specific game board and the specific details of the persistent homology computation. However, this should give you a general idea of how persistent homology can be used to solve Minesweeper.\"\n    },\n    {\n      \"from\": \"human\",\n      \"value\": \"how could i implement this in python?\"\n    },\n    {\n      \"from\": \"gpt\",\n      \"value\": \"To implement a Minesweeper solver that uses algebraic topology in Python, you would need to do the following:\\n\\n1. Install a Python package for algebraic topology. There are several options available, such as Gudhi, Dionysus, or scikit-tda. These packages provide algorithms and data structures for computing topological invariants, such as homotopy groups, Betti numbers, or persistent homology.\\n2. Represent the Minesweeper game board as a topological space. This could be done using a data structure provided by the algebraic topology package, such as a simplicial complex or a Vietoris-Rips complex. Each square on the board would correspond to a point in the space, and the mines on the board would be represented as \\\"holes\\\" in the space.\\n3. Compute the topological invariants of the space. This could be done using algorithms provided by the algebraic topology package. For example, you could compute the persistent homology of the space using the persistent homology algorithm.\\n4. 
Use the information from the topological invariants to determine the locations of the mines on the board. This could be done using a combination of mathematical reasoning and programming. For example, if the persistent homology indicates that there are four holes in the space, arranged in a square shape, then you could use a piece of code to infer that the mines are located in the four corners of the board.\\n5. Use this information to solve the Minesweeper game. Once you know the locations of the mines, you could use a piece of code to automatically click on the safe squares and clear the board.\\n\\nOf course, this is just one possible approach to implementing a Minesweeper solver that uses algebraic topology in Python. There may be other ways to do it, depending on the specific details of the implementation. However, this should give you a general idea of how to get started with such a project.\"\n    }\n  ]\n}\n"
  },
  {
    "path": "tests/fixtures/conversation.tokenized.json",
    "content": "{\"input_ids\": [1, 319, 13563, 1546, 263, 12758, 1404, 322, 385, 23116, 21082, 20255, 29889, 450, 20255, 4076, 8444, 29892, 13173, 29892, 322, 1248, 568, 6089, 304, 278, 1404, 29915, 29879, 5155, 29889, 29871, 3148, 1001, 29901, 920, 1033, 474, 2334, 263, 29086, 705, 11356, 5687, 393, 3667, 4637, 21531, 20159, 304, 4505, 1045, 3163, 29973, 29871, 319, 1799, 9047, 13566, 29901, 739, 29915, 29879, 1950, 304, 671, 21531, 20159, 304, 4505, 341, 1475, 705, 11356, 29892, 541, 372, 723, 5517, 367, 3755, 4280, 29889, 838, 5869, 293, 20159, 338, 263, 5443, 310, 23964, 393, 11898, 278, 4426, 310, 8162, 393, 526, 21634, 1090, 9126, 316, 689, 800, 29892, 1316, 408, 16116, 292, 322, 289, 2548, 29889, 512, 278, 3030, 310, 341, 1475, 705, 11356, 29892, 445, 1033, 367, 1304, 304, 27599, 278, 3829, 310, 278, 3748, 7613, 322, 8161, 278, 14354, 310, 29086, 2729, 373, 3058, 25002, 15939, 1934, 29889, 2398, 29892, 16049, 445, 723, 1996, 263, 6483, 8004, 310, 1716, 21531, 20159, 322, 278, 3748, 310, 341, 1475, 705, 11356, 29889, 739, 29915, 29879, 451, 1554, 393, 508, 367, 10824, 297, 263, 2560, 1234, 21106, 29879, 29958, 2, 3148, 1001, 29901, 1033, 366, 2367, 592, 278, 330, 391, 310, 920, 372, 1033, 2466, 29973, 29871, 319, 1799, 9047, 13566, 29901, 18585, 29889, 450, 6996, 2969, 723, 367, 304, 671, 21531, 20159, 304, 27599, 278, 3829, 310, 278, 341, 1475, 705, 11356, 3748, 7613, 322, 8161, 278, 14354, 310, 29086, 2729, 373, 3058, 25002, 15939, 1934, 29889, 910, 1033, 25135, 15783, 278, 3748, 7613, 408, 263, 25002, 2913, 322, 773, 25002, 15939, 1934, 1316, 408, 3632, 327, 2270, 6471, 322, 350, 9890, 3694, 304, 10115, 2472, 1048, 278, 4423, 310, 29086, 373, 278, 7613, 29889, 2398, 29892, 408, 306, 5276, 1434, 29892, 16049, 445, 723, 1996, 263, 6483, 8004, 310, 1716, 21531, 20159, 322, 278, 3748, 310, 341, 1475, 705, 11356, 29892, 577, 372, 29915, 29879, 451, 1554, 393, 508, 367, 10824, 297, 263, 2560, 1234, 21106, 29879, 29958, 2, 3148, 1001, 29901, 2367, 592, 263, 
2702, 1342, 29871, 319, 1799, 9047, 13566, 29901, 1094, 263, 2702, 1342, 29892, 7755, 591, 505, 263, 341, 1475, 705, 11356, 3748, 7613, 411, 278, 1494, 5285, 29901, 13, 28956, 13, 29871, 29896, 259, 29896, 259, 29896, 259, 29896, 259, 29896, 13, 29871, 29896, 29871, 1577, 29871, 1577, 29871, 1577, 259, 29896, 13, 29871, 29896, 29871, 1577, 29871, 1577, 29871, 1577, 259, 29896, 13, 29871, 29896, 29871, 1577, 29871, 1577, 29871, 1577, 259, 29896, 13, 29871, 29896, 259, 29896, 259, 29896, 259, 29896, 259, 29896, 13, 28956, 13, 797, 445, 1206, 29892, 278, 3748, 7613, 508, 367, 9875, 408, 263, 25002, 2913, 29892, 988, 1269, 6862, 373, 278, 7613, 16161, 304, 263, 1298, 297, 278, 2913, 29889, 450, 29086, 373, 278, 7613, 508, 367, 2714, 310, 408, 376, 5391, 267, 29908, 297, 278, 2913, 29892, 607, 526, 3291, 393, 526, 451, 760, 310, 278, 2913, 29889, 5293, 21531, 20159, 29892, 591, 508, 27599, 278, 3829, 310, 445, 2913, 322, 8161, 278, 14354, 310, 278, 29086, 2729, 373, 3058, 25002, 15939, 1934, 29889, 13, 13, 2831, 2777, 29892, 697, 982, 304, 437, 445, 338, 304, 10272, 278, 3632, 327, 2270, 6471, 310, 278, 2913, 29889, 15089, 327, 2270, 6471, 526, 263, 982, 310, 20766, 278, 26532, 297, 263, 2913, 29892, 322, 896, 508, 367, 1304, 304, 10115, 2472, 1048, 278, 14354, 310, 278, 29086, 373, 278, 341, 1475, 705, 11356, 7613, 29889, 512, 278, 1206, 310, 278, 7613, 2038, 29892, 591, 508, 10272, 278, 937, 3632, 327, 2270, 2318, 310, 278, 2913, 29892, 607, 16612, 278, 1353, 322, 376, 12181, 29908, 310, 278, 26532, 297, 278, 2913, 29889, 512, 445, 1206, 29892, 278, 937, 3632, 327, 2270, 2318, 723, 2649, 502, 393, 727, 526, 3023, 26532, 297, 278, 2913, 29892, 21050, 297, 263, 6862, 8267, 29889, 13, 13, 29933, 1463, 373, 445, 2472, 29892, 591, 508, 10115, 393, 278, 29086, 297, 278, 3748, 526, 5982, 297, 278, 3023, 26995, 310, 278, 7613, 29892, 408, 4318, 2400, 29901, 13, 28956, 13, 29871, 29896, 259, 29896, 259, 29896, 259, 29896, 259, 29896, 13, 29871, 29896, 29871, 341, 29871, 1577, 
29871, 1577, 259, 29896, 13, 29871, 29896, 29871, 1577, 29871, 1577, 29871, 1577, 259, 29896, 13, 29871, 29896, 29871, 1577, 29871, 1577, 29871, 341, 259, 29896, 13, 29871, 29896, 259, 29896, 259, 29896, 259, 29896, 259, 29896, 13, 28956, 13, 2776, 3236, 29892, 445, 338, 925, 697, 1342, 310, 920, 21531, 20159, 1033, 367, 1304, 304, 4505, 341, 1475, 705, 11356, 29889, 1670, 526, 1784, 916, 1950, 13501, 29892, 322, 278, 2702, 4902, 310, 920, 304, 2334, 963, 723, 8839, 373, 278, 2702, 3748, 7613, 322, 278, 25002, 15939, 1934, 1641, 1304, 29889, 1094, 306, 5276, 1434, 29892, 16049, 263, 1650, 763, 445, 723, 1996, 263, 6483, 8004, 310, 1716, 21531, 20159, 322, 278, 3748, 310, 341, 1475, 705, 11356, 21106, 29879, 29958, 2, 3148, 1001, 29901, 10241, 474, 471, 4856, 411, 263, 6483, 8004, 310, 1716, 29892, 1033, 366, 5649, 278, 1021, 6964, 304, 592, 411, 393, 11833, 29973, 29871, 319, 1799, 9047, 13566, 29901, 18585, 29889, 17090, 366, 505, 263, 6483, 8004, 310, 1716, 21531, 20159, 322, 278, 3748, 310, 341, 1475, 705, 11356, 29892, 1244, 338, 920, 366, 1795, 671, 21531, 20159, 304, 4505, 263, 341, 1475, 705, 11356, 3748, 7613, 29901, 13, 13, 29896, 29889, 16314, 278, 341, 1475, 705, 11356, 3748, 7613, 408, 263, 25002, 2913, 29889, 7806, 6862, 373, 278, 7613, 16161, 304, 263, 1298, 297, 278, 2913, 29892, 322, 278, 29086, 373, 278, 7613, 508, 367, 2714, 310, 408, 376, 5391, 267, 29908, 297, 278, 2913, 29892, 607, 526, 3291, 393, 526, 451, 760, 310, 278, 2913, 29889, 13, 29906, 29889, 4803, 25002, 15939, 1934, 304, 10115, 2472, 1048, 278, 14354, 310, 278, 29086, 373, 278, 7613, 29889, 1670, 526, 1784, 1950, 25002, 15939, 1934, 393, 1033, 367, 1304, 363, 445, 29892, 1316, 408, 3632, 327, 2270, 6471, 29892, 350, 9890, 3694, 29892, 470, 28152, 3632, 3002, 29889, 13, 29941, 29889, 4803, 278, 2472, 10115, 1127, 515, 278, 25002, 15939, 1934, 304, 8161, 278, 14354, 310, 278, 29086, 373, 278, 7613, 29889, 1152, 1342, 29892, 565, 278, 937, 3632, 327, 2270, 2318, 310, 278, 2913, 14088, 
393, 727, 526, 3023, 26532, 297, 278, 2913, 29892, 21050, 297, 263, 6862, 8267, 29892, 769, 591, 508, 10115, 393, 278, 29086, 526, 5982, 297, 278, 3023, 26995, 310, 278, 7613, 29889, 13, 29946, 29889, 4803, 445, 2472, 304, 4505, 278, 341, 1475, 705, 11356, 3748, 29889, 9038, 366, 1073, 278, 14354, 310, 278, 29086, 29892, 366, 508, 671, 445, 2472, 304, 4772, 963, 322, 2821, 278, 7613, 29889, 13, 13, 2776, 3236, 29892, 445, 338, 925, 697, 1950, 2948, 304, 773, 21531, 20159, 304, 4505, 341, 1475, 705, 11356, 29889, 1670, 1122, 367, 916, 5837, 304, 437, 372, 29892, 8679, 373, 278, 2702, 3748, 7613, 322, 278, 25002, 15939, 1934, 1641, 1304, 29889, 2398, 29892, 445, 881, 2367, 366, 263, 2498, 2969, 310, 920, 21531, 20159, 508, 367, 1304, 304, 4505, 341, 1475, 705, 11356, 21106, 29879, 29958, 2, 3148, 1001, 29901, 920, 723, 474, 437, 372, 411, 28152, 3632, 3002, 470, 738, 310, 278, 916, 25002, 15939, 1934, 29973, 29871, 319, 1799, 9047, 13566, 29901, 1763, 671, 28152, 3632, 3002, 470, 738, 310, 278, 916, 25002, 15939, 1934, 304, 4505, 263, 341, 1475, 705, 11356, 3748, 7613, 29892, 366, 723, 1101, 278, 1021, 6996, 6576, 408, 5439, 2038, 29892, 541, 411, 777, 12651, 297, 278, 2702, 4902, 29889, 13, 13, 2831, 2777, 29892, 1244, 338, 920, 366, 1795, 671, 28152, 3632, 3002, 304, 4505, 263, 341, 1475, 705, 11356, 3748, 7613, 29901, 13, 13, 29896, 29889, 16314, 278, 341, 1475, 705, 11356, 3748, 7613, 408, 263, 25002, 2913, 29889, 7806, 6862, 373, 278, 7613, 16161, 304, 263, 1298, 297, 278, 2913, 29892, 322, 278, 29086, 373, 278, 7613, 508, 367, 2714, 310, 408, 376, 5391, 267, 29908, 297, 278, 2913, 29892, 607, 526, 3291, 393, 526, 451, 760, 310, 278, 2913, 29889, 13, 29906, 29889, 11796, 29872, 278, 28152, 3632, 3002, 310, 278, 2913, 29889, 9034, 9696, 3632, 3002, 338, 263, 982, 310, 20766, 278, 26532, 297, 263, 2913, 975, 931, 29892, 322, 372, 508, 367, 1304, 304, 10115, 2472, 1048, 278, 14354, 310, 278, 29086, 373, 278, 341, 1475, 705, 11356, 7613, 29889, 13, 29941, 29889, 
4803, 278, 2472, 515, 278, 28152, 3632, 3002, 304, 8161, 278, 14354, 310, 278, 29086, 373, 278, 7613, 29889, 1152, 1342, 29892, 565, 278, 28152, 3632, 3002, 14088, 393, 727, 526, 3023, 26532, 297, 278, 2913, 29892, 21050, 297, 263, 6862, 8267, 29892, 769, 591, 508, 10115, 393, 278, 29086, 526, 5982, 297, 278, 3023, 26995, 310, 278, 7613, 29889, 13, 29946, 29889, 4803, 445, 2472, 304, 4505, 278, 341, 1475, 705, 11356, 3748, 29889, 9038, 366, 1073, 278, 14354, 310, 278, 29086, 29892, 366, 508, 671, 445, 2472, 304, 4772, 963, 322, 2821, 278, 7613, 29889, 13, 13, 2776, 3236, 29892, 445, 338, 925, 697, 1950, 2948, 304, 773, 28152, 3632, 3002, 304, 4505, 341, 1475, 705, 11356, 29889, 1670, 1122, 367, 916, 5837, 304, 437, 372, 29892, 8679, 373, 278, 2702, 3748, 7613, 322, 278, 2702, 4902, 310, 278, 28152, 3632, 3002, 16287, 29889, 2398, 29892, 445, 881, 2367, 366, 263, 2498, 2969, 310, 920, 28152, 3632, 3002, 508, 367, 1304, 304, 4505, 341, 1475, 705, 11356, 21106, 29879, 29958, 2, 3148, 1001, 29901, 920, 1033, 474, 2334, 445, 297, 3017, 29973, 29871, 319, 1799, 9047, 13566, 29901, 1763, 2334, 263, 341, 1475, 705, 11356, 899, 369, 393, 3913, 21531, 20159, 297, 5132, 29892, 366, 723, 817, 304, 437, 278, 1494, 29901, 13, 13, 29896, 29889, 16052, 263, 5132, 3577, 363, 21531, 20159, 29889, 1670, 526, 3196, 3987, 3625, 29892, 1316, 408, 402, 566, 2918, 29892, 360, 291, 952, 375, 29892, 470, 4560, 7354, 29899, 29873, 1388, 29889, 4525, 9741, 3867, 14009, 322, 848, 12286, 363, 20602, 25002, 15939, 1934, 29892, 1316, 408, 3632, 327, 2270, 6471, 29892, 350, 9890, 3694, 29892, 470, 28152, 3632, 3002, 29889, 13, 29906, 29889, 16314, 278, 341, 1475, 705, 11356, 3748, 7613, 408, 263, 25002, 2913, 29889, 910, 1033, 367, 2309, 773, 263, 848, 3829, 4944, 491, 278, 21531, 20159, 3577, 29892, 1316, 408, 263, 3053, 506, 616, 4280, 470, 263, 478, 2035, 29367, 29899, 29934, 4512, 4280, 29889, 7806, 6862, 373, 278, 7613, 723, 3928, 304, 263, 1298, 297, 278, 2913, 29892, 322, 278, 29086, 373, 
278, 7613, 723, 367, 9875, 408, 376, 5391, 267, 29908, 297, 278, 2913, 29889, 13, 29941, 29889, 11796, 29872, 278, 25002, 15939, 1934, 310, 278, 2913, 29889, 910, 1033, 367, 2309, 773, 14009, 4944, 491, 278, 21531, 20159, 3577, 29889, 1152, 1342, 29892, 366, 1033, 10272, 278, 28152, 3632, 3002, 310, 278, 2913, 773, 278, 28152, 3632, 3002, 5687, 29889, 13, 29946, 29889, 4803, 278, 2472, 515, 278, 25002, 15939, 1934, 304, 8161, 278, 14354, 310, 278, 29086, 373, 278, 7613, 29889, 910, 1033, 367, 2309, 773, 263, 10296, 310, 19475, 24481, 322, 8720, 29889, 1152, 1342, 29892, 565, 278, 28152, 3632, 3002, 14088, 393, 727, 526, 3023, 26532, 297, 278, 2913, 29892, 21050, 297, 263, 6862, 8267, 29892, 769, 366, 1033, 671, 263, 8424, 310, 775, 304, 10115, 393, 278, 29086, 526, 5982, 297, 278, 3023, 26995, 310, 278, 7613, 29889, 13, 29945, 29889, 4803, 445, 2472, 304, 4505, 278, 341, 1475, 705, 11356, 3748, 29889, 9038, 366, 1073, 278, 14354, 310, 278, 29086, 29892, 366, 1033, 671, 263, 8424, 310, 775, 304, 6336, 2828, 373, 278, 9109, 25256, 322, 2821, 278, 7613, 29889, 13, 13, 2776, 3236, 29892, 445, 338, 925, 697, 1950, 2948, 304, 16049, 263, 341, 1475, 705, 11356, 899, 369, 393, 3913, 21531, 20159, 297, 5132, 29889, 1670, 1122, 367, 916, 5837, 304, 437, 372, 29892, 8679, 373, 278, 2702, 4902, 310, 278, 5314, 29889, 2398, 29892, 445, 881, 2367, 366, 263, 2498, 2969, 310, 920, 304, 679, 4687, 411, 1316, 263, 2060, 21106, 29879, 29958, 2], \"attention_mask\": [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], \"labels\": [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 739, 29915, 29879, 1950, 304, 671, 21531, 20159, 304, 4505, 341, 1475, 705, 11356, 29892, 541, 372, 723, 5517, 367, 3755, 4280, 29889, 838, 5869, 293, 20159, 338, 263, 5443, 310, 23964, 393, 11898, 278, 4426, 310, 8162, 393, 526, 21634, 
1090, 9126, 316, 689, 800, 29892, 1316, 408, 16116, 292, 322, 289, 2548, 29889, 512, 278, 3030, 310, 341, 1475, 705, 11356, 29892, 445, 1033, 367, 1304, 304, 27599, 278, 3829, 310, 278, 3748, 7613, 322, 8161, 278, 14354, 310, 29086, 2729, 373, 3058, 25002, 15939, 1934, 29889, 2398, 29892, 16049, 445, 723, 1996, 263, 6483, 8004, 310, 1716, 21531, 20159, 322, 278, 3748, 310, 341, 1475, 705, 11356, 29889, 739, 29915, 29879, 451, 1554, 393, 508, 367, 10824, 297, 263, 2560, 1234, 21106, 29879, 29958, 2, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 18585, 29889, 450, 6996, 2969, 723, 367, 304, 671, 21531, 20159, 304, 27599, 278, 3829, 310, 278, 341, 1475, 705, 11356, 3748, 7613, 322, 8161, 278, 14354, 310, 29086, 2729, 373, 3058, 25002, 15939, 1934, 29889, 910, 1033, 25135, 15783, 278, 3748, 7613, 408, 263, 25002, 2913, 322, 773, 25002, 15939, 1934, 1316, 408, 3632, 327, 2270, 6471, 322, 350, 9890, 3694, 304, 10115, 2472, 1048, 278, 4423, 310, 29086, 373, 278, 7613, 29889, 2398, 29892, 408, 306, 5276, 1434, 29892, 16049, 445, 723, 1996, 263, 6483, 8004, 310, 1716, 21531, 20159, 322, 278, 3748, 310, 341, 1475, 705, 11356, 29892, 577, 372, 29915, 29879, 451, 1554, 393, 508, 367, 10824, 297, 263, 2560, 1234, 21106, 29879, 29958, 2, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1094, 263, 2702, 1342, 29892, 7755, 591, 505, 263, 341, 1475, 705, 11356, 3748, 7613, 411, 278, 1494, 5285, 29901, 13, 28956, 13, 29871, 29896, 259, 29896, 259, 29896, 259, 29896, 259, 29896, 13, 29871, 29896, 29871, 1577, 29871, 1577, 29871, 1577, 259, 29896, 13, 29871, 29896, 29871, 1577, 29871, 1577, 29871, 1577, 259, 29896, 13, 29871, 29896, 29871, 1577, 29871, 1577, 29871, 1577, 259, 29896, 13, 29871, 29896, 259, 29896, 259, 29896, 259, 29896, 259, 29896, 13, 28956, 13, 797, 445, 1206, 29892, 278, 3748, 7613, 508, 367, 9875, 408, 263, 25002, 2913, 29892, 988, 1269, 6862, 
373, 278, 7613, 16161, 304, 263, 1298, 297, 278, 2913, 29889, 450, 29086, 373, 278, 7613, 508, 367, 2714, 310, 408, 376, 5391, 267, 29908, 297, 278, 2913, 29892, 607, 526, 3291, 393, 526, 451, 760, 310, 278, 2913, 29889, 5293, 21531, 20159, 29892, 591, 508, 27599, 278, 3829, 310, 445, 2913, 322, 8161, 278, 14354, 310, 278, 29086, 2729, 373, 3058, 25002, 15939, 1934, 29889, 13, 13, 2831, 2777, 29892, 697, 982, 304, 437, 445, 338, 304, 10272, 278, 3632, 327, 2270, 6471, 310, 278, 2913, 29889, 15089, 327, 2270, 6471, 526, 263, 982, 310, 20766, 278, 26532, 297, 263, 2913, 29892, 322, 896, 508, 367, 1304, 304, 10115, 2472, 1048, 278, 14354, 310, 278, 29086, 373, 278, 341, 1475, 705, 11356, 7613, 29889, 512, 278, 1206, 310, 278, 7613, 2038, 29892, 591, 508, 10272, 278, 937, 3632, 327, 2270, 2318, 310, 278, 2913, 29892, 607, 16612, 278, 1353, 322, 376, 12181, 29908, 310, 278, 26532, 297, 278, 2913, 29889, 512, 445, 1206, 29892, 278, 937, 3632, 327, 2270, 2318, 723, 2649, 502, 393, 727, 526, 3023, 26532, 297, 278, 2913, 29892, 21050, 297, 263, 6862, 8267, 29889, 13, 13, 29933, 1463, 373, 445, 2472, 29892, 591, 508, 10115, 393, 278, 29086, 297, 278, 3748, 526, 5982, 297, 278, 3023, 26995, 310, 278, 7613, 29892, 408, 4318, 2400, 29901, 13, 28956, 13, 29871, 29896, 259, 29896, 259, 29896, 259, 29896, 259, 29896, 13, 29871, 29896, 29871, 341, 29871, 1577, 29871, 1577, 259, 29896, 13, 29871, 29896, 29871, 1577, 29871, 1577, 29871, 1577, 259, 29896, 13, 29871, 29896, 29871, 1577, 29871, 1577, 29871, 341, 259, 29896, 13, 29871, 29896, 259, 29896, 259, 29896, 259, 29896, 259, 29896, 13, 28956, 13, 2776, 3236, 29892, 445, 338, 925, 697, 1342, 310, 920, 21531, 20159, 1033, 367, 1304, 304, 4505, 341, 1475, 705, 11356, 29889, 1670, 526, 1784, 916, 1950, 13501, 29892, 322, 278, 2702, 4902, 310, 920, 304, 2334, 963, 723, 8839, 373, 278, 2702, 3748, 7613, 322, 278, 25002, 15939, 1934, 1641, 1304, 29889, 1094, 306, 5276, 1434, 29892, 16049, 263, 1650, 763, 445, 723, 1996, 263, 6483, 8004, 
310, 1716, 21531, 20159, 322, 278, 3748, 310, 341, 1475, 705, 11356, 21106, 29879, 29958, 2, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 18585, 29889, 17090, 366, 505, 263, 6483, 8004, 310, 1716, 21531, 20159, 322, 278, 3748, 310, 341, 1475, 705, 11356, 29892, 1244, 338, 920, 366, 1795, 671, 21531, 20159, 304, 4505, 263, 341, 1475, 705, 11356, 3748, 7613, 29901, 13, 13, 29896, 29889, 16314, 278, 341, 1475, 705, 11356, 3748, 7613, 408, 263, 25002, 2913, 29889, 7806, 6862, 373, 278, 7613, 16161, 304, 263, 1298, 297, 278, 2913, 29892, 322, 278, 29086, 373, 278, 7613, 508, 367, 2714, 310, 408, 376, 5391, 267, 29908, 297, 278, 2913, 29892, 607, 526, 3291, 393, 526, 451, 760, 310, 278, 2913, 29889, 13, 29906, 29889, 4803, 25002, 15939, 1934, 304, 10115, 2472, 1048, 278, 14354, 310, 278, 29086, 373, 278, 7613, 29889, 1670, 526, 1784, 1950, 25002, 15939, 1934, 393, 1033, 367, 1304, 363, 445, 29892, 1316, 408, 3632, 327, 2270, 6471, 29892, 350, 9890, 3694, 29892, 470, 28152, 3632, 3002, 29889, 13, 29941, 29889, 4803, 278, 2472, 10115, 1127, 515, 278, 25002, 15939, 1934, 304, 8161, 278, 14354, 310, 278, 29086, 373, 278, 7613, 29889, 1152, 1342, 29892, 565, 278, 937, 3632, 327, 2270, 2318, 310, 278, 2913, 14088, 393, 727, 526, 3023, 26532, 297, 278, 2913, 29892, 21050, 297, 263, 6862, 8267, 29892, 769, 591, 508, 10115, 393, 278, 29086, 526, 5982, 297, 278, 3023, 26995, 310, 278, 7613, 29889, 13, 29946, 29889, 4803, 445, 2472, 304, 4505, 278, 341, 1475, 705, 11356, 3748, 29889, 9038, 366, 1073, 278, 14354, 310, 278, 29086, 29892, 366, 508, 671, 445, 2472, 304, 4772, 963, 322, 2821, 278, 7613, 29889, 13, 13, 2776, 3236, 29892, 445, 338, 925, 697, 1950, 2948, 304, 773, 21531, 20159, 304, 4505, 341, 1475, 705, 11356, 29889, 1670, 1122, 367, 916, 5837, 304, 437, 372, 29892, 8679, 373, 278, 2702, 3748, 7613, 322, 278, 25002, 15939, 
1934, 1641, 1304, 29889, 2398, 29892, 445, 881, 2367, 366, 263, 2498, 2969, 310, 920, 21531, 20159, 508, 367, 1304, 304, 4505, 341, 1475, 705, 11356, 21106, 29879, 29958, 2, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1763, 671, 28152, 3632, 3002, 470, 738, 310, 278, 916, 25002, 15939, 1934, 304, 4505, 263, 341, 1475, 705, 11356, 3748, 7613, 29892, 366, 723, 1101, 278, 1021, 6996, 6576, 408, 5439, 2038, 29892, 541, 411, 777, 12651, 297, 278, 2702, 4902, 29889, 13, 13, 2831, 2777, 29892, 1244, 338, 920, 366, 1795, 671, 28152, 3632, 3002, 304, 4505, 263, 341, 1475, 705, 11356, 3748, 7613, 29901, 13, 13, 29896, 29889, 16314, 278, 341, 1475, 705, 11356, 3748, 7613, 408, 263, 25002, 2913, 29889, 7806, 6862, 373, 278, 7613, 16161, 304, 263, 1298, 297, 278, 2913, 29892, 322, 278, 29086, 373, 278, 7613, 508, 367, 2714, 310, 408, 376, 5391, 267, 29908, 297, 278, 2913, 29892, 607, 526, 3291, 393, 526, 451, 760, 310, 278, 2913, 29889, 13, 29906, 29889, 11796, 29872, 278, 28152, 3632, 3002, 310, 278, 2913, 29889, 9034, 9696, 3632, 3002, 338, 263, 982, 310, 20766, 278, 26532, 297, 263, 2913, 975, 931, 29892, 322, 372, 508, 367, 1304, 304, 10115, 2472, 1048, 278, 14354, 310, 278, 29086, 373, 278, 341, 1475, 705, 11356, 7613, 29889, 13, 29941, 29889, 4803, 278, 2472, 515, 278, 28152, 3632, 3002, 304, 8161, 278, 14354, 310, 278, 29086, 373, 278, 7613, 29889, 1152, 1342, 29892, 565, 278, 28152, 3632, 3002, 14088, 393, 727, 526, 3023, 26532, 297, 278, 2913, 29892, 21050, 297, 263, 6862, 8267, 29892, 769, 591, 508, 10115, 393, 278, 29086, 526, 5982, 297, 278, 3023, 26995, 310, 278, 7613, 29889, 13, 29946, 29889, 4803, 445, 2472, 304, 4505, 278, 341, 1475, 705, 11356, 3748, 29889, 9038, 366, 1073, 278, 14354, 310, 278, 29086, 29892, 366, 508, 671, 445, 2472, 304, 4772, 963, 322, 2821, 278, 7613, 29889, 13, 13, 2776, 3236, 29892, 445, 338, 925, 697, 1950, 2948, 304, 
773, 28152, 3632, 3002, 304, 4505, 341, 1475, 705, 11356, 29889, 1670, 1122, 367, 916, 5837, 304, 437, 372, 29892, 8679, 373, 278, 2702, 3748, 7613, 322, 278, 2702, 4902, 310, 278, 28152, 3632, 3002, 16287, 29889, 2398, 29892, 445, 881, 2367, 366, 263, 2498, 2969, 310, 920, 28152, 3632, 3002, 508, 367, 1304, 304, 4505, 341, 1475, 705, 11356, 21106, 29879, 29958, 2, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1763, 2334, 263, 341, 1475, 705, 11356, 899, 369, 393, 3913, 21531, 20159, 297, 5132, 29892, 366, 723, 817, 304, 437, 278, 1494, 29901, 13, 13, 29896, 29889, 16052, 263, 5132, 3577, 363, 21531, 20159, 29889, 1670, 526, 3196, 3987, 3625, 29892, 1316, 408, 402, 566, 2918, 29892, 360, 291, 952, 375, 29892, 470, 4560, 7354, 29899, 29873, 1388, 29889, 4525, 9741, 3867, 14009, 322, 848, 12286, 363, 20602, 25002, 15939, 1934, 29892, 1316, 408, 3632, 327, 2270, 6471, 29892, 350, 9890, 3694, 29892, 470, 28152, 3632, 3002, 29889, 13, 29906, 29889, 16314, 278, 341, 1475, 705, 11356, 3748, 7613, 408, 263, 25002, 2913, 29889, 910, 1033, 367, 2309, 773, 263, 848, 3829, 4944, 491, 278, 21531, 20159, 3577, 29892, 1316, 408, 263, 3053, 506, 616, 4280, 470, 263, 478, 2035, 29367, 29899, 29934, 4512, 4280, 29889, 7806, 6862, 373, 278, 7613, 723, 3928, 304, 263, 1298, 297, 278, 2913, 29892, 322, 278, 29086, 373, 278, 7613, 723, 367, 9875, 408, 376, 5391, 267, 29908, 297, 278, 2913, 29889, 13, 29941, 29889, 11796, 29872, 278, 25002, 15939, 1934, 310, 278, 2913, 29889, 910, 1033, 367, 2309, 773, 14009, 4944, 491, 278, 21531, 20159, 3577, 29889, 1152, 1342, 29892, 366, 1033, 10272, 278, 28152, 3632, 3002, 310, 278, 2913, 773, 278, 28152, 3632, 3002, 5687, 29889, 13, 29946, 29889, 4803, 278, 2472, 515, 278, 25002, 15939, 1934, 304, 8161, 278, 14354, 310, 278, 29086, 373, 278, 7613, 29889, 910, 1033, 367, 2309, 773, 263, 10296, 310, 19475, 24481, 322, 8720, 29889, 1152, 1342, 29892, 565, 278, 28152, 3632, 3002, 14088, 393, 727, 
526, 3023, 26532, 297, 278, 2913, 29892, 21050, 297, 263, 6862, 8267, 29892, 769, 366, 1033, 671, 263, 8424, 310, 775, 304, 10115, 393, 278, 29086, 526, 5982, 297, 278, 3023, 26995, 310, 278, 7613, 29889, 13, 29945, 29889, 4803, 445, 2472, 304, 4505, 278, 341, 1475, 705, 11356, 3748, 29889, 9038, 366, 1073, 278, 14354, 310, 278, 29086, 29892, 366, 1033, 671, 263, 8424, 310, 775, 304, 6336, 2828, 373, 278, 9109, 25256, 322, 2821, 278, 7613, 29889, 13, 13, 2776, 3236, 29892, 445, 338, 925, 697, 1950, 2948, 304, 16049, 263, 341, 1475, 705, 11356, 899, 369, 393, 3913, 21531, 20159, 297, 5132, 29889, 1670, 1122, 367, 916, 5837, 304, 437, 372, 29892, 8679, 373, 278, 2702, 4902, 310, 278, 5314, 29889, 2398, 29892, 445, 881, 2367, 366, 263, 2498, 2969, 310, 920, 304, 679, 4687, 411, 1316, 263, 2060, 21106, 29879, 29958, 2]}\n"
  },
  {
    "path": "tests/fixtures/conversation.tokenized_llama2chat.json",
    "content": "{\"input_ids\": [1, 518, 25580, 29962, 3532, 14816, 29903, 6778, 13, 3492, 526, 263, 8444, 29892, 3390, 1319, 322, 15993, 20255, 29889, 29849, 1234, 408, 1371, 3730, 408, 1950, 29892, 1550, 1641, 9109, 29889, 3575, 6089, 881, 451, 3160, 738, 10311, 1319, 29892, 443, 621, 936, 29892, 11021, 391, 29892, 7916, 391, 29892, 304, 27375, 29892, 18215, 29892, 470, 27302, 2793, 29889, 3529, 9801, 393, 596, 20890, 526, 5374, 635, 443, 5365, 1463, 322, 6374, 297, 5469, 29889, 13, 13, 3644, 263, 1139, 947, 451, 1207, 738, 4060, 29892, 470, 338, 451, 2114, 1474, 16165, 261, 296, 29892, 5649, 2020, 2012, 310, 22862, 1554, 451, 1959, 29889, 960, 366, 1016, 29915, 29873, 1073, 278, 1234, 304, 263, 1139, 29892, 3113, 1016, 29915, 29873, 6232, 2089, 2472, 29889, 13, 29966, 829, 14816, 29903, 6778, 13, 13, 3525, 1033, 474, 2334, 263, 29086, 705, 11356, 5687, 393, 3667, 4637, 21531, 20159, 304, 4505, 1045, 3163, 29973, 518, 29914, 25580, 29962, 739, 29915, 29879, 1950, 304, 671, 21531, 20159, 304, 4505, 341, 1475, 705, 11356, 29892, 541, 372, 723, 5517, 367, 3755, 4280, 29889, 838, 5869, 293, 20159, 338, 263, 5443, 310, 23964, 393, 11898, 278, 4426, 310, 8162, 393, 526, 21634, 1090, 9126, 316, 689, 800, 29892, 1316, 408, 16116, 292, 322, 289, 2548, 29889, 512, 278, 3030, 310, 341, 1475, 705, 11356, 29892, 445, 1033, 367, 1304, 304, 27599, 278, 3829, 310, 278, 3748, 7613, 322, 8161, 278, 14354, 310, 29086, 2729, 373, 3058, 25002, 15939, 1934, 29889, 2398, 29892, 16049, 445, 723, 1996, 263, 6483, 8004, 310, 1716, 21531, 20159, 322, 278, 3748, 310, 341, 1475, 705, 11356, 29889, 739, 29915, 29879, 451, 1554, 393, 508, 367, 10824, 297, 263, 2560, 1234, 29889, 29871, 2, 1, 518, 25580, 29962, 1033, 366, 2367, 592, 278, 330, 391, 310, 920, 372, 1033, 2466, 29973, 518, 29914, 25580, 29962, 18585, 29889, 450, 6996, 2969, 723, 367, 304, 671, 21531, 20159, 304, 27599, 278, 3829, 310, 278, 341, 1475, 705, 11356, 3748, 7613, 322, 8161, 278, 14354, 310, 29086, 2729, 373, 3058, 25002, 
15939, 1934, 29889, 910, 1033, 25135, 15783, 278, 3748, 7613, 408, 263, 25002, 2913, 322, 773, 25002, 15939, 1934, 1316, 408, 3632, 327, 2270, 6471, 322, 350, 9890, 3694, 304, 10115, 2472, 1048, 278, 4423, 310, 29086, 373, 278, 7613, 29889, 2398, 29892, 408, 306, 5276, 1434, 29892, 16049, 445, 723, 1996, 263, 6483, 8004, 310, 1716, 21531, 20159, 322, 278, 3748, 310, 341, 1475, 705, 11356, 29892, 577, 372, 29915, 29879, 451, 1554, 393, 508, 367, 10824, 297, 263, 2560, 1234, 29889, 29871, 2, 1, 518, 25580, 29962, 2367, 592, 263, 2702, 1342, 518, 29914, 25580, 29962, 1094, 263, 2702, 1342, 29892, 7755, 591, 505, 263, 341, 1475, 705, 11356, 3748, 7613, 411, 278, 1494, 5285, 29901, 13, 28956, 13, 29871, 29896, 259, 29896, 259, 29896, 259, 29896, 259, 29896, 13, 29871, 29896, 29871, 1577, 29871, 1577, 29871, 1577, 259, 29896, 13, 29871, 29896, 29871, 1577, 29871, 1577, 29871, 1577, 259, 29896, 13, 29871, 29896, 29871, 1577, 29871, 1577, 29871, 1577, 259, 29896, 13, 29871, 29896, 259, 29896, 259, 29896, 259, 29896, 259, 29896, 13, 28956, 13, 797, 445, 1206, 29892, 278, 3748, 7613, 508, 367, 9875, 408, 263, 25002, 2913, 29892, 988, 1269, 6862, 373, 278, 7613, 16161, 304, 263, 1298, 297, 278, 2913, 29889, 450, 29086, 373, 278, 7613, 508, 367, 2714, 310, 408, 376, 5391, 267, 29908, 297, 278, 2913, 29892, 607, 526, 3291, 393, 526, 451, 760, 310, 278, 2913, 29889, 5293, 21531, 20159, 29892, 591, 508, 27599, 278, 3829, 310, 445, 2913, 322, 8161, 278, 14354, 310, 278, 29086, 2729, 373, 3058, 25002, 15939, 1934, 29889, 13, 13, 2831, 2777, 29892, 697, 982, 304, 437, 445, 338, 304, 10272, 278, 3632, 327, 2270, 6471, 310, 278, 2913, 29889, 15089, 327, 2270, 6471, 526, 263, 982, 310, 20766, 278, 26532, 297, 263, 2913, 29892, 322, 896, 508, 367, 1304, 304, 10115, 2472, 1048, 278, 14354, 310, 278, 29086, 373, 278, 341, 1475, 705, 11356, 7613, 29889, 512, 278, 1206, 310, 278, 7613, 2038, 29892, 591, 508, 10272, 278, 937, 3632, 327, 2270, 2318, 310, 278, 2913, 29892, 607, 16612, 278, 
1353, 322, 376, 12181, 29908, 310, 278, 26532, 297, 278, 2913, 29889, 512, 445, 1206, 29892, 278, 937, 3632, 327, 2270, 2318, 723, 2649, 502, 393, 727, 526, 3023, 26532, 297, 278, 2913, 29892, 21050, 297, 263, 6862, 8267, 29889, 13, 13, 29933, 1463, 373, 445, 2472, 29892, 591, 508, 10115, 393, 278, 29086, 297, 278, 3748, 526, 5982, 297, 278, 3023, 26995, 310, 278, 7613, 29892, 408, 4318, 2400, 29901, 13, 28956, 13, 29871, 29896, 259, 29896, 259, 29896, 259, 29896, 259, 29896, 13, 29871, 29896, 29871, 341, 29871, 1577, 29871, 1577, 259, 29896, 13, 29871, 29896, 29871, 1577, 29871, 1577, 29871, 1577, 259, 29896, 13, 29871, 29896, 29871, 1577, 29871, 1577, 29871, 341, 259, 29896, 13, 29871, 29896, 259, 29896, 259, 29896, 259, 29896, 259, 29896, 13, 28956, 13, 2776, 3236, 29892, 445, 338, 925, 697, 1342, 310, 920, 21531, 20159, 1033, 367, 1304, 304, 4505, 341, 1475, 705, 11356, 29889, 1670, 526, 1784, 916, 1950, 13501, 29892, 322, 278, 2702, 4902, 310, 920, 304, 2334, 963, 723, 8839, 373, 278, 2702, 3748, 7613, 322, 278, 25002, 15939, 1934, 1641, 1304, 29889, 1094, 306, 5276, 1434, 29892, 16049, 263, 1650, 763, 445, 723, 1996, 263, 6483, 8004, 310, 1716, 21531, 20159, 322, 278, 3748, 310, 341, 1475, 705, 11356, 29889, 29871, 2, 1, 518, 25580, 29962, 10241, 474, 471, 4856, 411, 263, 6483, 8004, 310, 1716, 29892, 1033, 366, 5649, 278, 1021, 6964, 304, 592, 411, 393, 11833, 29973, 518, 29914, 25580, 29962, 18585, 29889, 17090, 366, 505, 263, 6483, 8004, 310, 1716, 21531, 20159, 322, 278, 3748, 310, 341, 1475, 705, 11356, 29892, 1244, 338, 920, 366, 1795, 671, 21531, 20159, 304, 4505, 263, 341, 1475, 705, 11356, 3748, 7613, 29901, 13, 13, 29896, 29889, 16314, 278, 341, 1475, 705, 11356, 3748, 7613, 408, 263, 25002, 2913, 29889, 7806, 6862, 373, 278, 7613, 16161, 304, 263, 1298, 297, 278, 2913, 29892, 322, 278, 29086, 373, 278, 7613, 508, 367, 2714, 310, 408, 376, 5391, 267, 29908, 297, 278, 2913, 29892, 607, 526, 3291, 393, 526, 451, 760, 310, 278, 2913, 29889, 13, 29906, 
29889, 4803, 25002, 15939, 1934, 304, 10115, 2472, 1048, 278, 14354, 310, 278, 29086, 373, 278, 7613, 29889, 1670, 526, 1784, 1950, 25002, 15939, 1934, 393, 1033, 367, 1304, 363, 445, 29892, 1316, 408, 3632, 327, 2270, 6471, 29892, 350, 9890, 3694, 29892, 470, 28152, 3632, 3002, 29889, 13, 29941, 29889, 4803, 278, 2472, 10115, 1127, 515, 278, 25002, 15939, 1934, 304, 8161, 278, 14354, 310, 278, 29086, 373, 278, 7613, 29889, 1152, 1342, 29892, 565, 278, 937, 3632, 327, 2270, 2318, 310, 278, 2913, 14088, 393, 727, 526, 3023, 26532, 297, 278, 2913, 29892, 21050, 297, 263, 6862, 8267, 29892, 769, 591, 508, 10115, 393, 278, 29086, 526, 5982, 297, 278, 3023, 26995, 310, 278, 7613, 29889, 13, 29946, 29889, 4803, 445, 2472, 304, 4505, 278, 341, 1475, 705, 11356, 3748, 29889, 9038, 366, 1073, 278, 14354, 310, 278, 29086, 29892, 366, 508, 671, 445, 2472, 304, 4772, 963, 322, 2821, 278, 7613, 29889, 13, 13, 2776, 3236, 29892, 445, 338, 925, 697, 1950, 2948, 304, 773, 21531, 20159, 304, 4505, 341, 1475, 705, 11356, 29889, 1670, 1122, 367, 916, 5837, 304, 437, 372, 29892, 8679, 373, 278, 2702, 3748, 7613, 322, 278, 25002, 15939, 1934, 1641, 1304, 29889, 2398, 29892, 445, 881, 2367, 366, 263, 2498, 2969, 310, 920, 21531, 20159, 508, 367, 1304, 304, 4505, 341, 1475, 705, 11356, 29889, 29871, 2, 1, 518, 25580, 29962, 920, 723, 474, 437, 372, 411, 28152, 3632, 3002, 470, 738, 310, 278, 916, 25002, 15939, 1934, 29973, 518, 29914, 25580, 29962, 1763, 671, 28152, 3632, 3002, 470, 738, 310, 278, 916, 25002, 15939, 1934, 304, 4505, 263, 341, 1475, 705, 11356, 3748, 7613, 29892, 366, 723, 1101, 278, 1021, 6996, 6576, 408, 5439, 2038, 29892, 541, 411, 777, 12651, 297, 278, 2702, 4902, 29889, 13, 13, 2831, 2777, 29892, 1244, 338, 920, 366, 1795, 671, 28152, 3632, 3002, 304, 4505, 263, 341, 1475, 705, 11356, 3748, 7613, 29901, 13, 13, 29896, 29889, 16314, 278, 341, 1475, 705, 11356, 3748, 7613, 408, 263, 25002, 2913, 29889, 7806, 6862, 373, 278, 7613, 16161, 304, 263, 1298, 297, 278, 2913, 
29892, 322, 278, 29086, 373, 278, 7613, 508, 367, 2714, 310, 408, 376, 5391, 267, 29908, 297, 278, 2913, 29892, 607, 526, 3291, 393, 526, 451, 760, 310, 278, 2913, 29889, 13, 29906, 29889, 11796, 29872, 278, 28152, 3632, 3002, 310, 278, 2913, 29889, 9034, 9696, 3632, 3002, 338, 263, 982, 310, 20766, 278, 26532, 297, 263, 2913, 975, 931, 29892, 322, 372, 508, 367, 1304, 304, 10115, 2472, 1048, 278, 14354, 310, 278, 29086, 373, 278, 341, 1475, 705, 11356, 7613, 29889, 13, 29941, 29889, 4803, 278, 2472, 515, 278, 28152, 3632, 3002, 304, 8161, 278, 14354, 310, 278, 29086, 373, 278, 7613, 29889, 1152, 1342, 29892, 565, 278, 28152, 3632, 3002, 14088, 393, 727, 526, 3023, 26532, 297, 278, 2913, 29892, 21050, 297, 263, 6862, 8267, 29892, 769, 591, 508, 10115, 393, 278, 29086, 526, 5982, 297, 278, 3023, 26995, 310, 278, 7613, 29889, 13, 29946, 29889, 4803, 445, 2472, 304, 4505, 278, 341, 1475, 705, 11356, 3748, 29889, 9038, 366, 1073, 278, 14354, 310, 278, 29086, 29892, 366, 508, 671, 445, 2472, 304, 4772, 963, 322, 2821, 278, 7613, 29889, 13, 13, 2776, 3236, 29892, 445, 338, 925, 697, 1950, 2948, 304, 773, 28152, 3632, 3002, 304, 4505, 341, 1475, 705, 11356, 29889, 1670, 1122, 367, 916, 5837, 304, 437, 372, 29892, 8679, 373, 278, 2702, 3748, 7613, 322, 278, 2702, 4902, 310, 278, 28152, 3632, 3002, 16287, 29889, 2398, 29892, 445, 881, 2367, 366, 263, 2498, 2969, 310, 920, 28152, 3632, 3002, 508, 367, 1304, 304, 4505, 341, 1475, 705, 11356, 29889, 29871, 2, 1, 518, 25580, 29962, 920, 1033, 474, 2334, 445, 297, 3017, 29973, 518, 29914, 25580, 29962, 1763, 2334, 263, 341, 1475, 705, 11356, 899, 369, 393, 3913, 21531, 20159, 297, 5132, 29892, 366, 723, 817, 304, 437, 278, 1494, 29901, 13, 13, 29896, 29889, 16052, 263, 5132, 3577, 363, 21531, 20159, 29889, 1670, 526, 3196, 3987, 3625, 29892, 1316, 408, 402, 566, 2918, 29892, 360, 291, 952, 375, 29892, 470, 4560, 7354, 29899, 29873, 1388, 29889, 4525, 9741, 3867, 14009, 322, 848, 12286, 363, 20602, 25002, 15939, 1934, 29892, 
1316, 408, 3632, 327, 2270, 6471, 29892, 350, 9890, 3694, 29892, 470, 28152, 3632, 3002, 29889, 13, 29906, 29889, 16314, 278, 341, 1475, 705, 11356, 3748, 7613, 408, 263, 25002, 2913, 29889, 910, 1033, 367, 2309, 773, 263, 848, 3829, 4944, 491, 278, 21531, 20159, 3577, 29892, 1316, 408, 263, 3053, 506, 616, 4280, 470, 263, 478, 2035, 29367, 29899, 29934, 4512, 4280, 29889, 7806, 6862, 373, 278, 7613, 723, 3928, 304, 263, 1298, 297, 278, 2913, 29892, 322, 278, 29086, 373, 278, 7613, 723, 367, 9875, 408, 376, 5391, 267, 29908, 297, 278, 2913, 29889, 13, 29941, 29889, 11796, 29872, 278, 25002, 15939, 1934, 310, 278, 2913, 29889, 910, 1033, 367, 2309, 773, 14009, 4944, 491, 278, 21531, 20159, 3577, 29889, 1152, 1342, 29892, 366, 1033, 10272, 278, 28152, 3632, 3002, 310, 278, 2913, 773, 278, 28152, 3632, 3002, 5687, 29889, 13, 29946, 29889, 4803, 278, 2472, 515, 278, 25002, 15939, 1934, 304, 8161, 278, 14354, 310, 278, 29086, 373, 278, 7613, 29889, 910, 1033, 367, 2309, 773, 263, 10296, 310, 19475, 24481, 322, 8720, 29889, 1152, 1342, 29892, 565, 278, 28152, 3632, 3002, 14088, 393, 727, 526, 3023, 26532, 297, 278, 2913, 29892, 21050, 297, 263, 6862, 8267, 29892, 769, 366, 1033, 671, 263, 8424, 310, 775, 304, 10115, 393, 278, 29086, 526, 5982, 297, 278, 3023, 26995, 310, 278, 7613, 29889, 13, 29945, 29889, 4803, 445, 2472, 304, 4505, 278, 341, 1475, 705, 11356, 3748, 29889, 9038, 366, 1073, 278, 14354, 310, 278, 29086, 29892, 366, 1033, 671, 263, 8424, 310, 775, 304, 6336, 2828, 373, 278, 9109, 25256, 322, 2821, 278, 7613, 29889, 13, 13, 2776, 3236, 29892, 445, 338, 925, 697, 1950, 2948, 304, 16049, 263, 341, 1475, 705, 11356, 899, 369, 393, 3913, 21531, 20159, 297, 5132, 29889, 1670, 1122, 367, 916, 5837, 304, 437, 372, 29892, 8679, 373, 278, 2702, 4902, 310, 278, 5314, 29889, 2398, 29892, 445, 881, 2367, 366, 263, 2498, 2969, 310, 920, 304, 679, 4687, 411, 1316, 263, 2060, 29889, 29871, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], \"labels\": [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 739, 29915, 29879, 1950, 304, 671, 21531, 20159, 304, 4505, 341, 1475, 705, 11356, 29892, 541, 372, 723, 5517, 367, 3755, 4280, 29889, 838, 5869, 293, 20159, 338, 263, 5443, 310, 23964, 393, 11898, 278, 4426, 310, 8162, 393, 526, 21634, 1090, 9126, 316, 689, 800, 29892, 1316, 408, 16116, 292, 322, 289, 2548, 29889, 512, 278, 3030, 310, 341, 1475, 705, 11356, 29892, 445, 1033, 367, 1304, 304, 27599, 278, 3829, 310, 278, 3748, 7613, 322, 8161, 278, 14354, 310, 29086, 2729, 373, 3058, 25002, 15939, 1934, 29889, 2398, 29892, 16049, 445, 723, 1996, 263, 6483, 8004, 310, 1716, 21531, 20159, 322, 278, 3748, 310, 341, 1475, 705, 11356, 29889, 739, 29915, 29879, 451, 1554, 393, 508, 367, 10824, 297, 263, 2560, 1234, 29889, 29871, 2, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 18585, 29889, 450, 6996, 2969, 723, 367, 304, 671, 21531, 20159, 304, 27599, 278, 3829, 310, 278, 341, 1475, 705, 11356, 3748, 7613, 322, 
8161, 278, 14354, 310, 29086, 2729, 373, 3058, 25002, 15939, 1934, 29889, 910, 1033, 25135, 15783, 278, 3748, 7613, 408, 263, 25002, 2913, 322, 773, 25002, 15939, 1934, 1316, 408, 3632, 327, 2270, 6471, 322, 350, 9890, 3694, 304, 10115, 2472, 1048, 278, 4423, 310, 29086, 373, 278, 7613, 29889, 2398, 29892, 408, 306, 5276, 1434, 29892, 16049, 445, 723, 1996, 263, 6483, 8004, 310, 1716, 21531, 20159, 322, 278, 3748, 310, 341, 1475, 705, 11356, 29892, 577, 372, 29915, 29879, 451, 1554, 393, 508, 367, 10824, 297, 263, 2560, 1234, 29889, 29871, 2, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1094, 263, 2702, 1342, 29892, 7755, 591, 505, 263, 341, 1475, 705, 11356, 3748, 7613, 411, 278, 1494, 5285, 29901, 13, 28956, 13, 29871, 29896, 259, 29896, 259, 29896, 259, 29896, 259, 29896, 13, 29871, 29896, 29871, 1577, 29871, 1577, 29871, 1577, 259, 29896, 13, 29871, 29896, 29871, 1577, 29871, 1577, 29871, 1577, 259, 29896, 13, 29871, 29896, 29871, 1577, 29871, 1577, 29871, 1577, 259, 29896, 13, 29871, 29896, 259, 29896, 259, 29896, 259, 29896, 259, 29896, 13, 28956, 13, 797, 445, 1206, 29892, 278, 3748, 7613, 508, 367, 9875, 408, 263, 25002, 2913, 29892, 988, 1269, 6862, 373, 278, 7613, 16161, 304, 263, 1298, 297, 278, 2913, 29889, 450, 29086, 373, 278, 7613, 508, 367, 2714, 310, 408, 376, 5391, 267, 29908, 297, 278, 2913, 29892, 607, 526, 3291, 393, 526, 451, 760, 310, 278, 2913, 29889, 5293, 21531, 20159, 29892, 591, 508, 27599, 278, 3829, 310, 445, 2913, 322, 8161, 278, 14354, 310, 278, 29086, 2729, 373, 3058, 25002, 15939, 1934, 29889, 13, 13, 2831, 2777, 29892, 697, 982, 304, 437, 445, 338, 304, 10272, 278, 3632, 327, 2270, 6471, 310, 278, 2913, 29889, 15089, 327, 2270, 6471, 526, 263, 982, 310, 20766, 278, 26532, 297, 263, 2913, 29892, 322, 896, 508, 367, 1304, 304, 10115, 2472, 1048, 278, 14354, 310, 278, 29086, 373, 278, 341, 1475, 705, 11356, 7613, 29889, 512, 278, 1206, 310, 278, 7613, 2038, 29892, 591, 508, 10272, 278, 937, 3632, 327, 
2270, 2318, 310, 278, 2913, 29892, 607, 16612, 278, 1353, 322, 376, 12181, 29908, 310, 278, 26532, 297, 278, 2913, 29889, 512, 445, 1206, 29892, 278, 937, 3632, 327, 2270, 2318, 723, 2649, 502, 393, 727, 526, 3023, 26532, 297, 278, 2913, 29892, 21050, 297, 263, 6862, 8267, 29889, 13, 13, 29933, 1463, 373, 445, 2472, 29892, 591, 508, 10115, 393, 278, 29086, 297, 278, 3748, 526, 5982, 297, 278, 3023, 26995, 310, 278, 7613, 29892, 408, 4318, 2400, 29901, 13, 28956, 13, 29871, 29896, 259, 29896, 259, 29896, 259, 29896, 259, 29896, 13, 29871, 29896, 29871, 341, 29871, 1577, 29871, 1577, 259, 29896, 13, 29871, 29896, 29871, 1577, 29871, 1577, 29871, 1577, 259, 29896, 13, 29871, 29896, 29871, 1577, 29871, 1577, 29871, 341, 259, 29896, 13, 29871, 29896, 259, 29896, 259, 29896, 259, 29896, 259, 29896, 13, 28956, 13, 2776, 3236, 29892, 445, 338, 925, 697, 1342, 310, 920, 21531, 20159, 1033, 367, 1304, 304, 4505, 341, 1475, 705, 11356, 29889, 1670, 526, 1784, 916, 1950, 13501, 29892, 322, 278, 2702, 4902, 310, 920, 304, 2334, 963, 723, 8839, 373, 278, 2702, 3748, 7613, 322, 278, 25002, 15939, 1934, 1641, 1304, 29889, 1094, 306, 5276, 1434, 29892, 16049, 263, 1650, 763, 445, 723, 1996, 263, 6483, 8004, 310, 1716, 21531, 20159, 322, 278, 3748, 310, 341, 1475, 705, 11356, 29889, 29871, 2, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 18585, 29889, 17090, 366, 505, 263, 6483, 8004, 310, 1716, 21531, 20159, 322, 278, 3748, 310, 341, 1475, 705, 11356, 29892, 1244, 338, 920, 366, 1795, 671, 21531, 20159, 304, 4505, 263, 341, 1475, 705, 11356, 3748, 7613, 29901, 13, 13, 29896, 29889, 16314, 278, 341, 1475, 705, 11356, 3748, 7613, 408, 263, 25002, 2913, 29889, 7806, 6862, 373, 278, 7613, 16161, 304, 263, 1298, 297, 278, 2913, 29892, 322, 278, 29086, 373, 278, 7613, 508, 367, 2714, 310, 408, 376, 5391, 267, 29908, 297, 278, 2913, 29892, 607, 526, 
3291, 393, 526, 451, 760, 310, 278, 2913, 29889, 13, 29906, 29889, 4803, 25002, 15939, 1934, 304, 10115, 2472, 1048, 278, 14354, 310, 278, 29086, 373, 278, 7613, 29889, 1670, 526, 1784, 1950, 25002, 15939, 1934, 393, 1033, 367, 1304, 363, 445, 29892, 1316, 408, 3632, 327, 2270, 6471, 29892, 350, 9890, 3694, 29892, 470, 28152, 3632, 3002, 29889, 13, 29941, 29889, 4803, 278, 2472, 10115, 1127, 515, 278, 25002, 15939, 1934, 304, 8161, 278, 14354, 310, 278, 29086, 373, 278, 7613, 29889, 1152, 1342, 29892, 565, 278, 937, 3632, 327, 2270, 2318, 310, 278, 2913, 14088, 393, 727, 526, 3023, 26532, 297, 278, 2913, 29892, 21050, 297, 263, 6862, 8267, 29892, 769, 591, 508, 10115, 393, 278, 29086, 526, 5982, 297, 278, 3023, 26995, 310, 278, 7613, 29889, 13, 29946, 29889, 4803, 445, 2472, 304, 4505, 278, 341, 1475, 705, 11356, 3748, 29889, 9038, 366, 1073, 278, 14354, 310, 278, 29086, 29892, 366, 508, 671, 445, 2472, 304, 4772, 963, 322, 2821, 278, 7613, 29889, 13, 13, 2776, 3236, 29892, 445, 338, 925, 697, 1950, 2948, 304, 773, 21531, 20159, 304, 4505, 341, 1475, 705, 11356, 29889, 1670, 1122, 367, 916, 5837, 304, 437, 372, 29892, 8679, 373, 278, 2702, 3748, 7613, 322, 278, 25002, 15939, 1934, 1641, 1304, 29889, 2398, 29892, 445, 881, 2367, 366, 263, 2498, 2969, 310, 920, 21531, 20159, 508, 367, 1304, 304, 4505, 341, 1475, 705, 11356, 29889, 29871, 2, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1763, 671, 28152, 3632, 3002, 470, 738, 310, 278, 916, 25002, 15939, 1934, 304, 4505, 263, 341, 1475, 705, 11356, 3748, 7613, 29892, 366, 723, 1101, 278, 1021, 6996, 6576, 408, 5439, 2038, 29892, 541, 411, 777, 12651, 297, 278, 2702, 4902, 29889, 13, 13, 2831, 2777, 29892, 1244, 338, 920, 366, 1795, 671, 28152, 3632, 3002, 304, 4505, 263, 341, 1475, 705, 11356, 3748, 7613, 29901, 13, 13, 29896, 29889, 16314, 278, 341, 1475, 705, 11356, 3748, 7613, 408, 263, 25002, 2913, 29889, 
7806, 6862, 373, 278, 7613, 16161, 304, 263, 1298, 297, 278, 2913, 29892, 322, 278, 29086, 373, 278, 7613, 508, 367, 2714, 310, 408, 376, 5391, 267, 29908, 297, 278, 2913, 29892, 607, 526, 3291, 393, 526, 451, 760, 310, 278, 2913, 29889, 13, 29906, 29889, 11796, 29872, 278, 28152, 3632, 3002, 310, 278, 2913, 29889, 9034, 9696, 3632, 3002, 338, 263, 982, 310, 20766, 278, 26532, 297, 263, 2913, 975, 931, 29892, 322, 372, 508, 367, 1304, 304, 10115, 2472, 1048, 278, 14354, 310, 278, 29086, 373, 278, 341, 1475, 705, 11356, 7613, 29889, 13, 29941, 29889, 4803, 278, 2472, 515, 278, 28152, 3632, 3002, 304, 8161, 278, 14354, 310, 278, 29086, 373, 278, 7613, 29889, 1152, 1342, 29892, 565, 278, 28152, 3632, 3002, 14088, 393, 727, 526, 3023, 26532, 297, 278, 2913, 29892, 21050, 297, 263, 6862, 8267, 29892, 769, 591, 508, 10115, 393, 278, 29086, 526, 5982, 297, 278, 3023, 26995, 310, 278, 7613, 29889, 13, 29946, 29889, 4803, 445, 2472, 304, 4505, 278, 341, 1475, 705, 11356, 3748, 29889, 9038, 366, 1073, 278, 14354, 310, 278, 29086, 29892, 366, 508, 671, 445, 2472, 304, 4772, 963, 322, 2821, 278, 7613, 29889, 13, 13, 2776, 3236, 29892, 445, 338, 925, 697, 1950, 2948, 304, 773, 28152, 3632, 3002, 304, 4505, 341, 1475, 705, 11356, 29889, 1670, 1122, 367, 916, 5837, 304, 437, 372, 29892, 8679, 373, 278, 2702, 3748, 7613, 322, 278, 2702, 4902, 310, 278, 28152, 3632, 3002, 16287, 29889, 2398, 29892, 445, 881, 2367, 366, 263, 2498, 2969, 310, 920, 28152, 3632, 3002, 508, 367, 1304, 304, 4505, 341, 1475, 705, 11356, 29889, 29871, 2, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1763, 2334, 263, 341, 1475, 705, 11356, 899, 369, 393, 3913, 21531, 20159, 297, 5132, 29892, 366, 723, 817, 304, 437, 278, 1494, 29901, 13, 13, 29896, 29889, 16052, 263, 5132, 3577, 363, 21531, 20159, 29889, 1670, 526, 3196, 3987, 3625, 29892, 1316, 408, 402, 566, 2918, 29892, 360, 291, 952, 375, 29892, 470, 4560, 7354, 29899, 29873, 1388, 29889, 4525, 9741, 
3867, 14009, 322, 848, 12286, 363, 20602, 25002, 15939, 1934, 29892, 1316, 408, 3632, 327, 2270, 6471, 29892, 350, 9890, 3694, 29892, 470, 28152, 3632, 3002, 29889, 13, 29906, 29889, 16314, 278, 341, 1475, 705, 11356, 3748, 7613, 408, 263, 25002, 2913, 29889, 910, 1033, 367, 2309, 773, 263, 848, 3829, 4944, 491, 278, 21531, 20159, 3577, 29892, 1316, 408, 263, 3053, 506, 616, 4280, 470, 263, 478, 2035, 29367, 29899, 29934, 4512, 4280, 29889, 7806, 6862, 373, 278, 7613, 723, 3928, 304, 263, 1298, 297, 278, 2913, 29892, 322, 278, 29086, 373, 278, 7613, 723, 367, 9875, 408, 376, 5391, 267, 29908, 297, 278, 2913, 29889, 13, 29941, 29889, 11796, 29872, 278, 25002, 15939, 1934, 310, 278, 2913, 29889, 910, 1033, 367, 2309, 773, 14009, 4944, 491, 278, 21531, 20159, 3577, 29889, 1152, 1342, 29892, 366, 1033, 10272, 278, 28152, 3632, 3002, 310, 278, 2913, 773, 278, 28152, 3632, 3002, 5687, 29889, 13, 29946, 29889, 4803, 278, 2472, 515, 278, 25002, 15939, 1934, 304, 8161, 278, 14354, 310, 278, 29086, 373, 278, 7613, 29889, 910, 1033, 367, 2309, 773, 263, 10296, 310, 19475, 24481, 322, 8720, 29889, 1152, 1342, 29892, 565, 278, 28152, 3632, 3002, 14088, 393, 727, 526, 3023, 26532, 297, 278, 2913, 29892, 21050, 297, 263, 6862, 8267, 29892, 769, 366, 1033, 671, 263, 8424, 310, 775, 304, 10115, 393, 278, 29086, 526, 5982, 297, 278, 3023, 26995, 310, 278, 7613, 29889, 13, 29945, 29889, 4803, 445, 2472, 304, 4505, 278, 341, 1475, 705, 11356, 3748, 29889, 9038, 366, 1073, 278, 14354, 310, 278, 29086, 29892, 366, 1033, 671, 263, 8424, 310, 775, 304, 6336, 2828, 373, 278, 9109, 25256, 322, 2821, 278, 7613, 29889, 13, 13, 2776, 3236, 29892, 445, 338, 925, 697, 1950, 2948, 304, 16049, 263, 341, 1475, 705, 11356, 899, 369, 393, 3913, 21531, 20159, 297, 5132, 29889, 1670, 1122, 367, 916, 5837, 304, 437, 372, 29892, 8679, 373, 278, 2702, 4902, 310, 278, 5314, 29889, 2398, 29892, 445, 881, 2367, 366, 263, 2498, 2969, 310, 920, 304, 679, 4687, 411, 1316, 263, 2060, 29889, 29871, 2, 1, -100, 
-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 
-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 
-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 
-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 
-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 
-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 
-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100], \"attention_mask\": [true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, 
true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, 
true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, 
true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, 
true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, 
true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, 
true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, 
false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, 
false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, 
false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, 
false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, 
false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, 
false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, 
false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false]}\n"
  },
  {
    "path": "tests/hf_offline_utils.py",
    "content": "\"\"\"\ntest utils for helpers and decorators\n\"\"\"\n\nimport os\nfrom contextlib import contextmanager\nfrom functools import wraps\n\n\ndef reload_modules(hf_hub_offline):\n    # Force reload of the modules that check this variable\n    import importlib\n\n    import datasets\n    import huggingface_hub.constants\n    # from huggingface_hub.utils import reset_sessions\n\n    # Reload the constants module first, as others depend on it\n    importlib.reload(huggingface_hub.constants)\n    huggingface_hub.constants.HF_HUB_OFFLINE = hf_hub_offline\n    importlib.reload(datasets.config)\n    datasets.config.HF_HUB_OFFLINE = hf_hub_offline\n\n\ndef enable_hf_offline(test_func):\n    \"\"\"\n    test decorator that sets HF_HUB_OFFLINE environment variable to True and restores it after the test even if the test fails.\n    :param test_func:\n    :return:\n    \"\"\"\n\n    @wraps(test_func)\n    def wrapper(*args, **kwargs):\n        # Save the original value of HF_HUB_OFFLINE environment variable\n        original_hf_offline = os.getenv(\"HF_HUB_OFFLINE\")\n\n        # Set HF_OFFLINE environment variable to True\n        os.environ[\"HF_HUB_OFFLINE\"] = \"1\"\n\n        reload_modules(True)\n        try:\n            # Run the test function\n            return test_func(*args, **kwargs)\n        finally:\n            # Restore the original value of HF_HUB_OFFLINE environment variable\n            if original_hf_offline is not None:\n                os.environ[\"HF_HUB_OFFLINE\"] = original_hf_offline\n                reload_modules(bool(original_hf_offline))\n            else:\n                del os.environ[\"HF_HUB_OFFLINE\"]\n                reload_modules(False)\n\n    return wrapper\n\n\ndef disable_hf_offline(test_func):\n    \"\"\"\n    test decorator that sets HF_HUB_OFFLINE environment variable to False and restores it after the wrapped func\n    :param test_func:\n    :return:\n    \"\"\"\n\n    @wraps(test_func)\n    def wrapper(*args, 
**kwargs):\n        # Save the original value of HF_HUB_OFFLINE environment variable\n        original_hf_offline = os.getenv(\"HF_HUB_OFFLINE\")\n\n        # Set HF_OFFLINE environment variable to True\n        os.environ[\"HF_HUB_OFFLINE\"] = \"0\"\n\n        reload_modules(False)\n        try:\n            # Run the test function\n            return test_func(*args, **kwargs)\n        finally:\n            # Restore the original value of HF_HUB_OFFLINE environment variable\n            if original_hf_offline is not None:\n                os.environ[\"HF_HUB_OFFLINE\"] = original_hf_offline\n                reload_modules(bool(original_hf_offline))\n            else:\n                del os.environ[\"HF_HUB_OFFLINE\"]\n                reload_modules(False)\n\n    return wrapper\n\n\n@contextmanager\ndef hf_offline_context(hf_hub_offline):\n    \"\"\"\n    Context manager that sets HF_HUB_OFFLINE environment variable to the given value.\n    :param hf_hub_offline: The new value for HF_HUB_OFFLINE.\n    :return: A context manager.\n    \"\"\"\n    original_hf_offline = os.getenv(\"HF_HUB_OFFLINE\")\n    os.environ[\"HF_HUB_OFFLINE\"] = str(hf_hub_offline)\n    reload_modules(bool(hf_hub_offline))\n    yield\n    # Restore the original value of HF_HUB_OFFLINE environment variable\n    if original_hf_offline is not None:\n        os.environ[\"HF_HUB_OFFLINE\"] = original_hf_offline\n        reload_modules(bool(original_hf_offline))\n    else:\n        del os.environ[\"HF_HUB_OFFLINE\"]\n        reload_modules(False)\n"
  },
  {
    "path": "tests/integrations/__init__.py",
    "content": ""
  },
  {
    "path": "tests/integrations/test_diffusion.py",
    "content": "\"\"\"Tests for diffusion trainer integration.\"\"\"\n\n# pylint: disable=redefined-outer-name,protected-access\n\nfrom unittest.mock import Mock\n\nimport pytest\nimport torch\n\nfrom axolotl.integrations.diffusion import DiffusionTrainer\nfrom axolotl.integrations.diffusion.utils import create_bidirectional_attention_mask\nfrom axolotl.utils.dict import DictDefault\n\n\n@pytest.fixture\ndef mock_tokenizer():\n    \"\"\"Create a mock tokenizer.\"\"\"\n    tokenizer = Mock()\n    tokenizer.bos_token_id = 1\n    tokenizer.eos_token_id = 2\n    tokenizer.pad_token_id = 0\n    return tokenizer\n\n\n@pytest.fixture\ndef diffusion_config():\n    \"\"\"Create a diffusion config.\"\"\"\n    return DictDefault(\n        {\n            \"diffusion\": {\n                \"mask_token_id\": 32000,\n                \"eps\": 1e-3,\n                \"importance_weighting\": False,\n            },\n            \"sample_packing\": False,\n        }\n    )\n\n\n@pytest.fixture\ndef diffusion_trainer_instance(mock_tokenizer, diffusion_config):\n    \"\"\"Create a diffusion trainer instance for testing methods directly.\"\"\"\n    # Create a minimal trainer instance just for testing methods\n    trainer = object.__new__(DiffusionTrainer)  # Bypass __init__\n    trainer.cfg = diffusion_config\n    trainer._special_token_ids = {0, 1, 2}  # pad, bos, eos\n    trainer.processing_class = mock_tokenizer\n    trainer.store_metrics = Mock()  # Mock metrics storage\n    return trainer\n\n\nclass TestDiffusionTrainer:\n    \"\"\"Test the DiffusionTrainer class.\"\"\"\n\n    def test_forward_process_basic(self, diffusion_trainer_instance):\n        \"\"\"Test basic forward process without labels.\"\"\"\n        input_ids = torch.tensor([[1, 10, 20, 30, 2]], dtype=torch.long)\n\n        noisy_batch, masked_indices, p_mask = (\n            diffusion_trainer_instance._forward_process(input_ids, eps=0.1)\n        )\n\n        # Check shapes\n        assert noisy_batch.shape == 
input_ids.shape\n        assert masked_indices.shape == input_ids.shape\n        assert p_mask.shape == input_ids.shape\n\n        # Check that special tokens are not masked\n        special_token_positions = (input_ids == 1) | (input_ids == 2) | (input_ids == 0)\n        assert not masked_indices[special_token_positions].any()\n\n        # Check that mask token is applied\n        mask_token_id = diffusion_trainer_instance.cfg.diffusion.mask_token_id\n        masked_positions = masked_indices\n        if masked_positions.any():\n            assert (noisy_batch[masked_positions] == mask_token_id).all()\n\n    def test_forward_process_with_labels(self, diffusion_trainer_instance):\n        \"\"\"Test forward process with SFT labels.\"\"\"\n        input_ids = torch.tensor([[1, 10, 20, 30, 2]], dtype=torch.long)\n        labels = torch.tensor([[-100, -100, 20, 30, 2]], dtype=torch.long)\n\n        noisy_batch, masked_indices, p_mask = (\n            diffusion_trainer_instance._forward_process(\n                input_ids, labels=labels, eps=0.1\n            )\n        )\n\n        # Check shapes\n        assert noisy_batch.shape == input_ids.shape\n        assert masked_indices.shape == input_ids.shape\n        assert p_mask.shape == input_ids.shape\n\n        # Check that only answer tokens can be masked (where labels != -100)\n        non_answer_mask = labels == -100\n\n        # No masking should occur on non-answer tokens\n        assert not masked_indices[non_answer_mask].any()\n\n        # p_mask should be the same for all positions (sampled timestep),\n        # but masking is only applied to answer tokens\n        assert p_mask.shape == input_ids.shape\n        # Verify that masked_indices respects the answer mask\n        assert not masked_indices[non_answer_mask].any()\n\n    def test_forward_process_with_attention_mask(self, diffusion_trainer_instance):\n        \"\"\"Test forward process with attention mask.\"\"\"\n        input_ids = torch.tensor([[1, 10, 
20, 0]], dtype=torch.long)\n        attention_mask = torch.tensor([[1, 1, 1, 0]], dtype=torch.long)\n\n        _, masked_indices, p_mask = diffusion_trainer_instance._forward_process(\n            input_ids, attention_mask=attention_mask, eps=0.1\n        )\n\n        # Check that padding tokens are not masked\n        padding_positions = attention_mask == 0\n        assert not masked_indices[padding_positions].any()\n        assert (p_mask[padding_positions] == 0).all()\n\n    def test_bidirectional_attention_mask_no_packing(self, diffusion_trainer_instance):\n        \"\"\"Test bidirectional attention mask without sample packing.\"\"\"\n        input_ids = torch.tensor([[1, 10, 20, 2]], dtype=torch.long)\n\n        mask = create_bidirectional_attention_mask(input_ids)\n\n        # Should be all-to-all attention\n        expected_shape = (1, 1, 4, 4)\n        assert mask.shape == expected_shape\n        assert mask.all()\n\n    def test_bidirectional_attention_mask_with_packing(\n        self, diffusion_trainer_instance\n    ):\n        \"\"\"Test bidirectional attention mask with sample packing.\"\"\"\n        diffusion_trainer_instance.cfg.sample_packing = True\n        input_ids = torch.tensor([[1, 10, 20, 30, 40, 2]], dtype=torch.long)\n        # Sample IDs: first sample (1), second sample (2)\n        attention_mask = torch.tensor([[1, 1, 1, 2, 2, 2]], dtype=torch.long)\n\n        mask = create_bidirectional_attention_mask(\n            input_ids, attention_mask, sample_packing=True\n        )\n\n        # Check that tokens within same sample can attend to each other\n        # but not across samples\n        assert mask[0, 0, 0, 1].item()  # First sample tokens can attend to each other\n        assert mask[0, 0, 1, 2].item()\n        assert not mask[0, 0, 0, 3].item()  # Can't attend across samples\n        assert not mask[0, 0, 2, 4].item()\n        assert mask[0, 0, 3, 4].item()  # Second sample tokens can attend to each other\n\n    def 
test_compute_loss_basic(self, diffusion_trainer_instance):\n        \"\"\"Test basic loss computation.\"\"\"\n        # Mock model that returns logits\n        mock_model = Mock()\n        mock_outputs = Mock()\n        vocab_size = 1000\n        seq_len = 5\n        mock_outputs.logits = torch.randn(1, seq_len, vocab_size, requires_grad=True)\n        mock_model.return_value = mock_outputs\n        mock_model.training = True\n\n        input_ids = torch.tensor([[1, 10, 20, 30, 2]], dtype=torch.long)\n\n        loss, outputs = diffusion_trainer_instance._compute_diffusion_loss(\n            mock_model, input_ids\n        )\n\n        # Check that loss is computed\n        assert isinstance(loss, torch.Tensor)\n        assert loss.requires_grad\n        assert outputs == mock_outputs\n\n        # Check that metrics were stored\n        diffusion_trainer_instance.store_metrics.assert_called_once()\n\n    def test_compute_loss_sft(self, diffusion_trainer_instance):\n        \"\"\"Test loss computation with SFT labels.\"\"\"\n        # Mock model\n        mock_model = Mock()\n        mock_outputs = Mock()\n        vocab_size = 1000\n        seq_len = 5\n        mock_outputs.logits = torch.randn(1, seq_len, vocab_size, requires_grad=True)\n        mock_model.return_value = mock_outputs\n        mock_model.training = True\n        diffusion_trainer_instance.cfg.datasets = Mock()\n\n        input_ids = torch.tensor([[1, 10, 20, 30, 2]], dtype=torch.long)\n        labels = torch.tensor([[-100, -100, 20, 30, 2]], dtype=torch.long)\n\n        loss, _ = diffusion_trainer_instance._compute_diffusion_loss(\n            mock_model, input_ids, labels=labels\n        )\n\n        # Check that loss is computed\n        assert isinstance(loss, torch.Tensor)\n        assert loss.requires_grad\n\n        # Check that SFT metrics were added\n        call_args = diffusion_trainer_instance.store_metrics.call_args[0][0]\n        assert \"answer_ratio\" in call_args\n        assert 
\"avg_answer_length\" in call_args\n\n    def test_compute_loss_no_masked_tokens(self, diffusion_trainer_instance):\n        \"\"\"Test loss computation when no tokens are masked.\"\"\"\n        # Mock model\n        mock_model = Mock()\n        mock_outputs = Mock()\n        vocab_size = 1000\n        seq_len = 3\n        mock_outputs.logits = torch.randn(1, seq_len, vocab_size)\n        mock_model.return_value = mock_outputs\n        mock_model.training = True\n\n        # Only special tokens (which won't be masked)\n        input_ids = torch.tensor([[1, 0, 2]], dtype=torch.long)\n\n        loss, _ = diffusion_trainer_instance._compute_diffusion_loss(\n            mock_model, input_ids\n        )\n\n        # Loss should be zero when no tokens are masked\n        assert loss.item() == 0.0\n        assert loss.requires_grad\n\n    def test_cache_special_token_ids(self, mock_tokenizer):\n        \"\"\"Test caching of special token IDs.\"\"\"\n        trainer = object.__new__(DiffusionTrainer)\n        trainer.processing_class = mock_tokenizer\n        trainer._cache_special_token_ids()\n        assert trainer._special_token_ids == {0, 1, 2}\n\n    def test_cache_special_token_ids_no_tokenizer(self):\n        \"\"\"Test caching when no tokenizer is available.\"\"\"\n        trainer = object.__new__(DiffusionTrainer)\n        trainer.processing_class = None\n        trainer._cache_special_token_ids()\n\n        assert trainer._special_token_ids == set()\n\n    def test_main_compute_loss_interface(self, diffusion_trainer_instance):\n        \"\"\"Test the main compute_loss interface.\"\"\"\n        # Mock model\n        mock_model = Mock()\n        mock_outputs = Mock()\n        mock_outputs.logits = torch.randn(1, 5, 1000)\n        mock_model.return_value = mock_outputs\n        mock_model.training = True\n\n        inputs = {\n            \"input_ids\": torch.tensor([[1, 10, 20, 30, 2]], dtype=torch.long),\n            \"attention_mask\": torch.tensor([[1, 1, 1, 1, 
1]], dtype=torch.long),\n            \"labels\": torch.tensor([[-100, -100, 20, 30, 2]], dtype=torch.long),\n        }\n\n        # Test without return_outputs\n        loss = diffusion_trainer_instance.compute_loss(mock_model, inputs)\n        assert isinstance(loss, torch.Tensor)\n\n        # Test with return_outputs\n        loss, outputs = diffusion_trainer_instance.compute_loss(\n            mock_model, inputs, return_outputs=True\n        )\n        assert isinstance(loss, torch.Tensor)\n        assert outputs == mock_outputs\n\n    def test_missing_input_ids_raises_error(self, diffusion_trainer_instance):\n        \"\"\"Test that missing input_ids raises ValueError.\"\"\"\n        mock_model = Mock()\n        inputs = {\"attention_mask\": torch.tensor([[1, 1, 1]])}\n\n        with pytest.raises(ValueError, match=\"input_ids is required\"):\n            diffusion_trainer_instance.compute_loss(mock_model, inputs)\n"
  },
  {
    "path": "tests/integrations/test_diffusion_callback.py",
    "content": "\"\"\"Tests for diffusion generation callback dataloader selection and triggering.\"\"\"\n\nfrom types import SimpleNamespace\nfrom unittest.mock import Mock\n\nimport pytest\n\nfrom axolotl.integrations.diffusion import DiffusionGenerationCallback\n\n\nclass DummyTrainer:\n    \"\"\"Minimal trainer double with required attributes/methods for the callback.\"\"\"\n\n    def __init__(self, use_eval: bool):\n        # Config used by callback\n        self.cfg = SimpleNamespace(\n            diffusion=SimpleNamespace(\n                generation_interval=1,\n                num_generation_samples=1,\n                generation_max_length=32,\n                generation_steps=4,\n                generation_temperature=0.0,\n                mask_token_id=16,\n            ),\n            use_wandb=False,\n        )\n\n        # Model/tokenizer are passed through to generate_samples; not used here\n        self.model = Mock()\n        self.processing_class = Mock()\n\n        # Datasets and loaders\n        self.eval_dataset = object() if use_eval else None\n        self._train_loader = object()\n        self._eval_loader = object()\n\n        # State for world process check\n        self.state = SimpleNamespace(is_world_process_zero=True)\n\n        # Track which loader was requested\n        self.requested: list[str] = []\n\n    def get_train_dataloader(self):\n        self.requested.append(\"train\")\n        return self._train_loader\n\n    def get_eval_dataloader(self):\n        self.requested.append(\"eval\")\n        return self._eval_loader\n\n\n@pytest.mark.parametrize(\"use_eval\", [False, True])\ndef test_callback_uses_correct_dataloader(monkeypatch, use_eval):\n    trainer = DummyTrainer(use_eval=use_eval)\n    callback = DiffusionGenerationCallback(trainer)\n\n    captured = {}\n\n    # Patch generate_samples in the callback module's namespace\n    def fake_generate_samples(**kwargs):\n        captured[\"dataloader\"] = 
kwargs.get(\"dataloader\")\n        # Return one dummy sample to exercise logging path\n        return [\n            {\n                \"original\": \"o\",\n                \"masked\": \"m\",\n                \"generated\": \"g\",\n                \"mask_ratio\": 0.5,\n                \"masked_tokens\": 1,\n                \"total_tokens\": 2,\n            }\n        ]\n\n    monkeypatch.setattr(\n        \"axolotl.integrations.diffusion.callbacks.generate_samples\",\n        fake_generate_samples,\n    )\n\n    # Trigger at step 1 (interval=1)\n    args = SimpleNamespace()\n    state = SimpleNamespace(global_step=1)\n    control = SimpleNamespace()\n\n    callback.on_step_end(args=args, state=state, control=control)\n\n    # Assert the expected dataloader path was used\n    if use_eval:\n        assert trainer.requested[0] == \"eval\"\n        assert captured[\"dataloader\"] is trainer._eval_loader\n    else:\n        assert trainer.requested[0] == \"train\"\n        assert captured[\"dataloader\"] is trainer._train_loader\n"
  },
  {
    "path": "tests/integrations/test_kd_chat_template.py",
    "content": "\"\"\"\nTest for KD chat template strategies\n\"\"\"\n\nfrom unittest.mock import Mock\n\nimport pytest\n\nfrom axolotl.integrations.kd.chat_template import ChatTemplateStrategyWithKDv2\n\n\nclass TestChatTemplateStrategyWithKDv2:\n    \"\"\"Test v2 strategy correctly handles target_token_ids\"\"\"\n\n    @pytest.fixture\n    def v2_strategy(self):\n        \"\"\"Create v2 strategy instance with mocked dependencies\"\"\"\n        # Mock prompter\n        mock_prompter = Mock()\n        mock_prompter.roles = {\"user\": \"user\", \"assistant\": \"assistant\"}\n        mock_prompter.chat_template_msg_variables = [\"role\", \"content\"]\n        mock_prompter.chat_template = \"{{ messages }}\"\n\n        # Mock tokenizer\n        mock_tokenizer = Mock()\n        mock_tokenizer.pad_token_id = 0\n        mock_tokenizer.eos_token_id = 2\n        mock_tokenizer.bos_token_id = 1\n        mock_tokenizer.eos_token = \"<|endoftext|>\"\n        mock_tokenizer.apply_chat_template = Mock(return_value=[1, 10, 20, 30, 2])\n        mock_tokenizer.encode = Mock(return_value=[2])\n\n        return ChatTemplateStrategyWithKDv2(\n            prompter=mock_prompter,\n            tokenizer=mock_tokenizer,\n            train_on_inputs=False,\n            sequence_len=512,\n            logprobs_field=\"logprobs\",\n            gen_temperature=1.0,\n            kd_temperature=1.0,\n        )\n\n    def test_v2_prepare_kd_fields_adds_target_token_ids(self, v2_strategy):\n        \"\"\"\n        Test that v2's _prepare_kd_fields hook adds target_token_ids.\n\n        Validates the Template Method pattern fix where v2 overrides\n        the hook to add target_token_ids before transform.\n        \"\"\"\n        tokenized = {\"input_ids\": [1, 10, 20, 30, 2], \"labels\": [1, 10, 20, 30, 2]}\n        original = {\"target_token_ids\": [[10, 20], [30, 40]]}\n\n        result = v2_strategy._prepare_kd_fields(tokenized, original)\n\n        assert \"target_token_ids\" in result\n      
  assert result[\"target_token_ids\"] == [[10, 20], [30, 40]]\n\n    def test_v2_prepare_kd_fields_handles_missing_field(self, v2_strategy):\n        \"\"\"Test hook handles missing target_token_ids gracefully\"\"\"\n        tokenized = {\"input_ids\": [1, 10, 20, 30, 2], \"labels\": [1, 10, 20, 30, 2]}\n        original = {}\n\n        result = v2_strategy._prepare_kd_fields(tokenized, original)\n\n        assert \"target_token_ids\" not in result\n\n    def test_v2_transform_requires_target_token_ids(self, v2_strategy):\n        \"\"\"\n        Test v2's transform fails without target_token_ids.\n\n        Validates the bug fix - transform expects target_token_ids\n        to be added by the hook.\n        \"\"\"\n        sample = {\n            \"input_ids\": [1, 10, 20, 30, 2],\n            \"labels\": [1, 10, 20, 30, 2],\n            \"logprobs\": [[-0.1, -0.2], [-0.3, -0.4]],\n        }\n\n        with pytest.raises(KeyError, match=\"target_token_ids\"):\n            v2_strategy.transform_logprobs(sample)\n"
  },
  {
    "path": "tests/integrations/test_liger.py",
    "content": "\"\"\"\nconfig validation tests for swiglu args\n\"\"\"\n\nfrom typing import Optional\n\nimport pytest\n\nfrom axolotl.utils.config import prepare_plugins, validate_config\nfrom axolotl.utils.dict import DictDefault\n\n\n@pytest.fixture(name=\"minimal_liger_cfg\")\ndef fixture_cfg():\n    return DictDefault(\n        {\n            \"base_model\": \"TinyLlama/TinyLlama-1.1B-Chat-v0.6\",\n            \"learning_rate\": 0.000001,\n            \"datasets\": [\n                {\n                    \"path\": \"mhenrichsen/alpaca_2k_test\",\n                    \"type\": \"alpaca\",\n                }\n            ],\n            \"micro_batch_size\": 1,\n            \"gradient_accumulation_steps\": 1,\n            \"plugins\": [\"axolotl.integrations.liger.LigerPlugin\"],\n        }\n    )\n\n\nclass TestValidation:\n    \"\"\"\n    Test the validation module for liger\n    \"\"\"\n\n    _caplog: Optional[pytest.LogCaptureFixture] = None\n\n    @pytest.fixture(autouse=True)\n    def inject_fixtures(self, caplog):\n        caplog.set_level(\"WARNING\")\n        self._caplog = caplog\n\n    def test_deprecated_swiglu(self, minimal_liger_cfg):\n        test_cfg = DictDefault(\n            {\n                \"liger_swiglu\": False,\n            }\n            | minimal_liger_cfg\n        )\n\n        with self._caplog.at_level(\"WARNING\", logger=\"axolotl.integrations.liger.args\"):\n            prepare_plugins(test_cfg)\n            updated_cfg = validate_config(test_cfg)\n            # TODO this test is brittle in CI\n            # assert (\n            #     \"The 'liger_swiglu' argument is deprecated\"\n            #     in self._caplog.records[0].message\n            # )\n            assert updated_cfg.liger_swiglu is None\n            assert updated_cfg.liger_glu_activation is False\n\n    def test_conflict_swiglu_ligergluactivation(self, minimal_liger_cfg):\n        test_cfg = DictDefault(\n            {\n                \"liger_swiglu\": 
False,\n                \"liger_glu_activation\": True,\n            }\n            | minimal_liger_cfg\n        )\n\n        with pytest.raises(\n            ValueError,\n            match=r\".*You cannot have both `liger_swiglu` and `liger_glu_activation` set.*\",\n        ):\n            prepare_plugins(test_cfg)\n            validate_config(test_cfg)\n\n    def test_use_token_scaling_require_flce(self, minimal_liger_cfg):\n        test_cfg = DictDefault(\n            {\n                \"liger_fused_linear_cross_entropy\": False,\n                \"liger_use_token_scaling\": True,\n            }\n            | minimal_liger_cfg\n        )\n\n        with pytest.raises(\n            ValueError,\n            match=r\"`liger_use_token_scaling: true` requires `liger_fused_linear_cross_entropy` enabled.\",\n        ):\n            prepare_plugins(test_cfg)\n            validate_config(test_cfg)\n"
  },
  {
    "path": "tests/integrations/test_routing_parity.py",
    "content": "# SPDX-License-Identifier: Apache-2.0\n# Copyright (c) Axolotl AI\n# Licensed under the Apache License, Version 2.0\n\n\"\"\"\nParity tests between scattermoe-lora and sonicmoe routing implementations.\n\nThese tests verify that both implementations produce numerically identical\nresults for the same inputs, ensuring safe centralization of the routing code.\n\nScatterMoE returns 2D tensors [T, K]; SonicMoE returns flattened 1D [T*K].\nThe core algorithm should be identical — only the output format differs.\n\"\"\"\n\nfrom types import SimpleNamespace\n\nimport pytest\nimport torch\n\n\ndef _require_triton():\n    pytest.importorskip(\"triton\")\n\n\n# ============================================================================\n# Fixtures / helpers\n# ============================================================================\n\n\ndef _make_softmax_block(T=8, H=16, E=4, K=2):\n    \"\"\"Qwen/OLMoE-style block usable by both implementations.\"\"\"\n    gate = SimpleNamespace(\n        weight=torch.randn(E, H),\n        top_k=K,\n        num_experts=E,\n        norm_topk_prob=True,\n    )\n    moe_block = SimpleNamespace(gate=gate)\n    hidden = torch.randn(T, H)\n    return moe_block, gate, hidden, T, H, E, K\n\n\ndef _make_sigmoid_block(\n    T=8, H=16, E=16, K=4, n_group=2, topk_group=1, bias_on_gate=True\n):\n    \"\"\"GLM/DeepSeek-style block usable by both implementations.\"\"\"\n    if bias_on_gate:\n        gate = SimpleNamespace(\n            weight=torch.randn(E, H),\n            e_score_correction_bias=torch.zeros(E),\n        )\n        moe_block = SimpleNamespace(\n            gate=gate,\n            top_k=K,\n            n_routed_experts=E,\n            n_group=n_group,\n            topk_group=topk_group,\n            norm_topk_prob=True,\n            routed_scaling_factor=1.0,\n        )\n    else:\n        # minimax_m2 style: bias on block\n        gate = SimpleNamespace(\n            weight=torch.randn(E, H),\n            top_k=K,\n  
      )\n        moe_block = SimpleNamespace(\n            gate=gate,\n            top_k=K,\n            e_score_correction_bias=torch.zeros(E),\n        )\n    return moe_block, gate, hidden_states(T, H), T, H, E, K\n\n\ndef hidden_states(T, H):\n    return torch.randn(T, H)\n\n\n# ============================================================================\n# 1. Softmax routing parity\n# ============================================================================\n\n\nclass TestSoftmaxRoutingParity:\n    \"\"\"Verify scattermoe and sonicmoe softmax routing produce identical results.\"\"\"\n\n    @pytest.fixture(autouse=True)\n    def _require(self):\n        _require_triton()\n\n    def test_weights_match(self):\n        \"\"\"2D weights from scattermoe == reshaped 1D weights from sonicmoe.\"\"\"\n        from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (\n            _softmax_topk_route,\n        )\n        from axolotl.integrations.kernels.sonicmoe.routing import softmax_topk_routing\n\n        moe_block, gate, hidden, T, H, E, K = _make_softmax_block()\n\n        # ScatterMoE path (no LoRA delta)\n        sm_weights, sm_experts, sm_topk, sm_E = _softmax_topk_route(\n            moe_block, gate, hidden, gate.weight, None\n        )\n\n        # SonicMoE path\n        sonic_scores, sonic_tok_idx, sonic_exp_idx, sonic_logits = softmax_topk_routing(\n            hidden, moe_block\n        )\n\n        # ScatterMoE returns [T, K], SonicMoE returns [T*K] flattened\n        sonic_weights_2d = sonic_scores.reshape(T, K)\n        sonic_experts_2d = sonic_exp_idx.reshape(T, K)\n\n        assert sm_topk == K\n        assert sm_E == E\n\n        # Both should select the same experts and produce the same weights\n        assert torch.equal(sm_experts, sonic_experts_2d.to(sm_experts.dtype))\n        assert torch.allclose(sm_weights, sonic_weights_2d, atol=1e-6)\n\n    def test_logits_not_returned_by_scattermoe(self):\n        \"\"\"ScatterMoE doesn't 
return logits; SonicMoE does — verify SonicMoE logits shape.\"\"\"\n        from axolotl.integrations.kernels.sonicmoe.routing import softmax_topk_routing\n\n        moe_block, gate, hidden, T, H, E, K = _make_softmax_block()\n        _, _, _, logits = softmax_topk_routing(hidden, moe_block)\n        assert logits.shape == (T, E)\n\n    def test_no_renorm(self):\n        \"\"\"With norm_topk_prob=False, both should skip renormalization.\"\"\"\n        from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (\n            _softmax_topk_route,\n        )\n        from axolotl.integrations.kernels.sonicmoe.routing import softmax_topk_routing\n\n        moe_block, gate, hidden, T, H, E, K = _make_softmax_block()\n        gate.norm_topk_prob = False\n\n        sm_weights, sm_experts, _, _ = _softmax_topk_route(\n            moe_block, gate, hidden, gate.weight, None\n        )\n        sonic_scores, _, sonic_exp_idx, _ = softmax_topk_routing(hidden, moe_block)\n\n        sonic_weights_2d = sonic_scores.reshape(T, K)\n        sonic_experts_2d = sonic_exp_idx.reshape(T, K)\n\n        assert torch.equal(sm_experts, sonic_experts_2d.to(sm_experts.dtype))\n        assert torch.allclose(sm_weights, sonic_weights_2d, atol=1e-6)\n\n    def test_various_expert_counts(self):\n        \"\"\"Parity across different E and K values.\"\"\"\n        from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (\n            _softmax_topk_route,\n        )\n        from axolotl.integrations.kernels.sonicmoe.routing import softmax_topk_routing\n\n        for E, K in [(2, 1), (8, 2), (16, 4), (32, 8)]:\n            moe_block, gate, hidden, T, H, _, _ = _make_softmax_block(E=E, K=K)\n\n            sm_weights, sm_experts, _, _ = _softmax_topk_route(\n                moe_block, gate, hidden, gate.weight, None\n            )\n            sonic_scores, _, sonic_exp_idx, _ = softmax_topk_routing(hidden, moe_block)\n\n            sonic_weights_2d = sonic_scores.reshape(T, 
K)\n            sonic_experts_2d = sonic_exp_idx.reshape(T, K)\n\n            assert torch.equal(sm_experts, sonic_experts_2d.to(sm_experts.dtype)), (\n                f\"Expert mismatch for E={E}, K={K}\"\n            )\n            assert torch.allclose(sm_weights, sonic_weights_2d, atol=1e-6), (\n                f\"Weight mismatch for E={E}, K={K}\"\n            )\n\n\n# ============================================================================\n# 2. Sigmoid routing parity\n# ============================================================================\n\n\nclass TestSigmoidRoutingParity:\n    \"\"\"Verify scattermoe and sonicmoe sigmoid routing produce identical results.\"\"\"\n\n    @pytest.fixture(autouse=True)\n    def _require(self):\n        _require_triton()\n\n    def test_weights_match_with_groups(self):\n        \"\"\"Both implementations should produce identical weights with group selection.\"\"\"\n        from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (\n            _sigmoid_topk_route,\n        )\n        from axolotl.integrations.kernels.sonicmoe.routing import sigmoid_topk_routing\n\n        moe_block, gate, hidden, T, H, E, K = _make_sigmoid_block(\n            E=16, K=4, n_group=2, topk_group=1, bias_on_gate=True\n        )\n\n        sm_weights, sm_experts, sm_topk, sm_E = _sigmoid_topk_route(\n            moe_block, gate, hidden, gate.weight, None\n        )\n\n        sonic_scores, sonic_tok_idx, sonic_exp_idx, sonic_logits = sigmoid_topk_routing(\n            hidden, moe_block\n        )\n\n        sonic_weights_2d = sonic_scores.reshape(T, K)\n        sonic_experts_2d = sonic_exp_idx.reshape(T, K)\n\n        assert sm_topk == K\n        assert sm_E == E\n\n        # Sort experts within each token to handle different topk orderings\n        sm_sorted, sm_order = sm_experts.sort(dim=-1)\n        sonic_sorted, sonic_order = sonic_experts_2d.to(sm_experts.dtype).sort(dim=-1)\n\n        assert torch.equal(sm_sorted, 
sonic_sorted)\n\n        # Gather weights in sorted order for comparison\n        sm_weights_sorted = sm_weights.gather(1, sm_order)\n        sonic_weights_sorted = sonic_weights_2d.gather(1, sonic_order)\n        assert torch.allclose(sm_weights_sorted, sonic_weights_sorted, atol=1e-6)\n\n    def test_weights_match_no_groups(self):\n        \"\"\"Both implementations match without group selection (n_group=1).\"\"\"\n        from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (\n            _sigmoid_topk_route,\n        )\n        from axolotl.integrations.kernels.sonicmoe.routing import sigmoid_topk_routing\n\n        moe_block, gate, hidden, T, H, E, K = _make_sigmoid_block(\n            E=16, K=4, n_group=1, topk_group=1, bias_on_gate=True\n        )\n\n        sm_weights, sm_experts, _, _ = _sigmoid_topk_route(\n            moe_block, gate, hidden, gate.weight, None\n        )\n        sonic_scores, _, sonic_exp_idx, _ = sigmoid_topk_routing(hidden, moe_block)\n\n        sonic_weights_2d = sonic_scores.reshape(T, K)\n        sonic_experts_2d = sonic_exp_idx.reshape(T, K)\n\n        # Sort for comparison (topk with sorted=False may differ in order)\n        sm_sorted, sm_order = sm_experts.sort(dim=-1)\n        sonic_sorted, sonic_order = sonic_experts_2d.to(sm_experts.dtype).sort(dim=-1)\n\n        assert torch.equal(sm_sorted, sonic_sorted)\n        sm_weights_sorted = sm_weights.gather(1, sm_order)\n        sonic_weights_sorted = sonic_weights_2d.gather(1, sonic_order)\n        assert torch.allclose(sm_weights_sorted, sonic_weights_sorted, atol=1e-6)\n\n    def test_bias_on_block_parity(self):\n        \"\"\"minimax_m2 style: bias on block, not gate.\"\"\"\n        from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (\n            _sigmoid_topk_route,\n        )\n        from axolotl.integrations.kernels.sonicmoe.routing import sigmoid_topk_routing\n\n        moe_block, gate, hidden, T, H, E, K = _make_sigmoid_block(\n        
    E=16, K=4, n_group=1, bias_on_gate=False\n        )\n\n        sm_weights, sm_experts, _, _ = _sigmoid_topk_route(\n            moe_block, gate, hidden, gate.weight, None\n        )\n        sonic_scores, _, sonic_exp_idx, _ = sigmoid_topk_routing(hidden, moe_block)\n\n        sonic_weights_2d = sonic_scores.reshape(T, K)\n        sonic_experts_2d = sonic_exp_idx.reshape(T, K)\n\n        sm_sorted, sm_order = sm_experts.sort(dim=-1)\n        sonic_sorted, sonic_order = sonic_experts_2d.to(sm_experts.dtype).sort(dim=-1)\n\n        assert torch.equal(sm_sorted, sonic_sorted)\n        sm_weights_sorted = sm_weights.gather(1, sm_order)\n        sonic_weights_sorted = sonic_weights_2d.gather(1, sonic_order)\n        assert torch.allclose(sm_weights_sorted, sonic_weights_sorted, atol=1e-6)\n\n    def test_scaling_factor_parity(self):\n        \"\"\"routed_scaling_factor applied identically by both.\"\"\"\n        from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (\n            _sigmoid_topk_route,\n        )\n        from axolotl.integrations.kernels.sonicmoe.routing import sigmoid_topk_routing\n\n        moe_block, gate, hidden, T, H, E, K = _make_sigmoid_block(\n            n_group=1, bias_on_gate=True\n        )\n        moe_block.routed_scaling_factor = 2.5\n\n        sm_weights, sm_experts, _, _ = _sigmoid_topk_route(\n            moe_block, gate, hidden, gate.weight, None\n        )\n        sonic_scores, _, sonic_exp_idx, _ = sigmoid_topk_routing(hidden, moe_block)\n\n        sonic_weights_2d = sonic_scores.reshape(T, K)\n        sonic_experts_2d = sonic_exp_idx.reshape(T, K)\n\n        sm_sorted, sm_order = sm_experts.sort(dim=-1)\n        sonic_sorted, sonic_order = sonic_experts_2d.to(sm_experts.dtype).sort(dim=-1)\n\n        assert torch.equal(sm_sorted, sonic_sorted)\n        sm_weights_sorted = sm_weights.gather(1, sm_order)\n        sonic_weights_sorted = sonic_weights_2d.gather(1, sonic_order)\n        assert 
torch.allclose(sm_weights_sorted, sonic_weights_sorted, atol=1e-6)\n\n    def test_no_renorm_parity(self):\n        \"\"\"norm_topk_prob=False produces same results in both.\"\"\"\n        from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (\n            _sigmoid_topk_route,\n        )\n        from axolotl.integrations.kernels.sonicmoe.routing import sigmoid_topk_routing\n\n        moe_block, gate, hidden, T, H, E, K = _make_sigmoid_block(\n            n_group=1, bias_on_gate=True\n        )\n        moe_block.norm_topk_prob = False\n\n        sm_weights, sm_experts, _, _ = _sigmoid_topk_route(\n            moe_block, gate, hidden, gate.weight, None\n        )\n        sonic_scores, _, sonic_exp_idx, _ = sigmoid_topk_routing(hidden, moe_block)\n\n        sonic_weights_2d = sonic_scores.reshape(T, K)\n        sonic_experts_2d = sonic_exp_idx.reshape(T, K)\n\n        sm_sorted, sm_order = sm_experts.sort(dim=-1)\n        sonic_sorted, sonic_order = sonic_experts_2d.to(sm_experts.dtype).sort(dim=-1)\n\n        assert torch.equal(sm_sorted, sonic_sorted)\n        sm_weights_sorted = sm_weights.gather(1, sm_order)\n        sonic_weights_sorted = sonic_weights_2d.gather(1, sonic_order)\n        assert torch.allclose(sm_weights_sorted, sonic_weights_sorted, atol=1e-6)\n\n\n# ============================================================================\n# 3. 
Shared expert parity\n# ============================================================================\n\n\nclass TestSharedExpertParity:\n    \"\"\"Verify both _compute_shared_expert implementations behave identically.\"\"\"\n\n    @pytest.fixture(autouse=True)\n    def _require(self):\n        _require_triton()\n\n    def _get_both_fns(self):\n        from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (\n            _compute_shared_expert as scatter_compute,\n        )\n        from axolotl.integrations.kernels.sonicmoe.patch import (\n            _compute_shared_expert as sonic_compute,\n        )\n\n        return scatter_compute, sonic_compute\n\n    def test_shared_expert_singular(self):\n        scatter_fn, sonic_fn = self._get_both_fns()\n        out = torch.randn(4, 8)\n        block = SimpleNamespace(shared_expert=lambda x: out)\n        hidden = torch.randn(4, 8)\n\n        assert torch.equal(scatter_fn(block, hidden), sonic_fn(block, hidden))\n\n    def test_shared_experts_plural(self):\n        scatter_fn, sonic_fn = self._get_both_fns()\n        out = torch.randn(4, 8)\n        block = SimpleNamespace(shared_experts=lambda x: out)\n        hidden = torch.randn(4, 8)\n\n        assert torch.equal(scatter_fn(block, hidden), sonic_fn(block, hidden))\n\n    def test_shared_mlp(self):\n        scatter_fn, sonic_fn = self._get_both_fns()\n        out = torch.randn(4, 8)\n        block = SimpleNamespace(shared_mlp=lambda x: out)\n        hidden = torch.randn(4, 8)\n\n        assert torch.equal(scatter_fn(block, hidden), sonic_fn(block, hidden))\n\n    def test_no_shared_expert(self):\n        scatter_fn, sonic_fn = self._get_both_fns()\n        block = SimpleNamespace()\n        hidden = torch.randn(4, 8)\n\n        assert scatter_fn(block, hidden) is None\n        assert sonic_fn(block, hidden) is None\n\n    def test_shared_expert_gate_only_in_scattermoe(self):\n        \"\"\"ScatterMoE's _compute_shared_expert handles shared_expert_gate;\n 
       SonicMoE's patch.py handles it externally in the forward function.\n\n        This documents the known divergence: the scattermoe version applies\n        sigmoid gating inline, while sonicmoe applies it in the forward.\n        \"\"\"\n        scatter_fn, sonic_fn = self._get_both_fns()\n\n        H = 8\n        expert_out = torch.ones(4, H)\n        gate_fn = lambda x: torch.zeros(4, H)  # noqa: E731  # sigmoid(0) = 0.5\n\n        block = SimpleNamespace(\n            shared_expert=lambda x: expert_out,\n            shared_expert_gate=gate_fn,\n        )\n        hidden = torch.randn(4, H)\n\n        scatter_result = scatter_fn(block, hidden)\n        sonic_result = sonic_fn(block, hidden)\n\n        # ScatterMoE applies the gate: expert_out * sigmoid(0) = 0.5\n        expected_gated = expert_out * 0.5\n        assert torch.allclose(scatter_result, expected_gated, atol=1e-6)\n\n        # SonicMoE does NOT apply the gate here (it does it in the forward)\n        assert torch.equal(sonic_result, expert_out)\n\n\n# ============================================================================\n# 4. 
Route dispatcher parity\n# ============================================================================\n\n\nclass TestRouteDispatcherParity:\n    \"\"\"Verify _route in scattermoe dispatches correctly and matches individual fns.\"\"\"\n\n    @pytest.fixture(autouse=True)\n    def _require(self):\n        _require_triton()\n\n    def test_route_dispatches_softmax(self):\n        \"\"\"_route should use softmax when no e_score_correction_bias.\"\"\"\n        from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (\n            _route,\n            _softmax_topk_route,\n        )\n\n        moe_block, gate, hidden, T, H, E, K = _make_softmax_block()\n\n        route_w, route_e, route_k, route_E = _route(\n            moe_block, gate, hidden, gate.weight, None\n        )\n        direct_w, direct_e, direct_k, direct_E = _softmax_topk_route(\n            moe_block, gate, hidden, gate.weight, None\n        )\n\n        assert torch.equal(route_w, direct_w)\n        assert torch.equal(route_e, direct_e)\n        assert route_k == direct_k\n        assert route_E == direct_E\n\n    def test_route_dispatches_sigmoid(self):\n        \"\"\"_route should use sigmoid when e_score_correction_bias is present.\"\"\"\n        from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (\n            _route,\n            _sigmoid_topk_route,\n        )\n\n        moe_block, gate, hidden, T, H, E, K = _make_sigmoid_block(\n            n_group=1, bias_on_gate=True\n        )\n\n        route_w, route_e, route_k, route_E = _route(\n            moe_block, gate, hidden, gate.weight, None\n        )\n        direct_w, direct_e, direct_k, direct_E = _sigmoid_topk_route(\n            moe_block, gate, hidden, gate.weight, None\n        )\n\n        assert torch.equal(route_w, direct_w)\n        assert torch.equal(route_e, direct_e)\n        assert route_k == direct_k\n        assert route_E == direct_E\n"
  },
  {
    "path": "tests/integrations/test_scattermoe_autotune_telemetry.py",
    "content": "\"\"\"Tests for scattermoe autotune telemetry integration.\n\nThese tests use mocking to verify the collection and reporting logic\nwithout requiring Triton or CUDA.\n\"\"\"\n\nimport sys\nfrom types import SimpleNamespace\nfrom unittest.mock import MagicMock, patch\n\n# ---------------------------------------------------------------------------\n# Helpers\n# ---------------------------------------------------------------------------\n\n# Simulate the hash-suffixed module name that LocalLayerRepository creates.\n_FAKE_MODULE_NAME = \"scattermoe_lora_abc123.kernels.lora_ops\"\n\n# Patch target for _find_lora_ops_module inside the collector module.\n_FIND_MODULE_PATH = (\n    \"axolotl.integrations.kernels.autotune_collector._find_lora_ops_module\"\n)\n\n\ndef _make_mock_config(kwargs, num_warps=4, num_stages=3):\n    \"\"\"Create a mock triton.Config-like object.\"\"\"\n    return SimpleNamespace(kwargs=kwargs, num_warps=num_warps, num_stages=num_stages)\n\n\ndef _make_mock_kernel(cache=None):\n    \"\"\"Create a mock autotuned kernel object with a ``.cache`` dict.\"\"\"\n    kernel = SimpleNamespace()\n    kernel.cache = cache if cache is not None else {}\n    return kernel\n\n\ndef _make_mock_lora_ops(\n    fwd_cache=None, dx_cache=None, bwd_cache=None, fused_cache=None\n):\n    \"\"\"Build a mock ``lora_ops`` module with the four kernel attributes.\"\"\"\n    mod = SimpleNamespace(\n        _scatter2scatter_lora=_make_mock_kernel(fwd_cache),\n        _scatter2scatter_lora_dX=_make_mock_kernel(dx_cache),\n        _group_bwd_lora=_make_mock_kernel(bwd_cache),\n        _group_bwd_lora_fused=_make_mock_kernel(fused_cache),\n    )\n    return mod\n\n\ndef _real_lora_ops_module_names():\n    \"\"\"Return sys.modules keys that match the lora_ops discovery pattern.\n\n    Other tests in the same xdist worker may have loaded the *real*\n    lora_ops module.  
We need to temporarily hide those entries so the\n    discovery test finds only the mock we inject.\n    \"\"\"\n    return [\n        name\n        for name, mod in list(sys.modules.items())\n        if mod is not None\n        and \"lora_ops\" in name\n        and hasattr(mod, \"_scatter2scatter_lora\")\n    ]\n\n\n# =========================================================================\n# TestAutotuneCollector\n# =========================================================================\n\n\nclass TestAutotuneCollector:\n    \"\"\"Test ``collect_autotune_configs`` with mocked kernel objects.\n\n    Collection tests patch ``_find_lora_ops_module`` directly so they are\n    not affected by real ``lora_ops`` modules that other tests in the same\n    pytest-xdist worker may have loaded into ``sys.modules``.\n    \"\"\"\n\n    def test_empty_cache_returns_empty_list(self):\n        \"\"\"When no kernel has been autotuned yet, return ``[]``.\"\"\"\n        mock_lora_ops = _make_mock_lora_ops()\n\n        with patch(_FIND_MODULE_PATH, return_value=mock_lora_ops):\n            from axolotl.integrations.kernels.autotune_collector import (\n                collect_autotune_configs,\n            )\n\n            result = collect_autotune_configs()\n            assert result == []\n\n    def test_populated_cache_returns_configs(self):\n        \"\"\"When a cache entry exists, it appears in the output.\"\"\"\n        cfg = _make_mock_config(\n            {\"BLOCK_N\": 128, \"BLOCK_K\": 64}, num_warps=8, num_stages=4\n        )\n        mock_lora_ops = _make_mock_lora_ops(fwd_cache={(2048, 4096, 1024): cfg})\n\n        with patch(_FIND_MODULE_PATH, return_value=mock_lora_ops):\n            from axolotl.integrations.kernels.autotune_collector import (\n                collect_autotune_configs,\n            )\n\n            result = collect_autotune_configs()\n\n        assert len(result) == 1\n        entry = result[0]\n        assert entry[\"kernel\"] == 
\"scatter2scatter_lora_fwd\"\n        assert entry[\"key\"] == {\"M\": 2048, \"N\": 4096, \"K\": 1024}\n        assert entry[\"config\"][\"BLOCK_N\"] == 128\n        assert entry[\"config\"][\"BLOCK_K\"] == 64\n        assert entry[\"config\"][\"num_warps\"] == 8\n        assert entry[\"config\"][\"num_stages\"] == 4\n\n    def test_multiple_kernels_and_keys(self):\n        \"\"\"Multiple cache entries across kernels are all returned.\"\"\"\n        cfg_fwd = _make_mock_config({\"BLOCK_N\": 128, \"BLOCK_K\": 32})\n        cfg_dx = _make_mock_config({\"BLOCK_K\": 64, \"BLOCK_N\": 128}, num_warps=8)\n\n        mock_lora_ops = _make_mock_lora_ops(\n            fwd_cache={(16, 256, 128): cfg_fwd},\n            dx_cache={(16, 256, 128): cfg_dx},\n        )\n\n        with patch(_FIND_MODULE_PATH, return_value=mock_lora_ops):\n            from axolotl.integrations.kernels.autotune_collector import (\n                collect_autotune_configs,\n            )\n\n            result = collect_autotune_configs()\n\n        assert len(result) == 2\n        names = {r[\"kernel\"] for r in result}\n        assert \"scatter2scatter_lora_fwd\" in names\n        assert \"scatter2scatter_lora_dX\" in names\n\n    def test_extra_key_elements_stored(self):\n        \"\"\"Dtype or other extra elements in the cache key are captured.\"\"\"\n        cfg = _make_mock_config({\"BLOCK_N\": 64, \"BLOCK_K\": 32})\n        cache_key = (512, 1024, 256, \"float16\", \"float16\")\n\n        mock_lora_ops = _make_mock_lora_ops(fwd_cache={cache_key: cfg})\n\n        with patch(_FIND_MODULE_PATH, return_value=mock_lora_ops):\n            from axolotl.integrations.kernels.autotune_collector import (\n                collect_autotune_configs,\n            )\n\n            result = collect_autotune_configs()\n\n        assert len(result) == 1\n        key = result[0][\"key\"]\n        assert key[\"M\"] == 512\n        assert key[\"N\"] == 1024\n        assert key[\"K\"] == 256\n        assert 
key[\"_extra\"] == [\"float16\", \"float16\"]\n\n    def test_no_module_in_sys_modules_returns_empty(self):\n        \"\"\"If no lora_ops module is loaded, return ``[]``.\"\"\"\n        from axolotl.integrations.kernels.autotune_collector import (\n            collect_autotune_configs,\n        )\n\n        with patch(_FIND_MODULE_PATH, return_value=None):\n            result = collect_autotune_configs()\n        assert result == []\n\n    def test_finds_module_under_hash_suffixed_name(self):\n        \"\"\"Collector finds lora_ops regardless of the hash suffix.\"\"\"\n        cfg = _make_mock_config({\"BLOCK_N\": 256, \"BLOCK_K\": 128})\n        mock_lora_ops = _make_mock_lora_ops(fwd_cache={(8, 512, 64): cfg})\n\n        # Use a different hash to prove it's not hardcoded.\n        alt_name = \"scattermoe_lora_deadbeef.kernels.lora_ops\"\n\n        # Temporarily hide any real lora_ops modules that other tests in\n        # the same xdist worker may have loaded, so only our mock is found.\n        real_names = _real_lora_ops_module_names()\n        hide_patch = {name: None for name in real_names}\n\n        with patch.dict(sys.modules, {alt_name: mock_lora_ops, **hide_patch}):\n            from axolotl.integrations.kernels.autotune_collector import (\n                collect_autotune_configs,\n            )\n\n            result = collect_autotune_configs()\n\n        assert len(result) == 1\n        assert result[0][\"config\"][\"BLOCK_N\"] == 256\n\n\n# =========================================================================\n# TestAutotuneReportCallback\n# =========================================================================\n\n\nclass TestAutotuneReportCallback:\n    \"\"\"Test the callback fires once and sends the correct event.\"\"\"\n\n    def test_reports_once_on_first_step(self):\n        \"\"\"Callback should call ``send_event`` exactly once.\"\"\"\n        from axolotl.integrations.kernels.autotune_callback import (\n            
AutotuneReportCallback,\n        )\n\n        cb = AutotuneReportCallback()\n        mock_state = MagicMock()\n        mock_state.global_step = 1\n\n        fake_configs = [{\"kernel\": \"test_fwd\", \"key\": {}, \"config\": {}}]\n\n        with (\n            patch(\n                \"axolotl.integrations.kernels.autotune_collector.collect_autotune_configs\",\n                return_value=fake_configs,\n            ),\n            patch(\"axolotl.telemetry.manager.TelemetryManager\") as mock_tm_cls,\n        ):\n            mock_tm = MagicMock()\n            mock_tm.enabled = True\n            mock_tm_cls.get_instance.return_value = mock_tm\n\n            cb.on_step_end(args=MagicMock(), state=mock_state, control=MagicMock())\n            assert mock_tm.send_event.call_count == 1\n\n            call_kwargs = mock_tm.send_event.call_args[1]\n            assert call_kwargs[\"event_type\"] == \"scattermoe-autotune\"\n            assert call_kwargs[\"properties\"][\"kernel_count\"] == 1\n\n            # Second call should NOT send again.\n            cb.on_step_end(args=MagicMock(), state=mock_state, control=MagicMock())\n            assert mock_tm.send_event.call_count == 1\n\n    def test_retries_until_step_5_then_gives_up(self):\n        \"\"\"If no configs found by step 5, stop retrying.\"\"\"\n        from axolotl.integrations.kernels.autotune_callback import (\n            AutotuneReportCallback,\n        )\n\n        cb = AutotuneReportCallback()\n\n        with patch(\n            \"axolotl.integrations.kernels.autotune_collector.collect_autotune_configs\",\n            return_value=[],\n        ):\n            for step in range(1, 7):\n                mock_state = MagicMock()\n                mock_state.global_step = step\n                cb.on_step_end(args=MagicMock(), state=mock_state, control=MagicMock())\n\n            assert cb._reported is True\n\n    def test_reports_on_retry_when_data_arrives(self):\n        \"\"\"If step 1 has no data but step 2 
does, report at step 2.\"\"\"\n        from axolotl.integrations.kernels.autotune_callback import (\n            AutotuneReportCallback,\n        )\n\n        cb = AutotuneReportCallback()\n        fake_configs = [{\"kernel\": \"fwd\", \"key\": {}, \"config\": {}}]\n\n        call_count = 0\n\n        def _collector():\n            nonlocal call_count\n            call_count += 1\n            if call_count == 1:\n                return []\n            return fake_configs\n\n        with (\n            patch(\n                \"axolotl.integrations.kernels.autotune_collector.collect_autotune_configs\",\n                side_effect=_collector,\n            ),\n            patch(\"axolotl.telemetry.manager.TelemetryManager\") as mock_tm_cls,\n        ):\n            mock_tm = MagicMock()\n            mock_tm.enabled = True\n            mock_tm_cls.get_instance.return_value = mock_tm\n\n            # Step 1 — empty, no report\n            s1 = MagicMock()\n            s1.global_step = 1\n            cb.on_step_end(args=MagicMock(), state=s1, control=MagicMock())\n            assert mock_tm.send_event.call_count == 0\n\n            # Step 2 — data arrives, report\n            s2 = MagicMock()\n            s2.global_step = 2\n            cb.on_step_end(args=MagicMock(), state=s2, control=MagicMock())\n            assert mock_tm.send_event.call_count == 1\n\n    def test_includes_gpu_info(self):\n        \"\"\"Event properties should include GPU identification.\"\"\"\n        from axolotl.integrations.kernels.autotune_callback import (\n            AutotuneReportCallback,\n        )\n\n        cb = AutotuneReportCallback()\n        mock_state = MagicMock()\n        mock_state.global_step = 1\n\n        fake_configs = [{\"kernel\": \"fwd\", \"key\": {}, \"config\": {}}]\n        fake_gpu = {\n            \"gpu_name\": \"NVIDIA H100\",\n            \"gpu_compute_capability\": \"9.0\",\n            \"gpu_memory_bytes\": 85899345920,\n        }\n\n        fake_smem = 
{\"smem_capacity_bytes\": 233472}\n\n        with (\n            patch(\n                \"axolotl.integrations.kernels.autotune_collector.collect_autotune_configs\",\n                return_value=fake_configs,\n            ),\n            patch(\n                \"axolotl.integrations.kernels.autotune_callback._get_gpu_info\",\n                return_value=fake_gpu,\n            ),\n            patch(\n                \"axolotl.integrations.kernels.autotune_callback._get_smem_capacity\",\n                return_value=fake_smem,\n            ),\n            patch(\"axolotl.telemetry.manager.TelemetryManager\") as mock_tm_cls,\n        ):\n            mock_tm = MagicMock()\n            mock_tm.enabled = True\n            mock_tm_cls.get_instance.return_value = mock_tm\n\n            cb.on_step_end(args=MagicMock(), state=mock_state, control=MagicMock())\n            props = mock_tm.send_event.call_args[1][\"properties\"]\n            assert props[\"gpu_name\"] == \"NVIDIA H100\"\n            assert props[\"gpu_compute_capability\"] == \"9.0\"\n            assert props[\"gpu_memory_bytes\"] == 85899345920\n            assert props[\"smem_capacity_bytes\"] == 233472\n\n    def test_skips_send_when_telemetry_disabled(self):\n        \"\"\"If telemetry is disabled, no event is sent.\"\"\"\n        from axolotl.integrations.kernels.autotune_callback import (\n            AutotuneReportCallback,\n        )\n\n        cb = AutotuneReportCallback()\n        mock_state = MagicMock()\n        mock_state.global_step = 1\n\n        with (\n            patch(\n                \"axolotl.integrations.kernels.autotune_collector.collect_autotune_configs\",\n                return_value=[{\"kernel\": \"fwd\", \"key\": {}, \"config\": {}}],\n            ),\n            patch(\"axolotl.telemetry.manager.TelemetryManager\") as mock_tm_cls,\n        ):\n            mock_tm = MagicMock()\n            mock_tm.enabled = False\n            mock_tm_cls.get_instance.return_value = mock_tm\n\n  
          cb.on_step_end(args=MagicMock(), state=mock_state, control=MagicMock())\n            assert mock_tm.send_event.call_count == 0\n            # Should still mark as reported so we don't retry.\n            assert cb._reported is True\n\n\n# =========================================================================\n# TestKernelsPluginCallbackRegistration\n# =========================================================================\n\n\nclass TestKernelsPluginCallbackRegistration:\n    \"\"\"Test that ``KernelsPlugin`` registers the callback correctly.\"\"\"\n\n    def test_scattermoe_registers_callback(self):\n        \"\"\"When ``use_scattermoe=True``, plugin returns the callback.\"\"\"\n        from axolotl.integrations.kernels.autotune_callback import (\n            AutotuneReportCallback,\n        )\n        from axolotl.integrations.kernels.plugin import KernelsPlugin\n\n        plugin = KernelsPlugin()\n        cfg = MagicMock()\n        cfg.use_scattermoe = True\n        model = MagicMock()\n\n        callbacks = plugin.add_callbacks_pre_trainer(cfg, model)\n        assert len(callbacks) == 1\n        assert isinstance(callbacks[0], AutotuneReportCallback)\n\n    def test_no_scattermoe_no_callback(self):\n        \"\"\"When ``use_scattermoe=False``, plugin returns empty list.\"\"\"\n        from axolotl.integrations.kernels.plugin import KernelsPlugin\n\n        plugin = KernelsPlugin()\n        cfg = MagicMock()\n        cfg.use_scattermoe = False\n        model = MagicMock()\n\n        callbacks = plugin.add_callbacks_pre_trainer(cfg, model)\n        assert callbacks == []\n"
  },
  {
    "path": "tests/integrations/test_scattermoe_lora.py",
    "content": "# SPDX-License-Identifier: Apache-2.0\n# Copyright (c) Axolotl AI\n# Licensed under the Apache License, Version 2.0\n\n\"\"\"\nUnit tests for scattermoe-lora.\n\nTests cover:\n- KernelsArgs validator: disable_mlp_kernel\n- ParallelExperts: scaling=0.0 not treated as falsy\n- single2scatter: non-aligned K/N dimensions\n- group_compileable: coeff=None accepted\n- HFScatterMoEGatedMLP / ScatterMoEGatedMLP: return value contract\n- Routing strategy detection and sigmoid routing\n- Generic shared expert handling\n\"\"\"\n\nfrom types import SimpleNamespace\nfrom unittest.mock import patch\n\nimport pytest\nimport torch\n\n# ============================================================================\n# 1. KernelsArgs: disable_mlp_kernel validator\n# ============================================================================\n\n\nclass TestKernelsArgsValidator:\n    \"\"\"Test that disable_mlp_kernel sets both flags correctly.\n\n    These tests call the validator classmethod directly on raw dicts,\n    since lora_mlp_kernel / mlp_kernel are not declared model fields.\n    \"\"\"\n\n    def test_disables_lora_mlp_kernel_when_scattermoe(self):\n        \"\"\"lora_mlp_kernel=True gets set to False when use_scattermoe=True.\"\"\"\n        from axolotl.integrations.kernels.args import KernelsArgs\n\n        data = {\n            \"use_kernels\": True,\n            \"use_scattermoe\": True,\n            \"lora_mlp_kernel\": True,\n        }\n        result = KernelsArgs.disable_mlp_kernel(data)\n        assert result[\"lora_mlp_kernel\"] is False\n        assert result[\"mlp_kernel\"] is False\n\n    def test_mlp_kernel_disabled_without_lora(self):\n        \"\"\"Even without lora_mlp_kernel, mlp_kernel should be disabled.\"\"\"\n        from axolotl.integrations.kernels.args import KernelsArgs\n\n        data = {\n            \"use_kernels\": True,\n            \"use_scattermoe\": True,\n        }\n        result = KernelsArgs.disable_mlp_kernel(data)\n      
  assert result[\"mlp_kernel\"] is False\n        # lora_mlp_kernel was not in data, should not be added\n        assert \"lora_mlp_kernel\" not in result\n\n    def test_lora_mlp_kernel_false_unchanged(self):\n        \"\"\"lora_mlp_kernel=False should stay False (no warning, no change).\"\"\"\n        from axolotl.integrations.kernels.args import KernelsArgs\n\n        data = {\n            \"use_kernels\": True,\n            \"use_scattermoe\": True,\n            \"lora_mlp_kernel\": False,\n        }\n        result = KernelsArgs.disable_mlp_kernel(data)\n        assert result[\"lora_mlp_kernel\"] is False\n\n    def test_no_change_when_scattermoe_disabled(self):\n        \"\"\"When use_scattermoe is not True, nothing should be changed.\"\"\"\n        from axolotl.integrations.kernels.args import KernelsArgs\n\n        data = {\n            \"use_kernels\": True,\n            \"use_scattermoe\": False,\n            \"lora_mlp_kernel\": True,\n        }\n        result = KernelsArgs.disable_mlp_kernel(data)\n        assert result[\"lora_mlp_kernel\"] is True\n\n\nclass TestParallelExpertsScaling:\n    \"\"\"Test that scaling=0.0 is preserved and not overridden to 1.0.\"\"\"\n\n    def test_scaling_zero_preserved(self):\n        \"\"\"scaling=0.0 should be passed as 0.0, not replaced with 1.0.\"\"\"\n        pytest.importorskip(\"triton\")\n        from axolotl.integrations.kernels.libs.scattermoe_lora.lora_ops import (\n            ParallelExperts,\n        )\n\n        pe = ParallelExperts(num_experts=2, input_size=4, output_size=4)\n        pe.set_lora(\n            lora_A=torch.randn(4, 4),\n            lora_B=torch.randn(4, 4),\n            scaling=0.0,\n        )\n        assert pe._lora_scaling == 0.0\n\n        # Patch parallel_linear_lora to capture the scaling arg\n        with patch(\n            \"axolotl.integrations.kernels.libs.scattermoe_lora.lora_ops.parallel_linear_lora\"\n        ) as mock_pll:\n            mock_pll.return_value = 
torch.randn(4, 4)\n            # Create dummy routing tensors\n            pe.forward(\n                inputs=torch.randn(2, 4),\n                k=1,\n                sorted_expert_idxs=torch.tensor([0, 0, 1, 1]),\n                sorted_scattered_idxs=torch.tensor([0, 1, 0, 1]),\n                expert_offsets=torch.tensor([2, 4]),\n            )\n            # Check that scaling=0.0 was passed, not 1.0\n            call_kwargs = mock_pll.call_args\n            assert (\n                call_kwargs.kwargs.get(\"scaling\") == 0.0\n                or call_kwargs[1].get(\"scaling\") == 0.0\n            ), f\"Expected scaling=0.0 but got {call_kwargs}\"\n\n    def test_scaling_none_defaults_to_one(self):\n        \"\"\"scaling=None (no LoRA attached) should default to 1.0.\"\"\"\n        pytest.importorskip(\"triton\")\n        from axolotl.integrations.kernels.libs.scattermoe_lora.lora_ops import (\n            ParallelExperts,\n        )\n\n        pe = ParallelExperts(num_experts=2, input_size=4, output_size=4)\n        # No set_lora called, so _lora_scaling is None\n\n        with patch(\n            \"axolotl.integrations.kernels.libs.scattermoe_lora.lora_ops.parallel_linear_lora\"\n        ) as mock_pll:\n            mock_pll.return_value = torch.randn(4, 4)\n            pe.forward(\n                inputs=torch.randn(2, 4),\n                k=1,\n                sorted_expert_idxs=torch.tensor([0, 0, 1, 1]),\n                sorted_scattered_idxs=torch.tensor([0, 1, 0, 1]),\n                expert_offsets=torch.tensor([2, 4]),\n            )\n            call_kwargs = mock_pll.call_args\n            scaling_val = call_kwargs.kwargs.get(\"scaling\") or call_kwargs[1].get(\n                \"scaling\"\n            )\n            assert scaling_val == 1.0, (\n                f\"Expected scaling=1.0 for None but got {scaling_val}\"\n            )\n\n    def test_scaling_positive_preserved(self):\n        \"\"\"Normal positive scaling should be preserved.\"\"\"\n  
      pytest.importorskip(\"triton\")\n        from axolotl.integrations.kernels.libs.scattermoe_lora.lora_ops import (\n            ParallelExperts,\n        )\n\n        pe = ParallelExperts(num_experts=2, input_size=4, output_size=4)\n        pe.set_lora(\n            lora_A=torch.randn(4, 4),\n            lora_B=torch.randn(4, 4),\n            scaling=0.5,\n        )\n\n        with patch(\n            \"axolotl.integrations.kernels.libs.scattermoe_lora.lora_ops.parallel_linear_lora\"\n        ) as mock_pll:\n            mock_pll.return_value = torch.randn(4, 4)\n            pe.forward(\n                inputs=torch.randn(2, 4),\n                k=1,\n                sorted_expert_idxs=torch.tensor([0, 0, 1, 1]),\n                sorted_scattered_idxs=torch.tensor([0, 1, 0, 1]),\n                expert_offsets=torch.tensor([2, 4]),\n            )\n            call_kwargs = mock_pll.call_args\n            scaling_val = call_kwargs.kwargs.get(\"scaling\") or call_kwargs[1].get(\n                \"scaling\"\n            )\n            assert scaling_val == 0.5\n\n\n# ============================================================================\n# 4. 
single2scatter: non-aligned K/N dimensions (GPU only)\n# ============================================================================\n\n\n@pytest.mark.skipif(not torch.cuda.is_available(), reason=\"CUDA not available\")\nclass TestSingle2ScatterBounds:\n    \"\"\"Test single2scatter with non-aligned dimensions.\"\"\"\n\n    def test_non_aligned_k(self):\n        \"\"\"K not a multiple of BLOCK_K should produce correct results.\"\"\"\n        from axolotl.integrations.kernels.libs.scattermoe_lora.kernels.single import (\n            single2scatter,\n        )\n\n        E, K, N = 2, 100, 128  # K=100 not a multiple of 128\n        W = torch.randn(E, K, N, device=\"cuda\", dtype=torch.float32)\n        X = torch.randn(1, K, device=\"cuda\", dtype=torch.float32)\n        expert_idxs = torch.tensor([[0, 1]], device=\"cuda\", dtype=torch.long)\n\n        Y = single2scatter(X, W, expert_idxs)\n        assert Y.shape == (2, N)\n\n        # Verify against manual computation\n        Y_ref_0 = X[0] @ W[0]\n        Y_ref_1 = X[0] @ W[1]\n        torch.testing.assert_close(Y[0], Y_ref_0, atol=1e-2, rtol=1e-2)\n        torch.testing.assert_close(Y[1], Y_ref_1, atol=1e-2, rtol=1e-2)\n\n    def test_non_aligned_n(self):\n        \"\"\"N not a multiple of BLOCK_N should produce correct results.\"\"\"\n        from axolotl.integrations.kernels.libs.scattermoe_lora.kernels.single import (\n            single2scatter,\n        )\n\n        E, K, N = 2, 128, 100  # N=100 not a multiple of 128\n        W = torch.randn(E, K, N, device=\"cuda\", dtype=torch.float32)\n        X = torch.randn(1, K, device=\"cuda\", dtype=torch.float32)\n        expert_idxs = torch.tensor([[0, 1]], device=\"cuda\", dtype=torch.long)\n\n        Y = single2scatter(X, W, expert_idxs)\n        assert Y.shape == (2, N)\n\n        Y_ref_0 = X[0] @ W[0]\n        Y_ref_1 = X[0] @ W[1]\n        torch.testing.assert_close(Y[0], Y_ref_0, atol=1e-2, rtol=1e-2)\n        torch.testing.assert_close(Y[1], Y_ref_1, 
atol=1e-2, rtol=1e-2)\n\n    def test_non_aligned_both(self):\n        \"\"\"Both K and N not aligned should produce correct results.\"\"\"\n        from axolotl.integrations.kernels.libs.scattermoe_lora.kernels.single import (\n            single2scatter,\n        )\n\n        E, K, N = 2, 100, 100  # Neither aligned to 128\n        W = torch.randn(E, K, N, device=\"cuda\", dtype=torch.float32)\n        X = torch.randn(1, K, device=\"cuda\", dtype=torch.float32)\n        expert_idxs = torch.tensor([[0, 1]], device=\"cuda\", dtype=torch.long)\n\n        Y = single2scatter(X, W, expert_idxs)\n        assert Y.shape == (2, N)\n\n        Y_ref_0 = X[0] @ W[0]\n        Y_ref_1 = X[0] @ W[1]\n        torch.testing.assert_close(Y[0], Y_ref_0, atol=1e-2, rtol=1e-2)\n        torch.testing.assert_close(Y[1], Y_ref_1, atol=1e-2, rtol=1e-2)\n\n\n# ============================================================================\n# 5. group_compileable: coeff=None accepted\n# ============================================================================\n\n\n@pytest.mark.skipif(not torch.cuda.is_available(), reason=\"CUDA not available\")\nclass TestGroupCoeffNone:\n    \"\"\"Test that group() works with coeff=None.\"\"\"\n\n    def test_group_with_none_coeff(self):\n        \"\"\"group() should accept coeff=None without errors.\"\"\"\n        from axolotl.integrations.kernels.libs.scattermoe_lora.kernels.ops import group\n\n        M, K = 4, 32\n        A = torch.randn(M, K, device=\"cuda\", dtype=torch.float32)\n        sorted_expert_idxs = torch.tensor([0, 1, 2, 3], device=\"cuda\", dtype=torch.long)\n\n        # This should not raise a TypeError\n        Y = group(A, sorted_expert_idxs, coeff=None, fan_out=1)\n        assert Y.shape == (M, K)\n\n    def test_group_with_coeff(self):\n        \"\"\"group() should also work with actual coeff values.\"\"\"\n        from axolotl.integrations.kernels.libs.scattermoe_lora.kernels.ops import group\n\n        M, K = 4, 32\n        A = 
torch.randn(M, K, device=\"cuda\", dtype=torch.float32)\n        sorted_expert_idxs = torch.tensor([0, 1, 2, 3], device=\"cuda\", dtype=torch.long)\n        coeff = torch.ones(M, device=\"cuda\", dtype=torch.float32) * 0.5\n\n        Y = group(A, sorted_expert_idxs, coeff=coeff, fan_out=1)\n        assert Y.shape == (M, K)\n\n\n# ============================================================================\n# 6. Layer return value contracts\n# ============================================================================\n\n\nclass TestLayerReturnValues:\n    \"\"\"Test that layer forward methods return the correct types.\"\"\"\n\n    def test_hf_scatter_moe_returns_single_tensor(self):\n        \"\"\"HFScatterMoEGatedMLP.forward should return a single tensor, not a tuple.\"\"\"\n        pytest.importorskip(\"triton\")\n        # Verify the forward method signature and return annotation\n        import inspect\n\n        from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (\n            HFScatterMoEGatedMLP,\n        )\n\n        sig = inspect.signature(HFScatterMoEGatedMLP.forward)\n        # It's a staticmethod taking (self, layer_input)\n        params = list(sig.parameters.keys())\n        assert \"self\" in params\n        assert \"layer_input\" in params\n\n    def test_scatter_moe_gated_mlp_docstring_no_router_logits(self):\n        \"\"\"ScatterMoEGatedMLP.forward docstring should not mention router logits as return.\"\"\"\n        pytest.importorskip(\"triton\")\n        from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (\n            ScatterMoEGatedMLP,\n        )\n\n        docstring = ScatterMoEGatedMLP.forward.__doc__\n        assert docstring is not None\n        # The docstring should mention output tensor but NOT router logits\n        assert \"Output tensor\" in docstring or \"output tensor\" in docstring.lower()\n        assert \"Router logits\" not in docstring, (\n            \"Docstring should not mention 
'Router logits' in Returns section\"\n        )\n\n\n# ============================================================================\n# 7. Routing strategy detection and sigmoid routing\n# ============================================================================\n\n\ndef _make_softmax_gate(E=4, H=16, K=2):\n    \"\"\"Create a mock softmax-style gate (Qwen/OLMoE).\"\"\"\n    return SimpleNamespace(\n        weight=torch.randn(E, H),\n        top_k=K,\n        num_experts=E,\n        norm_topk_prob=True,\n    )\n\n\ndef _make_sigmoid_gate_with_bias(E=16, H=16):\n    \"\"\"Create a mock sigmoid-style gate with e_score_correction_bias on gate.\"\"\"\n    return SimpleNamespace(\n        weight=torch.randn(E, H),\n        e_score_correction_bias=torch.zeros(E),\n    )\n\n\ndef _make_sigmoid_moe_block(\n    T=8, H=16, E=16, K=4, n_group=2, topk_group=1, bias_on_gate=True\n):\n    \"\"\"Create a mock GLM/DeepSeek-style MoE block for sigmoid routing tests.\"\"\"\n    if bias_on_gate:\n        gate = SimpleNamespace(\n            weight=torch.randn(E, H),\n            e_score_correction_bias=torch.zeros(E),\n        )\n        moe_block = SimpleNamespace(\n            gate=gate,\n            top_k=K,\n            n_routed_experts=E,\n            n_group=n_group,\n            topk_group=topk_group,\n            norm_topk_prob=True,\n            routed_scaling_factor=1.0,\n        )\n    else:\n        # minimax_m2 style: bias on block, not gate\n        gate = SimpleNamespace(\n            weight=torch.randn(E, H),\n            top_k=K,\n        )\n        moe_block = SimpleNamespace(\n            gate=gate,\n            top_k=K,\n            e_score_correction_bias=torch.zeros(E),\n        )\n    return moe_block, T, H, E, K\n\n\ndef _skip_without_triton():\n    pytest.importorskip(\"triton\")\n\n\nclass TestSigmoidRoutingInScatterMoE:\n    \"\"\"Test _sigmoid_topk_route from layers.py.\"\"\"\n\n    @pytest.fixture(autouse=True)\n    def _require_triton(self):\n        
_skip_without_triton()\n\n    def test_output_shapes(self):\n        from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (\n            _sigmoid_topk_route,\n        )\n\n        moe_block, T, H, E, K = _make_sigmoid_moe_block()\n        gate = moe_block.gate\n        hidden = torch.randn(T, H)\n\n        weights, experts, top_k, num_experts = _sigmoid_topk_route(\n            moe_block, gate, hidden, gate.weight, None\n        )\n\n        assert weights.shape == (T, K)\n        assert experts.shape == (T, K)\n        assert top_k == K\n        assert num_experts == E\n\n    def test_weights_nonnegative(self):\n        from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (\n            _sigmoid_topk_route,\n        )\n\n        moe_block, T, H, E, K = _make_sigmoid_moe_block()\n        gate = moe_block.gate\n        hidden = torch.randn(T, H)\n\n        weights, _, _, _ = _sigmoid_topk_route(\n            moe_block, gate, hidden, gate.weight, None\n        )\n        assert (weights >= 0).all()\n\n    def test_group_selection_restricts_experts(self):\n        \"\"\"With n_group=4, topk_group=1, experts should be from selected groups.\"\"\"\n        from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (\n            _sigmoid_topk_route,\n        )\n\n        moe_block, T, H, E, K = _make_sigmoid_moe_block(\n            E=16, K=2, n_group=4, topk_group=1\n        )\n        gate = moe_block.gate\n        hidden = torch.randn(T, H)\n\n        _, expert_idx, _, _ = _sigmoid_topk_route(\n            moe_block, gate, hidden, gate.weight, None\n        )\n\n        # Each token's experts should fall within a single group (size E//n_group=4)\n        for t in range(T):\n            experts_t = expert_idx[t]\n            groups = experts_t // (E // moe_block.n_group)\n            assert (groups == groups[0]).all()\n\n    def test_scaling_factor_applied(self):\n        from 
axolotl.integrations.kernels.libs.scattermoe_lora.layers import (\n            _sigmoid_topk_route,\n        )\n\n        moe_block, T, H, E, K = _make_sigmoid_moe_block(n_group=1)\n        gate = moe_block.gate\n        hidden = torch.randn(T, H)\n\n        weights_1x, _, _, _ = _sigmoid_topk_route(\n            moe_block, gate, hidden, gate.weight, None\n        )\n\n        moe_block.routed_scaling_factor = 2.0\n        weights_2x, _, _, _ = _sigmoid_topk_route(\n            moe_block, gate, hidden, gate.weight, None\n        )\n\n        assert torch.allclose(weights_2x, weights_1x * 2.0, atol=1e-5)\n\n    def test_bias_on_gate(self):\n        \"\"\"e_score_correction_bias on gate is found.\"\"\"\n        from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (\n            _sigmoid_topk_route,\n        )\n\n        moe_block, T, H, E, K = _make_sigmoid_moe_block(bias_on_gate=True)\n        gate = moe_block.gate\n        hidden = torch.randn(T, H)\n\n        weights, experts, _, _ = _sigmoid_topk_route(\n            moe_block, gate, hidden, gate.weight, None\n        )\n        assert weights.shape == (T, K)\n\n    def test_bias_on_block(self):\n        \"\"\"e_score_correction_bias on moe_block (not gate) is found.\"\"\"\n        from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (\n            _sigmoid_topk_route,\n        )\n\n        moe_block, T, H, E, K = _make_sigmoid_moe_block(bias_on_gate=False)\n        gate = moe_block.gate\n        hidden = torch.randn(T, H)\n\n        weights, experts, _, _ = _sigmoid_topk_route(\n            moe_block, gate, hidden, gate.weight, None\n        )\n        assert weights.shape == (T, K)\n\n    def test_gate_lora_delta_applied(self):\n        \"\"\"Gate LoRA delta should affect routing logits.\"\"\"\n        from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (\n            _sigmoid_topk_route,\n        )\n\n        moe_block, T, H, E, K = 
_make_sigmoid_moe_block(n_group=1)\n        gate = moe_block.gate\n        hidden = torch.randn(T, H)\n\n        weights_no_lora, _, _, _ = _sigmoid_topk_route(\n            moe_block, gate, hidden, gate.weight, None\n        )\n\n        # Large delta should change the results\n        delta = torch.randn(E, H) * 10.0\n        weights_with_lora, _, _, _ = _sigmoid_topk_route(\n            moe_block, gate, hidden, gate.weight, delta\n        )\n\n        assert not torch.equal(weights_no_lora, weights_with_lora)\n\n    def test_no_bias_does_not_crash(self):\n        \"\"\"Calling _sigmoid_topk_route with no e_score_correction_bias should not crash.\"\"\"\n        from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (\n            _sigmoid_topk_route,\n        )\n\n        T, H, E, K = 8, 16, 8, 2\n        gate = SimpleNamespace(weight=torch.randn(E, H))\n        moe_block = SimpleNamespace(\n            gate=gate,\n            top_k=K,\n            n_routed_experts=E,\n            n_group=1,\n            norm_topk_prob=True,\n            routed_scaling_factor=1.0,\n        )\n        hidden = torch.randn(T, H)\n\n        weights, experts, top_k, num_experts = _sigmoid_topk_route(\n            moe_block, gate, hidden, gate.weight, None\n        )\n        assert weights.shape == (T, K)\n        assert experts.shape == (T, K)\n        # Without bias, scores_for_choice == sigmoid(logits) — all positive\n        assert (weights >= 0).all()\n\n    def test_missing_topk_group_defaults_to_n_group(self):\n        \"\"\"When topk_group is absent but n_group > 1, should default to n_group (no-op masking).\"\"\"\n        from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (\n            _sigmoid_topk_route,\n        )\n\n        T, H, E, K, n_group = 8, 16, 16, 2, 4\n        gate = SimpleNamespace(\n            weight=torch.randn(E, H),\n            e_score_correction_bias=torch.zeros(E),\n        )\n        # Intentionally omit topk_group\n   
     moe_block = SimpleNamespace(\n            gate=gate,\n            top_k=K,\n            n_routed_experts=E,\n            n_group=n_group,\n            norm_topk_prob=True,\n            routed_scaling_factor=1.0,\n        )\n        hidden = torch.randn(T, H)\n\n        # Should not raise AttributeError; defaults topk_group to n_group\n        weights, experts, top_k_out, num_experts = _sigmoid_topk_route(\n            moe_block, gate, hidden, gate.weight, None\n        )\n        assert weights.shape == (T, K)\n        assert experts.shape == (T, K)\n\n\nclass TestRoutingStrategyDetection:\n    \"\"\"Test that _route dispatches to the correct strategy.\"\"\"\n\n    @pytest.fixture(autouse=True)\n    def _require_triton(self):\n        _skip_without_triton()\n\n    def test_softmax_for_qwen_style(self):\n        \"\"\"Block without e_score_correction_bias should use softmax.\"\"\"\n        from axolotl.integrations.kernels.libs.scattermoe_lora.layers import _route\n\n        gate = _make_softmax_gate(E=4, H=16, K=2)\n        moe_block = SimpleNamespace(gate=gate)\n        hidden = torch.randn(8, 16)\n\n        weights, experts, top_k, num_experts = _route(\n            moe_block, gate, hidden, gate.weight, None\n        )\n\n        assert weights.shape == (8, 2)\n        assert experts.shape == (8, 2)\n        assert top_k == 2\n        assert num_experts == 4\n        per_token_sums = weights.sum(dim=-1)\n        assert torch.allclose(per_token_sums, torch.ones(8), atol=1e-5)\n\n    def test_sigmoid_for_glm_style(self):\n        \"\"\"Block with e_score_correction_bias on gate should use sigmoid.\"\"\"\n        from axolotl.integrations.kernels.libs.scattermoe_lora.layers import _route\n\n        moe_block, T, H, E, K = _make_sigmoid_moe_block(bias_on_gate=True, n_group=1)\n        gate = moe_block.gate\n        hidden = torch.randn(T, H)\n\n        weights, experts, top_k, num_experts = _route(\n            moe_block, gate, hidden, gate.weight, None\n        
)\n\n        assert weights.shape == (T, K)\n        assert experts.shape == (T, K)\n        assert (weights >= 0).all()\n\n    def test_sigmoid_for_minimax_m2_style(self):\n        \"\"\"Block with e_score_correction_bias on block (not gate) should use sigmoid.\"\"\"\n        from axolotl.integrations.kernels.libs.scattermoe_lora.layers import _route\n\n        moe_block, T, H, E, K = _make_sigmoid_moe_block(bias_on_gate=False)\n        gate = moe_block.gate\n        hidden = torch.randn(T, H)\n\n        weights, experts, top_k, num_experts = _route(\n            moe_block, gate, hidden, gate.weight, None\n        )\n\n        assert weights.shape == (T, K)\n        assert (weights >= 0).all()\n\n\n# ============================================================================\n# 8. Generic shared expert handling\n# ============================================================================\n\n\nclass TestGenericSharedExpert:\n    \"\"\"Test _compute_shared_expert from layers.py.\"\"\"\n\n    @pytest.fixture(autouse=True)\n    def _require_triton(self):\n        _skip_without_triton()\n\n    def test_shared_expert_singular(self):\n        \"\"\"shared_expert attribute (Qwen2MoE style).\"\"\"\n        from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (\n            _compute_shared_expert,\n        )\n\n        called = torch.randn(4, 8)\n        moe_block = SimpleNamespace(\n            shared_expert=lambda x: called,\n        )\n        result = _compute_shared_expert(moe_block, torch.randn(4, 8))\n        assert torch.equal(result, called)\n\n    def test_shared_experts_plural(self):\n        \"\"\"shared_experts attribute (DeepSeek V3 style).\"\"\"\n        from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (\n            _compute_shared_expert,\n        )\n\n        called = torch.randn(4, 8)\n        moe_block = SimpleNamespace(\n            shared_experts=lambda x: called,\n        )\n        result = 
_compute_shared_expert(moe_block, torch.randn(4, 8))\n        assert torch.equal(result, called)\n\n    def test_shared_mlp(self):\n        \"\"\"shared_mlp attribute (Hunyuan style).\"\"\"\n        from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (\n            _compute_shared_expert,\n        )\n\n        called = torch.randn(4, 8)\n        moe_block = SimpleNamespace(\n            shared_mlp=lambda x: called,\n        )\n        result = _compute_shared_expert(moe_block, torch.randn(4, 8))\n        assert torch.equal(result, called)\n\n    def test_shared_expert_with_gate(self):\n        \"\"\"shared_expert + shared_expert_gate applies sigmoid gating.\"\"\"\n        from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (\n            _compute_shared_expert,\n        )\n\n        H = 8\n        expert_out = torch.ones(4, H)\n        gate_fn = lambda x: torch.zeros(4, H)  # noqa: E731\n\n        moe_block = SimpleNamespace(\n            shared_expert=lambda x: expert_out,\n            shared_expert_gate=gate_fn,\n        )\n        result = _compute_shared_expert(moe_block, torch.randn(4, H))\n        expected = expert_out * 0.5  # sigmoid(0) = 0.5\n        assert torch.allclose(result, expected, atol=1e-6)\n\n    def test_no_shared_expert(self):\n        \"\"\"No shared expert attributes returns None.\"\"\"\n        from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (\n            _compute_shared_expert,\n        )\n\n        moe_block = SimpleNamespace()\n        result = _compute_shared_expert(moe_block, torch.randn(4, 8))\n        assert result is None\n"
  },
  {
    "path": "tests/integrations/test_scattermoe_lora_kernels.py",
    "content": "# SPDX-License-Identifier: Apache-2.0\n# Copyright (c) Axolotl AI\n# Licensed under the Apache License, Version 2.0\n\n\"\"\"\nUnit tests for ScatterMoE LoRA Triton kernels.\n\nTests correctness of:\n  - scatter2scatter_lora (forward)\n  - scatter2scatter_lora_dX (backward input gradient)\n  - group_bwd_lora (backward LoRA weight gradients via split dA/dB)\n  - ScatterMoELoRA autograd function (full forward + backward)\n\nEach kernel is tested against a pure PyTorch per-expert-loop reference\nimplementation at multiple model shapes and LoRA ranks.\n\"\"\"\n\nimport pytest\nimport torch\n\nfrom axolotl.integrations.kernels.libs.scattermoe_lora.kernels import (\n    lora_ops,\n    ops as base_ops,\n)\nfrom axolotl.integrations.kernels.libs.scattermoe_lora.parallel_experts import (\n    flatten_sort_count,\n)\nfrom axolotl.integrations.kernels.libs.scattermoe_lora.parallel_linear_lora import (\n    ScatterMoELoRA,\n)\n\nDEVICE = \"cuda\"\nDTYPE = torch.bfloat16\n\n\ndef _requires_cuda():\n    return pytest.mark.skipif(\n        not torch.cuda.is_available(), reason=\"CUDA not available\"\n    )\n\n\npytestmark = _requires_cuda()\n\n\n# ─── Helpers ─────────────────────────────────────────────────────────────────\n\n\ndef _setup(E, K, N, T, top_k, R, seed=42):\n    \"\"\"Create synthetic expert weights, LoRA, routing, and grouped inputs.\"\"\"\n    torch.manual_seed(seed)\n    x = torch.randn(T, K, device=DEVICE, dtype=DTYPE)\n    W = torch.randn(E, K, N, device=DEVICE, dtype=DTYPE) * 0.02\n    lora_A = torch.randn(R * E, K, device=DEVICE, dtype=DTYPE) * 0.01\n    lora_B = torch.randn(N, R * E, device=DEVICE, dtype=DTYPE) * 0.01\n    logits = torch.randn(T, E, device=DEVICE)\n    _, top_idx = torch.topk(torch.softmax(logits, dim=-1), top_k, dim=-1)\n    sei, ssi, eo = flatten_sort_count(top_idx, E)\n    return x, W, lora_A, lora_B, sei, ssi, eo\n\n\ndef _reference_fwd(x, W, sei, ssi, eo, k, lora_A, lora_B, scaling, E):\n    \"\"\"Per-expert loop 
reference: Y = X@W + scaling*(X@A^T)@B^T.\"\"\"\n    grouped_x = base_ops.group(x, ssi, fan_out=k)\n    M, N = grouped_x.size(0), W.size(2)\n    R = lora_A.size(0) // E\n    out = torch.zeros(M, N, device=DEVICE, dtype=DTYPE)\n    for e in range(E):\n        s = eo[e - 1].item() if e > 0 else 0\n        end = eo[e].item()\n        if s == end:\n            continue\n        xe = grouped_x[s:end].float()\n        we = W[e].float()\n        ae = lora_A[e * R : (e + 1) * R].float()\n        be = lora_B[:, e * R : (e + 1) * R].float()\n        out[s:end] = (xe @ we + scaling * (xe @ ae.T) @ be.T).to(DTYPE)\n    result = torch.zeros(M, N, device=DEVICE, dtype=DTYPE)\n    result[ssi] = out\n    return result\n\n\ndef _reference_dX(dy_grouped, W, sei, ssi, eo, lora_A, lora_B, scaling, E):\n    \"\"\"Per-expert loop reference: dX = dY@W^T + scaling*(dY@B)@A.\"\"\"\n    M, K = dy_grouped.size(0), W.size(1)\n    R = lora_A.size(0) // E\n    out = torch.zeros(M, K, device=DEVICE, dtype=DTYPE)\n    for e in range(E):\n        s = eo[e - 1].item() if e > 0 else 0\n        end = eo[e].item()\n        if s == end:\n            continue\n        dye = dy_grouped[s:end].float()\n        we = W[e].float()\n        ae = lora_A[e * R : (e + 1) * R].float()\n        be = lora_B[:, e * R : (e + 1) * R].float()\n        out[s:end] = (dye @ we.T + scaling * (dye @ be) @ ae).to(DTYPE)\n    result = torch.zeros(M, K, device=DEVICE, dtype=DTYPE)\n    result[ssi] = out\n    return result\n\n\ndef _reference_bwd_lora(dy, grouped_x, lora_A, lora_B, eo, E, scaling):\n    \"\"\"Per-expert loop reference: dA, dB for LoRA weight gradients.\"\"\"\n    R = lora_A.size(0) // E\n    dA = torch.zeros_like(lora_A)\n    dB = torch.zeros_like(lora_B)\n    for e in range(E):\n        s = eo[e - 1].item() if e > 0 else 0\n        end = eo[e].item()\n        if s == end:\n            continue\n        xe = grouped_x[s:end].float()\n        dye = dy[s:end].float()\n        ae = lora_A[e * R : (e + 1) * 
R].float()\n        be = lora_B[:, e * R : (e + 1) * R].float()\n        dA[e * R : (e + 1) * R] = (scaling * (dye @ be).T @ xe).to(DTYPE)\n        dB[:, e * R : (e + 1) * R] = (scaling * dye.T @ (xe @ ae.T)).to(DTYPE)\n    return dA, dB\n\n\n# ─── Model shape configs ────────────────────────────────────────────────────\n\n# (E, K, N, T, top_k, R, description)\nCONFIGS_SMALL = [\n    (32, 128, 64, 64, 2, 4, \"tiny\"),\n    (64, 256, 128, 128, 4, 8, \"small\"),\n]\n\nCONFIGS_REAL = [\n    (256, 2048, 1024, 2048, 8, 16, \"qwen35_gate_up\"),\n    (256, 512, 2048, 2048, 8, 16, \"qwen35_down\"),\n    (64, 2048, 2048, 2048, 8, 16, \"olmoe_gate_up\"),\n    (128, 2048, 1536, 2048, 8, 16, \"qwen3_gate_up\"),\n]\n\nSCALING = 2.0\n\n\n# ─── Forward tests ──────────────────────────────────────────────────────────\n\n\nclass TestScatter2ScatterLoRAForward:\n    \"\"\"Test scatter2scatter_lora forward kernel vs reference.\"\"\"\n\n    @pytest.fixture(params=CONFIGS_SMALL + CONFIGS_REAL)\n    def config(self, request):\n        return request.param\n\n    def test_matches_reference(self, config):\n        E, K, N, T, k, R, desc = config\n        x, W, lA, lB, sei, ssi, eo = _setup(E, K, N, T, k, R)\n\n        kernel_out = lora_ops.scatter2scatter_lora(\n            X=x,\n            W=W,\n            sorted_expert_idxs=sei,\n            sorted_scattered_idxs=ssi,\n            k=k,\n            lora_A=lA,\n            lora_B=lB,\n            scaling=SCALING,\n        )\n        ref_out = _reference_fwd(x, W, sei, ssi, eo, k, lA, lB, SCALING, E)\n\n        err = (kernel_out.float() - ref_out.float()).abs().max().item()\n        assert err < 1.0, f\"[{desc}] fwd max_err={err}\"\n\n    def test_output_shape(self, config):\n        E, K, N, T, k, R, desc = config\n        x, W, lA, lB, sei, ssi, eo = _setup(E, K, N, T, k, R)\n\n        out = lora_ops.scatter2scatter_lora(\n            X=x,\n            W=W,\n            sorted_expert_idxs=sei,\n            sorted_scattered_idxs=ssi,\n 
           k=k,\n            lora_A=lA,\n            lora_B=lB,\n            scaling=SCALING,\n        )\n        assert out.shape == (T * k, N)\n        assert out.dtype == DTYPE\n\n\n# ─── Backward dX tests ──────────────────────────────────────────────────────\n\n\nclass TestScatter2ScatterLoRADX:\n    \"\"\"Test scatter2scatter_lora_dX backward kernel vs reference.\"\"\"\n\n    @pytest.fixture(params=CONFIGS_SMALL + CONFIGS_REAL)\n    def config(self, request):\n        return request.param\n\n    def test_matches_reference(self, config):\n        E, K, N, T, k, R, desc = config\n        x, W, lA, lB, sei, ssi, eo = _setup(E, K, N, T, k, R)\n        gx = base_ops.group(x, ssi, fan_out=k)\n        dy = torch.randn(gx.size(0), N, device=DEVICE, dtype=DTYPE)\n\n        kernel_dx = lora_ops.scatter2scatter_lora_dX(\n            DY=dy,\n            W=W,\n            sorted_expert_idxs=sei,\n            sorted_scattered_idxs=ssi,\n            k=1,\n            lora_A=lA,\n            lora_B=lB,\n            scaling=SCALING,\n            dy_grouped=True,\n            dx_grouped=False,\n        )\n        ref_dx = _reference_dX(dy, W, sei, ssi, eo, lA, lB, SCALING, E)\n\n        err = (kernel_dx.float() - ref_dx.float()).abs().max().item()\n        assert err < 1.0, f\"[{desc}] dX max_err={err}\"\n\n\n# ─── Backward LoRA gradient tests ───────────────────────────────────────────\n\n\nclass TestGroupBwdLoRA:\n    \"\"\"Test group_bwd_lora (split dA/dB kernel) vs reference.\"\"\"\n\n    @pytest.fixture(params=CONFIGS_SMALL + CONFIGS_REAL)\n    def config(self, request):\n        return request.param\n\n    def test_matches_reference(self, config):\n        E, K, N, T, k, R, desc = config\n        x, W, lA, lB, sei, ssi, eo = _setup(E, K, N, T, k, R)\n        gx = base_ops.group(x, ssi, fan_out=k)\n        dy = torch.randn(gx.size(0), N, device=DEVICE, dtype=DTYPE)\n\n        kern_dA, kern_dB = lora_ops.group_bwd_lora(\n            DY=dy,\n            X=gx,\n            
lora_A=lA,\n            lora_B=lB,\n            expert_offsets=eo,\n            E=E,\n            scaling=SCALING,\n        )\n        ref_dA, ref_dB = _reference_bwd_lora(dy, gx, lA, lB, eo, E, SCALING)\n\n        # Use norm-relative error: bf16 accumulation order differs between\n        # kernel (tiled + different reduction order) and reference (per-expert\n        # fp32 loop), so max absolute error can be large on individual elements\n        # while the overall tensor is correct.\n        dA_norm_err = (\n            (kern_dA.float() - ref_dA.float()).norm() / (ref_dA.float().norm() + 1e-6)\n        ).item()\n        dB_norm_err = (\n            (kern_dB.float() - ref_dB.float()).norm() / (ref_dB.float().norm() + 1e-6)\n        ).item()\n        assert dA_norm_err < 0.01, f\"[{desc}] dA norm_rel_err={dA_norm_err}\"\n        assert dB_norm_err < 0.01, f\"[{desc}] dB norm_rel_err={dB_norm_err}\"\n\n    def test_zero_expert_tokens(self):\n        \"\"\"Experts with zero routed tokens produce zero gradients.\"\"\"\n        E, K, N, R = 8, 64, 32, 4\n        torch.manual_seed(42)\n        # Route all tokens to expert 0 only\n        T, k = 16, 1\n        top_idx = torch.zeros(T, k, dtype=torch.long, device=DEVICE)\n        sei, ssi, eo = flatten_sort_count(top_idx, E)\n        gx = torch.randn(T, K, device=DEVICE, dtype=DTYPE)\n        dy = torch.randn(T, N, device=DEVICE, dtype=DTYPE)\n        lA = torch.randn(R * E, K, device=DEVICE, dtype=DTYPE)\n        lB = torch.randn(N, R * E, device=DEVICE, dtype=DTYPE)\n\n        dA, dB = lora_ops.group_bwd_lora(\n            DY=dy,\n            X=gx,\n            lora_A=lA,\n            lora_B=lB,\n            expert_offsets=eo,\n            E=E,\n            scaling=2.0,\n        )\n\n        # Experts 1..7 should have zero gradients\n        for e in range(1, E):\n            assert dA[e * R : (e + 1) * R].abs().max() == 0, f\"Expert {e} dA not zero\"\n            assert dB[:, e * R : (e + 1) * R].abs().max() == 0, (\n 
               f\"Expert {e} dB not zero\"\n            )\n\n\n# ─── Full autograd tests ────────────────────────────────────────────────────\n\n\nclass TestScatterMoELoRAAutograd:\n    \"\"\"Test full forward + backward through ScatterMoELoRA autograd function.\"\"\"\n\n    @pytest.fixture(params=CONFIGS_SMALL + CONFIGS_REAL[:2])\n    def config(self, request):\n        return request.param\n\n    def test_gradients_exist_and_finite(self, config):\n        E, K, N, T, k, R, desc = config\n        x, W, lA, lB, sei, ssi, eo = _setup(E, K, N, T, k, R)\n\n        x = x.requires_grad_(True)\n        lA = lA.requires_grad_(True)\n        lB = lB.requires_grad_(True)\n\n        out = ScatterMoELoRA.apply(\n            x,\n            W,\n            k,\n            sei,\n            ssi,\n            eo,\n            lA,\n            lB,\n            SCALING,\n            None,\n            None,\n            False,\n            False,\n            True,\n            False,\n        )\n        out.sum().backward()\n\n        assert x.grad is not None, f\"[{desc}] x.grad is None\"\n        assert lA.grad is not None, f\"[{desc}] lA.grad is None\"\n        assert lB.grad is not None, f\"[{desc}] lB.grad is None\"\n        assert torch.isfinite(x.grad).all(), f\"[{desc}] x.grad has non-finite\"\n        assert torch.isfinite(lA.grad).all(), f\"[{desc}] lA.grad has non-finite\"\n        assert torch.isfinite(lB.grad).all(), f\"[{desc}] lB.grad has non-finite\"\n        assert x.grad.abs().sum() > 0, f\"[{desc}] x.grad all zero\"\n        assert lA.grad.abs().sum() > 0, f\"[{desc}] lA.grad all zero\"\n\n    def test_split_matches_fused(self):\n        \"\"\"Split dispatch (for few large experts) matches fused kernel.\"\"\"\n        # Use a shape where split would be dispatched (large K*N, few E)\n        E, K, N, T, k, R = 8, 512, 1024, 128, 2, 16\n        x, W, lA, lB, sei, ssi, eo = _setup(E, K, N, T, k, R)\n\n        # Force fused path\n        orig = 
lora_ops._SPLIT_LORA_FWD_THRESHOLD\n        lora_ops._SPLIT_LORA_FWD_THRESHOLD = 10**18\n        out_fused = lora_ops.scatter2scatter_lora(\n            X=x,\n            W=W,\n            sorted_expert_idxs=sei,\n            sorted_scattered_idxs=ssi,\n            k=k,\n            lora_A=lA,\n            lora_B=lB,\n            scaling=SCALING,\n        )\n\n        # Force split path\n        lora_ops._SPLIT_LORA_FWD_THRESHOLD = 0\n        out_split = lora_ops.scatter2scatter_lora(\n            X=x,\n            W=W,\n            sorted_expert_idxs=sei,\n            sorted_scattered_idxs=ssi,\n            k=k,\n            lora_A=lA,\n            lora_B=lB,\n            scaling=SCALING,\n        )\n        lora_ops._SPLIT_LORA_FWD_THRESHOLD = orig\n\n        norm_err = (\n            (out_fused.float() - out_split.float()).norm()\n            / (out_fused.float().norm() + 1e-6)\n        ).item()\n        assert norm_err < 0.01, f\"split vs fused norm_err={norm_err}\"\n\n    def test_scaling_zero_gives_base_only(self):\n        \"\"\"With scaling=0.0, LoRA contribution vanishes. Output = X@W.\"\"\"\n        E, K, N, T, k, R = 16, 64, 32, 32, 2, 4\n        x, W, lA, lB, sei, ssi, eo = _setup(E, K, N, T, k, R)\n\n        out_lora = ScatterMoELoRA.apply(\n            x,\n            W,\n            k,\n            sei,\n            ssi,\n            eo,\n            lA,\n            lB,\n            0.0,\n            None,\n            None,\n            False,\n            False,\n            True,\n            False,\n        )\n        out_base = base_ops.scatter2scatter(\n            X=x,\n            W=W,\n            sorted_expert_idxs=sei,\n            sorted_scattered_idxs=ssi,\n            k=k,\n        )\n        err = (out_lora.float() - out_base.float()).abs().max().item()\n        assert err < 0.01, f\"scaling=0 should match base: err={err}\"\n"
  },
  {
    "path": "tests/integrations/test_sonicmoe.py",
    "content": "\"\"\"Unit tests for the SonicMoE integration.\"\"\"\n\nfrom types import SimpleNamespace\n\nimport pytest\nimport torch\n\nfrom axolotl.integrations.kernels.args import KernelsArgs\nfrom axolotl.integrations.kernels.sonicmoe.routing import (\n    sigmoid_topk_routing,\n    softmax_topk_routing,\n)\nfrom axolotl.integrations.kernels.sonicmoe.weight_converter import (\n    ConcatenatedToInterleaved,\n    InterleavedToConcatenated,\n    register_sonicmoe_weight_converter,\n)\n\n\nclass TestKernelsArgs:\n    def test_mutual_exclusivity_raises(self):\n        with pytest.raises(ValueError, match=\"Cannot use both\"):\n            KernelsArgs.model_validate({\"use_scattermoe\": True, \"use_sonicmoe\": True})\n\n    def test_sonicmoe_only(self):\n        result = KernelsArgs.model_validate({\"use_sonicmoe\": True})\n        assert result.use_sonicmoe is True\n        assert result.use_scattermoe is None\n\n    def test_scattermoe_only(self):\n        result = KernelsArgs.model_validate({\"use_scattermoe\": True})\n        assert result.use_scattermoe is True\n        assert result.use_sonicmoe is None\n\n    def test_neither_set(self):\n        result = KernelsArgs.model_validate({})\n        assert result.use_scattermoe is None\n        assert result.use_sonicmoe is None\n\n    def test_disables_mlp_kernel_when_sonicmoe(self):\n        data = {\"use_sonicmoe\": True, \"lora_mlp_kernel\": True}\n        result = KernelsArgs.disable_mlp_kernel(data)\n        assert result[\"lora_mlp_kernel\"] is False\n        assert result[\"mlp_kernel\"] is False\n\n\nclass TestConcatenatedToInterleaved:\n    @pytest.fixture\n    def sample_tensor(self):\n        \"\"\"Create a test tensor [E=2, 2*I=4, H=3] with distinct gate/up values.\"\"\"\n        E, I, H = 2, 2, 3  # noqa: E741\n        gate = torch.arange(1, E * I * H + 1, dtype=torch.float32).reshape(E, I, H)\n        up = torch.arange(100, 100 + E * I * H, dtype=torch.float32).reshape(E, I, H)\n        return 
torch.cat([gate, up], dim=1)\n\n    def test_interleave_rows_alternate(self, sample_tensor):\n        op = ConcatenatedToInterleaved(dim=1)\n        result = op.convert(\n            {\"test\": sample_tensor},\n            source_patterns=[\"test\"],\n            target_patterns=[\"test\"],\n        )\n        interleaved = result[\"test\"]\n\n        # For expert 0: even rows should be gate, odd rows should be up\n        E, two_I, H = sample_tensor.shape\n        I = two_I // 2  # noqa: E741\n        gate_orig = sample_tensor[:, :I, :]\n        up_orig = sample_tensor[:, I:, :]\n\n        assert torch.equal(interleaved[:, 0::2, :], gate_orig)\n        assert torch.equal(interleaved[:, 1::2, :], up_orig)\n\n    def test_interleave_handles_list_input(self, sample_tensor):\n        op = ConcatenatedToInterleaved(dim=1)\n        result = op.convert(\n            {\"test\": [sample_tensor]},\n            source_patterns=[\"test\"],\n            target_patterns=[\"test\"],\n        )\n        assert result[\"test\"].shape == sample_tensor.shape\n\n    def test_reverse_op_type(self):\n        op = ConcatenatedToInterleaved(dim=1)\n        assert isinstance(op.reverse_op, InterleavedToConcatenated)\n        assert op.reverse_op.dim == 1\n\n\nclass TestInterleavedToConcatenated:\n    @pytest.fixture\n    def interleaved_tensor(self):\n        \"\"\"Create an interleaved tensor [E=2, 2*I=4, H=3].\"\"\"\n        E, I, H = 2, 2, 3  # noqa: E741\n        gate = torch.arange(1, E * I * H + 1, dtype=torch.float32).reshape(E, I, H)\n        up = torch.arange(100, 100 + E * I * H, dtype=torch.float32).reshape(E, I, H)\n        interleaved = torch.empty(E, 2 * I, H)\n        interleaved[:, 0::2, :] = gate\n        interleaved[:, 1::2, :] = up\n        return interleaved\n\n    def test_deinterleave_gate_up_separated(self, interleaved_tensor):\n        op = InterleavedToConcatenated(dim=1)\n        result = op.convert(\n            {\"test\": interleaved_tensor},\n            
source_patterns=[\"test\"],\n            target_patterns=[\"test\"],\n        )\n        concatenated = result[\"test\"]\n\n        E, two_I, H = concatenated.shape\n        I = two_I // 2  # noqa: E741\n\n        # First half should be gate (even rows from interleaved)\n        assert torch.equal(concatenated[:, :I, :], interleaved_tensor[:, 0::2, :])\n        # Second half should be up (odd rows from interleaved)\n        assert torch.equal(concatenated[:, I:, :], interleaved_tensor[:, 1::2, :])\n\n    def test_reverse_op_type(self):\n        op = InterleavedToConcatenated(dim=1)\n        assert isinstance(op.reverse_op, ConcatenatedToInterleaved)\n        assert op.reverse_op.dim == 1\n\n\nclass TestRoundTrip:\n    @pytest.fixture\n    def concat_tensor(self):\n        E, I, H = 4, 8, 16  # noqa: E741\n        gate = torch.randn(E, I, H)\n        up = torch.randn(E, I, H)\n        return torch.cat([gate, up], dim=1)\n\n    def test_interleave_then_deinterleave_is_identity(self, concat_tensor):\n        fwd = ConcatenatedToInterleaved(dim=1)\n        rev = InterleavedToConcatenated(dim=1)\n\n        interleaved = fwd.convert(\n            {\"k\": concat_tensor}, source_patterns=[\"k\"], target_patterns=[\"k\"]\n        )[\"k\"]\n        recovered = rev.convert(\n            {\"k\": interleaved}, source_patterns=[\"k\"], target_patterns=[\"k\"]\n        )[\"k\"]\n\n        assert torch.equal(concat_tensor, recovered)\n\n    def test_reverse_op_chain_is_identity(self, concat_tensor):\n        \"\"\"Verify that op.reverse_op produces an exact inverse.\"\"\"\n        op = ConcatenatedToInterleaved(dim=1)\n        rev = op.reverse_op\n\n        interleaved = op.convert(\n            {\"k\": concat_tensor}, source_patterns=[\"k\"], target_patterns=[\"k\"]\n        )[\"k\"]\n        recovered = rev.convert(\n            {\"k\": interleaved}, source_patterns=[\"k\"], target_patterns=[\"k\"]\n        )[\"k\"]\n\n        assert torch.equal(concat_tensor, recovered)\n\n    
def test_various_shapes(self):\n        \"\"\"Test with different expert counts and dimensions.\"\"\"\n        fwd = ConcatenatedToInterleaved(dim=1)\n        rev = InterleavedToConcatenated(dim=1)\n\n        for E, I, H in [(1, 4, 8), (8, 16, 32), (16, 128, 256)]:  # noqa: E741\n            concat = torch.randn(E, 2 * I, H)\n            interleaved = fwd.convert(\n                {\"k\": concat}, source_patterns=[\"k\"], target_patterns=[\"k\"]\n            )[\"k\"]\n            recovered = rev.convert(\n                {\"k\": interleaved}, source_patterns=[\"k\"], target_patterns=[\"k\"]\n            )[\"k\"]\n            assert torch.equal(concat, recovered), (\n                f\"Failed for shape ({E}, {2 * I}, {H})\"\n            )\n\n\nclass TestWeightConverterRegistration:\n    def test_register_appends_interleave_op(self):\n        from transformers.conversion_mapping import get_checkpoint_conversion_mapping\n\n        register_sonicmoe_weight_converter(\"qwen3_moe\")\n\n        modified = get_checkpoint_conversion_mapping(\"qwen3_moe\")\n        # Find the gate_up_proj converter\n        gate_up_converter = None\n        for conv in modified:\n            if hasattr(conv, \"operations\") and any(\n                \"gate_up_proj\" in pat for pat in conv.target_patterns\n            ):\n                gate_up_converter = conv\n                break\n\n        assert gate_up_converter is not None\n        assert isinstance(gate_up_converter.operations[-1], ConcatenatedToInterleaved)\n\n    def test_double_registration_is_idempotent(self):\n        from transformers.conversion_mapping import get_checkpoint_conversion_mapping\n\n        register_sonicmoe_weight_converter(\"qwen3_moe\")\n        register_sonicmoe_weight_converter(\"qwen3_moe\")\n\n        modified = get_checkpoint_conversion_mapping(\"qwen3_moe\")\n        for conv in modified:\n            if hasattr(conv, \"operations\") and any(\n                \"gate_up_proj\" in pat for pat in 
conv.target_patterns\n            ):\n                interleave_count = sum(\n                    isinstance(op, ConcatenatedToInterleaved) for op in conv.operations\n                )\n                assert interleave_count == 1, (\n                    f\"Expected 1 ConcatenatedToInterleaved op, got {interleave_count}\"\n                )\n                break\n\n    def test_register_unsupported_model_type_warns(self):\n        # A model type with no conversion mapping should warn but not raise\n        register_sonicmoe_weight_converter(\"nonexistent_model_type_xyz\")\n\n\ndef _make_qwen_moe_block(T=8, H=16, E=4, K=2):\n    \"\"\"Create a mock qwen-style MoE block for routing tests.\"\"\"\n    gate = SimpleNamespace(\n        weight=torch.randn(E, H),\n        top_k=K,\n        num_experts=E,\n        norm_topk_prob=True,\n    )\n    return SimpleNamespace(gate=gate), T, H, E, K\n\n\ndef _make_glm_moe_block(T=8, H=16, E=16, K=4, n_group=2, topk_group=1):\n    \"\"\"Create a mock GLM5-style MoE block for routing tests.\"\"\"\n    gate = SimpleNamespace(\n        weight=torch.randn(E, H),\n        e_score_correction_bias=torch.zeros(E),\n    )\n    moe_block = SimpleNamespace(\n        gate=gate,\n        top_k=K,\n        n_routed_experts=E,\n        n_group=n_group,\n        topk_group=topk_group,\n        norm_topk_prob=True,\n        routed_scaling_factor=1.0,\n    )\n    return moe_block, T, H, E, K\n\n\ndef _make_minimax_m2_moe_block(T=8, H=16, E=16, K=4):\n    \"\"\"Create a mock minimax_m2-style MoE block for routing tests.\n\n    minimax_m2 uses sigmoid->topk WITHOUT group selection:\n    - e_score_correction_bias is on the moe_block (not on gate)\n    - No n_group / topk_group attributes\n    - Always normalizes (norm_topk_prob defaults to True)\n    - No routed_scaling_factor (defaults to 1.0)\n    \"\"\"\n    gate = SimpleNamespace(\n        weight=torch.randn(E, H),\n        top_k=K,\n    )\n    moe_block = SimpleNamespace(\n        gate=gate,\n    
    top_k=K,\n        e_score_correction_bias=torch.zeros(E),\n    )\n    return moe_block, T, H, E, K\n\n\nclass TestSoftmaxTopkRouting:\n    def test_output_shapes(self):\n        moe_block, T, H, E, K = _make_qwen_moe_block()\n        hidden = torch.randn(T, H)\n\n        scores, token_idx, expert_idx, logits = softmax_topk_routing(hidden, moe_block)\n\n        assert scores.shape == (T * K,)\n        assert token_idx.shape == (T * K,)\n        assert expert_idx.shape == (T * K,)\n        assert logits.shape == (T, E)\n\n    def test_scores_are_float32(self):\n        moe_block, T, H, E, K = _make_qwen_moe_block()\n        hidden = torch.randn(T, H)\n\n        scores, _, _, _ = softmax_topk_routing(hidden, moe_block)\n        assert scores.dtype == torch.float32\n\n    def test_token_indices_sorted_ascending(self):\n        moe_block, T, H, E, K = _make_qwen_moe_block()\n        hidden = torch.randn(T, H)\n\n        _, token_idx, _, _ = softmax_topk_routing(hidden, moe_block)\n\n        # Token indices must be sorted ascending (SonicMoE requirement)\n        diffs = token_idx[1:] - token_idx[:-1]\n        assert (diffs >= 0).all()\n\n    def test_expert_indices_in_range(self):\n        moe_block, T, H, E, K = _make_qwen_moe_block()\n        hidden = torch.randn(T, H)\n\n        _, _, expert_idx, _ = softmax_topk_routing(hidden, moe_block)\n\n        assert (expert_idx >= 0).all()\n        assert (expert_idx < E).all()\n\n    def test_renormalized_scores_sum_to_one(self):\n        moe_block, T, H, E, K = _make_qwen_moe_block()\n        hidden = torch.randn(T, H)\n\n        scores, _, _, _ = softmax_topk_routing(hidden, moe_block)\n        per_token_sums = scores.reshape(T, K).sum(dim=-1)\n        assert torch.allclose(per_token_sums, torch.ones(T), atol=1e-5)\n\n\nclass TestSigmoidTopkRouting:\n    def test_output_shapes(self):\n        moe_block, T, H, E, K = _make_glm_moe_block()\n        hidden = torch.randn(T, H)\n\n        scores, token_idx, expert_idx, 
logits = sigmoid_topk_routing(hidden, moe_block)\n\n        assert scores.shape == (T * K,)\n        assert token_idx.shape == (T * K,)\n        assert expert_idx.shape == (T * K,)\n        assert logits.shape == (T, E)\n\n    def test_scores_are_float32(self):\n        moe_block, T, H, E, K = _make_glm_moe_block()\n        hidden = torch.randn(T, H)\n\n        scores, _, _, _ = sigmoid_topk_routing(hidden, moe_block)\n        assert scores.dtype == torch.float32\n\n    def test_token_indices_sorted_ascending(self):\n        moe_block, T, H, E, K = _make_glm_moe_block()\n        hidden = torch.randn(T, H)\n\n        _, token_idx, _, _ = sigmoid_topk_routing(hidden, moe_block)\n\n        diffs = token_idx[1:] - token_idx[:-1]\n        assert (diffs >= 0).all()\n\n    def test_expert_indices_in_range(self):\n        moe_block, T, H, E, K = _make_glm_moe_block()\n        hidden = torch.randn(T, H)\n\n        _, _, expert_idx, _ = sigmoid_topk_routing(hidden, moe_block)\n\n        assert (expert_idx >= 0).all()\n        assert (expert_idx < E).all()\n\n    def test_scores_are_nonnegative(self):\n        \"\"\"Sigmoid outputs are in [0, 1], so scores should be non-negative.\"\"\"\n        moe_block, T, H, E, K = _make_glm_moe_block()\n        hidden = torch.randn(T, H)\n\n        scores, _, _, _ = sigmoid_topk_routing(hidden, moe_block)\n        assert (scores >= 0).all()\n\n    def test_scaling_factor_applied(self):\n        moe_block, T, H, E, K = _make_glm_moe_block()\n        hidden = torch.randn(T, H)\n\n        # Get scores with scaling_factor=1.0\n        scores_1x, _, _, _ = sigmoid_topk_routing(hidden, moe_block)\n\n        # Get scores with scaling_factor=2.0\n        moe_block.routed_scaling_factor = 2.0\n        scores_2x, _, _, _ = sigmoid_topk_routing(hidden, moe_block)\n\n        assert torch.allclose(scores_2x, scores_1x * 2.0, atol=1e-5)\n\n    def test_group_selection_restricts_experts(self):\n        \"\"\"With n_group=4 and topk_group=1, only 1/4 of 
experts should be selectable.\"\"\"\n        moe_block, T, H, E, K = _make_glm_moe_block(E=16, K=2, n_group=4, topk_group=1)\n        hidden = torch.randn(T, H)\n\n        _, _, expert_idx, _ = sigmoid_topk_routing(hidden, moe_block)\n\n        # Each token's experts should all fall within a single group (size E//n_group=4)\n        expert_idx_2d = expert_idx.reshape(T, K)\n        for t in range(T):\n            experts = expert_idx_2d[t]\n            groups = experts // (E // moe_block.n_group)\n            # All selected experts should be from the same group\n            assert (groups == groups[0]).all()\n\n\nclass TestMiniMaxM2SigmoidRouting:\n    \"\"\"Tests for minimax_m2 routing: sigmoid->topk without group selection.\"\"\"\n\n    def test_output_shapes(self):\n        \"\"\"Validates getattr defaults work: n_group=1, E from gate.weight.shape[0].\"\"\"\n        moe_block, T, H, E, K = _make_minimax_m2_moe_block()\n        hidden = torch.randn(T, H)\n\n        scores, token_idx, expert_idx, logits = sigmoid_topk_routing(hidden, moe_block)\n\n        assert scores.shape == (T * K,)\n        assert token_idx.shape == (T * K,)\n        assert expert_idx.shape == (T * K,)\n        assert logits.shape == (T, E)\n\n    def test_bias_on_block_not_gate(self):\n        \"\"\"Verify that e_score_correction_bias on the block (not gate) is used.\"\"\"\n        T, H, E, K = 8, 16, 8, 2\n        gate = SimpleNamespace(\n            weight=torch.randn(E, H),\n            top_k=K,\n        )\n        # Large positive bias on expert 0 should make it selected more often\n        bias = torch.zeros(E)\n        bias[0] = 100.0\n        moe_block = SimpleNamespace(\n            gate=gate,\n            top_k=K,\n            e_score_correction_bias=bias,\n        )\n        hidden = torch.randn(T, H)\n\n        _, _, expert_idx, _ = sigmoid_topk_routing(hidden, moe_block)\n\n        # Expert 0 should appear for every token due to the large bias\n        expert_idx_2d = 
expert_idx.reshape(T, K)\n        for t in range(T):\n            assert 0 in expert_idx_2d[t]\n"
  },
  {
    "path": "tests/integrations/test_sonicmoe_gradients.py",
    "content": "\"\"\"\nGradient correctness tests for SonicMoE routing functions (CPU-only).\n\nUses torch.autograd.gradcheck with float32 inputs to match the production\ncode path where routing happens in float32.\n\"\"\"\n\nimport torch\n\nfrom axolotl.integrations.kernels.sonicmoe.routing import (\n    sigmoid_topk_routing,\n    softmax_topk_routing,\n)\n\n_GC_EPS = 1e-3\n_GC_ATOL = 1e-3\n_GC_RTOL = 1e-3\n\n\ndef _make_softmax_moe_block(weight):\n    gate = torch.nn.Module()\n    gate.weight = weight\n    gate.top_k = 2\n    gate.norm_topk_prob = True\n\n    moe_block = torch.nn.Module()\n    moe_block.gate = gate\n    return moe_block\n\n\ndef _make_sigmoid_moe_block(weight, bias):\n    gate = torch.nn.Module()\n    gate.weight = weight\n    gate.e_score_correction_bias = bias\n\n    moe_block = torch.nn.Module()\n    moe_block.gate = gate\n    moe_block.top_k = 2\n    moe_block.n_routed_experts = weight.shape[0]\n    moe_block.n_group = 1\n    moe_block.norm_topk_prob = True\n    moe_block.routed_scaling_factor = 1.0\n    return moe_block\n\n\nclass TestSoftmaxTopkRoutingGradcheck:\n    \"\"\"Numerical gradient verification for softmax_topk_routing.\"\"\"\n\n    def test_gradcheck_wrt_gate_weight(self):\n        T, H, E = 4, 8, 4\n\n        hidden = torch.randn(T, H, dtype=torch.float32)\n\n        def fn(weight):\n            moe_block = _make_softmax_moe_block(weight)\n            scores, _, _, _ = softmax_topk_routing(hidden, moe_block)\n            return scores\n\n        weight = torch.randn(E, H, dtype=torch.float32, requires_grad=True)\n        torch.autograd.gradcheck(\n            fn, (weight,), eps=_GC_EPS, atol=_GC_ATOL, rtol=_GC_RTOL\n        )\n\n    def test_gradcheck_wrt_hidden_states(self):\n        T, H, E = 4, 8, 4\n\n        weight = torch.randn(E, H, dtype=torch.float32)\n        moe_block = _make_softmax_moe_block(weight)\n\n        def fn(hidden):\n            scores, _, _, _ = softmax_topk_routing(hidden, moe_block)\n            return 
scores\n\n        hidden = torch.randn(T, H, dtype=torch.float32, requires_grad=True)\n        torch.autograd.gradcheck(\n            fn, (hidden,), eps=_GC_EPS, atol=_GC_ATOL, rtol=_GC_RTOL\n        )\n\n    def test_gradcheck_wrt_router_logits(self):\n        T, H, E = 4, 8, 4\n\n        hidden = torch.randn(T, H, dtype=torch.float32)\n\n        def fn(weight):\n            moe_block = _make_softmax_moe_block(weight)\n            _, _, _, router_logits = softmax_topk_routing(hidden, moe_block)\n            return router_logits\n\n        weight = torch.randn(E, H, dtype=torch.float32, requires_grad=True)\n        torch.autograd.gradcheck(\n            fn, (weight,), eps=_GC_EPS, atol=_GC_ATOL, rtol=_GC_RTOL\n        )\n\n    def test_no_norm_variant(self):\n        T, H, E = 4, 8, 4\n\n        hidden = torch.randn(T, H, dtype=torch.float32)\n\n        def fn(weight):\n            moe_block = _make_softmax_moe_block(weight)\n            moe_block.gate.norm_topk_prob = False\n            scores, _, _, _ = softmax_topk_routing(hidden, moe_block)\n            return scores\n\n        weight = torch.randn(E, H, dtype=torch.float32, requires_grad=True)\n        torch.autograd.gradcheck(\n            fn, (weight,), eps=_GC_EPS, atol=_GC_ATOL, rtol=_GC_RTOL\n        )\n\n\nclass TestSigmoidTopkRoutingGradcheck:\n    \"\"\"Numerical gradient verification for sigmoid_topk_routing.\"\"\"\n\n    def test_gradcheck_wrt_gate_weight(self):\n        T, H, E = 4, 8, 4\n\n        hidden = torch.randn(T, H, dtype=torch.float32)\n        bias = torch.zeros(E, dtype=torch.float32)\n\n        def fn(weight):\n            moe_block = _make_sigmoid_moe_block(weight, bias)\n            scores, _, _, _ = sigmoid_topk_routing(hidden, moe_block)\n            return scores\n\n        weight = torch.randn(E, H, dtype=torch.float32, requires_grad=True)\n        torch.autograd.gradcheck(\n            fn, (weight,), eps=_GC_EPS, atol=_GC_ATOL, rtol=_GC_RTOL\n        )\n\n    def 
test_gradcheck_wrt_hidden_states(self):\n        T, H, E = 4, 8, 4\n\n        weight = torch.randn(E, H, dtype=torch.float32)\n        bias = torch.zeros(E, dtype=torch.float32)\n        moe_block = _make_sigmoid_moe_block(weight, bias)\n\n        def fn(hidden):\n            scores, _, _, _ = sigmoid_topk_routing(hidden, moe_block)\n            return scores\n\n        hidden = torch.randn(T, H, dtype=torch.float32, requires_grad=True)\n        torch.autograd.gradcheck(\n            fn, (hidden,), eps=_GC_EPS, atol=_GC_ATOL, rtol=_GC_RTOL\n        )\n\n    def test_gradcheck_wrt_bias(self):\n        T, H, E = 4, 8, 4\n\n        hidden = torch.randn(T, H, dtype=torch.float32)\n        weight = torch.randn(E, H, dtype=torch.float32)\n\n        def fn(bias):\n            moe_block = _make_sigmoid_moe_block(weight, bias)\n            scores, _, _, _ = sigmoid_topk_routing(hidden, moe_block)\n            return scores\n\n        bias = torch.zeros(E, dtype=torch.float32, requires_grad=True)\n        torch.autograd.gradcheck(fn, (bias,), eps=_GC_EPS, atol=_GC_ATOL, rtol=_GC_RTOL)\n"
  },
  {
    "path": "tests/integrations/test_swanlab.py",
    "content": "# Copyright 2024 Axolotl AI. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"\nUnit tests for SwanLab Integration Plugin.\n\nTests conflict detection, configuration validation, and multi-logger warnings.\n\"\"\"\n\nimport importlib.util\nimport logging\nimport os\nimport time\nfrom unittest.mock import MagicMock, patch\n\nimport pytest\nfrom pydantic import ValidationError\n\nfrom axolotl.integrations.swanlab.args import SwanLabConfig\nfrom axolotl.integrations.swanlab.plugins import SwanLabPlugin\n\nSWANLAB_INSTALLED = importlib.util.find_spec(\"swanlab\") is not None\n\n\n@pytest.mark.skipif(not SWANLAB_INSTALLED, reason=\"swanlab package not installed\")\nclass TestSwanLabConfigValidators:\n    \"\"\"Tests for Pydantic field validators in SwanLabConfig.\"\"\"\n\n    def test_valid_swanlab_mode_cloud(self):\n        \"\"\"Test that 'cloud' mode is valid.\"\"\"\n        config = SwanLabConfig(swanlab_mode=\"cloud\")\n        assert config.swanlab_mode == \"cloud\"\n\n    def test_valid_swanlab_mode_local(self):\n        \"\"\"Test that 'local' mode is valid.\"\"\"\n        config = SwanLabConfig(swanlab_mode=\"local\")\n        assert config.swanlab_mode == \"local\"\n\n    def test_valid_swanlab_mode_offline(self):\n        \"\"\"Test that 'offline' mode is valid.\"\"\"\n        config = SwanLabConfig(swanlab_mode=\"offline\")\n        assert config.swanlab_mode == \"offline\"\n\n    def 
test_valid_swanlab_mode_disabled(self):\n        \"\"\"Test that 'disabled' mode is valid.\"\"\"\n        config = SwanLabConfig(swanlab_mode=\"disabled\")\n        assert config.swanlab_mode == \"disabled\"\n\n    def test_invalid_swanlab_mode(self):\n        \"\"\"Test that invalid mode raises ValueError.\"\"\"\n        with pytest.raises(ValidationError) as exc_info:\n            SwanLabConfig(swanlab_mode=\"invalid\")\n\n        error_msg = str(exc_info.value)\n        assert \"Invalid swanlab_mode\" in error_msg\n        assert \"cloud\" in error_msg\n        assert \"local\" in error_msg\n        assert \"offline\" in error_msg\n        assert \"disabled\" in error_msg\n\n    def test_swanlab_mode_none_allowed(self):\n        \"\"\"Test that None mode is allowed (will use default).\"\"\"\n        config = SwanLabConfig(swanlab_mode=None)\n        assert config.swanlab_mode is None\n\n    def test_valid_swanlab_project(self):\n        \"\"\"Test that valid project name is accepted.\"\"\"\n        config = SwanLabConfig(swanlab_project=\"my-project\")\n        assert config.swanlab_project == \"my-project\"\n\n    def test_swanlab_project_none_allowed(self):\n        \"\"\"Test that None project is allowed.\"\"\"\n        config = SwanLabConfig(swanlab_project=None)\n        assert config.swanlab_project is None\n\n    def test_empty_swanlab_project_rejected(self):\n        \"\"\"Test that empty string project name is rejected.\"\"\"\n        with pytest.raises(ValidationError) as exc_info:\n            SwanLabConfig(swanlab_project=\"\")\n\n        error_msg = str(exc_info.value)\n        assert \"cannot be an empty string\" in error_msg\n\n    def test_whitespace_only_project_rejected(self):\n        \"\"\"Test that whitespace-only project name is rejected.\"\"\"\n        with pytest.raises(ValidationError) as exc_info:\n            SwanLabConfig(swanlab_project=\"   \")\n\n        error_msg = str(exc_info.value)\n        assert \"cannot be an empty string\" 
in error_msg\n\n    def test_use_swanlab_true_requires_project(self):\n        \"\"\"Test that use_swanlab=True requires swanlab_project.\"\"\"\n        with pytest.raises(ValidationError) as exc_info:\n            SwanLabConfig(use_swanlab=True, swanlab_project=None)\n\n        error_msg = str(exc_info.value)\n        assert \"swanlab_project\" in error_msg.lower()\n        assert \"not set\" in error_msg.lower()\n\n    def test_use_swanlab_true_with_project_valid(self):\n        \"\"\"Test that use_swanlab=True with project is valid.\"\"\"\n        config = SwanLabConfig(use_swanlab=True, swanlab_project=\"my-project\")\n        assert config.use_swanlab is True\n        assert config.swanlab_project == \"my-project\"\n\n    def test_use_swanlab_false_no_project_valid(self):\n        \"\"\"Test that use_swanlab=False without project is valid.\"\"\"\n        config = SwanLabConfig(use_swanlab=False, swanlab_project=None)\n        assert config.use_swanlab is False\n        assert config.swanlab_project is None\n\n    def test_use_swanlab_none_no_project_valid(self):\n        \"\"\"Test that use_swanlab=None without project is valid.\"\"\"\n        config = SwanLabConfig(use_swanlab=None, swanlab_project=None)\n        assert config.use_swanlab is None\n        assert config.swanlab_project is None\n\n\n@pytest.mark.skipif(not SWANLAB_INSTALLED, reason=\"swanlab package not installed\")\nclass TestSwanLabPluginRegister:\n    \"\"\"Tests for SwanLabPlugin.register() conflict detection.\"\"\"\n\n    def test_register_without_use_swanlab(self):\n        \"\"\"Test that register works when SwanLab is not enabled.\"\"\"\n        plugin = SwanLabPlugin()\n        cfg = {\"use_swanlab\": False}\n        # Should not raise\n        plugin.register(cfg)\n\n    def test_register_use_swanlab_missing_project(self):\n        \"\"\"Test that use_swanlab=True without project raises ValueError.\"\"\"\n        plugin = SwanLabPlugin()\n        cfg = {\"use_swanlab\": True}\n\n      
  with pytest.raises(ValueError) as exc_info:\n            plugin.register(cfg)\n\n        error_msg = str(exc_info.value)\n        assert \"swanlab_project\" in error_msg\n        assert \"not set\" in error_msg\n        assert \"Solutions\" in error_msg\n\n    def test_register_use_swanlab_with_project_valid(self):\n        \"\"\"Test that use_swanlab=True with project is valid.\"\"\"\n        plugin = SwanLabPlugin()\n        cfg = {\"use_swanlab\": True, \"swanlab_project\": \"my-project\"}\n        # Should not raise\n        plugin.register(cfg)\n\n    def test_register_invalid_mode(self):\n        \"\"\"Test that invalid swanlab_mode raises ValueError.\"\"\"\n        plugin = SwanLabPlugin()\n        cfg = {\n            \"use_swanlab\": True,\n            \"swanlab_project\": \"my-project\",\n            \"swanlab_mode\": \"invalid-mode\",\n        }\n\n        with pytest.raises(ValueError) as exc_info:\n            plugin.register(cfg)\n\n        error_msg = str(exc_info.value)\n        assert \"Invalid swanlab_mode\" in error_msg\n        assert \"cloud\" in error_msg\n        assert \"local\" in error_msg\n\n    def test_register_valid_modes(self):\n        \"\"\"Test that all valid modes are accepted.\"\"\"\n        plugin = SwanLabPlugin()\n        valid_modes = [\"cloud\", \"local\", \"offline\", \"disabled\"]\n\n        for mode in valid_modes:\n            cfg = {\n                \"use_swanlab\": True,\n                \"swanlab_project\": \"my-project\",\n                \"swanlab_mode\": mode,\n            }\n            # Should not raise\n            plugin.register(cfg)\n\n    def test_register_auto_enable_swanlab(self):\n        \"\"\"Test that providing swanlab_project auto-enables use_swanlab.\"\"\"\n        plugin = SwanLabPlugin()\n        cfg = {\"swanlab_project\": \"my-project\"}\n\n        plugin.register(cfg)\n\n        assert cfg[\"use_swanlab\"] is True\n\n    def test_register_cloud_mode_without_api_key_warns(self, caplog):\n     
   \"\"\"Test that cloud mode without API key logs warning.\"\"\"\n        plugin = SwanLabPlugin()\n        cfg = {\n            \"use_swanlab\": True,\n            \"swanlab_project\": \"my-project\",\n            \"swanlab_mode\": \"cloud\",\n        }\n\n        # Clear environment variable to ensure it's not set\n        with patch.dict(os.environ, {}, clear=True):\n            with caplog.at_level(logging.WARNING):\n                plugin.register(cfg)\n\n            # Should log warning about missing API key\n            warning_messages = [record.message for record in caplog.records]\n            assert any(\"API key\" in msg for msg in warning_messages)\n\n\n@pytest.mark.skipif(not SWANLAB_INSTALLED, reason=\"swanlab package not installed\")\nclass TestMultiLoggerDetection:\n    \"\"\"Tests for multi-logger conflict detection.\"\"\"\n\n    def test_single_logger_no_warning(self, caplog):\n        \"\"\"Test that single logger doesn't trigger warning.\"\"\"\n        plugin = SwanLabPlugin()\n        cfg = {\"use_swanlab\": True, \"swanlab_project\": \"my-project\"}\n\n        with caplog.at_level(logging.WARNING):\n            plugin.register(cfg)\n\n        # Should not log multi-logger warning\n        warning_messages = [record.message for record in caplog.records]\n        assert not any(\"Multiple logging tools\" in msg for msg in warning_messages)\n\n    def test_two_loggers_warning(self, caplog):\n        \"\"\"Test that two loggers trigger warning.\"\"\"\n        plugin = SwanLabPlugin()\n        cfg = {\n            \"use_swanlab\": True,\n            \"swanlab_project\": \"my-project\",\n            \"use_wandb\": True,\n        }\n\n        with caplog.at_level(logging.WARNING):\n            plugin.register(cfg)\n\n        # Should log multi-logger warning\n        warning_messages = [record.message for record in caplog.records]\n        assert any(\"Multiple logging tools\" in msg for msg in warning_messages)\n        assert any(\"SwanLab\" in 
msg and \"WandB\" in msg for msg in warning_messages)\n\n    def test_three_loggers_error(self, caplog):\n        \"\"\"Test that three loggers trigger error-level warning.\"\"\"\n        plugin = SwanLabPlugin()\n        cfg = {\n            \"use_swanlab\": True,\n            \"swanlab_project\": \"my-project\",\n            \"use_wandb\": True,\n            \"use_mlflow\": True,\n        }\n\n        with caplog.at_level(logging.ERROR):\n            plugin.register(cfg)\n\n        # Should log error-level warning\n        error_messages = [\n            record.message\n            for record in caplog.records\n            if record.levelno >= logging.ERROR\n        ]\n        assert any(\"logging tools enabled\" in msg for msg in error_messages)\n\n    def test_multi_logger_with_comet(self, caplog):\n        \"\"\"Test that Comet is detected in multi-logger scenario.\"\"\"\n        plugin = SwanLabPlugin()\n        cfg = {\n            \"use_swanlab\": True,\n            \"swanlab_project\": \"my-project\",\n            \"comet_api_key\": \"test-key\",\n        }\n\n        with caplog.at_level(logging.WARNING):\n            plugin.register(cfg)\n\n        # Should detect Comet\n        warning_messages = [record.message for record in caplog.records]\n        assert any(\"Comet\" in msg for msg in warning_messages)\n\n    def test_multi_logger_with_comet_project(self, caplog):\n        \"\"\"Test that Comet is detected via comet_project_name.\"\"\"\n        plugin = SwanLabPlugin()\n        cfg = {\n            \"use_swanlab\": True,\n            \"swanlab_project\": \"my-project\",\n            \"comet_project_name\": \"test-project\",\n        }\n\n        with caplog.at_level(logging.WARNING):\n            plugin.register(cfg)\n\n        # Should detect Comet\n        warning_messages = [record.message for record in caplog.records]\n        assert any(\"Comet\" in msg for msg in warning_messages)\n\n\n@pytest.mark.skipif(not SWANLAB_INSTALLED, 
reason=\"swanlab package not installed\")\nclass TestSwanLabPluginPreModelLoad:\n    \"\"\"Tests for SwanLabPlugin.pre_model_load() runtime checks.\"\"\"\n\n    def test_pre_model_load_disabled(self):\n        \"\"\"Test that pre_model_load does nothing when SwanLab is disabled.\"\"\"\n        plugin = SwanLabPlugin()\n        cfg = MagicMock()\n        cfg.use_swanlab = False\n\n        # Should not raise\n        plugin.pre_model_load(cfg)\n\n    def test_pre_model_load_import_error(self):\n        \"\"\"Test that missing swanlab package raises clear ImportError.\"\"\"\n        plugin = SwanLabPlugin()\n        cfg = MagicMock()\n        cfg.use_swanlab = True\n\n        with patch(\n            \"builtins.__import__\", side_effect=ImportError(\"No module named 'swanlab'\")\n        ):\n            with pytest.raises(ImportError) as exc_info:\n                plugin.pre_model_load(cfg)\n\n            error_msg = str(exc_info.value)\n            assert \"SwanLab is not installed\" in error_msg\n            assert \"pip install swanlab\" in error_msg\n\n    @patch(\"axolotl.utils.distributed.is_main_process\")\n    @patch(\"axolotl.utils.distributed.get_world_size\")\n    def test_pre_model_load_non_main_process_skips(\n        self, mock_get_world_size, mock_is_main_process\n    ):\n        \"\"\"Test that non-main process skips SwanLab initialization.\"\"\"\n        mock_get_world_size.return_value = 2\n        mock_is_main_process.return_value = False\n\n        plugin = SwanLabPlugin()\n        cfg = MagicMock()\n        cfg.use_swanlab = True\n\n        with patch(\"swanlab.init\") as mock_init:\n            plugin.pre_model_load(cfg)\n            # Should NOT call swanlab.init\n            mock_init.assert_not_called()\n\n    @patch(\"axolotl.utils.distributed.is_main_process\")\n    @patch(\"axolotl.utils.distributed.get_world_size\")\n    def test_pre_model_load_distributed_logging(\n        self, mock_get_world_size, mock_is_main_process, caplog\n    ):\n  
      \"\"\"Test that distributed training logs world size info.\"\"\"\n        mock_get_world_size.return_value = 4\n        mock_is_main_process.return_value = True\n\n        plugin = SwanLabPlugin()\n        cfg = MagicMock()\n        cfg.use_swanlab = True\n        cfg.swanlab_project = \"test-project\"\n        cfg.swanlab_mode = \"cloud\"\n\n        with patch(\"swanlab.init\"), patch(\"swanlab.__version__\", \"0.3.0\"):\n            with caplog.at_level(logging.INFO):\n                plugin.pre_model_load(cfg)\n\n            # Should log distributed training info\n            info_messages = [record.message for record in caplog.records]\n            assert any(\"world_size=4\" in msg for msg in info_messages)\n            assert any(\"Only rank 0\" in msg for msg in info_messages)\n\n\n@pytest.mark.skipif(not SWANLAB_INSTALLED, reason=\"swanlab package not installed\")\nclass TestSwanLabInitKwargs:\n    \"\"\"Tests for SwanLab initialization with direct parameter passing.\"\"\"\n\n    def test_custom_branding_added_to_config(self):\n        \"\"\"Test that Axolotl custom branding is added to SwanLab config.\"\"\"\n        from axolotl.integrations.swanlab.plugins import SwanLabPlugin\n        from axolotl.utils.dict import DictDefault\n\n        plugin = SwanLabPlugin()\n        cfg = DictDefault(\n            {\n                \"use_swanlab\": True,\n                \"swanlab_project\": \"test-project\",\n            }\n        )\n\n        init_kwargs = plugin._get_swanlab_init_kwargs(cfg)\n\n        # Verify custom branding is present\n        assert \"config\" in init_kwargs\n        assert init_kwargs[\"config\"][\"UPPERFRAME\"] == \"🦎 Axolotl\"\n\n    def test_api_key_passed_directly(self):\n        \"\"\"Test that API key is passed directly to swanlab.init() instead of via env var.\"\"\"\n        from axolotl.integrations.swanlab.plugins import SwanLabPlugin\n        from axolotl.utils.dict import DictDefault\n\n        plugin = SwanLabPlugin()\n   
     cfg = DictDefault(\n            {\n                \"use_swanlab\": True,\n                \"swanlab_project\": \"test-project\",\n                \"swanlab_api_key\": \"test-api-key-12345\",\n            }\n        )\n\n        init_kwargs = plugin._get_swanlab_init_kwargs(cfg)\n\n        # Verify API key is in init_kwargs (not set as env var)\n        assert \"api_key\" in init_kwargs\n        assert init_kwargs[\"api_key\"] == \"test-api-key-12345\"\n\n    def test_private_deployment_hosts_passed_directly(self):\n        \"\"\"Test that private deployment hosts are passed directly to swanlab.init().\"\"\"\n        from axolotl.integrations.swanlab.plugins import SwanLabPlugin\n        from axolotl.utils.dict import DictDefault\n\n        plugin = SwanLabPlugin()\n        cfg = DictDefault(\n            {\n                \"use_swanlab\": True,\n                \"swanlab_project\": \"internal-project\",\n                \"swanlab_web_host\": \"https://swanlab.company.com\",\n                \"swanlab_api_host\": \"https://api-swanlab.company.com\",\n            }\n        )\n\n        init_kwargs = plugin._get_swanlab_init_kwargs(cfg)\n\n        # Verify private deployment hosts are in init_kwargs\n        assert \"web_host\" in init_kwargs\n        assert init_kwargs[\"web_host\"] == \"https://swanlab.company.com\"\n        assert \"api_host\" in init_kwargs\n        assert init_kwargs[\"api_host\"] == \"https://api-swanlab.company.com\"\n\n    @patch(\"axolotl.utils.distributed.is_main_process\")\n    def test_full_private_deployment_init(self, mock_is_main_process):\n        \"\"\"Test complete initialization with private deployment configuration.\"\"\"\n        mock_is_main_process.return_value = True\n\n        from axolotl.integrations.swanlab.plugins import SwanLabPlugin\n        from axolotl.utils.dict import DictDefault\n\n        plugin = SwanLabPlugin()\n        cfg = DictDefault(\n            {\n                \"use_swanlab\": True,\n            
    \"swanlab_project\": \"secure-project\",\n                \"swanlab_experiment_name\": \"experiment-001\",\n                \"swanlab_mode\": \"cloud\",\n                \"swanlab_api_key\": \"private-key-xyz\",\n                \"swanlab_web_host\": \"https://swanlab.internal.net\",\n                \"swanlab_api_host\": \"https://api.swanlab.internal.net\",\n                \"swanlab_workspace\": \"research-team\",\n            }\n        )\n\n        with patch(\"swanlab.init\") as mock_init:\n            plugin.pre_model_load(cfg)\n\n            # Verify swanlab.init was called with all parameters\n            mock_init.assert_called_once()\n            call_kwargs = mock_init.call_args[1]\n\n            assert call_kwargs[\"project\"] == \"secure-project\"\n            assert call_kwargs[\"experiment_name\"] == \"experiment-001\"\n            assert call_kwargs[\"mode\"] == \"cloud\"\n            assert call_kwargs[\"api_key\"] == \"private-key-xyz\"\n            assert call_kwargs[\"web_host\"] == \"https://swanlab.internal.net\"\n            assert call_kwargs[\"api_host\"] == \"https://api.swanlab.internal.net\"\n            assert call_kwargs[\"workspace\"] == \"research-team\"\n            assert call_kwargs[\"config\"][\"UPPERFRAME\"] == \"🦎 Axolotl\"\n\n    def test_env_vars_not_set_for_api_params(self):\n        \"\"\"Test that environment variables are NOT set for API parameters.\"\"\"\n        import os\n\n        from axolotl.integrations.swanlab.plugins import SwanLabPlugin\n        from axolotl.utils.dict import DictDefault\n\n        # Clear any existing env vars\n        for key in [\n            \"SWANLAB_API_KEY\",\n            \"SWANLAB_WEB_HOST\",\n            \"SWANLAB_API_HOST\",\n            \"SWANLAB_MODE\",\n        ]:\n            os.environ.pop(key, None)\n\n        plugin = SwanLabPlugin()\n        cfg = DictDefault(\n            {\n                \"use_swanlab\": True,\n                \"swanlab_project\": \"test-project\",\n   
             \"swanlab_api_key\": \"test-key\",\n                \"swanlab_web_host\": \"https://test.com\",\n                \"swanlab_api_host\": \"https://api-test.com\",\n                \"swanlab_mode\": \"cloud\",\n            }\n        )\n\n        with (\n            patch(\"axolotl.utils.distributed.is_main_process\", return_value=True),\n            patch(\"swanlab.init\"),\n        ):\n            plugin.pre_model_load(cfg)\n\n        # Verify env vars were NOT set (simplified approach)\n        # The old _setup_swanlab_env() method is removed, so these shouldn't be set\n        # Note: SwanLab itself might set these, but our plugin shouldn't\n        # We're just testing that our plugin doesn't call _setup_swanlab_env()\n\n\n@pytest.mark.skipif(not SWANLAB_INSTALLED, reason=\"swanlab package not installed\")\nclass TestLarkNotificationIntegration:\n    \"\"\"Tests for Lark (Feishu) notification integration.\"\"\"\n\n    def test_lark_callback_registration_with_webhook_only(self):\n        \"\"\"Test Lark callback registration with webhook URL only (no secret).\"\"\"\n        plugin = SwanLabPlugin()\n\n        cfg = MagicMock()\n        cfg.use_swanlab = True\n        cfg.swanlab_project = \"test-project\"\n        cfg.swanlab_mode = \"local\"\n        cfg.swanlab_lark_webhook_url = (\n            \"https://open.feishu.cn/open-apis/bot/v2/hook/test-webhook\"\n        )\n        cfg.swanlab_lark_secret = None\n\n        with (\n            patch(\"swanlab.init\"),\n            patch(\"swanlab.__version__\", \"0.3.0\"),\n            patch(\"swanlab.register_callbacks\") as mock_register,\n            patch(\"axolotl.utils.distributed.is_main_process\", return_value=True),\n            patch(\"axolotl.utils.distributed.get_world_size\", return_value=1),\n        ):\n            # Mock LarkCallback import\n            with patch(\"swanlab.plugin.notification.LarkCallback\") as MockLarkCallback:\n                mock_lark_instance = MagicMock()\n            
    MockLarkCallback.return_value = mock_lark_instance\n\n                plugin.pre_model_load(cfg)\n\n                # Verify LarkCallback was instantiated with correct params\n                MockLarkCallback.assert_called_once_with(\n                    webhook_url=\"https://open.feishu.cn/open-apis/bot/v2/hook/test-webhook\",\n                    secret=None,\n                )\n\n                # Verify callback was registered\n                mock_register.assert_called_once_with([mock_lark_instance])\n\n    def test_lark_callback_registration_with_secret(self):\n        \"\"\"Test Lark callback registration with webhook URL and HMAC secret.\"\"\"\n        plugin = SwanLabPlugin()\n\n        cfg = MagicMock()\n        cfg.use_swanlab = True\n        cfg.swanlab_project = \"test-project\"\n        cfg.swanlab_mode = \"local\"\n        cfg.swanlab_lark_webhook_url = (\n            \"https://open.feishu.cn/open-apis/bot/v2/hook/test-webhook\"\n        )\n        cfg.swanlab_lark_secret = \"test-hmac-secret\"\n\n        with (\n            patch(\"swanlab.init\"),\n            patch(\"swanlab.__version__\", \"0.3.0\"),\n            patch(\"swanlab.register_callbacks\") as mock_register,\n            patch(\"axolotl.utils.distributed.is_main_process\", return_value=True),\n            patch(\"axolotl.utils.distributed.get_world_size\", return_value=1),\n        ):\n            with patch(\"swanlab.plugin.notification.LarkCallback\") as MockLarkCallback:\n                mock_lark_instance = MagicMock()\n                MockLarkCallback.return_value = mock_lark_instance\n\n                plugin.pre_model_load(cfg)\n\n                # Verify LarkCallback was instantiated with secret\n                MockLarkCallback.assert_called_once_with(\n                    webhook_url=\"https://open.feishu.cn/open-apis/bot/v2/hook/test-webhook\",\n                    secret=\"test-hmac-secret\",\n                )\n\n                
mock_register.assert_called_once_with([mock_lark_instance])\n\n    def test_lark_callback_not_registered_without_webhook(self):\n        \"\"\"Test that Lark callback is NOT registered when webhook URL not provided.\"\"\"\n        plugin = SwanLabPlugin()\n\n        cfg = MagicMock()\n        cfg.use_swanlab = True\n        cfg.swanlab_project = \"test-project\"\n        cfg.swanlab_mode = \"local\"\n        cfg.swanlab_lark_webhook_url = None  # No webhook\n        cfg.swanlab_lark_secret = None\n\n        with (\n            patch(\"swanlab.init\"),\n            patch(\"swanlab.__version__\", \"0.3.0\"),\n            patch(\"swanlab.register_callbacks\") as mock_register,\n            patch(\"axolotl.utils.distributed.is_main_process\", return_value=True),\n            patch(\"axolotl.utils.distributed.get_world_size\", return_value=1),\n        ):\n            plugin.pre_model_load(cfg)\n\n            # Verify register_callbacks was NOT called\n            mock_register.assert_not_called()\n\n    def test_lark_import_error_handled_gracefully(self, caplog):\n        \"\"\"Test that ImportError for Lark plugin is handled gracefully.\"\"\"\n        plugin = SwanLabPlugin()\n\n        cfg = MagicMock()\n        cfg.use_swanlab = True\n        cfg.swanlab_project = \"test-project\"\n        cfg.swanlab_mode = \"local\"\n        cfg.swanlab_lark_webhook_url = (\n            \"https://open.feishu.cn/open-apis/bot/v2/hook/test-webhook\"\n        )\n        cfg.swanlab_lark_secret = None\n\n        with (\n            patch(\"swanlab.init\"),\n            patch(\"swanlab.__version__\", \"0.3.0\"),\n            patch(\"axolotl.utils.distributed.is_main_process\", return_value=True),\n            patch(\"axolotl.utils.distributed.get_world_size\", return_value=1),\n        ):\n            # Mock ImportError for LarkCallback\n            with patch(\n                \"swanlab.plugin.notification.LarkCallback\",\n                side_effect=ImportError(\n                    
\"No module named 'swanlab.plugin.notification'\"\n                ),\n            ):\n                with caplog.at_level(logging.WARNING):\n                    plugin.pre_model_load(cfg)\n\n                    # Should log warning about missing Lark plugin\n                    warning_messages = [record.message for record in caplog.records]\n                    assert any(\n                        \"Failed to import SwanLab Lark plugin\" in msg\n                        for msg in warning_messages\n                    )\n                    assert any(\"SwanLab >= 0.3.0\" in msg for msg in warning_messages)\n\n    def test_lark_warning_for_missing_secret(self, caplog):\n        \"\"\"Test that warning is logged when Lark webhook has no HMAC secret.\"\"\"\n        plugin = SwanLabPlugin()\n\n        cfg = MagicMock()\n        cfg.use_swanlab = True\n        cfg.swanlab_project = \"test-project\"\n        cfg.swanlab_mode = \"local\"\n        cfg.swanlab_lark_webhook_url = (\n            \"https://open.feishu.cn/open-apis/bot/v2/hook/test-webhook\"\n        )\n        cfg.swanlab_lark_secret = None  # No secret\n\n        with (\n            patch(\"swanlab.init\"),\n            patch(\"swanlab.__version__\", \"0.3.0\"),\n            patch(\"swanlab.register_callbacks\"),\n            patch(\"axolotl.utils.distributed.is_main_process\", return_value=True),\n            patch(\"axolotl.utils.distributed.get_world_size\", return_value=1),\n        ):\n            with patch(\"swanlab.plugin.notification.LarkCallback\"):\n                with caplog.at_level(logging.WARNING):\n                    plugin.pre_model_load(cfg)\n\n                    # Should log warning about missing secret\n                    warning_messages = [record.message for record in caplog.records]\n                    assert any(\n                        \"no secret configured\" in msg.lower()\n                        for msg in warning_messages\n                    )\n                    
assert any(\"swanlab_lark_secret\" in msg for msg in warning_messages)\n\n\n@pytest.mark.skipif(not SWANLAB_INSTALLED, reason=\"swanlab package not installed\")\nclass TestSwanLabPluginIntegration:\n    \"\"\"Integration tests for SwanLab plugin lifecycle.\"\"\"\n\n    def test_full_lifecycle_valid_config(self):\n        \"\"\"Test full plugin lifecycle with valid configuration.\"\"\"\n        plugin = SwanLabPlugin()\n\n        # Register\n        cfg_dict = {\n            \"use_swanlab\": True,\n            \"swanlab_project\": \"test-project\",\n            \"swanlab_mode\": \"local\",\n        }\n        plugin.register(cfg_dict)\n\n        # Pre-model load (mock SwanLab)\n        cfg_obj = MagicMock()\n        cfg_obj.use_swanlab = True\n        cfg_obj.swanlab_project = \"test-project\"\n        cfg_obj.swanlab_mode = \"local\"\n        cfg_obj.swanlab_lark_webhook_url = None  # No Lark\n\n        with (\n            patch(\"swanlab.init\") as mock_init,\n            patch(\"swanlab.__version__\", \"0.3.0\"),\n            patch(\"axolotl.utils.distributed.is_main_process\", return_value=True),\n            patch(\"axolotl.utils.distributed.get_world_size\", return_value=1),\n        ):\n            plugin.pre_model_load(cfg_obj)\n            # Should call swanlab.init\n            mock_init.assert_called_once()\n\n    def test_lifecycle_with_multi_logger_warning(self, caplog):\n        \"\"\"Test lifecycle with multi-logger warning.\"\"\"\n        plugin = SwanLabPlugin()\n\n        cfg_dict = {\n            \"use_swanlab\": True,\n            \"swanlab_project\": \"test-project\",\n            \"use_wandb\": True,\n        }\n\n        with caplog.at_level(logging.WARNING):\n            plugin.register(cfg_dict)\n\n        # Should have multi-logger warning\n        warning_messages = [record.message for record in caplog.records]\n        assert any(\"Multiple logging tools\" in msg for msg in warning_messages)\n\n    def 
test_lifecycle_invalid_config_fails_early(self):\n        \"\"\"Test that invalid config fails at register stage.\"\"\"\n        plugin = SwanLabPlugin()\n\n        cfg_dict = {\n            \"use_swanlab\": True,\n            # Missing swanlab_project\n        }\n\n        # Should fail at register, not pre_model_load\n        with pytest.raises(ValueError):\n            plugin.register(cfg_dict)\n\n    def test_full_lifecycle_with_lark_notifications(self):\n        \"\"\"Test full lifecycle including Lark notification registration.\"\"\"\n        plugin = SwanLabPlugin()\n\n        # Register\n        cfg_dict = {\n            \"use_swanlab\": True,\n            \"swanlab_project\": \"test-project\",\n            \"swanlab_mode\": \"cloud\",\n        }\n        plugin.register(cfg_dict)\n\n        # Pre-model load with Lark config\n        cfg_obj = MagicMock()\n        cfg_obj.use_swanlab = True\n        cfg_obj.swanlab_project = \"test-project\"\n        cfg_obj.swanlab_mode = \"cloud\"\n        cfg_obj.swanlab_lark_webhook_url = (\n            \"https://open.feishu.cn/open-apis/bot/v2/hook/test\"\n        )\n        cfg_obj.swanlab_lark_secret = \"secret123\"\n\n        with (\n            patch(\"swanlab.init\"),\n            patch(\"swanlab.__version__\", \"0.3.0\"),\n            patch(\"swanlab.register_callbacks\") as mock_register,\n            patch(\"axolotl.utils.distributed.is_main_process\", return_value=True),\n            patch(\"axolotl.utils.distributed.get_world_size\", return_value=1),\n        ):\n            with patch(\"swanlab.plugin.notification.LarkCallback\") as MockLarkCallback:\n                mock_lark_instance = MagicMock()\n                MockLarkCallback.return_value = mock_lark_instance\n\n                plugin.pre_model_load(cfg_obj)\n\n                # Verify both SwanLab init AND Lark callback registration\n                MockLarkCallback.assert_called_once()\n                
mock_register.assert_called_once_with([mock_lark_instance])\n\n\n@pytest.mark.skipif(not SWANLAB_INSTALLED, reason=\"swanlab package not installed\")\nclass TestCompletionLogger:\n    \"\"\"Tests for CompletionLogger utility class.\"\"\"\n\n    def test_completion_logger_initialization(self):\n        \"\"\"Test CompletionLogger initializes with correct maxlen.\"\"\"\n        from axolotl.integrations.swanlab.completion_logger import CompletionLogger\n\n        logger = CompletionLogger(maxlen=64)\n        assert logger.maxlen == 64\n        assert len(logger) == 0\n\n    def test_add_dpo_completion(self):\n        \"\"\"Test adding DPO completions to buffer.\"\"\"\n        from axolotl.integrations.swanlab.completion_logger import CompletionLogger\n\n        logger = CompletionLogger(maxlen=10)\n\n        logger.add_dpo_completion(\n            step=0,\n            prompt=\"What is AI?\",\n            chosen=\"Artificial Intelligence is...\",\n            rejected=\"AI means...\",\n            reward_diff=0.5,\n        )\n\n        assert len(logger) == 1\n        entry = logger.data[0]\n        assert entry[\"step\"] == 0\n        assert entry[\"prompt\"] == \"What is AI?\"\n        assert entry[\"chosen\"] == \"Artificial Intelligence is...\"\n        assert entry[\"rejected\"] == \"AI means...\"\n        assert entry[\"reward_diff\"] == 0.5\n\n    def test_add_kto_completion(self):\n        \"\"\"Test adding KTO completions to buffer.\"\"\"\n        from axolotl.integrations.swanlab.completion_logger import CompletionLogger\n\n        logger = CompletionLogger(maxlen=10)\n\n        logger.add_kto_completion(\n            step=1,\n            prompt=\"Explain quantum physics\",\n            completion=\"Quantum physics is...\",\n            label=True,\n            reward=0.8,\n        )\n\n        assert len(logger) == 1\n        entry = logger.data[0]\n        assert entry[\"step\"] == 1\n        assert entry[\"prompt\"] == \"Explain quantum physics\"\n        
assert entry[\"completion\"] == \"Quantum physics is...\"\n        assert entry[\"label\"] == \"desirable\"\n        assert entry[\"reward\"] == 0.8\n\n    def test_add_orpo_completion(self):\n        \"\"\"Test adding ORPO completions to buffer.\"\"\"\n        from axolotl.integrations.swanlab.completion_logger import CompletionLogger\n\n        logger = CompletionLogger(maxlen=10)\n\n        logger.add_orpo_completion(\n            step=2,\n            prompt=\"Write a poem\",\n            chosen=\"Roses are red...\",\n            rejected=\"Violets are blue...\",\n            log_odds_ratio=1.2,\n        )\n\n        assert len(logger) == 1\n        entry = logger.data[0]\n        assert entry[\"step\"] == 2\n        assert entry[\"chosen\"] == \"Roses are red...\"\n        assert entry[\"rejected\"] == \"Violets are blue...\"\n        assert entry[\"log_odds_ratio\"] == 1.2\n\n    def test_add_grpo_completion(self):\n        \"\"\"Test adding GRPO completions to buffer.\"\"\"\n        from axolotl.integrations.swanlab.completion_logger import CompletionLogger\n\n        logger = CompletionLogger(maxlen=10)\n\n        logger.add_grpo_completion(\n            step=3,\n            prompt=\"Solve this problem\",\n            completion=\"The answer is 42\",\n            reward=0.9,\n            advantage=0.3,\n        )\n\n        assert len(logger) == 1\n        entry = logger.data[0]\n        assert entry[\"step\"] == 3\n        assert entry[\"completion\"] == \"The answer is 42\"\n        assert entry[\"reward\"] == 0.9\n        assert entry[\"advantage\"] == 0.3\n\n    def test_memory_bounded_buffer(self):\n        \"\"\"Test that buffer respects maxlen and drops oldest entries.\"\"\"\n        from axolotl.integrations.swanlab.completion_logger import CompletionLogger\n\n        logger = CompletionLogger(maxlen=3)\n\n        # Add 5 completions\n        for i in range(5):\n            logger.add_dpo_completion(\n                step=i,\n                
prompt=f\"Prompt {i}\",\n                chosen=f\"Chosen {i}\",\n                rejected=f\"Rejected {i}\",\n            )\n\n        # Should only keep last 3\n        assert len(logger) == 3\n        assert logger.data[0][\"step\"] == 2  # Oldest kept\n        assert logger.data[1][\"step\"] == 3\n        assert logger.data[2][\"step\"] == 4  # Newest\n\n    def test_log_to_swanlab_when_not_initialized(self):\n        \"\"\"Test logging gracefully fails when SwanLab not initialized.\"\"\"\n        from axolotl.integrations.swanlab.completion_logger import CompletionLogger\n\n        logger = CompletionLogger(maxlen=10)\n        logger.add_dpo_completion(\n            step=0,\n            prompt=\"Test\",\n            chosen=\"A\",\n            rejected=\"B\",\n        )\n\n        with patch(\"swanlab.get_run\", return_value=None):\n            result = logger.log_to_swanlab()\n            assert result is False  # Should fail gracefully\n\n    def test_log_to_swanlab_success(self):\n        \"\"\"Test successful logging to SwanLab.\"\"\"\n        from axolotl.integrations.swanlab.completion_logger import CompletionLogger\n\n        logger = CompletionLogger(maxlen=10)\n        logger.add_dpo_completion(\n            step=0,\n            prompt=\"Test prompt\",\n            chosen=\"Chosen response\",\n            rejected=\"Rejected response\",\n            reward_diff=0.5,\n        )\n\n        with (\n            patch(\"swanlab.get_run\") as mock_get_run,\n            patch(\"swanlab.log\") as mock_log,\n            patch(\"swanlab.echarts.Table\") as MockTable,\n        ):\n            mock_get_run.return_value = MagicMock()  # SwanLab initialized\n            mock_table_instance = MagicMock()\n            MockTable.return_value = mock_table_instance\n\n            result = logger.log_to_swanlab(table_name=\"test_table\")\n\n            assert result is True\n            mock_log.assert_called_once()\n            
mock_table_instance.add.assert_called_once()\n\n    def test_clear_buffer(self):\n        \"\"\"Test clearing the completion buffer.\"\"\"\n        from axolotl.integrations.swanlab.completion_logger import CompletionLogger\n\n        logger = CompletionLogger(maxlen=10)\n        logger.add_dpo_completion(\n            step=0,\n            prompt=\"Test\",\n            chosen=\"A\",\n            rejected=\"B\",\n        )\n\n        assert len(logger) == 1\n        logger.clear()\n        assert len(logger) == 0\n\n    def test_repr(self):\n        \"\"\"Test string representation.\"\"\"\n        from axolotl.integrations.swanlab.completion_logger import CompletionLogger\n\n        logger = CompletionLogger(maxlen=128)\n        logger.add_dpo_completion(\n            step=0,\n            prompt=\"Test\",\n            chosen=\"A\",\n            rejected=\"B\",\n        )\n\n        repr_str = repr(logger)\n        assert \"CompletionLogger\" in repr_str\n        assert \"maxlen=128\" in repr_str\n        assert \"buffered=1/128\" in repr_str\n\n\n@pytest.mark.skipif(not SWANLAB_INSTALLED, reason=\"swanlab package not installed\")\nclass TestSwanLabRLHFCompletionCallback:\n    \"\"\"Tests for SwanLabRLHFCompletionCallback.\"\"\"\n\n    def test_callback_initialization(self):\n        \"\"\"Test callback initializes with correct parameters.\"\"\"\n        from axolotl.integrations.swanlab.callbacks import SwanLabRLHFCompletionCallback\n\n        callback = SwanLabRLHFCompletionCallback(\n            log_interval=50,\n            max_completions=64,\n            table_name=\"custom_table\",\n        )\n\n        assert callback.log_interval == 50\n        assert callback.logger.maxlen == 64\n        assert callback.table_name == \"custom_table\"\n        assert callback.trainer_type is None\n\n    def test_trainer_type_detection_dpo(self):\n        \"\"\"Test DPO trainer type is detected correctly.\"\"\"\n        from axolotl.integrations.swanlab.callbacks import 
SwanLabRLHFCompletionCallback\n\n        callback = SwanLabRLHFCompletionCallback()\n\n        # Mock trainer with DPO in name\n        mock_trainer = MagicMock()\n        mock_trainer.__class__.__name__ = \"AxolotlDPOTrainer\"\n\n        callback.on_init_end(\n            args=MagicMock(),\n            state=MagicMock(),\n            control=MagicMock(),\n            trainer=mock_trainer,\n        )\n\n        assert callback.trainer_type == \"dpo\"\n\n    def test_trainer_type_detection_kto(self):\n        \"\"\"Test KTO trainer type is detected correctly.\"\"\"\n        from axolotl.integrations.swanlab.callbacks import SwanLabRLHFCompletionCallback\n\n        callback = SwanLabRLHFCompletionCallback()\n\n        mock_trainer = MagicMock()\n        mock_trainer.__class__.__name__ = \"AxolotlKTOTrainer\"\n\n        callback.on_init_end(\n            args=MagicMock(),\n            state=MagicMock(),\n            control=MagicMock(),\n            trainer=mock_trainer,\n        )\n\n        assert callback.trainer_type == \"kto\"\n\n    def test_on_train_end_logs_completions(self):\n        \"\"\"Test that completions are logged at end of training.\"\"\"\n        from axolotl.integrations.swanlab.callbacks import SwanLabRLHFCompletionCallback\n\n        callback = SwanLabRLHFCompletionCallback()\n        callback.trainer_type = \"dpo\"\n\n        # Add some completions to buffer\n        callback.logger.add_dpo_completion(\n            step=0,\n            prompt=\"Test\",\n            chosen=\"A\",\n            rejected=\"B\",\n        )\n\n        with patch.object(callback.logger, \"log_to_swanlab\") as mock_log:\n            callback.on_train_end(\n                args=MagicMock(),\n                state=MagicMock(global_step=100),\n                control=MagicMock(),\n            )\n\n            # Should log remaining completions\n            mock_log.assert_called_once()\n\n\n@pytest.mark.skipif(not SWANLAB_INSTALLED, reason=\"swanlab package not 
installed\")\nclass TestSwanLabPluginCompletionIntegration:\n    \"\"\"Integration tests for completion logging in SwanLabPlugin.\"\"\"\n\n    def test_completion_callback_registered_for_dpo_trainer(self):\n        \"\"\"Test that completion callback is registered for DPO trainer.\"\"\"\n        from axolotl.integrations.swanlab.plugins import SwanLabPlugin\n        from axolotl.utils.dict import DictDefault\n\n        plugin = SwanLabPlugin()\n        plugin.swanlab_initialized = True  # Simulate SwanLab initialized\n\n        cfg = {\n            \"use_swanlab\": True,\n            \"swanlab_project\": \"test-project\",\n            \"swanlab_log_completions\": True,\n            \"swanlab_completion_log_interval\": 50,\n            \"swanlab_completion_max_buffer\": 64,\n        }\n        cfg_obj = DictDefault(cfg)\n\n        # Mock DPO trainer\n        mock_trainer = MagicMock()\n        mock_trainer.__class__.__name__ = \"AxolotlDPOTrainer\"\n        mock_trainer.state = MagicMock(max_steps=1000)\n        mock_trainer.args = MagicMock(\n            num_train_epochs=3,\n            train_batch_size=4,\n            gradient_accumulation_steps=2,\n        )\n\n        with patch(\"swanlab.config.update\"):\n            plugin.post_trainer_create(cfg_obj, mock_trainer)\n\n        # Verify callback was added\n        mock_trainer.add_callback.assert_called_once()\n        callback = mock_trainer.add_callback.call_args[0][0]\n        assert callback.__class__.__name__ == \"SwanLabRLHFCompletionCallback\"\n        assert callback.log_interval == 50\n        assert callback.logger.maxlen == 64\n\n    def test_completion_callback_not_registered_for_non_rlhf_trainer(self):\n        \"\"\"Test that completion callback is NOT registered for non-RLHF trainers.\"\"\"\n        from axolotl.integrations.swanlab.plugins import SwanLabPlugin\n        from axolotl.utils.dict import DictDefault\n\n        plugin = SwanLabPlugin()\n        plugin.swanlab_initialized = True\n\n    
    cfg = {\n            \"use_swanlab\": True,\n            \"swanlab_project\": \"test-project\",\n            \"swanlab_log_completions\": True,\n        }\n        cfg_obj = DictDefault(cfg)\n\n        # Mock regular SFT trainer (not RLHF)\n        mock_trainer = MagicMock()\n        mock_trainer.__class__.__name__ = \"AxolotlTrainer\"  # Not RLHF\n        mock_trainer.state = MagicMock(max_steps=1000)\n        mock_trainer.args = MagicMock()\n\n        with patch(\"swanlab.config.update\"):\n            plugin.post_trainer_create(cfg_obj, mock_trainer)\n\n        # Callback should NOT be added for non-RLHF trainer\n        mock_trainer.add_callback.assert_not_called()\n\n    def test_completion_callback_not_registered_when_disabled(self):\n        \"\"\"Test that completion callback is not registered when disabled in config.\"\"\"\n        from axolotl.integrations.swanlab.plugins import SwanLabPlugin\n        from axolotl.utils.dict import DictDefault\n\n        plugin = SwanLabPlugin()\n        plugin.swanlab_initialized = True\n\n        cfg = {\n            \"use_swanlab\": True,\n            \"swanlab_project\": \"test-project\",\n            \"swanlab_log_completions\": False,  # Disabled\n        }\n        cfg_obj = DictDefault(cfg)\n\n        # Mock DPO trainer\n        mock_trainer = MagicMock()\n        mock_trainer.__class__.__name__ = \"AxolotlDPOTrainer\"\n        mock_trainer.state = MagicMock(max_steps=1000)\n        mock_trainer.args = MagicMock()\n\n        with patch(\"swanlab.config.update\"):\n            plugin.post_trainer_create(cfg_obj, mock_trainer)\n\n        # Callback should NOT be added when disabled\n        mock_trainer.add_callback.assert_not_called()\n\n\n@pytest.mark.skipif(not SWANLAB_INSTALLED, reason=\"swanlab package not installed\")\nclass TestSwanLabProfiling:\n    \"\"\"Tests for SwanLab profiling utilities.\"\"\"\n\n    def test_profiling_context_logs_duration(self):\n        \"\"\"Test that profiling context logs 
execution duration.\"\"\"\n        from axolotl.integrations.swanlab.profiling import swanlab_profiling_context\n\n        # Mock trainer with SwanLab enabled\n        mock_trainer = MagicMock()\n        mock_trainer.cfg = MagicMock(use_swanlab=True)\n        mock_trainer.__class__.__name__ = \"TestTrainer\"\n\n        with patch(\"swanlab.get_run\") as mock_get_run, patch(\"swanlab.log\") as mock_log:\n            mock_get_run.return_value = MagicMock()  # SwanLab initialized\n\n            with swanlab_profiling_context(mock_trainer, \"test_function\"):\n                time.sleep(0.01)  # Simulate work\n\n            # Verify log was called with correct metric name\n            mock_log.assert_called_once()\n            logged_data = mock_log.call_args[0][0]\n            assert \"profiling/Time taken: TestTrainer.test_function\" in logged_data\n            # Duration should be > 0.01 seconds\n            assert (\n                logged_data[\"profiling/Time taken: TestTrainer.test_function\"] >= 0.01\n            )\n\n    def test_profiling_context_skips_when_swanlab_disabled(self):\n        \"\"\"Test that profiling is skipped when SwanLab is disabled.\"\"\"\n        from axolotl.integrations.swanlab.profiling import swanlab_profiling_context\n\n        mock_trainer = MagicMock()\n        mock_trainer.cfg = MagicMock(use_swanlab=False)  # Disabled\n\n        with patch(\"swanlab.log\") as mock_log:\n            with swanlab_profiling_context(mock_trainer, \"test_function\"):\n                time.sleep(0.01)\n\n            # Should NOT log when disabled\n            mock_log.assert_not_called()\n\n    def test_profiling_context_skips_when_swanlab_not_initialized(self):\n        \"\"\"Test that profiling is skipped when SwanLab not initialized.\"\"\"\n        from axolotl.integrations.swanlab.profiling import swanlab_profiling_context\n\n        mock_trainer = MagicMock()\n        mock_trainer.cfg = MagicMock(use_swanlab=True)\n\n        with (\n            
patch(\"swanlab.get_run\", return_value=None),\n            patch(\"swanlab.log\") as mock_log,\n        ):\n            with swanlab_profiling_context(mock_trainer, \"test_function\"):\n                time.sleep(0.01)\n\n            # Should NOT log when not initialized\n            mock_log.assert_not_called()\n\n    def test_profiling_decorator(self):\n        \"\"\"Test swanlab_profile decorator.\"\"\"\n        from axolotl.integrations.swanlab.profiling import swanlab_profile\n\n        class MockTrainer:\n            def __init__(self):\n                self.cfg = MagicMock(use_swanlab=True)\n\n            @swanlab_profile\n            def expensive_method(self, x):\n                time.sleep(0.01)\n                return x * 2\n\n        trainer = MockTrainer()\n\n        with patch(\"swanlab.get_run\") as mock_get_run, patch(\"swanlab.log\") as mock_log:\n            mock_get_run.return_value = MagicMock()\n\n            result = trainer.expensive_method(5)\n\n            # Verify method still works correctly\n            assert result == 10\n\n            # Verify profiling was logged\n            mock_log.assert_called_once()\n            logged_data = mock_log.call_args[0][0]\n            assert \"profiling/Time taken: MockTrainer.expensive_method\" in logged_data\n\n    def test_profiling_config(self):\n        \"\"\"Test ProfilingConfig class.\"\"\"\n        from axolotl.integrations.swanlab.profiling import ProfilingConfig\n\n        config = ProfilingConfig(\n            enabled=True,\n            min_duration_ms=1.0,\n            log_interval=5,\n        )\n\n        # Test enabled check\n        assert config.enabled is True\n\n        # Test minimum duration filtering\n        assert config.should_log(\"func1\", 0.0001) is False  # 0.1ms < 1.0ms threshold\n        assert config.should_log(\"func2\", 0.002) is True  # 2.0ms > 1.0ms threshold\n\n        # Test log interval\n        assert config.should_log(\"func3\", 0.002) is True  # 1st call\n   
     assert config.should_log(\"func3\", 0.002) is False  # 2nd call\n        assert config.should_log(\"func3\", 0.002) is False  # 3rd call\n        assert config.should_log(\"func3\", 0.002) is False  # 4th call\n        assert config.should_log(\"func3\", 0.002) is True  # 5th call (interval=5)\n\n    def test_profiling_config_when_disabled(self):\n        \"\"\"Test ProfilingConfig when disabled.\"\"\"\n        from axolotl.integrations.swanlab.profiling import ProfilingConfig\n\n        config = ProfilingConfig(enabled=False)\n\n        # Should never log when disabled\n        assert config.should_log(\"func1\", 100.0) is False\n\n    def test_profiling_context_advanced(self):\n        \"\"\"Test advanced profiling context with custom config.\"\"\"\n        from axolotl.integrations.swanlab.profiling import (\n            ProfilingConfig,\n            swanlab_profiling_context_advanced,\n        )\n\n        mock_trainer = MagicMock()\n        mock_trainer.cfg = MagicMock(use_swanlab=True)\n        mock_trainer.__class__.__name__ = \"TestTrainer\"\n\n        # Config that filters out very fast operations\n        config = ProfilingConfig(min_duration_ms=10.0)  # 10ms minimum\n\n        with patch(\"swanlab.get_run\") as mock_get_run, patch(\"swanlab.log\") as mock_log:\n            mock_get_run.return_value = MagicMock()\n\n            # Fast operation (< 10ms) - should NOT log\n            with swanlab_profiling_context_advanced(mock_trainer, \"fast_op\", config):\n                time.sleep(0.001)  # 1ms\n\n            mock_log.assert_not_called()\n\n            # Slow operation (> 10ms) - should log\n            with swanlab_profiling_context_advanced(mock_trainer, \"slow_op\", config):\n                time.sleep(0.015)  # 15ms\n\n            mock_log.assert_called_once()\n\n    def test_profiling_with_exception(self):\n        \"\"\"Test that profiling still logs even when exception occurs.\"\"\"\n        from axolotl.integrations.swanlab.profiling 
import swanlab_profiling_context\n\n        mock_trainer = MagicMock()\n        mock_trainer.cfg = MagicMock(use_swanlab=True)\n        mock_trainer.__class__.__name__ = \"TestTrainer\"\n\n        with patch(\"swanlab.get_run\") as mock_get_run, patch(\"swanlab.log\") as mock_log:\n            mock_get_run.return_value = MagicMock()\n\n            try:\n                with swanlab_profiling_context(mock_trainer, \"error_function\"):\n                    time.sleep(0.01)\n                    raise ValueError(\"Test error\")\n            except ValueError:\n                pass  # Expected\n\n            # Should still log duration even with exception\n            mock_log.assert_called_once()\n"
  },
  {
    "path": "tests/monkeypatch/test_llama_attn_hijack_flash.py",
    "content": "\"\"\"\nUnit tests for the sequence-length helpers in axolotl.monkeypatch.utils\n\"\"\"\n\nimport unittest\n\nimport torch\n\nfrom axolotl.monkeypatch.utils import (\n    get_cu_seqlens,\n    get_cu_seqlens_from_pos_ids,\n    get_max_seqlen_in_batch,\n    get_unpad_data,\n)\n\n# One packed row: segments 1..4 with lengths 4, 3, 5 and 2, followed by two\n# padding (0) positions.\nPACKED_MASK_ROW = [1, 1, 1, 1, 2, 2, 2, 3, 3, 3, 3, 3, 4, 4, 0, 0]\n\n\nclass TestMonkeyPatchUtils(unittest.TestCase):\n    \"\"\"\n    Exercises the cu_seqlens / unpad helpers on hand-built packed batches\n    \"\"\"\n\n    def test_get_cu_seqlens_1d(self):\n        \"\"\"Cumulative boundaries of a single packed attention-mask row.\"\"\"\n        mask = torch.tensor([PACKED_MASK_ROW])\n        expected = torch.tensor([0, 4, 7, 12, 14, 16], dtype=torch.int32)\n        cu_seqlens = get_cu_seqlens(mask)[0]\n        self.assertTrue(torch.allclose(cu_seqlens, expected))\n\n    def test_get_cu_seqlens_from_pos_ids_1d(self):\n        \"\"\"Boundaries recovered from position ids that restart at 0.\"\"\"\n        pos_ids = torch.tensor([[0, 1, 2, 3, 0, 1, 2, 0, 1, 2, 3, 4, 0, 1, 0, 0]])\n        expected = torch.tensor([0, 4, 7, 12, 14, 16], dtype=torch.int32)\n        cu_seqlens = get_cu_seqlens_from_pos_ids(pos_ids)[0]\n        self.assertTrue(torch.allclose(cu_seqlens, expected))\n\n    def test_get_cu_seqlens_from_pos_ids_2d(self):\n        \"\"\"Per-row boundaries for a batch of two packed position-id rows.\"\"\"\n        pos_ids = torch.tensor(\n            [\n                [0, 1, 2, 3, 0, 1, 2, 0, 1, 2, 3, 4, 0, 1, 0, 0],\n                [0, 1, 2, 3, 4, 0, 1, 2, 0, 1, 2, 3, 4, 5, 6, 0],\n            ]\n        )\n        expected = torch.tensor(\n            [[0, 4, 7, 12, 14, 16], [0, 5, 8, 15, 16, 16]], dtype=torch.int32\n        )\n        cu_seqlens = get_cu_seqlens_from_pos_ids(pos_ids)[0]\n        
self.assertTrue(torch.allclose(cu_seqlens, expected))\n\n    def test_get_max_seqlen_in_batch(self):\n        \"\"\"Per-segment lengths of the packed row (padding is not counted).\"\"\"\n        mask = torch.tensor([PACKED_MASK_ROW])\n        expected = torch.tensor([4, 3, 5, 2], dtype=torch.int32)\n        self.assertTrue(torch.allclose(get_max_seqlen_in_batch(mask), expected))\n\n    def _check_unpad(self, mask_rows, want_indices, want_cu_seqlen, want_max_seqlen):\n        \"\"\"Run get_unpad_data on ``mask_rows`` and compare all three outputs.\"\"\"\n        indices, cu_seqlen, max_seqlen_in_batch = get_unpad_data(\n            torch.tensor(mask_rows)\n        )\n        self.assertTrue(torch.allclose(want_indices, indices))\n        self.assertTrue(torch.allclose(want_cu_seqlen, cu_seqlen))\n        self.assertEqual(want_max_seqlen, max_seqlen_in_batch)\n\n    def test_get_unpad_data(self):\n        \"\"\"Flat token indices, cumulative lengths and max length for 1 and 2 rows.\"\"\"\n        # Single row: the trailing padding positions 14 and 15 are dropped.\n        self._check_unpad(\n            [PACKED_MASK_ROW],\n            torch.tensor(list(range(14))),\n            torch.tensor([0, 4, 7, 12, 14], dtype=torch.int32),\n            5,\n        )\n        # Two rows: the second row has no padding, so its flat indices run\n        # from 16 (start of row two) through 31.\n        self._check_unpad(\n            [PACKED_MASK_ROW, [1, 1, 1, 2, 2, 2, 2, 2, 3, 3, 4, 4, 4, 5, 5, 5]],\n            torch.tensor(list(range(14)) + list(range(16, 32))),\n            torch.tensor([0, 4, 7, 12, 14, 17, 22, 24, 27, 30], dtype=torch.int32),\n            5,\n        )\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/monkeypatch/test_pixtral_flash_attention_patch.py",
    "content": "\"\"\"Integration tests for Pixtral Flash Attention patches.\"\"\"\n\nimport pytest\nimport torch\n\n\nclass TestPixtralFlashAttentionPatchIntegration:\n    \"\"\"Test Pixtral Flash Attention patch integration.\"\"\"\n\n    @pytest.mark.integration\n    def test_pixtral_flash_attention_patch(self):\n        \"\"\"Test that Pixtral Flash Attention patch can be applied and works correctly.\n\n        The unpatch function runs in a ``finally`` block, so a failing assertion\n        cannot leave the patched ``_is_packed_sequence`` installed for later tests.\n        \"\"\"\n        try:\n            from transformers import modeling_flash_attention_utils\n        except ImportError:\n            pytest.skip(\"Flash Attention utils not available\")\n\n        from axolotl.monkeypatch.models.pixtral.modeling_flash_attention_utils import (\n            apply_patch_is_packed_sequence,\n        )\n\n        # Store original method\n        original_is_packed_sequence = modeling_flash_attention_utils._is_packed_sequence\n\n        # (position_ids, batch_size, expected result, failure message)\n        cases = [\n            (\n                torch.tensor([0, 1, 2, 3]),\n                1,\n                False,\n                \"1D sequential position_ids should not be packed\",\n            ),\n            (\n                torch.tensor([0, 1, 2, 0, 1, 2]),\n                1,\n                True,\n                \"1D packed position_ids should be detected as packed\",\n            ),\n            (\n                torch.tensor([[0, 1, 2, 3, 0, 1]]),\n                1,\n                True,\n                \"2D packed position_ids should be detected as packed\",\n            ),\n            (\n                torch.tensor([[0, 1, 2, 3, 4, 5]]),\n                1,\n                False,\n                \"2D sequential position_ids should not be packed\",\n            ),\n            (\n                torch.tensor([[0, 1, 2, 3, 4, 5, 6, 7, 8]]),\n                2,\n  
              False,\n                \"2D position_ids batch 2 should not be packed\",\n            ),\n            (None, 1, False, \"None position_ids should return False\"),\n        ]\n\n        # Apply patch and get unpatch function\n        unpatch_fn = apply_patch_is_packed_sequence()\n        try:\n            # Verify patch was applied\n            assert (\n                modeling_flash_attention_utils._is_packed_sequence\n                != original_is_packed_sequence\n            ), \"_is_packed_sequence was not patched\"\n\n            # Every case must yield a real bool with the expected value.\n            patched_fn = modeling_flash_attention_utils._is_packed_sequence\n            for position_ids, batch_size, expected, message in cases:\n                result = patched_fn(position_ids, batch_size=batch_size)\n                assert isinstance(result, bool), \"Function should return a boolean\"\n                assert result is expected, message\n        finally:\n            # Restore the original even when an assertion above failed.\n            unpatch_fn()\n\n        assert (\n            modeling_flash_attention_utils._is_packed_sequence\n            == original_is_packed_sequence\n        ), \"unpatch function did not restore original method\"\n"
  },
  {
    "path": "tests/monkeypatch/test_qwen3_next_modeling_patch.py",
    "content": "\"\"\"Integration tests for Qwen3 Next modeling patches.\"\"\"\n\nimport pytest\nimport torch\n\n# Skip entire module if qwen3_next not available\nqwen3_next = pytest.importorskip(\"transformers.models.qwen3_next.modeling_qwen3_next\")\n\n\ndef _assert_forward_patch_roundtrip(layer_cls, patch_fn, label):\n    \"\"\"Apply ``patch_fn`` and verify it patches then restores ``layer_cls.forward``.\n\n    ``patch_fn`` is expected to replace ``layer_cls.forward`` and return an unpatch\n    callable or ``None``. The original ``forward`` is reassigned in a ``finally``\n    block, so neither a failing assertion nor a missing unpatch function can leak\n    the patched method into other tests.\n    \"\"\"\n    original_forward = layer_cls.forward\n    unpatch_fn = patch_fn()\n    try:\n        # Verify patch was applied\n        assert layer_cls.forward != original_forward, (\n            f\"{label} forward method was not patched\"\n        )\n\n        # Verify the method is still callable\n        assert callable(layer_cls.forward), \"Patched method is not callable\"\n\n        # Test unpatch function\n        if unpatch_fn:\n            unpatch_fn()\n            assert layer_cls.forward == original_forward, (\n                \"unpatch function did not restore original method\"\n            )\n    finally:\n        layer_cls.forward = original_forward\n\n\nclass TestQwen3NextModelingPatchIntegration:\n    \"\"\"Test Qwen3 Next modeling patch integration.\"\"\"\n\n    @pytest.mark.integration\n    def test_qwen3_next_decoder_layer_patch(self):\n        \"\"\"Test that Qwen3Next decoder layer patch can be applied.\"\"\"\n        from axolotl.monkeypatch.models.qwen3_next.modeling import (\n            patch_qwen3_next_decoder_layer,\n        )\n\n        _assert_forward_patch_roundtrip(\n            qwen3_next.Qwen3NextDecoderLayer,\n            patch_qwen3_next_decoder_layer,\n            \"decoder layer\",\n        )\n\n    @pytest.mark.integration\n    def 
test_qwen3_next_gateddelta_layer_patch(self):\n        \"\"\"Test that Qwen3Next GatedDeltaNet patch can be applied.\"\"\"\n        from axolotl.monkeypatch.models.qwen3_next.modeling import (\n            patch_qwen3_next_gateddelta_layer,\n        )\n\n        _assert_forward_patch_roundtrip(\n            qwen3_next.Qwen3NextGatedDeltaNet,\n            patch_qwen3_next_gateddelta_layer,\n            \"GatedDeltaNet\",\n        )\n\n    @pytest.mark.integration\n    def test_qwen3_next_imports_patch(self):\n        \"\"\"Test that Qwen3Next imports patch can be applied without errors.\"\"\"\n        from axolotl.monkeypatch.models.qwen3_next.modeling import (\n            patch_qwen3_next_imports,\n        )\n\n        # Apply patch - should not raise any exceptions even if modules unavailable\n        unpatch_fn = patch_qwen3_next_imports()\n        try:\n            # Test that unpatch function is returned (or None if skipped)\n            assert unpatch_fn is None or callable(unpatch_fn), (\n                \"patch_qwen3_next_imports should return None or callable unpatch function\"\n            )\n        finally:\n            # Undo the patch so it does not leak into other tests.\n            if callable(unpatch_fn):\n                unpatch_fn()\n\n    @pytest.mark.integration\n    def test_qwen3_next_modeling_packing_patch(self):\n        \"\"\"Test that all Qwen3Next modeling patches can be applied together.\"\"\"\n        from axolotl.monkeypatch.models.qwen3_next.modeling import (\n            patch_qwen3_next_modeling_packing,\n        )\n\n        # No unpatch handle is used here, so snapshot the forward methods this\n        # presumably replaces (the ones the individual patch functions target)\n        # and restore them afterwards.\n        # NOTE(review): anything else it patches is not reverted by this test.\n        patched_classes = (\n            qwen3_next.Qwen3NextDecoderLayer,\n            qwen3_next.Qwen3NextGatedDeltaNet,\n        )\n        saved_forwards = {cls: cls.forward for cls in 
patched_classes}\n        try:\n            # This should not raise any exceptions\n            patch_qwen3_next_modeling_packing()\n        finally:\n            for cls, forward in saved_forwards.items():\n                cls.forward = forward\n\n\n@pytest.mark.integration\ndef test_get_cu_seqlens_utility():\n    \"\"\"Test the get_cu_seqlens utility function.\"\"\"\n    from axolotl.monkeypatch.models.qwen3_next.modeling import get_cu_seqlens\n\n    # Test with simple position_ids\n    position_ids = torch.tensor([[0, 1, 2, 0, 1]])\n    cu_seqlens = get_cu_seqlens(position_ids)\n    assert cu_seqlens.dtype == torch.int32, \"Should be int32 dtype\"\n\n    # Should return tensor with start positions and total length\n    expected = torch.tensor([0, 3, 5], dtype=torch.int32)\n    assert torch.equal(cu_seqlens, expected), f\"Expected {expected}, got {cu_seqlens}\"\n"
  },
  {
    "path": "tests/monkeypatch/test_trainer_accelerator_args.py",
    "content": "\"\"\"\nUnit tests for trainer accelerator args monkeypatch\n\"\"\"\n\nimport unittest\n\nfrom axolotl.monkeypatch.trainer_accelerator_args import (\n    check_create_accelerate_code_is_patchable,\n)\n\n\nclass TestTrainerAcceleratorArgs(unittest.TestCase):\n    \"\"\"\n    Unit test class for trainer accelerator args monkeypatch\n    \"\"\"\n\n    def test_check_create_accelerate_code_is_patchable(self):\n        \"\"\"\n        Test that the upstream transformers code is still patchable.\n        This will fail if the patched code changes upstream.\n        \"\"\"\n        assert check_create_accelerate_code_is_patchable()\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/monkeypatch/test_trainer_context_parallel_patch.py",
    "content": "\"\"\"Tests for the HF Trainer context parallel patch.\"\"\"\n\nimport pytest\nfrom transformers import Trainer\n\nfrom axolotl.monkeypatch.transformers.trainer_context_parallel import (\n    GUARD_PATTERN,\n    PATCHED_GUARD,\n    patch_prepare_context_parallel_inputs,\n)\n\n\n@pytest.fixture\ndef restore_trainer_prepare_method():\n    \"\"\"Ensure Trainer._prepare_context_parallel_inputs is restored after a test.\"\"\"\n    original_method = getattr(\n        Trainer,\n        \"_original_prepare_context_parallel_inputs\",\n        Trainer._prepare_context_parallel_inputs,\n    )\n    patched_attr_present = hasattr(\n        Trainer, \"_axolotl_prepare_context_parallel_inputs_patched\"\n    )\n\n    yield\n\n    Trainer._prepare_context_parallel_inputs = original_method\n    if patched_attr_present:\n        delattr(Trainer, \"_axolotl_prepare_context_parallel_inputs_patched\")\n    if hasattr(Trainer, \"_original_prepare_context_parallel_inputs\"):\n        delattr(Trainer, \"_original_prepare_context_parallel_inputs\")\n    if hasattr(Trainer, \"_axolotl_prepare_context_parallel_inputs_source\"):\n        delattr(Trainer, \"_axolotl_prepare_context_parallel_inputs_source\")\n\n\ndef test_patch_attention_guard(restore_trainer_prepare_method):\n    \"\"\"Patch should swap the guard to allow sdpa or flash attention.\"\"\"\n    # Ensure we start from the unpatched method\n    if hasattr(Trainer, \"_original_prepare_context_parallel_inputs\"):\n        Trainer._prepare_context_parallel_inputs = (\n            Trainer._original_prepare_context_parallel_inputs\n        )\n        delattr(Trainer, \"_original_prepare_context_parallel_inputs\")\n    if hasattr(Trainer, \"_axolotl_prepare_context_parallel_inputs_patched\"):\n        delattr(Trainer, \"_axolotl_prepare_context_parallel_inputs_patched\")\n\n    patch_prepare_context_parallel_inputs()\n\n    patched_method = Trainer._prepare_context_parallel_inputs\n    assert patched_method is not None\n    
assert getattr(Trainer, \"_axolotl_prepare_context_parallel_inputs_patched\", False)\n\n    source = Trainer._axolotl_prepare_context_parallel_inputs_source\n    assert GUARD_PATTERN not in source\n    assert PATCHED_GUARD in source\n\n\ndef test_patch_is_idempotent(restore_trainer_prepare_method):\n    \"\"\"Calling the patch twice should leave the same patched function in place.\"\"\"\n    patch_prepare_context_parallel_inputs()\n    first_patched = Trainer._prepare_context_parallel_inputs\n\n    patch_prepare_context_parallel_inputs()\n    second_patched = Trainer._prepare_context_parallel_inputs\n\n    assert first_patched is second_patched\n"
  },
  {
    "path": "tests/monkeypatch/test_trainer_loss_calc.py",
    "content": "\"\"\"Unit tests for trainer loss calc monkeypatch.\"\"\"\n\nimport unittest\n\nfrom axolotl.monkeypatch.transformers.trainer_loss_calc import (\n    check_evaluation_loop_is_patchable,\n    check_maybe_log_save_evaluate_is_patchable,\n)\n\n\nclass TestTrainerLossCalc(unittest.TestCase):\n    \"\"\"\n    Unit test class for trainer loss calc monkeypatch\n    \"\"\"\n\n    def test_trainer_loss_calc_is_patchable(self):\n        \"\"\"\n        Test that the upstream transformers code is still patchable. This will fail if\n        the patched code changes upstream.\n        \"\"\"\n        assert check_evaluation_loop_is_patchable()\n        assert check_maybe_log_save_evaluate_is_patchable()\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/monkeypatch/test_trl_vllm.py",
    "content": "\"\"\"Unit tests for TRL vLLM monkeypatches.\n\nTests:\n- split_tensor_dict: scalar type preservation (int/float/bool)\n- shuffle_sequence_dict: scalar type preservation\n- extract_logprobs: NaN → 0.0 replacement\n- VLLMClient.batch_update_named_params: method exists after patch\n- VLLMGeneration: weight_sync_chunk_size attribute after patch\n- Patch idempotency: applying patch twice doesn't break anything\n\"\"\"\n\nimport unittest\nfrom dataclasses import dataclass\nfrom unittest.mock import MagicMock\n\nimport torch\n\n\nclass TestSplitTensorDict(unittest.TestCase):\n    \"\"\"Tests for patched split_tensor_dict.\"\"\"\n\n    def setUp(self):\n        from axolotl.monkeypatch.trainer.trl_vllm import _patched_split_tensor_dict\n\n        self.split = _patched_split_tensor_dict\n\n    def test_scalar_int_preserved(self):\n        d = {\"a\": torch.randn(4, 3), \"count\": 42}\n        chunks = self.split(d, 2)\n        self.assertEqual(len(chunks), 2)\n        self.assertEqual(chunks[0][\"count\"], 42)\n        self.assertEqual(chunks[1][\"count\"], 42)\n\n    def test_scalar_float_preserved(self):\n        d = {\"a\": torch.randn(6, 2), \"lr\": 1e-5}\n        chunks = self.split(d, 3)\n        for c in chunks:\n            self.assertEqual(c[\"lr\"], 1e-5)\n\n    def test_scalar_bool_preserved(self):\n        d = {\"a\": torch.randn(4, 2), \"flag\": True}\n        chunks = self.split(d, 2)\n        for c in chunks:\n            self.assertTrue(c[\"flag\"])\n\n    def test_none_preserved(self):\n        d = {\"a\": torch.randn(4, 2), \"b\": None}\n        chunks = self.split(d, 2)\n        for c in chunks:\n            self.assertIsNone(c[\"b\"])\n\n    def test_tensor_split(self):\n        t = torch.arange(8).reshape(4, 2)\n        d = {\"a\": t, \"n\": 10}\n        chunks = self.split(d, 2)\n        self.assertEqual(chunks[0][\"a\"].shape, (2, 2))\n        self.assertEqual(chunks[1][\"a\"].shape, (2, 2))\n        
torch.testing.assert_close(chunks[0][\"a\"], t[:2])\n        torch.testing.assert_close(chunks[1][\"a\"], t[2:])\n\n    def test_0d_tensor_preserved(self):\n        d = {\"a\": torch.randn(4, 2), \"scalar_t\": torch.tensor(3.14)}\n        chunks = self.split(d, 2)\n        for c in chunks:\n            self.assertAlmostEqual(c[\"scalar_t\"].item(), 3.14, places=5)\n\n    def test_list_split(self):\n        d = {\"a\": torch.randn(4, 2), \"names\": [\"a\", \"b\", \"c\", \"d\"]}\n        chunks = self.split(d, 2)\n        self.assertEqual(chunks[0][\"names\"], [\"a\", \"b\"])\n        self.assertEqual(chunks[1][\"names\"], [\"c\", \"d\"])\n\n\nclass TestShuffleSequenceDict(unittest.TestCase):\n    \"\"\"Tests for patched shuffle_sequence_dict.\"\"\"\n\n    def setUp(self):\n        from axolotl.monkeypatch.trainer.trl_vllm import _patched_shuffle_sequence_dict\n\n        self.shuffle = _patched_shuffle_sequence_dict\n\n    def test_scalar_int_preserved(self):\n        d = {\"a\": torch.randn(4, 3), \"count\": 42}\n        result = self.shuffle(d)\n        self.assertEqual(result[\"count\"], 42)\n\n    def test_scalar_float_preserved(self):\n        d = {\"a\": torch.randn(4, 3), \"lr\": 1e-5}\n        result = self.shuffle(d)\n        self.assertEqual(result[\"lr\"], 1e-5)\n\n    def test_scalar_bool_preserved(self):\n        d = {\"a\": torch.randn(4, 3), \"flag\": False}\n        result = self.shuffle(d)\n        self.assertFalse(result[\"flag\"])\n\n    def test_none_preserved(self):\n        d = {\"a\": torch.randn(4, 3), \"b\": None}\n        result = self.shuffle(d)\n        self.assertIsNone(result[\"b\"])\n\n    def test_tensor_permuted(self):\n        torch.manual_seed(42)\n        t = torch.arange(4).float()\n        d = {\"a\": t}\n        result = self.shuffle(d)\n        # Same elements, possibly different order\n        self.assertEqual(sorted(result[\"a\"].tolist()), sorted(t.tolist()))\n        self.assertEqual(result[\"a\"].shape, t.shape)\n\n    def 
test_list_permuted(self):\n        torch.manual_seed(42)\n        d = {\"a\": torch.randn(3, 2), \"names\": [\"x\", \"y\", \"z\"]}\n        result = self.shuffle(d)\n        self.assertEqual(sorted(result[\"names\"]), [\"x\", \"y\", \"z\"])\n        self.assertEqual(len(result[\"names\"]), 3)\n\n    def test_0d_tensor_preserved(self):\n        d = {\"a\": torch.randn(4, 2), \"scalar_t\": torch.tensor(3.14)}\n        result = self.shuffle(d)\n        self.assertAlmostEqual(result[\"scalar_t\"].item(), 3.14, places=5)\n\n\nclass TestExtractLogprobs(unittest.TestCase):\n    \"\"\"Tests for patched extract_logprobs (NaN → 0.0).\"\"\"\n\n    def setUp(self):\n        from axolotl.monkeypatch.trainer.trl_vllm import _patched_extract_logprobs\n\n        self.extract = _patched_extract_logprobs\n\n    def _make_output(self, logprob_values):\n        \"\"\"Create a mock vLLM RequestOutput with given logprob values.\"\"\"\n\n        @dataclass\n        class LogprobItem:\n            logprob: float\n            rank: int\n\n        @dataclass\n        class SeqOutput:\n            logprobs: list[dict[int, LogprobItem]] | None\n\n        @dataclass\n        class RequestOutput:\n            outputs: list[SeqOutput]\n\n        logprobs_list = []\n        for vals in logprob_values:\n            lp_dict = {i: LogprobItem(logprob=v, rank=i) for i, v in enumerate(vals)}\n            logprobs_list.append(lp_dict)\n\n        return RequestOutput(outputs=[SeqOutput(logprobs=logprobs_list)])\n\n    def test_nan_replaced_with_zero(self):\n        output = self._make_output([[float(\"nan\"), 0.5], [-0.3, float(\"nan\")]])\n        logprobs, token_ids = self.extract([output])\n        self.assertEqual(logprobs[0][0][0], 0.0)  # NaN → 0.0\n        self.assertEqual(logprobs[0][0][1], 0.5)\n        self.assertEqual(logprobs[0][1][0], -0.3)\n        self.assertEqual(logprobs[0][1][1], 0.0)  # NaN → 0.0\n\n    def test_normal_values_preserved(self):\n        output = 
self._make_output([[-0.5, -1.2], [-0.1, -2.0]])\n        logprobs, token_ids = self.extract([output])\n        self.assertAlmostEqual(logprobs[0][0][0], -0.5)\n        self.assertAlmostEqual(logprobs[0][0][1], -1.2)\n\n    def test_none_logprobs_returns_none(self):\n        @dataclass\n        class SeqOutput:\n            logprobs: None = None\n\n        @dataclass\n        class RequestOutput:\n            outputs: list\n\n        output = RequestOutput(outputs=[SeqOutput()])\n        logprobs, token_ids = self.extract([output])\n        self.assertIsNone(logprobs)\n        self.assertIsNone(token_ids)\n\n    def test_token_ids_extracted(self):\n        output = self._make_output([[-0.5]])\n        logprobs, token_ids = self.extract([output])\n        self.assertEqual(token_ids[0][0], [0])  # token_id=0 from enumerate\n\n\nclass TestPatchApplication(unittest.TestCase):\n    \"\"\"Tests for patch_trl_vllm() application.\"\"\"\n\n    def test_batch_update_added_to_client(self):\n        from axolotl.monkeypatch.trainer.trl_vllm import patch_trl_vllm\n\n        patch_trl_vllm()\n        from trl.generation.vllm_client import VLLMClient\n\n        self.assertTrue(hasattr(VLLMClient, \"batch_update_named_params\"))\n\n    def test_extract_logprobs_patched(self):\n        from axolotl.monkeypatch.trainer.trl_vllm import (\n            _patched_extract_logprobs,\n            patch_trl_vllm,\n        )\n\n        patch_trl_vllm()\n        from trl.generation import vllm_generation\n\n        self.assertIs(vllm_generation.extract_logprobs, _patched_extract_logprobs)\n\n    def test_utils_patched(self):\n        from axolotl.monkeypatch.trainer.trl_vllm import (\n            _patched_shuffle_sequence_dict,\n            _patched_split_tensor_dict,\n            patch_trl_vllm,\n        )\n\n        patch_trl_vllm()\n        import trl.trainer.utils\n\n        self.assertIs(trl.trainer.utils.split_tensor_dict, _patched_split_tensor_dict)\n        self.assertIs(\n            
trl.trainer.utils.shuffle_sequence_dict, _patched_shuffle_sequence_dict\n        )\n\n    def test_patch_idempotent(self):\n        from axolotl.monkeypatch.trainer.trl_vllm import patch_trl_vllm\n\n        patch_trl_vllm()\n        patch_trl_vllm()  # second call should not error\n        from trl.generation.vllm_client import VLLMClient\n\n        self.assertTrue(hasattr(VLLMClient, \"batch_update_named_params\"))\n\n\nclass TestBatchUpdateChunking(unittest.TestCase):\n    \"\"\"Tests for batch_update_named_params chunking logic.\"\"\"\n\n    def test_no_chunk_single_batch(self):\n        from axolotl.monkeypatch.trainer.trl_vllm import _batch_update_named_params\n\n        # Test that with chunk_size=None, all params go in one chunk\n        client = MagicMock()\n        client.base_url = \"http://localhost:8000\"\n        client.session.post.return_value = MagicMock(status_code=200)\n        client.communicator = MagicMock()\n        client.communicator.group = MagicMock()\n        client.rank = 0\n\n        params = [\n            (\"layer.0.weight\", torch.randn(10, 10)),\n            (\"layer.1.weight\", torch.randn(10, 10)),\n        ]\n        _batch_update_named_params(client, params, chunk_size=None)\n\n        # Should make exactly 1 HTTP call\n        self.assertEqual(client.session.post.call_count, 1)\n\n    def test_chunk_splits_params(self):\n        from axolotl.monkeypatch.trainer.trl_vllm import _batch_update_named_params\n\n        client = MagicMock()\n        client.base_url = \"http://localhost:8000\"\n        client.session.post.return_value = MagicMock(status_code=200)\n        client.communicator = MagicMock()\n        client.communicator.group = MagicMock()\n        client.rank = 0\n\n        params = [\n            (\"a\", torch.randn(100)),  # 100 elements\n            (\"b\", torch.randn(100)),  # 100 elements\n            (\"c\", torch.randn(100)),  # 100 elements\n        ]\n        _batch_update_named_params(client, params, 
chunk_size=150)\n\n        # Each param has 100 elements, so adding a second param to any chunk\n        # would exceed chunk_size (100 + 100 = 200 > 150). Every param therefore\n        # lands in its own chunk, [a], [b], [c]: 3 HTTP calls.\n        self.assertEqual(client.session.post.call_count, 3)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/monkeypatch/test_voxtral_modeling_patch.py",
    "content": "\"\"\"Integration tests for Voxtral modeling patches.\"\"\"\n\nimport pytest\n\n\nclass TestVoxtralModelingPatchIntegration:\n    \"\"\"Test Voxtral modeling patch integration.\"\"\"\n\n    @pytest.mark.integration\n    def test_voxtral_conditional_generation_patch(self):\n        \"\"\"Test that Voxtral conditional generation patch can be applied.\"\"\"\n        try:\n            from transformers.models.voxtral.modeling_voxtral import (\n                VoxtralForConditionalGeneration,\n            )\n        except ImportError:\n            pytest.skip(\"VoxtralForConditionalGeneration not available\")\n\n        from axolotl.monkeypatch.models.voxtral.modeling import (\n            patch_voxtral_conditional_generation_forward,\n        )\n\n        # Store original method\n        original_forward = VoxtralForConditionalGeneration.forward\n\n        # Apply patch and get unpatch function\n        unpatch_fn = patch_voxtral_conditional_generation_forward()\n\n        # Verify patch was applied\n        assert VoxtralForConditionalGeneration.forward != original_forward, (\n            \"forward method was not patched\"\n        )\n\n        # Verify the method is still callable\n        assert callable(VoxtralForConditionalGeneration.forward), (\n            \"Patched method is not callable\"\n        )\n\n        # Test unpatch function\n        unpatch_fn()\n        assert VoxtralForConditionalGeneration.forward == original_forward, (\n            \"unpatch function did not restore original method\"\n        )\n"
  },
  {
    "path": "tests/patched/test_validation.py",
    "content": "\"\"\"Module for testing the validation module\"\"\"\n\nimport os\nimport warnings\nfrom typing import Optional\n\nimport pytest\nfrom pydantic import ValidationError\n\nfrom axolotl.loaders.utils import check_model_config\nfrom axolotl.utils import is_comet_available\nfrom axolotl.utils.config import validate_config\nfrom axolotl.utils.dict import DictDefault\nfrom axolotl.utils.mlflow_ import setup_mlflow_env_vars\nfrom axolotl.utils.schemas.config import AxolotlConfigWCapabilities\nfrom axolotl.utils.wandb_ import setup_wandb_env_vars\n\nwarnings.filterwarnings(\"error\")\n\n\n@pytest.fixture(name=\"minimal_cfg\")\ndef fixture_cfg():\n    return DictDefault(\n        {\n            \"base_model\": \"TinyLlama/TinyLlama-1.1B-Chat-v0.6\",\n            \"learning_rate\": 0.000001,\n            \"datasets\": [\n                {\n                    \"path\": \"mhenrichsen/alpaca_2k_test\",\n                    \"type\": \"alpaca\",\n                }\n            ],\n            \"micro_batch_size\": 1,\n            \"gradient_accumulation_steps\": 1,\n        }\n    )\n\n\nclass BaseValidation:\n    \"\"\"\n    Base validation module to setup the log capture\n    \"\"\"\n\n    _caplog: Optional[pytest.LogCaptureFixture] = None\n\n    @pytest.fixture(autouse=True)\n    def inject_fixtures(self, caplog):\n        self._caplog = caplog\n\n\nclass TestValidation(BaseValidation):\n    \"\"\"\n    Test the validation module\n    \"\"\"\n\n    def test_defaults(self, minimal_cfg):\n        test_cfg = DictDefault(\n            {\n                \"weight_decay\": None,\n            }\n            | minimal_cfg\n        )\n        cfg = validate_config(test_cfg)\n\n        assert cfg.train_on_inputs is False\n        assert cfg.weight_decay is None\n\n    def test_zero3_qlora_use_reentrant_false(self, minimal_cfg):\n        test_cfg = DictDefault(\n            {\n                \"deepspeed\": \"deepspeed_configs/zero3_bf16.json\",\n                
\"gradient_checkpointing\": True,\n                \"gradient_checkpointing_kwargs\": {\"use_reentrant\": False},\n                \"load_in_4bit\": True,\n                \"adapter\": \"qlora\",\n            }\n            | minimal_cfg\n        )\n\n        with self._caplog.at_level(\"WARNING\"):\n            validate_config(test_cfg)\n            assert (\n                \"qlora + zero3 with use_reentrant: false may result in a CheckpointError about recomputed values\"\n                in self._caplog.records[0].message\n            )\n\n    def test_deepspeed_empty(self, minimal_cfg):\n        test_cfg = DictDefault(\n            {\n                \"deepspeed\": \"\",\n                \"gradient_checkpointing\": True,\n                \"gradient_checkpointing_kwargs\": {\"use_reentrant\": False},\n                \"load_in_4bit\": True,\n                \"adapter\": \"qlora\",\n            }\n            | minimal_cfg\n        )\n\n        _ = validate_config(test_cfg)\n\n    def test_deepspeed_not_set(self, minimal_cfg):\n        test_cfg = DictDefault(\n            {\n                \"deepspeed\": None,\n                \"gradient_checkpointing\": True,\n                \"gradient_checkpointing_kwargs\": {\"use_reentrant\": False},\n                \"load_in_4bit\": True,\n                \"adapter\": \"qlora\",\n            }\n            | minimal_cfg\n        )\n\n        _ = validate_config(test_cfg)\n\n    def test_datasets_min_length(self):\n        cfg = DictDefault(\n            {\n                \"base_model\": \"TinyLlama/TinyLlama-1.1B-Chat-v0.6\",\n                \"learning_rate\": 0.000001,\n                \"datasets\": [],\n                \"micro_batch_size\": 1,\n                \"gradient_accumulation_steps\": 1,\n            }\n        )\n\n        with pytest.raises(\n            ValidationError,\n            match=r\".*List should have at least 1 item after validation*\",\n        ):\n            validate_config(cfg)\n\n    def 
test_datasets_min_length_empty(self):\n        cfg = DictDefault(\n            {\n                \"base_model\": \"TinyLlama/TinyLlama-1.1B-Chat-v0.6\",\n                \"learning_rate\": 0.000001,\n                \"micro_batch_size\": 1,\n                \"gradient_accumulation_steps\": 1,\n            }\n        )\n\n        with pytest.raises(\n            ValueError, match=r\".*either datasets or pretraining_dataset is required*\"\n        ):\n            validate_config(cfg)\n\n    def test_pretrain_dataset_min_length(self):\n        cfg = DictDefault(\n            {\n                \"base_model\": \"TinyLlama/TinyLlama-1.1B-Chat-v0.6\",\n                \"learning_rate\": 0.000001,\n                \"pretraining_dataset\": [],\n                \"micro_batch_size\": 1,\n                \"gradient_accumulation_steps\": 1,\n                \"max_steps\": 100,\n            }\n        )\n\n        with pytest.raises(\n            ValidationError,\n            match=r\".*List should have at least 1 item after validation*\",\n        ):\n            validate_config(cfg)\n\n    def test_valid_pretrain_dataset(self):\n        cfg = DictDefault(\n            {\n                \"base_model\": \"TinyLlama/TinyLlama-1.1B-Chat-v0.6\",\n                \"learning_rate\": 0.000001,\n                \"pretraining_dataset\": [\n                    {\n                        \"path\": \"mhenrichsen/alpaca_2k_test\",\n                        \"type\": \"alpaca\",\n                    }\n                ],\n                \"micro_batch_size\": 1,\n                \"gradient_accumulation_steps\": 1,\n                \"max_steps\": 100,\n            }\n        )\n\n        validate_config(cfg)\n\n    def test_valid_sft_dataset(self):\n        cfg = DictDefault(\n            {\n                \"base_model\": \"TinyLlama/TinyLlama-1.1B-Chat-v0.6\",\n                \"learning_rate\": 0.000001,\n                \"datasets\": [\n                    {\n                        
\"path\": \"mhenrichsen/alpaca_2k_test\",\n                        \"type\": \"alpaca\",\n                    }\n                ],\n                \"micro_batch_size\": 1,\n                \"gradient_accumulation_steps\": 1,\n            }\n        )\n\n        validate_config(cfg)\n\n    def test_batch_size_unused_warning(self):\n        cfg = DictDefault(\n            {\n                \"base_model\": \"TinyLlama/TinyLlama-1.1B-Chat-v0.6\",\n                \"learning_rate\": 0.000001,\n                \"datasets\": [\n                    {\n                        \"path\": \"mhenrichsen/alpaca_2k_test\",\n                        \"type\": \"alpaca\",\n                    }\n                ],\n                \"micro_batch_size\": 4,\n                \"batch_size\": 32,\n            }\n        )\n\n        with self._caplog.at_level(\"WARNING\"):\n            validate_config(cfg)\n            assert \"batch_size is not recommended\" in self._caplog.records[0].message\n\n    def test_batch_size_more_params(self):\n        cfg = DictDefault(\n            {\n                \"base_model\": \"TinyLlama/TinyLlama-1.1B-Chat-v0.6\",\n                \"learning_rate\": 0.000001,\n                \"datasets\": [\n                    {\n                        \"path\": \"mhenrichsen/alpaca_2k_test\",\n                        \"type\": \"alpaca\",\n                    }\n                ],\n                \"batch_size\": 32,\n            }\n        )\n\n        with pytest.raises(ValueError, match=r\".*At least two of*\"):\n            validate_config(cfg)\n\n    def test_lr_as_float(self, minimal_cfg):\n        cfg = (\n            DictDefault(\n                {\n                    \"learning_rate\": \"5e-5\",\n                }\n            )\n            | minimal_cfg\n        )\n\n        new_cfg = validate_config(cfg)\n\n        assert new_cfg.learning_rate == 0.00005\n\n    def test_model_config_remap(self, minimal_cfg):\n        cfg = (\n            
DictDefault(\n                {\n                    \"model_config\": {\"model_type\": \"mistral\"},\n                }\n            )\n            | minimal_cfg\n        )\n\n        new_cfg = validate_config(cfg)\n        assert new_cfg.overrides_of_model_config[\"model_type\"] == \"mistral\"\n\n    def test_model_type_remap(self, minimal_cfg):\n        cfg = (\n            DictDefault(\n                {\n                    \"model_type\": \"AutoModelForCausalLM\",\n                }\n            )\n            | minimal_cfg\n        )\n\n        new_cfg = validate_config(cfg)\n        assert new_cfg.type_of_model == \"AutoModelForCausalLM\"\n\n    def test_reward_model_defaults(self, minimal_cfg):\n        cfg = (\n            DictDefault(\n                {\n                    \"reward_model\": True,\n                }\n            )\n            | minimal_cfg\n        )\n\n        new_cfg = validate_config(cfg)\n        assert new_cfg.num_labels == 1\n        assert new_cfg.type_of_model == \"AutoModelForSequenceClassification\"\n\n    def test_process_reward_model_defaults(self, minimal_cfg):\n        cfg = (\n            DictDefault(\n                {\n                    \"process_reward_model\": True,\n                }\n            )\n            | minimal_cfg\n        )\n\n        new_cfg = validate_config(cfg)\n        assert new_cfg.num_labels == 2\n        assert new_cfg.type_of_model == \"AutoModelForTokenClassification\"\n\n    def test_model_revision_remap(self, minimal_cfg):\n        cfg = (\n            DictDefault(\n                {\n                    \"model_revision\": \"main\",\n                }\n            )\n            | minimal_cfg\n        )\n\n        new_cfg = validate_config(cfg)\n        assert new_cfg.revision_of_model == \"main\"\n\n    def test_qlora(self, minimal_cfg):\n        base_cfg = (\n            DictDefault(\n                {\n                    \"adapter\": \"qlora\",\n                }\n            )\n       
     | minimal_cfg\n        )\n\n        cfg = (\n            DictDefault(\n                {\n                    \"load_in_8bit\": True,\n                }\n            )\n            | base_cfg\n        )\n\n        with pytest.raises(ValueError, match=r\".*8bit.*\"):\n            validate_config(cfg)\n\n        cfg = (\n            DictDefault(\n                {\n                    \"gptq\": True,\n                }\n            )\n            | base_cfg\n        )\n\n        with pytest.raises(ValueError, match=r\".*gptq.*\"):\n            validate_config(cfg)\n\n        cfg = (\n            DictDefault(\n                {\n                    \"load_in_4bit\": False,\n                }\n            )\n            | base_cfg\n        )\n\n        with pytest.raises(ValueError, match=r\".*4bit.*\"):\n            validate_config(cfg)\n\n        cfg = (\n            DictDefault(\n                {\n                    \"load_in_4bit\": True,\n                }\n            )\n            | base_cfg\n        )\n\n        validate_config(cfg)\n\n    def test_qlora_merge(self, minimal_cfg):\n        base_cfg = (\n            DictDefault(\n                {\n                    \"adapter\": \"qlora\",\n                    \"merge_lora\": True,\n                }\n            )\n            | minimal_cfg\n        )\n\n        cfg = (\n            DictDefault(\n                {\n                    \"load_in_8bit\": True,\n                }\n            )\n            | base_cfg\n        )\n\n        with pytest.raises(ValueError, match=r\".*8bit.*\"):\n            validate_config(cfg)\n\n        cfg = (\n            DictDefault(\n                {\n                    \"gptq\": True,\n                }\n            )\n            | base_cfg\n        )\n\n        with pytest.raises(ValueError, match=r\".*gptq.*\"):\n            validate_config(cfg)\n\n        cfg = (\n            DictDefault(\n                {\n                    \"load_in_4bit\": True,\n          
      }\n            )\n            | base_cfg\n        )\n\n        with pytest.raises(ValueError, match=r\".*4bit.*\"):\n            validate_config(cfg)\n\n    def test_hf_use_auth_token(self, minimal_cfg):\n        cfg = (\n            DictDefault(\n                {\n                    \"push_dataset_to_hub\": \"namespace/repo\",\n                }\n            )\n            | minimal_cfg\n        )\n\n        with pytest.raises(ValueError, match=r\".*hf_use_auth_token.*\"):\n            validate_config(cfg)\n\n        cfg = (\n            DictDefault(\n                {\n                    \"push_dataset_to_hub\": \"namespace/repo\",\n                    \"hf_use_auth_token\": True,\n                }\n            )\n            | minimal_cfg\n        )\n        validate_config(cfg)\n\n    def test_gradient_accumulations_or_batch_size(self):\n        cfg = DictDefault(\n            {\n                \"base_model\": \"TinyLlama/TinyLlama-1.1B-Chat-v0.6\",\n                \"learning_rate\": 0.000001,\n                \"datasets\": [\n                    {\n                        \"path\": \"mhenrichsen/alpaca_2k_test\",\n                        \"type\": \"alpaca\",\n                    }\n                ],\n                \"gradient_accumulation_steps\": 1,\n                \"batch_size\": 1,\n            }\n        )\n\n        with pytest.raises(\n            ValueError, match=r\".*gradient_accumulation_steps or batch_size.*\"\n        ):\n            validate_config(cfg)\n\n    def test_falcon_fsdp(self, minimal_cfg):\n        regex_exp = r\".*FSDP is not supported for falcon models.*\"\n\n        # Check for lower-case\n        cfg = (\n            DictDefault(\n                {\n                    \"base_model\": \"tiiuae/falcon-7b\",\n                    \"fsdp\": [\"full_shard\", \"auto_wrap\"],\n                }\n            )\n            | minimal_cfg\n        )\n\n        with pytest.raises(ValueError, match=regex_exp):\n            
validate_config(cfg)\n\n        # Check for upper-case\n        cfg = (\n            DictDefault(\n                {\n                    \"base_model\": \"Falcon-7b\",\n                    \"fsdp\": [\"full_shard\", \"auto_wrap\"],\n                }\n            )\n            | minimal_cfg\n        )\n\n        with pytest.raises(ValueError, match=regex_exp):\n            validate_config(cfg)\n\n        cfg = (\n            DictDefault(\n                {\n                    \"base_model\": \"tiiuae/falcon-7b\",\n                }\n            )\n            | minimal_cfg\n        )\n\n        validate_config(cfg)\n\n    def test_mpt_gradient_checkpointing(self, minimal_cfg):\n        regex_exp = r\".*gradient_checkpointing is not supported for MPT models*\"\n\n        # Check for lower-case\n        cfg = (\n            DictDefault(\n                {\n                    \"base_model\": \"mosaicml/mpt-7b\",\n                    \"gradient_checkpointing\": True,\n                }\n            )\n            | minimal_cfg\n        )\n\n        with pytest.raises(ValueError, match=regex_exp):\n            validate_config(cfg)\n\n    def test_flash_optimum(self, minimal_cfg):\n        cfg = (\n            DictDefault(\n                {\n                    \"flash_optimum\": True,\n                    \"adapter\": \"lora\",\n                    \"bf16\": False,\n                }\n            )\n            | minimal_cfg\n        )\n\n        with self._caplog.at_level(\"WARNING\"):\n            validate_config(cfg)\n            assert any(\n                \"BetterTransformers probably doesn't work with PEFT adapters\"\n                in record.message\n                for record in self._caplog.records\n            )\n\n        cfg = (\n            DictDefault(\n                {\n                    \"flash_optimum\": True,\n                    \"bf16\": False,\n                }\n            )\n            | minimal_cfg\n        )\n\n        with 
self._caplog.at_level(\"WARNING\"):\n            validate_config(cfg)\n            assert any(\n                \"probably set bfloat16 or float16\" in record.message\n                for record in self._caplog.records\n            )\n\n        cfg = (\n            DictDefault(\n                {\n                    \"flash_optimum\": True,\n                    \"fp16\": True,\n                }\n            )\n            | minimal_cfg\n        )\n        regex_exp = r\".*AMP is not supported.*\"\n\n        with pytest.raises(ValueError, match=regex_exp):\n            validate_config(cfg)\n\n        cfg = (\n            DictDefault(\n                {\n                    \"flash_optimum\": True,\n                    \"bf16\": True,\n                }\n            )\n            | minimal_cfg\n        )\n        regex_exp = r\".*AMP is not supported.*\"\n\n        with pytest.raises(ValueError, match=regex_exp):\n            validate_config(cfg)\n\n    def test_adamw_hyperparams(self, minimal_cfg):\n        cfg = (\n            DictDefault(\n                {\n                    \"optimizer\": None,\n                    \"adam_epsilon\": 0.0001,\n                }\n            )\n            | minimal_cfg\n        )\n\n        with self._caplog.at_level(\"WARNING\"):\n            validate_config(cfg)\n            assert any(\n                \"adamw hyperparameters found, but no adamw optimizer set\"\n                in record.message\n                for record in self._caplog.records\n            )\n\n        cfg = (\n            DictDefault(\n                {\n                    \"optimizer\": \"adafactor\",\n                    \"adam_beta1\": 0.0001,\n                }\n            )\n            | minimal_cfg\n        )\n\n        with self._caplog.at_level(\"WARNING\"):\n            validate_config(cfg)\n            assert any(\n                \"adamw hyperparameters found, but no adamw optimizer set\"\n                in record.message\n               
 for record in self._caplog.records\n            )\n\n        cfg = (\n            DictDefault(\n                {\n                    \"optimizer\": \"adamw_bnb_8bit\",\n                    \"adam_beta1\": 0.9,\n                    \"adam_beta2\": 0.99,\n                    \"adam_epsilon\": 0.0001,\n                }\n            )\n            | minimal_cfg\n        )\n\n        validate_config(cfg)\n\n        cfg = (\n            DictDefault(\n                {\n                    \"optimizer\": \"adafactor\",\n                }\n            )\n            | minimal_cfg\n        )\n\n        validate_config(cfg)\n\n    def test_deprecated_packing(self, minimal_cfg):\n        cfg = (\n            DictDefault(\n                {\n                    \"max_packed_sequence_len\": 1024,\n                }\n            )\n            | minimal_cfg\n        )\n        with pytest.raises(\n            DeprecationWarning,\n            match=r\"`max_packed_sequence_len` is no longer supported\",\n        ):\n            validate_config(cfg)\n\n    def test_packing(self, minimal_cfg):\n        cfg = (\n            DictDefault(\n                {\n                    \"sample_packing\": True,\n                    \"pad_to_sequence_len\": False,\n                    \"flash_attention\": True,\n                }\n            )\n            | minimal_cfg\n        )\n        with self._caplog.at_level(\"WARNING\"):\n            validate_config(cfg)\n            assert any(\n                \"`pad_to_sequence_len: true` is recommended when using sample_packing\"\n                in record.message\n                for record in self._caplog.records\n            )\n\n    def test_packing_autoset(self, minimal_cfg):\n        cfg = (\n            DictDefault(\n                {\n                    \"sample_packing\": True,\n                    \"pad_to_sequence_len\": None,\n                    \"flash_attention\": True,\n                }\n            )\n            | 
minimal_cfg\n        )\n        with self._caplog.at_level(\"INFO\"):\n            cfg = validate_config(cfg)\n            assert any(\n                \"Setting `pad_to_sequence_len: true` to prevent memory leaks when sample_packing\"\n                in record.message\n                for record in self._caplog.records\n            )\n            assert cfg.pad_to_sequence_len is True\n\n    def test_merge_lora_no_bf16_fail(self, minimal_cfg):\n        \"\"\"\n        This is assumed to be run on a CPU machine, so bf16 is not supported.\n        \"\"\"\n\n        cfg = (\n            DictDefault(\n                {\n                    \"bf16\": True,\n                    \"capabilities\": {\"bf16\": False},\n                    \"env_capabilities\": {\n                        \"torch_version\": \"2.6.0\",\n                    },\n                }\n            )\n            | minimal_cfg\n        )\n\n        with pytest.raises(ValueError, match=r\".*AMP is not supported on this GPU*\"):\n            AxolotlConfigWCapabilities(**cfg.to_dict())\n\n        cfg = (\n            DictDefault(\n                {\n                    \"bf16\": True,\n                    \"merge_lora\": True,\n                    \"capabilities\": {\"bf16\": False},\n                }\n            )\n            | minimal_cfg\n        )\n\n        validate_config(cfg)\n\n    def test_no_conflict_save_strategy(self, minimal_cfg):\n        cfg = (\n            DictDefault(\n                {\n                    \"save_strategy\": \"epoch\",\n                    \"save_steps\": 10,\n                }\n            )\n            | minimal_cfg\n        )\n\n        with pytest.raises(\n            ValueError, match=r\".*save_strategy and save_steps mismatch.*\"\n        ):\n            validate_config(cfg)\n\n        cfg = (\n            DictDefault(\n                {\n                    \"save_strategy\": \"no\",\n                    \"save_steps\": 10,\n                }\n            
)\n            | minimal_cfg\n        )\n\n        with pytest.raises(\n            ValueError, match=r\".*save_strategy and save_steps mismatch.*\"\n        ):\n            validate_config(cfg)\n\n        cfg = (\n            DictDefault(\n                {\n                    \"save_strategy\": \"steps\",\n                }\n            )\n            | minimal_cfg\n        )\n\n        validate_config(cfg)\n\n        cfg = (\n            DictDefault(\n                {\n                    \"save_strategy\": \"steps\",\n                    \"save_steps\": 10,\n                }\n            )\n            | minimal_cfg\n        )\n\n        validate_config(cfg)\n\n        cfg = (\n            DictDefault(\n                {\n                    \"save_steps\": 10,\n                }\n            )\n            | minimal_cfg\n        )\n\n        validate_config(cfg)\n\n        cfg = (\n            DictDefault(\n                {\n                    \"save_strategy\": \"no\",\n                }\n            )\n            | minimal_cfg\n        )\n\n        validate_config(cfg)\n\n    def test_no_conflict_eval_strategy(self, minimal_cfg):\n        cfg = (\n            DictDefault(\n                {\n                    \"eval_strategy\": \"epoch\",\n                    \"eval_steps\": 10,\n                }\n            )\n            | minimal_cfg\n        )\n\n        with pytest.raises(\n            ValueError, match=r\".*eval_strategy and eval_steps mismatch.*\"\n        ):\n            validate_config(cfg)\n\n        cfg = (\n            DictDefault(\n                {\n                    \"eval_strategy\": \"no\",\n                    \"eval_steps\": 10,\n                }\n            )\n            | minimal_cfg\n        )\n\n        with pytest.raises(\n            ValueError, match=r\".*eval_strategy and eval_steps mismatch.*\"\n        ):\n            validate_config(cfg)\n\n        cfg = (\n            DictDefault(\n                {\n             
       \"eval_strategy\": \"steps\",\n                }\n            )\n            | minimal_cfg\n        )\n\n        validate_config(cfg)\n\n        cfg = (\n            DictDefault(\n                {\n                    \"eval_strategy\": \"steps\",\n                    \"eval_steps\": 10,\n                }\n            )\n            | minimal_cfg\n        )\n\n        validate_config(cfg)\n\n        cfg = (\n            DictDefault(\n                {\n                    \"eval_steps\": 10,\n                }\n            )\n            | minimal_cfg\n        )\n\n        validate_config(cfg)\n\n        cfg = (\n            DictDefault(\n                {\n                    \"eval_strategy\": \"no\",\n                }\n            )\n            | minimal_cfg\n        )\n\n        validate_config(cfg)\n\n        cfg = (\n            DictDefault(\n                {\n                    \"eval_strategy\": \"epoch\",\n                    \"val_set_size\": 0,\n                }\n            )\n            | minimal_cfg\n        )\n\n        with pytest.raises(\n            ValueError,\n            match=r\".*eval_steps and eval_strategy are not supported with val_set_size == 0.*\",\n        ):\n            validate_config(cfg)\n\n        cfg = (\n            DictDefault(\n                {\n                    \"eval_steps\": 10,\n                    \"val_set_size\": 0,\n                }\n            )\n            | minimal_cfg\n        )\n\n        with pytest.raises(\n            ValueError,\n            match=r\".*eval_steps and eval_strategy are not supported with val_set_size == 0.*\",\n        ):\n            validate_config(cfg)\n\n        cfg = (\n            DictDefault(\n                {\n                    \"val_set_size\": 0,\n                }\n            )\n            | minimal_cfg\n        )\n\n        validate_config(cfg)\n\n        cfg = (\n            DictDefault(\n                {\n                    \"eval_steps\": 10,\n        
            \"val_set_size\": 0.01,\n                }\n            )\n            | minimal_cfg\n        )\n\n        validate_config(cfg)\n\n        cfg = (\n            DictDefault(\n                {\n                    \"eval_strategy\": \"epoch\",\n                    \"val_set_size\": 0.01,\n                }\n            )\n            | minimal_cfg\n        )\n\n        validate_config(cfg)\n\n    def test_eval_table_size_conflict_eval_packing(self, minimal_cfg):\n        cfg = (\n            DictDefault(\n                {\n                    \"sample_packing\": True,\n                    \"eval_table_size\": 100,\n                    \"flash_attention\": True,\n                }\n            )\n            | minimal_cfg\n        )\n\n        with pytest.raises(\n            ValueError, match=r\".*Please set 'eval_sample_packing' to false.*\"\n        ):\n            validate_config(cfg)\n\n        cfg = (\n            DictDefault(\n                {\n                    \"sample_packing\": True,\n                    \"eval_sample_packing\": False,\n                    \"flash_attention\": True,\n                }\n            )\n            | minimal_cfg\n        )\n\n        validate_config(cfg)\n\n        cfg = (\n            DictDefault(\n                {\n                    \"sample_packing\": False,\n                    \"eval_table_size\": 100,\n                    \"flash_attention\": True,\n                }\n            )\n            | minimal_cfg\n        )\n\n        validate_config(cfg)\n\n        cfg = (\n            DictDefault(\n                {\n                    \"sample_packing\": True,\n                    \"eval_table_size\": 100,\n                    \"eval_sample_packing\": False,\n                    \"flash_attention\": True,\n                }\n            )\n            | minimal_cfg\n        )\n\n        validate_config(cfg)\n\n    def test_load_in_x_bit_without_adapter(self, minimal_cfg):\n        cfg = (\n            
DictDefault(\n                {\n                    \"load_in_4bit\": True,\n                }\n            )\n            | minimal_cfg\n        )\n\n        with pytest.raises(\n            ValueError,\n            match=r\".*load_in_8bit and load_in_4bit are not supported without setting an adapter.*\",\n        ):\n            validate_config(cfg)\n\n        cfg = (\n            DictDefault(\n                {\n                    \"load_in_8bit\": True,\n                }\n            )\n            | minimal_cfg\n        )\n\n        with pytest.raises(\n            ValueError,\n            match=r\".*load_in_8bit and load_in_4bit are not supported without setting an adapter.*\",\n        ):\n            validate_config(cfg)\n\n        cfg = (\n            DictDefault(\n                {\n                    \"load_in_4bit\": True,\n                    \"adapter\": \"qlora\",\n                }\n            )\n            | minimal_cfg\n        )\n\n        validate_config(cfg)\n\n        cfg = (\n            DictDefault(\n                {\n                    \"load_in_8bit\": True,\n                    \"adapter\": \"lora\",\n                }\n            )\n            | minimal_cfg\n        )\n\n        validate_config(cfg)\n\n    def test_warmup_step_no_conflict(self, minimal_cfg):\n        cfg = (\n            DictDefault(\n                {\n                    \"warmup_steps\": 10,\n                    \"warmup_ratio\": 0.1,\n                }\n            )\n            | minimal_cfg\n        )\n\n        with pytest.raises(\n            ValueError,\n            match=r\".*warmup_steps and warmup_ratio are mutually exclusive*\",\n        ):\n            validate_config(cfg)\n\n        cfg = (\n            DictDefault(\n                {\n                    \"warmup_steps\": 10,\n                }\n            )\n            | minimal_cfg\n        )\n\n        validate_config(cfg)\n\n        cfg = (\n            DictDefault(\n                {\n   
                 \"warmup_ratio\": 0.1,\n                }\n            )\n            | minimal_cfg\n        )\n\n        validate_config(cfg)\n\n    def test_unfrozen_parameters_w_peft_layers_to_transform(self, minimal_cfg):\n        cfg = (\n            DictDefault(\n                {\n                    \"adapter\": \"lora\",\n                    \"unfrozen_parameters\": [\n                        \"model.layers.2[0-9]+.block_sparse_moe.gate.*\"\n                    ],\n                    \"peft_layers_to_transform\": [0, 1],\n                }\n            )\n            | minimal_cfg\n        )\n\n        with pytest.raises(\n            ValueError,\n            match=r\".*can have unexpected behavior*\",\n        ):\n            validate_config(cfg)\n\n    def test_hub_model_id_save_value_warns_save_stragey_no(self, minimal_cfg):\n        cfg = DictDefault({\"hub_model_id\": \"test\", \"save_strategy\": \"no\"}) | minimal_cfg\n\n        with self._caplog.at_level(\"WARNING\"):\n            validate_config(cfg)\n            assert len(self._caplog.records) == 1\n\n    def test_hub_model_id_save_value_warns_random_value(self, minimal_cfg):\n        cfg = (\n            DictDefault({\"hub_model_id\": \"test\", \"save_strategy\": \"test\"}) | minimal_cfg\n        )\n\n        with self._caplog.at_level(\"WARNING\"):\n            validate_config(cfg)\n            assert len(self._caplog.records) == 1\n\n    def test_hub_model_id_save_value_steps(self, minimal_cfg):\n        cfg = (\n            DictDefault({\"hub_model_id\": \"test\", \"save_strategy\": \"steps\"})\n            | minimal_cfg\n        )\n\n        with self._caplog.at_level(\"WARNING\"):\n            validate_config(cfg)\n            assert len(self._caplog.records) == 0\n\n    def test_hub_model_id_save_value_epochs(self, minimal_cfg):\n        cfg = (\n            DictDefault({\"hub_model_id\": \"test\", \"save_strategy\": \"epoch\"})\n            | minimal_cfg\n        )\n\n        with 
self._caplog.at_level(\"WARNING\"):\n            validate_config(cfg)\n            assert len(self._caplog.records) == 0\n\n    def test_hub_model_id_save_value_none(self, minimal_cfg):\n        cfg = DictDefault({\"hub_model_id\": \"test\", \"save_strategy\": None}) | minimal_cfg\n\n        with self._caplog.at_level(\"WARNING\"):\n            validate_config(cfg)\n            assert len(self._caplog.records) == 0\n\n    def test_hub_model_id_save_value_no_set_save_strategy(self, minimal_cfg):\n        cfg = DictDefault({\"hub_model_id\": \"test\"}) | minimal_cfg\n\n        with self._caplog.at_level(\"WARNING\"):\n            validate_config(cfg)\n            assert len(self._caplog.records) == 0\n\n    def test_dpo_beta_deprecation(self, minimal_cfg):\n        cfg = DictDefault({\"dpo_beta\": 0.2}) | minimal_cfg\n\n        with self._caplog.at_level(\"WARNING\"):\n            new_cfg = validate_config(cfg)\n            assert new_cfg[\"rl_beta\"] == 0.2\n            assert new_cfg[\"dpo_beta\"] is None\n            assert len(self._caplog.records) == 1\n\n    def test_eval_strategy_remap(self, minimal_cfg):\n        cfg = (\n            DictDefault(\n                {\n                    \"evaluation_strategy\": \"steps\",\n                }\n            )\n            | minimal_cfg\n        )\n\n        with self._caplog.at_level(\"WARNING\"):\n            new_cfg = validate_config(cfg)\n            assert new_cfg.eval_strategy == \"steps\"\n            assert (\n                \"evaluation_strategy is deprecated, use eval_strategy instead\"\n                in self._caplog.records[0].message\n            )\n\n    def test_torch_version_adopt_req(self, minimal_cfg):\n        cfg = (\n            DictDefault(\n                {\n                    \"optimizer\": \"adopt_adamw\",\n                }\n            )\n            | minimal_cfg\n        )\n\n        with pytest.raises(\n            ValueError,\n            match=r\".*ADOPT optimizer is incompatible 
with torch version*\",\n        ):\n            env_capabilities = {\"torch_version\": \"2.3.0\"}\n            capabilities = {\"bf16\": False}\n            _ = validate_config(\n                cfg, capabilities=capabilities, env_capabilities=env_capabilities\n            )\n\n        env_capabilities = {\"torch_version\": \"2.6.0\"}\n        capabilities = {\"bf16\": False}\n        _ = validate_config(\n            cfg, capabilities=capabilities, env_capabilities=env_capabilities\n        )\n\n        env_capabilities = {\"torch_version\": \"2.5.2\"}\n        capabilities = {\"bf16\": False}\n        _ = validate_config(\n            cfg, capabilities=capabilities, env_capabilities=env_capabilities\n        )\n\n    def test_cfg_throws_error_with_s2_attention_and_sample_packing(self, minimal_cfg):\n        test_cfg = DictDefault(\n            {\n                \"s2_attention\": True,\n                \"sample_packing\": True,\n            }\n            | minimal_cfg\n        )\n        with pytest.raises(\n            ValidationError,\n            match=r\".*shifted-sparse attention does not currently support sample packing*\",\n        ):\n            validate_config(test_cfg)\n\n\nclass TestTorchCompileValidation(BaseValidation):\n    \"\"\"\n    test suite for when torch_compile is set to 'auto'\n    \"\"\"\n\n    def test_torch_compile_auto(self, minimal_cfg):\n        cfg = (\n            DictDefault(\n                {\n                    \"torch_compile\": \"auto\",\n                }\n            )\n            | minimal_cfg\n        )\n\n        env_capabilities = {\"torch_version\": \"2.6.0\"}\n        capabilities = {\"bf16\": True}\n        updated_cfg = validate_config(\n            cfg, capabilities=capabilities, env_capabilities=env_capabilities\n        )\n\n        assert updated_cfg.torch_compile is True\n\n        env_capabilities = {\"torch_version\": \"2.4.1\"}\n        capabilities = {\"bf16\": True}\n        updated_cfg = 
validate_config(\n            cfg, capabilities=capabilities, env_capabilities=env_capabilities\n        )\n\n        assert updated_cfg.torch_compile is False\n\n        env_capabilities = {}\n        capabilities = {\"bf16\": True}\n        updated_cfg = validate_config(\n            cfg, capabilities=capabilities, env_capabilities=env_capabilities\n        )\n\n        assert updated_cfg.torch_compile is False\n\n\nclass TestSampleOptimConfigValidation(BaseValidation):\n    \"\"\"\n    test configurations for sample optimizations like batch flattening\n    \"\"\"\n\n    def test_batch_flattening_auto_enables(self, minimal_cfg):\n        cfg = (\n            DictDefault(\n                {\n                    \"flash_attention\": True,\n                    \"sample_packing\": None,\n                    \"micro_batch_size\": 2,\n                    \"batch_flattening\": \"auto\",\n                }\n            )\n            | minimal_cfg\n        )\n\n        new_cfg = validate_config(cfg)\n        assert new_cfg[\"batch_flattening\"] is True\n\n    def test_batch_flattening_auto_no_fa(self, minimal_cfg):\n        cfg = (\n            DictDefault(\n                {\n                    \"flash_attention\": False,\n                    \"sample_packing\": None,\n                    \"micro_batch_size\": 2,\n                    \"batch_flattening\": \"auto\",\n                }\n            )\n            | minimal_cfg\n        )\n\n        new_cfg = validate_config(cfg)\n        assert new_cfg[\"batch_flattening\"] is False\n\n    def test_batch_flattening_auto_mbsz_1(self, minimal_cfg):\n        cfg = (\n            DictDefault(\n                {\n                    \"flash_attention\": True,\n                    \"sample_packing\": None,\n                    \"micro_batch_size\": 1,\n                    \"batch_flattening\": \"auto\",\n                }\n            )\n            | minimal_cfg\n        )\n\n        new_cfg = validate_config(cfg)\n        
assert new_cfg[\"batch_flattening\"] is False\n\n    def test_batch_flattening_auto_packing(self, minimal_cfg):\n        cfg = (\n            DictDefault(\n                {\n                    \"flash_attention\": True,\n                    \"sample_packing\": True,\n                    \"micro_batch_size\": 2,\n                    \"batch_flattening\": \"auto\",\n                }\n            )\n            | minimal_cfg\n        )\n\n        new_cfg = validate_config(cfg)\n        assert new_cfg[\"batch_flattening\"] is False\n\n\nclass TestValidationCheckModelConfig(BaseValidation):\n    \"\"\"\n    Test the validation for the config when the model config is available\n    \"\"\"\n\n    def test_llama_add_tokens_adapter(self, minimal_cfg):\n        cfg = (\n            DictDefault(\n                {\"adapter\": \"qlora\", \"load_in_4bit\": True, \"tokens\": [\"<|imstart|>\"]}\n            )\n            | minimal_cfg\n        )\n        model_config = DictDefault({\"model_type\": \"llama\"})\n\n        with pytest.raises(\n            ValueError,\n            match=r\".*`lora_modules_to_save` not properly set when adding new tokens*\",\n        ):\n            check_model_config(cfg, model_config)\n\n        cfg = (\n            DictDefault(\n                {\n                    \"adapter\": \"qlora\",\n                    \"load_in_4bit\": True,\n                    \"tokens\": [\"<|imstart|>\"],\n                    \"lora_modules_to_save\": [\"embed_tokens\"],\n                }\n            )\n            | minimal_cfg\n        )\n\n        with pytest.raises(\n            ValueError,\n            match=r\".*`lora_modules_to_save` not properly set when adding new tokens*\",\n        ):\n            check_model_config(cfg, model_config)\n\n        cfg = (\n            DictDefault(\n                {\n                    \"adapter\": \"qlora\",\n                    \"load_in_4bit\": True,\n                    \"tokens\": [\"<|imstart|>\"],\n              
      \"lora_modules_to_save\": [\"embed_tokens\", \"lm_head\"],\n                }\n            )\n            | minimal_cfg\n        )\n\n        check_model_config(cfg, model_config)\n\n    def test_phi_add_tokens_adapter(self, minimal_cfg):\n        cfg = (\n            DictDefault(\n                {\"adapter\": \"qlora\", \"load_in_4bit\": True, \"tokens\": [\"<|imstart|>\"]}\n            )\n            | minimal_cfg\n        )\n        model_config = DictDefault({\"model_type\": \"phi\"})\n\n        with pytest.raises(\n            ValueError,\n            match=r\".*`lora_modules_to_save` not properly set when adding new tokens*\",\n        ):\n            check_model_config(cfg, model_config)\n\n        cfg = (\n            DictDefault(\n                {\n                    \"adapter\": \"qlora\",\n                    \"load_in_4bit\": True,\n                    \"tokens\": [\"<|imstart|>\"],\n                    \"lora_modules_to_save\": [\"embd.wte\", \"lm_head.linear\"],\n                }\n            )\n            | minimal_cfg\n        )\n\n        with pytest.raises(\n            ValueError,\n            match=r\".*`lora_modules_to_save` not properly set when adding new tokens*\",\n        ):\n            check_model_config(cfg, model_config)\n\n        cfg = (\n            DictDefault(\n                {\n                    \"adapter\": \"qlora\",\n                    \"load_in_4bit\": True,\n                    \"tokens\": [\"<|imstart|>\"],\n                    \"lora_modules_to_save\": [\"embed_tokens\", \"lm_head\"],\n                }\n            )\n            | minimal_cfg\n        )\n\n        check_model_config(cfg, model_config)\n\n\nclass TestValidationWandb(BaseValidation):\n    \"\"\"\n    Validation test for wandb\n    \"\"\"\n\n    def test_wandb_set_run_id_to_name(self, minimal_cfg):\n        cfg = (\n            DictDefault(\n                {\n                    \"wandb_run_id\": \"foo\",\n                }\n            )\n  
          | minimal_cfg\n        )\n\n        with self._caplog.at_level(\"WARNING\"):\n            new_cfg = validate_config(cfg)\n            assert any(\n                \"wandb_run_id sets the ID of the run. If you would like to set the name, please use wandb_name instead.\"\n                in record.message\n                for record in self._caplog.records\n            )\n\n            assert new_cfg.wandb_name == \"foo\" and new_cfg.wandb_run_id == \"foo\"\n\n        cfg = (\n            DictDefault(\n                {\n                    \"wandb_name\": \"foo\",\n                }\n            )\n            | minimal_cfg\n        )\n\n        new_cfg = validate_config(cfg)\n\n        assert new_cfg.wandb_name == \"foo\" and new_cfg.wandb_run_id is None\n\n    def test_wandb_sets_env(self, minimal_cfg):\n        cfg = (\n            DictDefault(\n                {\n                    \"wandb_project\": \"foo\",\n                    \"wandb_name\": \"bar\",\n                    \"wandb_run_id\": \"bat\",\n                    \"wandb_entity\": \"baz\",\n                    \"wandb_mode\": \"online\",\n                    \"wandb_watch\": \"false\",\n                    \"wandb_log_model\": \"checkpoint\",\n                }\n            )\n            | minimal_cfg\n        )\n\n        new_cfg = validate_config(cfg)\n\n        setup_wandb_env_vars(new_cfg)\n\n        assert os.environ.get(\"WANDB_PROJECT\", \"\") == \"foo\"\n        assert os.environ.get(\"WANDB_NAME\", \"\") == \"bar\"\n        assert os.environ.get(\"WANDB_RUN_ID\", \"\") == \"bat\"\n        assert os.environ.get(\"WANDB_ENTITY\", \"\") == \"baz\"\n        assert os.environ.get(\"WANDB_MODE\", \"\") == \"online\"\n        assert os.environ.get(\"WANDB_WATCH\", \"\") == \"false\"\n        assert os.environ.get(\"WANDB_LOG_MODEL\", \"\") == \"checkpoint\"\n\n        os.environ.pop(\"WANDB_PROJECT\", None)\n        os.environ.pop(\"WANDB_NAME\", None)\n        
os.environ.pop(\"WANDB_RUN_ID\", None)\n        os.environ.pop(\"WANDB_ENTITY\", None)\n        os.environ.pop(\"WANDB_MODE\", None)\n        os.environ.pop(\"WANDB_WATCH\", None)\n        os.environ.pop(\"WANDB_LOG_MODEL\", None)\n\n    def test_wandb_set_disabled(self, minimal_cfg):\n        cfg = DictDefault({}) | minimal_cfg\n        new_cfg = validate_config(cfg)\n        setup_wandb_env_vars(new_cfg)\n        assert new_cfg.use_wandb is None\n\n        cfg = (\n            DictDefault(\n                {\n                    \"wandb_project\": \"foo\",\n                }\n            )\n            | minimal_cfg\n        )\n\n        new_cfg = validate_config(cfg)\n        setup_wandb_env_vars(new_cfg)\n        assert new_cfg.use_wandb is True\n\n        os.environ.pop(\"WANDB_PROJECT\", None)\n\n\n@pytest.mark.skipif(is_comet_available() is False, reason=\"comet_ml is not installed\")\nclass TestValidationComet(BaseValidation):\n    \"\"\"\n    Validation test for comet\n    \"\"\"\n\n    def test_comet_sets_env(self, minimal_cfg):\n        from axolotl.utils.comet_ import setup_comet_env_vars\n\n        comet_config = {\n            \"comet_api_key\": \"foo\",\n            \"comet_workspace\": \"some_workspace\",\n            \"comet_project_name\": \"some_project\",\n            \"comet_experiment_key\": \"some_experiment_key\",\n            \"comet_mode\": \"get_or_create\",\n            \"comet_online\": False,\n            \"comet_experiment_config\": {\n                \"auto_histogram_activation_logging\": False,\n                \"auto_histogram_epoch_rate\": 2,\n                \"auto_histogram_gradient_logging\": True,\n                \"auto_histogram_tensorboard_logging\": False,\n                \"auto_histogram_weight_logging\": True,\n                \"auto_log_co2\": False,\n                \"auto_metric_logging\": True,\n                \"auto_metric_step_rate\": 15,\n                \"auto_output_logging\": False,\n                
\"auto_param_logging\": True,\n                \"comet_disabled\": False,\n                \"display_summary_level\": 2,\n                \"distributed_node_identifier\": \"some_distributed_node_identifier\",\n                \"log_code\": True,\n                \"log_env_cpu\": False,\n                \"log_env_details\": True,\n                \"log_env_disk\": False,\n                \"log_env_gpu\": True,\n                \"log_env_host\": False,\n                \"log_env_network\": True,\n                \"log_git_metadata\": False,\n                \"log_git_patch\": True,\n                \"log_graph\": False,\n                \"name\": \"some_name\",\n                \"offline_directory\": \"some_offline_directory\",\n                \"parse_args\": True,\n                \"tags\": [\"tag1\", \"tag2\"],\n            },\n        }\n\n        cfg = DictDefault(comet_config) | minimal_cfg\n\n        new_cfg = validate_config(cfg)\n\n        setup_comet_env_vars(new_cfg)\n\n        comet_env = {\n            key: value for key, value in os.environ.items() if key.startswith(\"COMET_\")\n        }\n\n        assert (\n            len(comet_env)\n            == len(comet_config) + len(comet_config[\"comet_experiment_config\"]) - 1\n        )\n\n        assert comet_env == {\n            \"COMET_API_KEY\": \"foo\",\n            \"COMET_AUTO_LOG_CLI_ARGUMENTS\": \"true\",\n            \"COMET_AUTO_LOG_CO2\": \"false\",\n            \"COMET_AUTO_LOG_CODE\": \"true\",\n            \"COMET_AUTO_LOG_DISABLE\": \"false\",\n            \"COMET_AUTO_LOG_ENV_CPU\": \"false\",\n            \"COMET_AUTO_LOG_ENV_DETAILS\": \"true\",\n            \"COMET_AUTO_LOG_ENV_DISK\": \"false\",\n            \"COMET_AUTO_LOG_ENV_GPU\": \"true\",\n            \"COMET_AUTO_LOG_ENV_HOST\": \"false\",\n            \"COMET_AUTO_LOG_ENV_NETWORK\": \"true\",\n            \"COMET_AUTO_LOG_GIT_METADATA\": \"false\",\n            \"COMET_AUTO_LOG_GIT_PATCH\": \"true\",\n            
\"COMET_AUTO_LOG_GRAPH\": \"false\",\n            \"COMET_AUTO_LOG_HISTOGRAM_ACTIVATIONS\": \"false\",\n            \"COMET_AUTO_LOG_HISTOGRAM_EPOCH_RATE\": \"2\",\n            \"COMET_AUTO_LOG_HISTOGRAM_GRADIENTS\": \"true\",\n            \"COMET_AUTO_LOG_HISTOGRAM_TENSORBOARD\": \"false\",\n            \"COMET_AUTO_LOG_HISTOGRAM_WEIGHTS\": \"true\",\n            \"COMET_AUTO_LOG_METRIC_STEP_RATE\": \"15\",\n            \"COMET_AUTO_LOG_METRICS\": \"true\",\n            \"COMET_AUTO_LOG_OUTPUT_LOGGER\": \"false\",\n            \"COMET_AUTO_LOG_PARAMETERS\": \"true\",\n            \"COMET_DISPLAY_SUMMARY_LEVEL\": \"2\",\n            \"COMET_DISTRIBUTED_NODE_IDENTIFIER\": \"some_distributed_node_identifier\",\n            \"COMET_EXPERIMENT_KEY\": \"some_experiment_key\",\n            \"COMET_OFFLINE_DIRECTORY\": \"some_offline_directory\",\n            \"COMET_PROJECT_NAME\": \"some_project\",\n            \"COMET_START_EXPERIMENT_NAME\": \"some_name\",\n            \"COMET_START_EXPERIMENT_TAGS\": \"tag1,tag2\",\n            \"COMET_START_MODE\": \"get_or_create\",\n            \"COMET_START_ONLINE\": \"false\",\n            \"COMET_WORKSPACE\": \"some_workspace\",\n        }\n\n        for key in comet_env.keys():\n            os.environ.pop(key, None)\n\n\nclass TestValidationMLflow(BaseValidation):\n    \"\"\"\n    Validation test for MLflow\n    \"\"\"\n\n    def test_hf_mlflow_artifacts_config_sets_env(self, minimal_cfg):\n        cfg = (\n            DictDefault(\n                {\n                    \"hf_mlflow_log_artifacts\": True,\n                }\n            )\n            | minimal_cfg\n        )\n\n        new_cfg = validate_config(cfg)\n\n        assert new_cfg.hf_mlflow_log_artifacts is True\n\n        # Check it's not already present in env\n        assert \"HF_MLFLOW_LOG_ARTIFACTS\" not in os.environ\n\n        setup_mlflow_env_vars(new_cfg)\n\n        assert os.environ.get(\"HF_MLFLOW_LOG_ARTIFACTS\") == \"true\"\n\n        
os.environ.pop(\"HF_MLFLOW_LOG_ARTIFACTS\", None)\n\n    def test_mlflow_not_used_by_default(self, minimal_cfg):\n        cfg = DictDefault({}) | minimal_cfg\n\n        new_cfg = validate_config(cfg)\n\n        setup_mlflow_env_vars(new_cfg)\n\n        assert cfg.use_mlflow is not True\n\n        cfg = (\n            DictDefault(\n                {\n                    \"mlflow_experiment_name\": \"foo\",\n                }\n            )\n            | minimal_cfg\n        )\n\n        new_cfg = validate_config(cfg)\n\n        setup_mlflow_env_vars(new_cfg)\n\n        assert new_cfg.use_mlflow is True\n\n        os.environ.pop(\"MLFLOW_EXPERIMENT_NAME\", None)\n\n\nclass TestDataloaderValidation(BaseValidation):\n    \"\"\"\n    tests for dataloader_* sane defaults\n    \"\"\"\n\n    def test_dataloader_auto_defaults(self, minimal_cfg):\n        cfg = minimal_cfg\n\n        new_cfg = validate_config(cfg, {\"n_gpu\": 8}, {\"torch_version\": \"2.6.0\"})\n\n        assert new_cfg.dataloader_num_workers == 8\n        assert new_cfg.dataloader_pin_memory is True\n        assert new_cfg.dataloader_prefetch_factor == 256\n"
  },
  {
    "path": "tests/prompt_strategies/__init__.py",
    "content": ""
  },
  {
    "path": "tests/prompt_strategies/conftest.py",
    "content": "\"\"\"\nshared fixtures for prompt strategies tests\n\"\"\"\n\nimport pytest\nfrom datasets import Dataset\nfrom transformers import AutoTokenizer\n\nfrom axolotl.prompt_strategies.jinja_template_analyzer import JinjaTemplateAnalyzer\nfrom axolotl.utils.chat_templates import _CHAT_TEMPLATES\n\nfrom tests.hf_offline_utils import enable_hf_offline\n\n\n@pytest.fixture(name=\"assistant_dataset\")\ndef fixture_assistant_dataset():\n    return Dataset.from_list(\n        [\n            {\n                \"messages\": [\n                    {\"role\": \"user\", \"content\": \"hello\"},\n                    {\"role\": \"assistant\", \"content\": \"hello\"},\n                    {\"role\": \"user\", \"content\": \"goodbye\"},\n                    {\"role\": \"assistant\", \"content\": \"goodbye\"},\n                ]\n            }\n        ]\n    )\n\n\n@pytest.fixture(name=\"sharegpt_dataset\")\ndef fixture_sharegpt_dataset():\n    return Dataset.from_list(\n        [\n            {\n                \"conversations\": [\n                    {\"from\": \"human\", \"value\": \"hello\"},\n                    {\"from\": \"gpt\", \"value\": \"hello\"},\n                    {\"from\": \"human\", \"value\": \"goodbye\"},\n                    {\"from\": \"gpt\", \"value\": \"goodbye\"},\n                ]\n            }\n        ]\n    )\n\n\n@pytest.fixture(name=\"basic_dataset\")\ndef fixture_basic_dataset():\n    return Dataset.from_list(\n        [\n            {\n                \"conversations\": [\n                    {\"from\": \"system\", \"value\": \"You are an AI assistant.\"},\n                    {\"from\": \"human\", \"value\": \"Hello\"},\n                    {\"from\": \"assistant\", \"value\": \"Hi there!\"},\n                    {\"from\": \"human\", \"value\": \"How are you?\"},\n                    {\"from\": \"assistant\", \"value\": \"I'm doing well, thank you!\"},\n                ]\n            }\n        ]\n    
)\n\n\n@pytest.fixture(name=\"toolcalling_dataset\")\ndef fixture_toolcalling_dataset():\n    return Dataset.from_list(\n        [\n            {\n                \"messages\": [\n                    {\n                        \"role\": \"system\",\n                        \"content\": \"You are a bot that responds to weather queries. You should reply with the unit used in the queried location.\",\n                    },\n                    {\n                        \"role\": \"user\",\n                        \"content\": \"Hey, what's the temperature in Paris right now?\",\n                    },\n                    {\n                        \"role\": \"assistant\",\n                        \"tool_calls\": [\n                            {\n                                \"type\": \"function\",\n                                \"function\": {\n                                    \"name\": \"get_current_temperature\",\n                                    \"arguments\": {\n                                        \"location\": \"Paris, France\",\n                                        \"unit\": \"celsius\",\n                                    },\n                                },\n                            }\n                        ],\n                    },\n                    {\n                        \"role\": \"tool\",\n                        \"name\": \"get_current_temperature\",\n                        \"content\": \"22.0\",\n                    },\n                    {\n                        \"role\": \"assistant\",\n                        \"content\": \"The temperature in Paris is 22.0 degrees Celsius.\",\n                    },\n                ]\n            }\n        ]\n    )\n\n\n@pytest.fixture(name=\"llama3_tokenizer\", scope=\"session\", autouse=True)\n@enable_hf_offline\ndef fixture_llama3_tokenizer(\n    download_llama3_8b_instruct_model_fixture,\n):\n    tokenizer = 
AutoTokenizer.from_pretrained(\"NousResearch/Meta-Llama-3-8B-Instruct\")\n\n    return tokenizer\n\n\n@pytest.fixture(name=\"smollm2_tokenizer\", scope=\"session\", autouse=True)\n@enable_hf_offline\ndef fixture_smollm2_tokenizer():\n    tokenizer = AutoTokenizer.from_pretrained(\"HuggingFaceTB/SmolLM2-135M\")\n    return tokenizer\n\n\n@pytest.fixture(name=\"mistralv03_tokenizer\", scope=\"session\", autouse=True)\n@enable_hf_offline\ndef fixture_mistralv03_tokenizer(\n    download_mlx_mistral_7b_model_fixture,\n):\n    tokenizer = AutoTokenizer.from_pretrained(\n        \"mlx-community/Mistral-7B-Instruct-v0.3-4bit\"\n    )\n    return tokenizer\n\n\n@pytest.fixture(name=\"phi35_tokenizer\", scope=\"session\", autouse=True)\n@enable_hf_offline\ndef fixture_phi35_tokenizer():\n    tokenizer = AutoTokenizer.from_pretrained(\"microsoft/Phi-3.5-mini-instruct\")\n    return tokenizer\n\n\n@pytest.fixture(name=\"phi4_tokenizer\", scope=\"session\", autouse=True)\n@enable_hf_offline\ndef fixture_phi4_tokenizer():\n    tokenizer = AutoTokenizer.from_pretrained(\"microsoft/Phi-4-reasoning\")\n    return tokenizer\n\n\n@pytest.fixture(name=\"gemma2_tokenizer\", scope=\"session\", autouse=True)\ndef fixture_gemma2_tokenizer():\n    tokenizer = AutoTokenizer.from_pretrained(\"mlx-community/gemma-2-9b-it-4bit\")\n\n    return tokenizer\n\n\n@pytest.fixture(name=\"magistral_tokenizer\")\ndef fixture_magistral_tokenizer():\n    from axolotl.utils.mistral import HFMistralTokenizer\n\n    tokenizer = HFMistralTokenizer.from_pretrained(\"mistralai/Magistral-Small-2506\")\n    return tokenizer\n\n\n@pytest.fixture(name=\"devstral_tokenizer\")\ndef fixture_devstral_tokenizer():\n    from axolotl.utils.mistral import HFMistralTokenizer\n\n    tokenizer = HFMistralTokenizer.from_pretrained(\"mistralai/Devstral-Small-2505\")\n    return tokenizer\n\n\n@pytest.fixture(name=\"devstral_1_1_tokenizer\")\ndef fixture_devstral_1_1_tokenizer():\n    from axolotl.utils.mistral import 
HFMistralTokenizer\n\n    tokenizer = HFMistralTokenizer.from_pretrained(\"mistralai/Devstral-Small-2507\")\n    return tokenizer\n\n\n@pytest.fixture(name=\"qwen3_tokenizer\")\n@enable_hf_offline\ndef qwen3_tokenizer_fixture(\n    download_qwen3_half_billion_model,\n):  # pylint: disable=unused-argument,redefined-outer-name\n    tokenizer = AutoTokenizer.from_pretrained(\"Qwen/Qwen3-0.6B\")\n\n    return tokenizer\n\n\n@pytest.fixture(name=\"mistralv03_tokenizer_chat_template_jinja\")\ndef fixture_mistralv03_chat_template_jinja_w_system() -> str:\n    return '{%- if messages[0][\"role\"] == \"system\" %}\\n    {%- set system_message = messages[0][\"content\"] %}\\n    {%- set loop_messages = messages[1:] %}\\n{%- else %}\\n    {%- set loop_messages = messages %}\\n{%- endif %}\\n{%- if not tools is defined %}\\n    {%- set tools = none %}\\n{%- endif %}\\n{%- set user_messages = loop_messages | selectattr(\"role\", \"equalto\", \"user\") | list %}\\n\\n{#- This block checks for alternating user/assistant messages, skipping tool calling messages #}\\n{%- set ns = namespace() %}\\n{%- set ns.index = 0 %}\\n{%- for message in loop_messages %}\\n    {%- if not (message.role == \"tool\" or message.role == \"tool_results\" or (message.tool_calls is defined and message.tool_calls is not none)) %}\\n        {%- if (message[\"role\"] == \"user\") != (ns.index % 2 == 0) %}\\n            {{- raise_exception(\"After the optional system message, conversation roles must alternate user/assistant/user/assistant/...\") }}\\n        {%- endif %}\\n        {%- set ns.index = ns.index + 1 %}\\n    {%- endif %}\\n{%- endfor %}\\n\\n{{- bos_token }}\\n{%- for message in loop_messages %}\\n    {%- if message[\"role\"] == \"user\" %}\\n        {%- if tools is not none and (message == user_messages[-1]) %}\\n            {{- \"[AVAILABLE_TOOLS] [\" }}\\n            {%- for tool in tools %}\\n                {%- set tool = tool.function %}\\n                {{- \\'{\"type\": \"function\", 
\"function\": {\\' }}\\n                {%- for key, val in tool.items() if key != \"return\" %}\\n                    {%- if val is string %}\\n                        {{- \\'\"\\' + key + \\'\": \"\\' + val + \\'\"\\' }}\\n                    {%- else %}\\n                        {{- \\'\"\\' + key + \\'\": \\' + val|tojson }}\\n                    {%- endif %}\\n                    {%- if not loop.last %}\\n                        {{- \", \" }}\\n                    {%- endif %}\\n                {%- endfor %}\\n                {{- \"}}\" }}\\n                {%- if not loop.last %}\\n                    {{- \", \" }}\\n                {%- else %}\\n                    {{- \"]\" }}\\n                {%- endif %}\\n            {%- endfor %}\\n            {{- \"[/AVAILABLE_TOOLS]\" }}\\n            {%- endif %}\\n        {%- if loop.first and system_message is defined %}\\n            {{- \"[INST] \" + system_message + \"\\\\n\\\\n\" + message[\"content\"] + \"[/INST]\" }}\\n        {%- else %}\\n            {{- \"[INST] \" + message[\"content\"] + \"[/INST]\" }}\\n        {%- endif %}\\n    {%- elif message.tool_calls is defined and message.tool_calls is not none %}\\n        {{- \"[TOOL_CALLS] [\" }}\\n        {%- for tool_call in message.tool_calls %}\\n            {%- set out = tool_call.function|tojson %}\\n            {{- out[:-1] }}\\n            {%- if not tool_call.id is defined or tool_call.id|length != 9 %}\\n                {{- raise_exception(\"Tool call IDs should be alphanumeric strings with length 9!\") }}\\n            {%- endif %}\\n            {{- \\', \"id\": \"\\' + tool_call.id + \\'\"}\\' }}\\n            {%- if not loop.last %}\\n                {{- \", \" }}\\n            {%- else %}\\n                {{- \"]\" + eos_token }}\\n            {%- endif %}\\n        {%- endfor %}\\n    {%- elif message[\"role\"] == \"assistant\" %}\\n        {{- \" \" + message[\"content\"]|trim + eos_token}}\\n    {%- elif message[\"role\"] == 
\"tool_results\" or message[\"role\"] == \"tool\" %}\\n        {%- if message.content is defined and message.content.content is defined %}\\n            {%- set content = message.content.content %}\\n        {%- else %}\\n            {%- set content = message.content %}\\n        {%- endif %}\\n        {{- \\'[TOOL_RESULTS] {\"content\": \\' + content|string + \", \" }}\\n        {%- if not message.tool_call_id is defined or message.tool_call_id|length != 9 %}\\n            {{- raise_exception(\"Tool call IDs should be alphanumeric strings with length 9!\") }}\\n        {%- endif %}\\n        {{- \\'\"call_id\": \"\\' + message.tool_call_id + \\'\"}[/TOOL_RESULTS]\\' }}\\n    {%- else %}\\n        {{- raise_exception(\"Only user and assistant roles are supported, with the exception of an initial optional system message!\") }}\\n    {%- endif %}\\n{%- endfor %}\\n'\n\n\n@pytest.fixture(name=\"gemma2_tokenizer_chat_template_jinja\")\ndef fixture_gemma2_chat_template_jinja_w_system() -> str:\n    return \"{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'assistant') %}{% set role = 'model' %}{% else %}{% set role = message['role'] %}{% endif %}{{ '<start_of_turn>' + role + '\\n' + message['content'] | trim + '<end_of_turn>\\n' }}{% endfor %}{% if add_generation_prompt %}{{'<start_of_turn>model\\n'}}{% endif %}\"\n\n\n@pytest.fixture(name=\"llama3_2_vision_chat_template_jinja\")\ndef fixture_llama3_2_vision_with_hardcoded_date() -> str:\n    \"\"\"Hardcodes the date in the template to avoid the need for date logic in the prompt\"\"\"\n\n    template = _CHAT_TEMPLATES[\"llama3_2_vision\"]\n\n    old_date_logic = \"\"\"{%- if not date_string is defined %}\n    {%- if strftime_now is defined %}\n        {%- set date_string = strftime_now(\"%d %b %Y\") %}\n    {%- else %}\n        {%- set date_string = \"26 Jul 2024\" %}\n    {%- endif %}\n{%- endif %}\"\"\"\n\n    new_date_logic = \"\"\"{%- set date_string = \"17 Dec 2024\" %}\"\"\"\n\n    
modified_template = template.replace(old_date_logic, new_date_logic)\n\n    return modified_template\n\n\n@pytest.fixture(name=\"chat_template_jinja_with_optional_fields\")\ndef fixture_chat_template_jinja_with_optional_fields() -> str:\n    return \"\"\"{% for message in messages %}\n{{'<|im_start|>'}}{{ message['role'] }}\n{% if message['thoughts'] is defined %}[Thoughts: {{ message['thoughts'] }}]{% endif %}\n{% if message['tool_calls'] is defined %}[Tool: {{ message['tool_calls'][0]['type'] }}]{% endif %}\n{{ message['content'] }}{{'<|im_end|>'}}\n{% endfor %}\"\"\"\n\n\n@pytest.fixture(name=\"basic_jinja_template_analyzer\")\ndef basic_jinja_template_analyzer():\n    return JinjaTemplateAnalyzer(\n        \"\"\"{% for message in messages %}{% if message['role'] == 'system' and message['content'] %}{{'<|system|>\n' + message['content'] + '<|end|>\n'}}{% elif message['role'] == 'user' %}{{'<|user|>\n' + message['content'] + '<|end|>\n'}}{% elif message['role'] == 'assistant' %}{{'<|assistant|>\n' + message['content'] + '<|end|>\n'}}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|assistant|>\n' }}{% else %}{{ eos_token }}{% endif %}\"\"\"\n    )\n\n\n@pytest.fixture(name=\"mistral_jinja_template_analyzer\")\ndef mistral_jinja_template_analyzer(mistralv03_tokenizer_chat_template_jinja):\n    return JinjaTemplateAnalyzer(mistralv03_tokenizer_chat_template_jinja)\n"
  },
  {
    "path": "tests/prompt_strategies/messages/__init__.py",
    "content": ""
  },
  {
    "path": "tests/prompt_strategies/messages/test_chat.py",
    "content": "\"\"\"\ntests for chat_template prompt strategy\n\"\"\"\n\nimport unittest\n\nfrom axolotl.prompt_strategies.messages.chat import load\nfrom axolotl.utils.dict import DictDefault\nfrom axolotl.utils.logging import get_logger\n\nLOG = get_logger(__name__, log_level=\"DEBUG\")\n\n\nclass TestMessagesChatLlama3:\n    \"\"\"\n    Test class for assistant style datasets with llama-3 prompts using the messages chat llama3 strategy.\n    \"\"\"\n\n    def test_llama3_load(self, llama3_tokenizer, assistant_dataset):\n        LOG.info(\"Loading llama-3 tokenizer with assistant dataset\")\n        strategy = load(\n            llama3_tokenizer,\n            DictDefault(\n                {\n                    \"train_on_inputs\": False,\n                    \"sequence_len\": 512,\n                }\n            ),\n            DictDefault(\n                {\n                    \"chat_template\": \"llama3\",\n                    \"message_field_role\": \"role\",\n                    \"message_field_content\": \"content\",\n                    \"field_messages\": \"messages\",\n                }\n            ),\n        )\n        res = strategy.wrap_dataset(assistant_dataset)\n        input_ids = res[0][\"input_ids\"]\n        # fmt: off\n        expected_input_ids = [\n            128000,  # bos\n            128006, 882, 128007,  # user header\n            271, 15339, 128009,  # user prompt eot\n            128006, 78191, 128007,  # assistant header\n            271, 15339, 128009,  # assistant response eot\n            128006, 882, 128007,\n            271, 19045, 29474, 128009,\n            128006, 78191, 128007,\n            271, 19045, 29474, 128009,\n        ]\n        # fmt: on\n        LOG.debug(f\"Expected input_ids: {expected_input_ids}\")\n        LOG.debug(f\"Actual input_ids: {input_ids}\")\n        assert input_ids == expected_input_ids, (\n            f\"Input IDs mismatch: {input_ids} != {expected_input_ids}\"\n        )\n\n\nif __name__ == 
\"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/prompt_strategies/test_alpaca.py",
    "content": "\"\"\"\nTest module for alpaca integration w chatml\n\"\"\"\n\nimport pytest\nfrom datasets import Dataset\nfrom tokenizers import AddedToken\nfrom transformers import AutoTokenizer\n\nfrom axolotl.datasets import TokenizedPromptDataset\nfrom axolotl.prompt_tokenizers import AlpacaPromptTokenizingStrategy\nfrom axolotl.prompters import AlpacaPrompter, PromptStyle\n\nfrom tests.hf_offline_utils import enable_hf_offline\n\n\n@pytest.fixture(name=\"alpaca_dataset\")\ndef fixture_alpaca_dataset():\n    return Dataset.from_list(\n        [\n            {\n                \"instruction\": \"Evaluate this sentence for spelling and grammar mistakes\",\n                \"input\": \"He finnished his meal and left the resturant\",\n                \"output\": \"He finished his meal and left the restaurant.\",\n            }\n        ]\n    )\n\n\n@pytest.fixture(name=\"tokenizer\")\n@enable_hf_offline\ndef fixture_tokenizer():\n    tokenizer = AutoTokenizer.from_pretrained(\n        \"casperhansen/mistral-7b-instruct-v0.1-awq\"\n    )\n    tokenizer.add_special_tokens(\n        {\n            \"eos_token\": AddedToken(\n                \"<|im_end|>\", rstrip=False, lstrip=False, normalized=False\n            )\n        }\n    )\n    tokenizer.add_tokens(\n        [\n            AddedToken(\"<|im_start|>\", rstrip=False, lstrip=False, normalized=False),\n        ]\n    )\n\n    return tokenizer\n\n\nclass TestAlpacaChatml:\n    \"\"\"\n    Test class for alpaca prompter\n    \"\"\"\n\n    def test_no_double_im_end(self, alpaca_dataset, tokenizer):\n        strategy = AlpacaPromptTokenizingStrategy(\n            AlpacaPrompter(prompt_style=PromptStyle.CHATML.value),\n            tokenizer,\n            False,  # train_on_inputs\n            2048,  # sequence_len\n        )\n\n        dataset_wrapper = TokenizedPromptDataset(\n            strategy, alpaca_dataset, process_count=1\n        )\n\n        input_ids = dataset_wrapper[0][\"input_ids\"]\n        # fmt: 
off\n        assert input_ids == [\n            1,  # Bos\n            32001, 1587, 13, 20548, 336, 349, 396, 13126, 369, 13966, 264, 3638, 28725, 5881, 1360, 395, 396, 2787, 369, 5312, 3629, 2758, 28723, 12018, 264, 2899, 369, 6582, 1999, 2691, 274, 272, 2159, 28723, 32000, 28705, 13,  # instruction\n            32001, 2188, 13, 16627, 11931, 456, 12271, 354, 668, 3572, 304, 18756, 3479, 17179, 13, 2428, 854, 28711, 1497, 516, 11314, 304, 1749, 272, 1846, 324, 440, 32000, 28705, 13,  # input\n            32001, 13892, 13, 650, 5967, 516, 11314, 304, 1749, 272, 9926, 28723, 32000,  # output\n        ]\n        # fmt: on\n\n    def test_no_train_on_input(self, alpaca_dataset, tokenizer):\n        strategy = AlpacaPromptTokenizingStrategy(\n            AlpacaPrompter(prompt_style=PromptStyle.CHATML.value),\n            tokenizer,\n            False,  # train_on_inputs\n            2048,  # sequence_len\n        )\n\n        dataset_wrapper = TokenizedPromptDataset(\n            strategy, alpaca_dataset, process_count=1\n        )\n\n        labels = dataset_wrapper[0][\"labels\"]\n        # fmt: off\n        assert labels == [\n            -100,  # bos\n            -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,  # instruction\n            -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,  # input\n            -100, -100, -100, 650, 5967, 516, 11314, 304, 1749, 272, 9926, 28723, 32000,  # Output\n        ]\n        # fmt: on\n\n    def test_w_train_on_input(self, alpaca_dataset, tokenizer):\n        strategy = AlpacaPromptTokenizingStrategy(\n            AlpacaPrompter(prompt_style=PromptStyle.CHATML.value),\n            tokenizer,\n            True,  # 
train_on_inputs\n            2048,  # sequence_len\n        )\n\n        dataset_wrapper = TokenizedPromptDataset(\n            strategy, alpaca_dataset, process_count=1\n        )\n\n        labels = dataset_wrapper[0][\"labels\"]\n        # fmt: off\n        assert labels == [\n            1,  # Bos\n            32001, 1587, 13, 20548, 336, 349, 396, 13126, 369, 13966, 264, 3638, 28725, 5881, 1360, 395, 396, 2787, 369, 5312, 3629, 2758, 28723, 12018, 264, 2899, 369, 6582, 1999, 2691, 274, 272, 2159, 28723, 32000, 28705, 13,  # instruction\n            32001, 2188, 13, 16627, 11931, 456, 12271, 354, 668, 3572, 304, 18756, 3479, 17179, 13, 2428, 854, 28711, 1497, 516, 11314, 304, 1749, 272, 1846, 324, 440, 32000, 28705, 13,  # input\n            32001, 13892, 13, 650, 5967, 516, 11314, 304, 1749, 272, 9926, 28723, 32000,  # output\n        ]\n        # fmt: on\n"
  },
  {
    "path": "tests/prompt_strategies/test_chat_template_ds_schema_unification.py",
    "content": "\"\"\"\nTests for chat template prompt strategy with schema unification for none fields\n\"\"\"\n\nimport json\n\nimport pytest\nfrom datasets import Dataset\n\nfrom axolotl.prompt_strategies.chat_template import StrategyLoader\nfrom axolotl.utils.dict import DictDefault\n\n\n@pytest.fixture(name=\"messages_w_tools\")\ndef fixture_messages_w_tools():\n    jsons = \"\"\"\n{\"messages\":[{\"role\":\"user\",\"content\":\"move to (0, 1)\"},{\"role\":\"assistant\",\"content\":\"\",\"tool_calls\":[{\"function\":{\"name\":\"move\",\"arguments\":{\"x\":0,\"y\":1}}}]}],\"tools\":[{\"type\":\"function\",\"function\":{\"name\":\"move\",\"description\":\"Move to a given location measured in meters\",\"parameters\":{\"type\":\"object\",\"properties\":{\"x\":{\"type\":\"number\",\"description\":\"The x coordinate of the location, negative values are to the left, positive values are to the right\"},\"y\":{\"type\":\"number\",\"description\":\"The y coordinate of the location, negative values are backward, positive values are forward\"}},\"required\":[\"x\",\"y\"]}}},{\"type\":\"function\",\"function\":{\"name\":\"turn\",\"description\":\"Turn the robot to a given direction\",\"parameters\":{\"type\":\"object\",\"properties\":{\"theta\":{\"type\":\"integer\",\"description\":\"The angle to turn to, in degrees, positive values are counter-clockwise, negative values are clockwise\"}},\"required\":[\"theta\"]}}},{\"type\":\"function\",\"function\":{\"name\":\"invalid_prompt\",\"description\":\"call when the user's prompt is invalid\",\"parameters\":{\"type\":\"object\",\"properties\":{\"message\":{\"type\":\"string\",\"description\":\"why the prompt is invalid\"}},\"required\":[\"message\"]}}}],\"add_generation_prompt\":false}\n{\"messages\":[{\"role\":\"user\",\"content\":\"turn 270 degree\"},{\"role\":\"assistant\",\"content\":\"\",\"tool_calls\":[{\"function\":{\"name\":\"turn\",\"arguments\":{\"theta\": 
270}}}]}],\"tools\":[{\"type\":\"function\",\"function\":{\"name\":\"move\",\"description\":\"Move to a given location measured in meters\",\"parameters\":{\"type\":\"object\",\"properties\":{\"x\":{\"type\":\"number\",\"description\":\"The x coordinate of the location, negative values are to the left, positive values are to the right\"},\"y\":{\"type\":\"number\",\"description\":\"The y coordinate of the location, negative values are backward, positive values are forward\"}},\"required\":[\"x\",\"y\"]}}},{\"type\":\"function\",\"function\":{\"name\":\"turn\",\"description\":\"Turn the robot to a given direction\",\"parameters\":{\"type\":\"object\",\"properties\":{\"theta\":{\"type\":\"integer\",\"description\":\"The angle to turn to, in degrees, positive values are counter-clockwise, negative values are clockwise\"}},\"required\":[\"theta\"]}}},{\"type\":\"function\",\"function\":{\"name\":\"invalid_prompt\",\"description\":\"call when the user's prompt is invalid\",\"parameters\":{\"type\":\"object\",\"properties\":{\"message\":{\"type\":\"string\",\"description\":\"why the prompt is invalid\"}},\"required\":[\"message\"]}}}],\"add_generation_prompt\":false}\n{\"messages\":[{\"role\":\"user\",\"content\":\"jump high\"},{\"role\":\"assistant\",\"content\":\"\",\"tool_calls\":[{\"function\":{\"name\":\"invalid_prompt\",\"arguments\":{\"message\": \"jump is not a valid action\"}}}]}],\"tools\":[{\"type\":\"function\",\"function\":{\"name\":\"move\",\"description\":\"Move to a given location measured in meters\",\"parameters\":{\"type\":\"object\",\"properties\":{\"x\":{\"type\":\"number\",\"description\":\"The x coordinate of the location, negative values are to the left, positive values are to the right\"},\"y\":{\"type\":\"number\",\"description\":\"The y coordinate of the location, negative values are backward, positive values are forward\"}},\"required\":[\"x\",\"y\"]}}},{\"type\":\"function\",\"function\":{\"name\":\"turn\",\"description\":\"Turn the robot to 
a given direction\",\"parameters\":{\"type\":\"object\",\"properties\":{\"theta\":{\"type\":\"integer\",\"description\":\"The angle to turn to, in degrees, positive values are counter-clockwise, negative values are clockwise\"}},\"required\":[\"theta\"]}}},{\"type\":\"function\",\"function\":{\"name\":\"invalid_prompt\",\"description\":\"call when the user's prompt is invalid\",\"parameters\":{\"type\":\"object\",\"properties\":{\"message\":{\"type\":\"string\",\"description\":\"why the prompt is invalid\"}},\"required\":[\"message\"]}}}],\"add_generation_prompt\":false}\n    \"\"\".strip().split(\"\\n\")\n    rows = [json.loads(row) for row in jsons]\n    return Dataset.from_list(rows)\n\n\n@pytest.fixture(name=\"qwen3_prompt_strategy\")\ndef qwen3_chat_template_strategy(qwen3_tokenizer):\n    cfg = DictDefault(\n        sequence_len=2048,\n        chat_template=\"qwen3\",\n        eot_tokens=[\"<|im_end|>\"],\n    )\n    ds_cfg = DictDefault(\n        type=\"chat_template\",\n    )\n    load = StrategyLoader()\n    strat = load(qwen3_tokenizer, cfg, ds_cfg)\n    return strat\n\n\nclass TestSchemaUnification:\n    \"\"\"\n    Test class on handling null fields for tool calling\n    \"\"\"\n\n    def test_schema_unification_single_prompt(\n        self, messages_w_tools, qwen3_prompt_strategy, qwen3_tokenizer\n    ):\n        for row in messages_w_tools:\n            inputs = qwen3_prompt_strategy.tokenize_prompt(row)\n            decoded = qwen3_tokenizer.decode(inputs[\"input_ids\"])\n            tool_call = decoded.split(\"<tool_call>\")[-1].split(\"</tool_call>\")[0]\n            assert '\"message\": null' not in tool_call\n            assert '\"theta\": null' not in tool_call\n\n    def test_schema_unification_batched(\n        self, messages_w_tools, qwen3_prompt_strategy, qwen3_tokenizer\n    ):\n        rows = messages_w_tools.map(qwen3_prompt_strategy.tokenize_prompt, batched=True)\n        for row in rows:\n            decoded = 
qwen3_tokenizer.decode(row[\"input_ids\"])\n            tool_call = decoded.split(\"<tool_call>\")[-1].split(\"</tool_call>\")[0]\n            assert '\"message\": null' not in tool_call\n            assert '\"theta\": null' not in tool_call\n"
  },
  {
    "path": "tests/prompt_strategies/test_chat_template_utils.py",
    "content": "\"\"\"\nTests for utils in axolotl.utils.chat_templates\n\"\"\"\n\nimport unittest\n\nimport pytest\nfrom transformers import AutoTokenizer\n\nfrom axolotl.utils.chat_templates import (\n    _CHAT_TEMPLATES,\n    extract_chat_template_args,\n    get_chat_template,\n)\n\nfrom tests.hf_offline_utils import enable_hf_offline\n\n\n@pytest.fixture(name=\"llama3_tokenizer\")\n@enable_hf_offline\ndef fixture_llama3_tokenizer():\n    tokenizer = AutoTokenizer.from_pretrained(\"NousResearch/Meta-Llama-3-8B\")\n\n    return tokenizer\n\n\nclass TestGetChatTemplateUtils:\n    \"\"\"\n    Tests the get_chat_template function.\n    \"\"\"\n\n    def test_known_chat_template(self):\n        chat_template_str = get_chat_template(\"llama3\")\n        assert chat_template_str == _CHAT_TEMPLATES[\"llama3\"]\n\n    def test_invalid_chat_template(self):\n        with pytest.raises(ValueError) as exc:\n            get_chat_template(\"invalid_template\")\n            assert str(exc) == \"Template 'invalid_template' not found.\"\n\n    def test_tokenizer_default_no_tokenizer(self):\n        with pytest.raises(ValueError):\n            get_chat_template(\"tokenizer_default\", tokenizer=None)\n\n    def test_tokenizer_default_no_chat_template_on_tokenizer(self, llama3_tokenizer):\n        with pytest.raises(ValueError):\n            get_chat_template(\"tokenizer_default\", tokenizer=llama3_tokenizer)\n\n    def test_tokenizer_default_with_chat_template_on_tokenizer(self, llama3_tokenizer):\n        llama3_tokenizer.chat_template = \"test_template\"\n        chat_template_str = get_chat_template(\n            \"tokenizer_default\", tokenizer=llama3_tokenizer\n        )\n        assert chat_template_str == \"test_template\"\n\n    def test_tokenizer_default_fallback_no_tokenizer(self):\n        with pytest.raises(ValueError):\n            get_chat_template(\"tokenizer_default_fallback_test\", tokenizer=None)\n\n    def 
test_tokenizer_default_fallback_no_chat_template_on_tokenizer(\n        self, llama3_tokenizer\n    ):\n        chat_template_str = get_chat_template(\n            \"tokenizer_default_fallback_chatml\", tokenizer=llama3_tokenizer\n        )\n        assert chat_template_str == get_chat_template(\"chatml\")\n\n    def test_tokenizer_default_fallback_with_chat_template_on_tokenizer(\n        self, llama3_tokenizer\n    ):\n        llama3_tokenizer.chat_template = \"test_template\"\n        chat_template_str = get_chat_template(\n            \"tokenizer_default_fallback_chatml\", tokenizer=llama3_tokenizer\n        )\n        assert chat_template_str == \"test_template\"\n\n    def test_jinja_template_mode(self):\n        jinja_template = \"example_jinja_template\"\n        chat_template_str = get_chat_template(\"jinja\", jinja_template=jinja_template)\n        assert chat_template_str == jinja_template\n\n    def test_jinja_template_mode_no_jinja_template(self):\n        with pytest.raises(ValueError):\n            get_chat_template(\"jinja\", jinja_template=None)\n\n    def test_extract_chat_template_args(self):\n        # No ds_cfg\n        chat_template_choice, chat_template_jinja = extract_chat_template_args(\n            cfg={\"chat_template\": \"chatml\"},\n        )\n        assert chat_template_choice == \"chatml\"\n        assert chat_template_jinja is None\n\n        # ds_cfg provided\n        chat_template_choice, chat_template_jinja = extract_chat_template_args(\n            cfg={\n                \"chat_template\": \"jinja\",\n                \"chat_template_jinja\": \"global_jinja_template\",\n            },\n            ds_cfg={\"chat_template\": \"llama3\", \"chat_template_jinja\": None},\n        )\n        assert chat_template_choice == \"llama3\"\n        assert chat_template_jinja is None\n\n        # ds_cfg provided with jinja template\n        chat_template_choice, chat_template_jinja = extract_chat_template_args(\n            
cfg={\"chat_template\": \"chatml\", \"chat_template_jinja\": None},\n            ds_cfg={\n                \"chat_template\": \"jinja\",\n                \"chat_template_jinja\": \"ds_jinja_template\",\n            },\n        )\n        assert chat_template_choice == \"jinja\"\n        assert chat_template_jinja == \"ds_jinja_template\"\n\n        # ds_cfg provided with no chat_template\n        chat_template_choice, chat_template_jinja = extract_chat_template_args(\n            cfg={\n                \"chat_template\": \"jinja\",\n                \"chat_template_jinja\": \"global_jinja_template\",\n            },\n            ds_cfg={\"chat_template\": None, \"chat_template_jinja\": \"ds_jinja_template\"},\n        )\n        assert chat_template_choice == \"jinja\"\n        assert chat_template_jinja == \"global_jinja_template\"\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/prompt_strategies/test_chat_templates.py",
    "content": "\"\"\"\ntests for chat_template prompt strategy\n\"\"\"\n\nimport unittest\n\nfrom axolotl.prompt_strategies.chat_template import (\n    ChatTemplatePrompter,\n    ChatTemplateStrategy,\n    load,\n)\nfrom axolotl.prompters import IGNORE_TOKEN_ID\nfrom axolotl.utils.chat_templates import get_chat_template\nfrom axolotl.utils.dict import DictDefault\nfrom axolotl.utils.logging import get_logger\n\nLOG = get_logger(__name__)\n\n\nclass TestAssistantChatTemplateLlama3:\n    \"\"\"\n    Test class for assistant style datasets with llama-3 prompts using the chat_template strategy.\n    \"\"\"\n\n    def test_llama3_load(self, llama3_tokenizer, assistant_dataset):\n        LOG.info(\"Loading llama-3 tokenizer with assistant dataset\")\n        strategy = load(\n            llama3_tokenizer,\n            DictDefault(\n                {\n                    \"train_on_inputs\": False,\n                    \"sequence_len\": 512,\n                }\n            ),\n            DictDefault(\n                {\n                    \"chat_template\": \"llama3\",\n                    \"message_field_role\": \"role\",\n                    \"message_field_content\": \"content\",\n                    \"message_property_mappings\": {\n                        \"role\": \"role\",\n                        \"content\": \"content\",\n                    },\n                    \"roles\": {\n                        \"user\": [\"user\"],\n                        \"assistant\": [\"assistant\"],\n                        \"system\": [\"system\"],\n                    },\n                    \"field_messages\": \"messages\",\n                }\n            ),\n        )\n        res = strategy.tokenize_prompt(assistant_dataset[0])\n        input_ids = res[\"input_ids\"]\n        # fmt: off\n        expected_input_ids = [\n            128000,  # bos\n            128006, 882, 128007,  # user header\n            271, 15339, 128009,  # user prompt eot\n            128006, 78191, 
128007,  # assistant header\n            271, 15339, 128009,  # assistant response eot\n            128006, 882, 128007,\n            271, 19045, 29474, 128009,\n            128006, 78191, 128007,\n            271, 19045, 29474, 128009,\n        ]\n        # fmt: on\n        LOG.debug(f\"Expected input_ids: {expected_input_ids}\")\n        LOG.debug(f\"Actual input_ids: {input_ids}\")\n        assert input_ids == expected_input_ids, (\n            f\"Input IDs mismatch: {input_ids} != {expected_input_ids}\"\n        )\n\n    def test_llama3(self, llama3_tokenizer, assistant_dataset):\n        LOG.info(\"Testing llama-3 with assistant dataset\")\n        strategy = ChatTemplateStrategy(\n            ChatTemplatePrompter(\n                llama3_tokenizer,\n                chat_template=get_chat_template(\"llama3\"),\n                message_property_mappings={\n                    \"role\": \"role\",\n                    \"content\": \"content\",\n                },\n                roles={\n                    \"user\": [\"user\"],\n                    \"assistant\": [\"assistant\"],\n                    \"system\": [\"system\"],\n                },\n            ),\n            tokenizer=llama3_tokenizer,\n            train_on_inputs=False,\n            sequence_len=512,\n        )\n\n        res = strategy.tokenize_prompt(assistant_dataset[0])\n        input_ids = res[\"input_ids\"]\n        # fmt: off\n        expected_input_ids = [\n            128000,  # bos\n            128006, 882, 128007,  # user header\n            271, 15339, 128009,  # user prompt eot\n            128006, 78191, 128007,  # assistant header\n            271, 15339, 128009,   # assistant response eot\n            128006, 882, 128007,\n            271, 19045, 29474, 128009,\n            128006, 78191, 128007,\n            271, 19045, 29474, 128009,\n        ]\n        # fmt: on\n        LOG.debug(f\"Expected input_ids: {expected_input_ids}\")\n        LOG.debug(f\"Actual input_ids: 
{input_ids}\")\n        assert input_ids == expected_input_ids, (\n            f\"Input IDs mismatch: {input_ids} != {expected_input_ids}\"\n        )\n\n    def test_phi35(self, phi35_tokenizer, assistant_dataset):\n        LOG.info(\"Testing phi-3.5 with assistant dataset\")\n        strategy = ChatTemplateStrategy(\n            ChatTemplatePrompter(\n                phi35_tokenizer,\n                chat_template=get_chat_template(\"phi_35\"),\n                message_property_mappings={\n                    \"role\": \"role\",\n                    \"content\": \"content\",\n                },\n                roles={\n                    \"user\": [\"user\"],\n                    \"assistant\": [\"assistant\"],\n                    \"system\": [\"system\"],\n                },\n            ),\n            tokenizer=phi35_tokenizer,\n            train_on_inputs=False,\n            sequence_len=512,\n        )\n\n        res = strategy.tokenize_prompt(assistant_dataset[0])\n        input_ids = res[\"input_ids\"]\n        labels = res[\"labels\"]\n        # fmt: off\n        expected_input_ids = [\n            32010,  # user\n            22172, 32007,  # user eot\n            32001,  # assistant\n            22172, 32007,  # assistant eot\n            32010,  # user\n            1781, 26966, 32007,  # user eot\n            32001,  # assistant\n            1781, 26966, 32007,  # assistant eot\n        ]\n        expected_labels = [\n            -100,  # user\n            -100, -100,  # user eot\n            -100,  # assistant\n            -100, -100,  # assistant eot,\n            -100,  # user\n            -100, -100, -100,  # user eot\n            -100,  # assistant\n            1781, 26966, 32007,  # assistant eot\n        ]\n        # fmt: on\n        LOG.debug(f\"Expected input_ids: {expected_input_ids}\")\n        LOG.debug(f\"Actual input_ids: {input_ids}\")\n        assert input_ids == expected_input_ids, (\n            f\"Input IDs mismatch: {input_ids} != 
{expected_input_ids}\"\n        )\n\n        LOG.debug(f\"Expected labels : {expected_labels}\")\n        LOG.debug(f\"Actual labels : {labels}\")\n        assert labels == expected_labels, (\n            f\"Labels mismatch: {labels} != {expected_labels}\"\n        )\n\n    def test_llama3_with_training_data(self, llama3_tokenizer, assistant_dataset):\n        LOG.info(\"Testing llama-3 with assistant dataset including training data\")\n        strategy = ChatTemplateStrategy(\n            ChatTemplatePrompter(\n                llama3_tokenizer,\n                chat_template=get_chat_template(\"llama3\"),\n                message_field_training=\"training\",\n                message_property_mappings={\n                    \"role\": \"role\",\n                    \"content\": \"content\",\n                },\n                roles={\n                    \"user\": [\"user\"],\n                    \"assistant\": [\"assistant\"],\n                    \"system\": [\"system\"],\n                },\n            ),\n            tokenizer=llama3_tokenizer,\n            train_on_inputs=False,\n            train_on_eos=\"none\",\n            sequence_len=512,\n            roles_to_train=[\"assistant\"],\n        )\n\n        prompt_tokens = strategy.prompter.build_prompt(\n            assistant_dataset[0][\"messages\"], False\n        )\n        prompt = llama3_tokenizer.decode(prompt_tokens, skip_special_tokens=False)\n        LOG.debug(f\"Generated prompt: {prompt}\")\n        res = strategy.tokenize_prompt(assistant_dataset[0])\n        labels = res[\"labels\"]\n        input_ids = res[\"input_ids\"]\n        # fmt: off\n        expected_labels = [\n            IGNORE_TOKEN_ID,  # bos\n            IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID,  # user header\n            IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID,  # user prompt eot\n            IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID,  # assistant header\n            IGNORE_TOKEN_ID, 15339, 
IGNORE_TOKEN_ID,  # assistant response eot\n            IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID,\n            IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID,\n            IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID,\n            IGNORE_TOKEN_ID, 19045, 29474, IGNORE_TOKEN_ID,\n        ]\n        # fmt: on\n\n        LOG.debug(f\"Expected labels: {expected_labels}\")\n        LOG.debug(f\"Actual labels: {labels}\")\n        assert labels == expected_labels, (\n            f\"Labels mismatch:\\n\"\n            f\"Expected: {expected_labels}\\n\"\n            f\"Actual: {labels}\\n\"\n            f\"Input IDs: {input_ids}\\n\"\n        )\n\n\nclass TestSharegptChatTemplateLlama3:\n    \"\"\"\n    Test class for ShareGPT style datasets with llama-3 prompts using the chat_template strategy.\n    \"\"\"\n\n    def test_llama3_assistant(self, llama3_tokenizer, sharegpt_dataset):\n        LOG.info(\"Testing ShareGPT style datasets with llama-3 assistant prompts\")\n\n        strategy = ChatTemplateStrategy(\n            ChatTemplatePrompter(\n                llama3_tokenizer,\n                chat_template=get_chat_template(\"llama3\"),\n                message_property_mappings={\n                    \"role\": \"from\",\n                    \"content\": \"value\",\n                },\n                field_messages=\"conversations\",\n            ),\n            tokenizer=llama3_tokenizer,\n            train_on_inputs=False,\n            train_on_eos=\"none\",\n            sequence_len=512,\n            roles_to_train=[\"gpt\"],\n        )\n\n        res = strategy.tokenize_prompt(sharegpt_dataset[0])\n        input_ids = res[\"input_ids\"]\n        labels = res[\"labels\"]\n        # fmt: off\n        expected_input_ids = [\n            128000,  # bos\n            128006, 882, 128007,  # user header\n            271, 15339, 128009,  # user prompt eot\n            128006, 78191, 128007,  # assistant header\n            271, 15339, 
128009,  # assistant response eot\n            128006, 882, 128007,\n            271, 19045, 29474, 128009,\n            128006, 78191, 128007,\n            271, 19045, 29474, 128009,\n        ]\n        expected_labels = [\n            IGNORE_TOKEN_ID,  # bos\n            IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID,  # user header\n            IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID,  # user prompt eot\n            IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID,  # assistant header\n            IGNORE_TOKEN_ID, 15339, IGNORE_TOKEN_ID,  # assistant response eot\n            IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID,\n            IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID,\n            IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID,\n            IGNORE_TOKEN_ID, 19045, 29474, IGNORE_TOKEN_ID,\n        ]\n        # fmt: on\n\n        LOG.debug(f\"Expected input_ids: {expected_input_ids}\")\n        LOG.debug(f\"Actual input_ids: {input_ids}\")\n        LOG.debug(f\"Expected labels: {expected_labels}\")\n        LOG.debug(f\"Actual labels: {labels}\")\n\n        assert input_ids == expected_input_ids, (\n            f\"Input IDs mismatch: {input_ids} != {expected_input_ids}\"\n        )\n        assert labels == expected_labels, (\n            f\"Labels mismatch: {labels} != {expected_labels}\"\n        )\n\n    def test_llama3_human(self, llama3_tokenizer, sharegpt_dataset):\n        LOG.info(\"Testing ShareGPT style datasets with llama-3 human prompts\")\n\n        strategy = ChatTemplateStrategy(\n            ChatTemplatePrompter(\n                llama3_tokenizer,\n                chat_template=get_chat_template(\"llama3\"),\n                message_property_mappings={\n                    \"role\": \"from\",\n                    \"content\": \"value\",\n                },\n                field_messages=\"conversations\",\n            ),\n            tokenizer=llama3_tokenizer,\n            
train_on_inputs=False,\n            train_on_eos=\"none\",\n            sequence_len=512,\n            roles_to_train=[\"human\"],\n        )\n\n        res = strategy.tokenize_prompt(sharegpt_dataset[0])\n        input_ids = res[\"input_ids\"]\n        labels = res[\"labels\"]\n        # fmt: off\n        expected_input_ids = [\n            128000,  # bos\n            128006, 882, 128007,  # user header\n            271, 15339, 128009,  # user prompt eot\n            128006, 78191, 128007,  # assistant header\n            271, 15339, 128009,  # assistant response eot\n            128006, 882, 128007,\n            271, 19045, 29474, 128009,\n            128006, 78191, 128007,\n            271, 19045, 29474, 128009,\n        ]\n        expected_labels = [\n            IGNORE_TOKEN_ID,  # bos\n            IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID,  # user header\n            IGNORE_TOKEN_ID, 15339, IGNORE_TOKEN_ID,  # user prompt eot\n            IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID,  # assistant header\n            IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID,  # assistant response eot\n            IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID,\n            IGNORE_TOKEN_ID, 19045, 29474, IGNORE_TOKEN_ID,\n            IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID,\n            IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID,\n        ]\n        # fmt: on\n\n        LOG.debug(f\"Expected input_ids: {expected_input_ids}\")\n        LOG.debug(f\"Actual input_ids: {input_ids}\")\n        LOG.debug(f\"Expected labels: {expected_labels}\")\n        LOG.debug(f\"Actual labels: {labels}\")\n\n        assert input_ids == expected_input_ids, (\n            f\"Input IDs mismatch: {input_ids} != {expected_input_ids}\"\n        )\n        assert labels == expected_labels, (\n            f\"Labels mismatch: {labels} != {expected_labels}\"\n        )\n\n    def test_llama3_system_human(self, llama3_tokenizer, basic_dataset):\n    
    LOG.info(\"Testing ShareGPT style datasets with llama-3 system/human prompts\")\n\n        strategy = ChatTemplateStrategy(\n            ChatTemplatePrompter(\n                llama3_tokenizer,\n                chat_template=get_chat_template(\"llama3\"),\n                message_property_mappings={\n                    \"role\": \"from\",\n                    \"content\": \"value\",\n                },\n                field_messages=\"conversations\",\n            ),\n            tokenizer=llama3_tokenizer,\n            train_on_inputs=False,\n            train_on_eos=\"none\",\n            sequence_len=512,\n            roles_to_train=[\"system\", \"human\"],\n        )\n\n        res = strategy.tokenize_prompt(basic_dataset[0])\n        input_ids = res[\"input_ids\"]\n        labels = res[\"labels\"]\n        # fmt: off\n        expected_input_ids = [\n            128000,  # bos\n            128006, 9125, 128007,\n            271, 2675, 527, 459, 15592, 18328, 13, 128009,\n            128006, 882, 128007,  # user header\n            271, 9906, 128009,  # user prompt eot\n            128006, 78191, 128007,  # assistant header\n            271, 13347, 1070, 0, 128009,  # assistant response eot\n            128006, 882, 128007,\n            271, 4438, 527, 499, 30, 128009,\n            128006, 78191, 128007,\n            271, 40, 2846, 3815, 1664, 11, 9901, 499, 0, 128009,\n        ]\n        expected_labels = [\n            IGNORE_TOKEN_ID,  # bos\n            IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID,  # system header\n            IGNORE_TOKEN_ID, 2675, 527, 459, 15592, 18328, 13, IGNORE_TOKEN_ID,  # system prompt eot\n            IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID,  # user header\n            IGNORE_TOKEN_ID, 9906, IGNORE_TOKEN_ID,  # user prompt eot\n            IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID,  # assistant header\n            IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID,  # 
assistant response eot\n            IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID,\n            IGNORE_TOKEN_ID, 4438, 527, 499, 30, IGNORE_TOKEN_ID,\n            IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID,\n            IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID,\n        ]\n        # fmt: on\n\n        LOG.debug(f\"Expected input_ids: {expected_input_ids}\")\n        LOG.debug(f\"Actual input_ids: {input_ids}\")\n        LOG.debug(f\"Expected labels: {expected_labels}\")\n        LOG.debug(f\"Actual labels: {labels}\")\n\n        assert input_ids == expected_input_ids, (\n            f\"Input IDs mismatch: {input_ids} != {expected_input_ids}\"\n        )\n        assert labels == expected_labels, (\n            f\"Labels mismatch: {labels} != {expected_labels}\"\n        )\n\n\nclass TestAssistantToolCallingChatTemplateLlama32Vision:\n    \"\"\"\n    Test class for assistant style datasets with tool_calling prompts using the llama-32_vision chat template.\n    \"\"\"\n\n    def test_llama32vision_train_on_assistant(\n        self, llama3_tokenizer, toolcalling_dataset, llama3_2_vision_chat_template_jinja\n    ):\n        LOG.info(\n            \"Testing assistant style datasets with tool_calling with llama-32 chat template, training on assistant\"\n        )\n\n        strategy = ChatTemplateStrategy(\n            ChatTemplatePrompter(\n                llama3_tokenizer,\n                chat_template=get_chat_template(\n                    \"jinja\", jinja_template=llama3_2_vision_chat_template_jinja\n                ),\n                message_property_mappings={\"role\": \"role\", \"content\": \"content\"},\n            ),\n            tokenizer=llama3_tokenizer,\n            train_on_inputs=False,\n            train_on_eos=\"turn\",\n            sequence_len=512,\n            roles_to_train=[\"assistant\"],\n        )\n\n      
  res = strategy.tokenize_prompt(toolcalling_dataset[0])\n\n        input_ids = res[\"input_ids\"]\n        labels = res[\"labels\"]\n\n        # fmt: off\n        expected_input_ids = [\n            128000,  # bos\n            128006, 9125, 128007, 271,  # system header\n            38766, 1303, 33025, 2696, 25, 6790, 220, 2366, 18, 198, 15724, 2696, 25, 220, 1114, 3799, 220, 2366, 19, 271,  # system date prompt\n            2675, 527, 264, 11164, 430, 31680, 311, 9282, 20126, 13, 1472, 1288, 10052, 449, 279, 5089, 1511, 304, 279, 79002, 3813, 13, 128009,  # system message\n            128006, 882, 128007, 271,  # user header\n            19182, 11, 1148, 596, 279, 9499, 304, 12366, 1314, 1457, 30, 128009,  # user message\n            128006, 78191, 128007, 271,  # assistant header\n            5018, 609, 794, 330, 456, 11327, 54625, 498, 330, 14105, 794, 5324, 2588, 794, 330, 60704, 11, 9822, 498, 330, 3928, 794, 330, 66, 41347, 32075, 128009,  # assistant message\n            128006, 23799, 4690, 128007, 271,  # tool header\n            1, 1313, 13, 15, 1, 128009,  # tool message\n            128006, 78191, 128007, 271,  # assistant header\n            791, 9499, 304, 12366, 374, 220, 1313, 13, 15, 12628, 62447, 13, 128009  # assistant message\n        ]\n\n        expected_labels = [\n            IGNORE_TOKEN_ID,  # bos\n            IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID,  # system header\n            IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID,  # system date prompt\n            IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, 
IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID,  # system message\n            IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID,  # user header\n            IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID,  # user message\n            IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID,  # assistant header\n            5018, 609, 794, 330, 456, 11327, 54625, 498, 330, 14105, 794, 5324, 2588, 794, 330, 60704, 11, 9822, 498, 330, 3928, 794, 330, 66, 41347, 32075, 128009,  # assistant message\n            IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID,  # tool header\n            IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID,  # tool message\n            IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID,  # assistant header\n            791, 9499, 304, 12366, 374, 220, 1313, 13, 15, 12628, 62447, 13, 128009  # assistant message\n        ]\n        # fmt: on\n\n        assert input_ids == expected_input_ids, (\n            f\"Input IDs mismatch: {input_ids} != {expected_input_ids}\"\n        )\n\n        assert labels == expected_labels, (\n            f\"Labels mismatch: {labels} != {expected_labels}\"\n        )\n\n    def test_llama32vision_train_on_tools(\n        self, llama3_tokenizer, toolcalling_dataset, llama3_2_vision_chat_template_jinja\n    ):\n        LOG.info(\n            \"Testing assistant style datasets with tool_calling with llama-32 chat template, training on tools\"\n        )\n\n        strategy = ChatTemplateStrategy(\n            ChatTemplatePrompter(\n                
llama3_tokenizer,\n                chat_template=get_chat_template(\n                    \"jinja\", jinja_template=llama3_2_vision_chat_template_jinja\n                ),\n                message_property_mappings={\"role\": \"role\", \"content\": \"content\"},\n            ),\n            tokenizer=llama3_tokenizer,\n            train_on_inputs=False,\n            train_on_eos=\"turn\",\n            sequence_len=512,\n            roles_to_train=[\"assistant\", \"tool\"],\n        )\n\n        res = strategy.tokenize_prompt(toolcalling_dataset[0])\n\n        input_ids = res[\"input_ids\"]\n        labels = res[\"labels\"]\n\n        # fmt: off\n        expected_input_ids = [\n            128000,  # bos\n            128006, 9125, 128007, 271,  # system header\n            38766, 1303, 33025, 2696, 25, 6790, 220, 2366, 18, 198, 15724, 2696, 25, 220, 1114, 3799, 220, 2366, 19, 271,  # system date prompt\n            2675, 527, 264, 11164, 430, 31680, 311, 9282, 20126, 13, 1472, 1288, 10052, 449, 279, 5089, 1511, 304, 279, 79002, 3813, 13, 128009,  # system message\n            128006, 882, 128007, 271,  # user header\n            19182, 11, 1148, 596, 279, 9499, 304, 12366, 1314, 1457, 30, 128009,  # user message\n            128006, 78191, 128007, 271,  # assistant header\n            5018, 609, 794, 330, 456, 11327, 54625, 498, 330, 14105, 794, 5324, 2588, 794, 330, 60704, 11, 9822, 498, 330, 3928, 794, 330, 66, 41347, 32075, 128009,  # assistant message\n            128006, 23799, 4690, 128007, 271,  # tool header\n            1, 1313, 13, 15, 1, 128009,  # tool message\n            128006, 78191, 128007, 271,  # assistant header\n            791, 9499, 304, 12366, 374, 220, 1313, 13, 15, 12628, 62447, 13, 128009  # assistant message\n        ]\n\n        expected_labels = [\n            IGNORE_TOKEN_ID,  # bos\n            IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID,  # system header\n            IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, 
IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID,  # system date prompt\n            IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID,  # system message\n            IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID,  # user header\n            IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID,  # user message\n            IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID,  # assistant header\n            5018, 609, 794, 330, 456, 11327, 54625, 498, 330, 14105, 794, 5324, 2588, 794, 330, 60704, 11, 9822, 498, 330, 3928, 794, 330, 66, 41347, 32075, 128009,  # assistant message\n            IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID,  # tool header\n            IGNORE_TOKEN_ID, 1313, 13, 15, IGNORE_TOKEN_ID, 128009,  # tool message\n            IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID,  # assistant header\n            791, 9499, 304, 12366, 374, 220, 1313, 13, 15, 12628, 62447, 13, 128009  # assistant message\n        ]\n        # fmt: on\n\n        assert input_ids == expected_input_ids, (\n            f\"Input IDs mismatch: {input_ids} != {expected_input_ids}\"\n        )\n\n        assert labels == expected_labels, (\n            
f\"Labels mismatch: {labels} != {expected_labels}\"\n        )\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/prompt_strategies/test_chat_templates_advanced.py",
    "content": "\"\"\"\ntests for chat_template prompt strategy\n\"\"\"\n\nfrom copy import deepcopy\n\nimport pytest\nfrom datasets import Dataset\nfrom tokenizers import AddedToken\nfrom transformers import PreTrainedTokenizer\n\nfrom axolotl.prompt_strategies.chat_template import (\n    ChatTemplatePrompter,\n    ChatTemplateStrategy,\n)\nfrom axolotl.prompters import IGNORE_TOKEN_ID\nfrom axolotl.utils.chat_templates import get_chat_template\nfrom axolotl.utils.logging import get_logger\n\nfrom tests.hf_offline_utils import enable_hf_offline\n\nLOG = get_logger(__name__)\n\nPARAMETRIZE_KEYS = \"tokenizer, chat_template, chat_template_jinja, eos_token\"\nPARAMETRIZE_PARAMS = [\n    (\"llama3_tokenizer\", \"llama3\", None, None),\n    (\"llama3_tokenizer\", \"chatml\", None, \"<|im_end|>\"),\n    (\n        \"mistralv03_tokenizer\",\n        \"jinja\",\n        \"mistralv03_tokenizer_chat_template_jinja\",\n        \"[/INST]\",\n    ),\n    (\n        \"gemma2_tokenizer\",\n        \"jinja\",\n        \"gemma2_tokenizer_chat_template_jinja\",\n        \"<end_of_turn>\",\n    ),\n    # (\"phi35_tokenizer\", \"phi_35\", None, \"<|end|>\"),  # seems to be broken w transformers v5\n    (\"phi4_tokenizer\", \"phi_4\", None, \"<|im_end|>\"),\n]\n\n\n@pytest.mark.parametrize(\n    PARAMETRIZE_KEYS,\n    PARAMETRIZE_PARAMS,\n)\nclass TestChatTemplateConfigurations:\n    \"\"\"\n    Test class for various configurations of ChatTemplateStrategy.\n    \"\"\"\n\n    @staticmethod\n    def setup_tokenizer(\n        tokenizer_name,\n        chat_template,\n        chat_template_jinja=None,\n        eos_token=None,\n        request=None,\n        eot_token=None,\n    ) -> tuple[PreTrainedTokenizer, str]:\n        \"\"\"\n        Helper function to set up the tokenizer and chat template for the test.\n        \"\"\"\n        tokenizer = deepcopy(request.getfixturevalue(tokenizer_name))\n        if chat_template == \"jinja\":\n            chat_template_jinja = 
request.getfixturevalue(chat_template_jinja)\n        if eos_token:\n            tokenizer.add_special_tokens(\n                {\n                    \"eos_token\": AddedToken(\n                        eos_token, rstrip=False, lstrip=False, normalized=False\n                    )\n                }\n            )\n            if tokenizer.__class__.__name__ in (\n                \"LlamaTokenizerFast\",\n                \"CodeLlamaTokenizerFast\",\n            ):\n                tokenizer.update_post_processor()\n\n        if eot_token:\n            tokenizer.add_special_tokens({\"additional_special_tokens\": [eot_token]})\n\n        return tokenizer, chat_template_jinja\n\n    def _should_skip_turn(self, tokenizer, turn, turn_idx, start_idx, end_idx):\n        \"\"\"Helper method to determine if a turn should be skipped in testing.\n        This is used to skip system messages for Mistral as the template does not output them without more turns.\n        \"\"\"\n        if (\n            turn_idx == 0\n            and turn.get(\"from\") in [\"system\", \"context\"]\n            and (\"mistral\" in tokenizer.name_or_path.lower())\n        ):\n            assert start_idx == -1 and end_idx == -1, (\n                \"Expected system message to be skipped\"\n            )\n            return True\n        return False\n\n    @enable_hf_offline\n    def test_train_on_inputs_true(\n        self,\n        tokenizer,\n        chat_template,\n        chat_template_jinja,\n        eos_token,\n        basic_dataset,\n        request,\n    ):\n        LOG.info(\"Testing with train_on_inputs=True\")\n\n        tokenizer, chat_template_jinja = self.setup_tokenizer(\n            tokenizer, chat_template, chat_template_jinja, eos_token, request\n        )\n\n        strategy = ChatTemplateStrategy(\n            ChatTemplatePrompter(\n                tokenizer,\n                chat_template=get_chat_template(\n                    chat_template, 
jinja_template=chat_template_jinja\n                ),\n                message_property_mappings={\"role\": \"from\", \"content\": \"value\"},\n                field_messages=\"conversations\",\n            ),\n            tokenizer=tokenizer,\n            train_on_inputs=True,\n            sequence_len=512,\n            roles_to_train=[\"assistant\"],\n        )\n\n        res = strategy.tokenize_prompt(basic_dataset[0])\n        turns = strategy.get_conversation_thread(basic_dataset[0])\n        labels = res[\"labels\"]\n        input_ids = res[\"input_ids\"]\n\n        # With train_on_inputs=True, every (non-skipped) turn should be labeled\n        for i, turn in enumerate(basic_dataset[0][\"conversations\"]):\n            start_idx, end_idx = strategy.find_turn(turns=turns, turn_idx=i)\n\n            if self._should_skip_turn(tokenizer, turn, i, start_idx, end_idx):\n                continue\n\n            decoded_response = tokenizer.decode(input_ids[start_idx:end_idx])\n            response = turn[\"value\"]\n\n            assert response in decoded_response, (\n                f\"Response {response} not found in index {start_idx}:{end_idx} \"\n                f\"decoded:{decoded_response}\"\n            )\n\n            assert all(\n                label != IGNORE_TOKEN_ID for label in labels[start_idx:end_idx]\n            ), (\n                f\"Expected labels for input '{response}' to be set, but got {labels[start_idx:end_idx]}\"\n            )\n\n        LOG.debug(\"Full labels: %s\", labels)\n        LOG.debug(\"Full input_ids: %s\", input_ids)\n\n    def test_train_on_inputs_false(\n        self,\n        tokenizer,\n        chat_template,\n        chat_template_jinja,\n        eos_token,\n        basic_dataset,\n        request,\n    ):\n        LOG.info(\"Testing with train_on_inputs=False, on assistant only\")\n\n        tokenizer, chat_template_jinja = self.setup_tokenizer(\n            tokenizer, chat_template, chat_template_jinja, eos_token, request\n        )\n\n       
 strategy = ChatTemplateStrategy(\n            ChatTemplatePrompter(\n                tokenizer,\n                chat_template=get_chat_template(\n                    chat_template, jinja_template=chat_template_jinja\n                ),\n                message_property_mappings={\"role\": \"from\", \"content\": \"value\"},\n                field_messages=\"conversations\",\n            ),\n            tokenizer=tokenizer,\n            train_on_inputs=False,\n            sequence_len=512,\n            roles_to_train=[\"assistant\"],\n        )\n\n        res = strategy.tokenize_prompt(basic_dataset[0])\n        turns = strategy.get_conversation_thread(basic_dataset[0])\n        labels = res[\"labels\"]\n        input_ids = res[\"input_ids\"]\n\n        # Process all turns and verify correct labeling based on role\n        for i, turn in enumerate(basic_dataset[0][\"conversations\"]):\n            start_idx, end_idx = strategy.find_turn(turns=turns, turn_idx=i)\n\n            if self._should_skip_turn(tokenizer, turn, i, start_idx, end_idx):\n                continue\n\n            decoded_response = tokenizer.decode(input_ids[start_idx:end_idx])\n            response = turn[\"value\"]\n\n            assert response in decoded_response, (\n                f\"Response {response} not found in index {start_idx}:{end_idx} \"\n                f\"decoded:{decoded_response}\"\n            )\n\n            # Verify that assistant responses are labeled and other inputs are not\n            is_assistant = turn[\"from\"] == \"assistant\"\n            if is_assistant:\n                assert all(\n                    label != IGNORE_TOKEN_ID for label in labels[start_idx:end_idx]\n                ), (\n                    f\"Expected labels for assistant response '{response}' to be set, but got {labels[start_idx:end_idx]}\"\n                )\n            else:\n                assert all(\n                    label == IGNORE_TOKEN_ID for label in labels[start_idx:end_idx]\n   
             ), (\n                    f\"Expected labels for human input '{response}' to be IGNORE_TOKEN_ID, but got {labels[start_idx:end_idx]}\"\n                )\n\n    def test_roles_to_train_human_assistant_only(\n        self,\n        tokenizer,\n        chat_template,\n        chat_template_jinja,\n        eos_token,\n        basic_dataset,\n        request,\n    ):\n        LOG.info(\"Testing roles_to_train with human assistant only\")\n\n        tokenizer, chat_template_jinja = self.setup_tokenizer(\n            tokenizer, chat_template, chat_template_jinja, eos_token, request\n        )\n\n        strategy = ChatTemplateStrategy(\n            ChatTemplatePrompter(\n                tokenizer,\n                chat_template=get_chat_template(\n                    chat_template, jinja_template=chat_template_jinja\n                ),\n                message_property_mappings={\"role\": \"from\", \"content\": \"value\"},\n                field_messages=\"conversations\",\n            ),\n            tokenizer=tokenizer,\n            train_on_inputs=False,\n            sequence_len=512,\n            roles_to_train=[\"assistant\", \"human\"],\n        )\n\n        res = strategy.tokenize_prompt(basic_dataset[0])\n        turns = strategy.get_conversation_thread(basic_dataset[0])\n        labels = res[\"labels\"]\n        input_ids = res[\"input_ids\"]\n\n        # Process all turns and verify correct labeling based on role\n        for i, turn in enumerate(basic_dataset[0][\"conversations\"]):\n            start_idx, end_idx = strategy.find_turn(turns=turns, turn_idx=i)\n\n            if self._should_skip_turn(tokenizer, turn, i, start_idx, end_idx):\n                continue\n\n            decoded_response = tokenizer.decode(input_ids[start_idx:end_idx])\n            response = turn[\"value\"]\n\n            assert response in decoded_response, (\n                f\"Response {response} not found in index {start_idx}:{end_idx} \"\n                
f\"decoded:{decoded_response}\"\n            )\n\n            # Verify that non-system responses are labeled and system are not\n            should_be_labelled = turn[\"from\"] != \"system\"\n            if should_be_labelled:\n                assert all(\n                    label != IGNORE_TOKEN_ID for label in labels[start_idx:end_idx]\n                ), (\n                    f\"Expected labels for assistant response '{response}' to be set, but got {labels[start_idx:end_idx]}\"\n                )\n            else:\n                assert all(\n                    label == IGNORE_TOKEN_ID for label in labels[start_idx:end_idx]\n                ), (\n                    f\"Expected labels for human input '{response}' to be IGNORE_TOKEN_ID, but got {labels[start_idx:end_idx]}\"\n                )\n\n    def test_roles_to_train_all(\n        self,\n        tokenizer,\n        chat_template,\n        chat_template_jinja,\n        eos_token,\n        basic_dataset,\n        request,\n    ):\n        LOG.info(\"Testing roles_to_train with all roles\")\n\n        tokenizer, chat_template_jinja = self.setup_tokenizer(\n            tokenizer, chat_template, chat_template_jinja, eos_token, request\n        )\n\n        strategy = ChatTemplateStrategy(\n            ChatTemplatePrompter(\n                tokenizer,\n                chat_template=get_chat_template(\n                    chat_template, jinja_template=chat_template_jinja\n                ),\n                message_property_mappings={\"role\": \"from\", \"content\": \"value\"},\n                field_messages=\"conversations\",\n            ),\n            tokenizer=tokenizer,\n            train_on_inputs=True,\n            sequence_len=512,\n            roles_to_train=[\"human\", \"assistant\"],\n        )\n\n        res = strategy.tokenize_prompt(basic_dataset[0])\n        turns = strategy.get_conversation_thread(basic_dataset[0])\n        labels = res[\"labels\"]\n        input_ids = res[\"input_ids\"]\n\n 
       # Verify that all responses are labeled (except for special tokens)\n        for i, turn in enumerate(basic_dataset[0][\"conversations\"]):\n            response = turn[\"value\"]\n\n            start_idx, end_idx = strategy.find_turn(turns=turns, turn_idx=i)\n\n            if self._should_skip_turn(tokenizer, turn, i, start_idx, end_idx):\n                continue\n\n            decoded_response = tokenizer.decode(input_ids[start_idx:end_idx])\n            assert response in decoded_response, (\n                f\"Response {response} not found in index {start_idx}:{end_idx} decoded:{decoded_response}\"\n            )\n\n            assert all(\n                label != IGNORE_TOKEN_ID for label in labels[start_idx:end_idx]\n            ), (\n                f\"Expected labels for response '{response}' to be set, but got {labels[start_idx:end_idx]}\"\n            )\n\n    def test_empty_roles_to_train(\n        self,\n        tokenizer,\n        chat_template,\n        chat_template_jinja,\n        eos_token,\n        basic_dataset,\n        request,\n    ):\n        LOG.info(\"Testing with empty roles_to_train\")\n\n        tokenizer, chat_template_jinja = self.setup_tokenizer(\n            tokenizer, chat_template, chat_template_jinja, eos_token, request\n        )\n\n        strategy = ChatTemplateStrategy(\n            ChatTemplatePrompter(\n                tokenizer,\n                chat_template=get_chat_template(\n                    chat_template, jinja_template=chat_template_jinja\n                ),\n                message_property_mappings={\"role\": \"from\", \"content\": \"value\"},\n                field_messages=\"conversations\",\n            ),\n            tokenizer=tokenizer,\n            train_on_inputs=False,\n            sequence_len=512,\n            roles_to_train=[],\n            train_on_eos=\"none\",  # Add this line\n        )\n\n        res = strategy.tokenize_prompt(basic_dataset[0])\n        labels = res[\"labels\"]\n\n       
 # Verify that no labels are set when roles_to_train is empty\n        LOG.debug(\"Full labels: %s\", labels)\n        assert all(label == IGNORE_TOKEN_ID for label in labels), (\n            \"Expected all labels to be IGNORE_TOKEN_ID when roles_to_train is empty\"\n        )\n\n    def test_train_on_eos_all(\n        self,\n        tokenizer,\n        chat_template,\n        chat_template_jinja,\n        eos_token,\n        basic_dataset,\n        request,\n    ):\n        LOG.info(\"Testing with train_on_eos='all'\")\n\n        tokenizer, chat_template_jinja = self.setup_tokenizer(\n            tokenizer, chat_template, chat_template_jinja, eos_token, request\n        )\n\n        strategy = ChatTemplateStrategy(\n            ChatTemplatePrompter(\n                tokenizer,\n                chat_template=get_chat_template(\n                    chat_template, jinja_template=chat_template_jinja\n                ),\n                message_property_mappings={\"role\": \"from\", \"content\": \"value\"},\n                field_messages=\"conversations\",\n            ),\n            tokenizer=tokenizer,\n            train_on_inputs=False,\n            sequence_len=512,\n            roles_to_train=[\"assistant\"],\n            train_on_eos=\"all\",\n        )\n\n        res = strategy.tokenize_prompt(basic_dataset[0])\n        labels = res[\"labels\"]\n        input_ids = res[\"input_ids\"]\n\n        eos_token_id = tokenizer.eos_token_id\n        eos_indices = [\n            i for i, token_id in enumerate(input_ids) if token_id == eos_token_id\n        ]\n\n        assert len(eos_indices) > 0, \"Expected at least one EOS token in the input\"\n        for eos_idx in eos_indices:\n            assert labels[eos_idx] != IGNORE_TOKEN_ID, (\n                f\"Expected EOS token at index {eos_idx} to be labeled\"\n            )\n\n    def test_train_on_eos_turn(\n        self,\n        tokenizer,\n        chat_template,\n        chat_template_jinja,\n        eos_token,\n  
      basic_dataset,\n        request,\n    ):\n        LOG.info(\"Testing with train_on_eos='turn'\")\n\n        tokenizer, chat_template_jinja = self.setup_tokenizer(\n            tokenizer, chat_template, chat_template_jinja, eos_token, request\n        )\n\n        strategy = ChatTemplateStrategy(\n            ChatTemplatePrompter(\n                tokenizer,\n                chat_template=get_chat_template(\n                    chat_template, jinja_template=chat_template_jinja\n                ),\n                message_property_mappings={\"role\": \"from\", \"content\": \"value\"},\n                field_messages=\"conversations\",\n            ),\n            tokenizer=tokenizer,\n            train_on_inputs=False,\n            sequence_len=512,\n            roles_to_train=[\"assistant\"],\n            train_on_eos=\"turn\",\n        )\n        res = strategy.tokenize_prompt(basic_dataset[0])\n        turns = strategy.get_conversation_thread(basic_dataset[0])\n        labels = res[\"labels\"]\n        input_ids = res[\"input_ids\"]\n\n        eos_token_id = tokenizer.eos_token_id\n        # Process all turns and verify EOS token labeling\n        for i, turn in enumerate(basic_dataset[0][\"conversations\"]):\n            start_idx, end_idx = strategy.find_turn(turns=turns, turn_idx=i)\n\n            if self._should_skip_turn(tokenizer, turn, i, start_idx, end_idx):\n                continue\n\n            decoded_response = tokenizer.decode(input_ids[start_idx:end_idx])\n            response = turn[\"value\"]\n\n            assert response in decoded_response, (\n                f\"Response {response} not found in index {start_idx}:{end_idx} \"\n                f\"decoded:{decoded_response}\"\n            )\n\n            # Find the EOS token after this turn\n            eos_idx = end_idx\n            while eos_idx < len(input_ids) and input_ids[eos_idx] != eos_token_id:\n                eos_idx += 1\n\n            assert eos_idx < len(input_ids), (\n       
         f\"Could not find EOS token after '{response}'\"\n            )\n\n            LOG.debug(\n                f\"Turn {i}: role={turn['from']}, content='{turn['value']}', start_idx={start_idx}, end_idx={end_idx}, eos_idx={eos_idx}\"\n            )\n\n            LOG.debug(\n                f\"Labels for turn {i}: {labels[start_idx:end_idx]}, EOS label: {labels[eos_idx]}\"\n            )\n\n            # Verify EOS token labeling based on role\n            is_assistant = turn[\"from\"] == \"assistant\"\n            if is_assistant:\n                assert labels[eos_idx] != IGNORE_TOKEN_ID, (\n                    f\"Expected EOS token after assistant response '{response}' to be labeled\"\n                )\n            else:\n                assert labels[eos_idx] == IGNORE_TOKEN_ID, (\n                    f\"Expected EOS token after non-assistant input '{response}' to not be labeled\"\n                )\n\n    def test_train_on_eos_last(\n        self,\n        tokenizer,\n        chat_template,\n        chat_template_jinja,\n        eos_token,\n        basic_dataset,\n        request,\n    ):\n        LOG.info(\"Testing with train_on_eos='last'\")\n\n        tokenizer, chat_template_jinja = self.setup_tokenizer(\n            tokenizer, chat_template, chat_template_jinja, eos_token, request\n        )\n\n        strategy = ChatTemplateStrategy(\n            ChatTemplatePrompter(\n                tokenizer,\n                chat_template=get_chat_template(\n                    chat_template, jinja_template=chat_template_jinja\n                ),\n                message_property_mappings={\"role\": \"from\", \"content\": \"value\"},\n                field_messages=\"conversations\",\n            ),\n            tokenizer=tokenizer,\n            train_on_inputs=False,\n            sequence_len=512,\n            roles_to_train=[\"assistant\"],\n            train_on_eos=\"last\",\n        )\n\n        res = strategy.tokenize_prompt(basic_dataset[0])\n        
labels = res[\"labels\"]\n        input_ids = res[\"input_ids\"]\n\n        eos_token_id = tokenizer.eos_token_id\n        eos_indices = [\n            i for i, token_id in enumerate(input_ids) if token_id == eos_token_id\n        ]\n\n        assert len(eos_indices) > 0, \"Expected at least one EOS token in the input\"\n        last_eos_idx = eos_indices[-1]\n\n        # Check that only the last EOS token is labeled\n        for idx in eos_indices[:-1]:\n            assert labels[idx] == IGNORE_TOKEN_ID, (\n                f\"Expected EOS token at index {idx} to not be labeled\"\n            )\n        assert labels[last_eos_idx] != IGNORE_TOKEN_ID, (\n            f\"Expected last EOS token at index {last_eos_idx} to be labeled\"\n        )\n\n    def test_train_on_eos_none(\n        self,\n        tokenizer,\n        chat_template,\n        chat_template_jinja,\n        eos_token,\n        basic_dataset,\n        request,\n    ):\n        LOG.info(\"Testing with train_on_eos='none'\")\n\n        tokenizer, chat_template_jinja = self.setup_tokenizer(\n            tokenizer, chat_template, chat_template_jinja, eos_token, request\n        )\n\n        strategy = ChatTemplateStrategy(\n            ChatTemplatePrompter(\n                tokenizer,\n                chat_template=get_chat_template(\n                    chat_template, jinja_template=chat_template_jinja\n                ),\n                message_property_mappings={\"role\": \"from\", \"content\": \"value\"},\n                field_messages=\"conversations\",\n            ),\n            tokenizer=tokenizer,\n            train_on_inputs=False,\n            sequence_len=512,\n            roles_to_train=[\"assistant\"],\n            train_on_eos=\"none\",\n        )\n\n        res = strategy.tokenize_prompt(basic_dataset[0])\n        labels = res[\"labels\"]\n        input_ids = res[\"input_ids\"]\n\n        eos_token_id = tokenizer.eos_token_id\n        eos_indices = [\n            i for i, token_id in 
enumerate(input_ids) if token_id == eos_token_id\n        ]\n\n        assert len(eos_indices) > 0, \"Expected at least one EOS token in the input\"\n        for eos_idx in eos_indices:\n            assert labels[eos_idx] == IGNORE_TOKEN_ID, (\n                f\"Expected EOS token at index {eos_idx} to not be labeled\"\n            )\n\n    def test_drop_system_message(\n        self,\n        tokenizer,\n        chat_template,\n        chat_template_jinja,\n        eos_token,\n        basic_dataset,\n        request,\n    ):\n        LOG.info(\"Testing with drop_system_message=True\")\n        tokenizer, chat_template_jinja = self.setup_tokenizer(\n            tokenizer, chat_template, chat_template_jinja, eos_token, request\n        )\n\n        strategy = ChatTemplateStrategy(\n            ChatTemplatePrompter(\n                tokenizer,\n                chat_template=get_chat_template(\n                    chat_template, jinja_template=chat_template_jinja\n                ),\n                drop_system_message=True,\n                message_property_mappings={\"role\": \"from\", \"content\": \"value\"},\n                field_messages=\"conversations\",\n            ),\n            tokenizer=tokenizer,\n            train_on_inputs=False,\n            sequence_len=512,\n            roles_to_train=[\"assistant\"],\n        )\n\n        res = strategy.tokenize_prompt(basic_dataset[0])\n        input_ids = res[\"input_ids\"]\n\n        # Check if system message is not present in input_ids\n        system_message = \"You are an AI assistant.\"\n        decoded_message = tokenizer.decode(input_ids)\n        assert system_message not in decoded_message, (\n            \"Expected system message to be dropped\"\n        )\n\n    def test_custom_roles(\n        self,\n        tokenizer,\n        chat_template,\n        chat_template_jinja,\n        eos_token,\n        request,\n    ):\n        LOG.info(\"Testing with custom roles mapping\")\n        custom_roles = {\n 
           \"user\": [\"human\", \"user\"],\n            \"assistant\": [\"ai\", \"assistant\"],\n            \"system\": [\"context\"],\n        }\n        tokenizer, chat_template_jinja = self.setup_tokenizer(\n            tokenizer, chat_template, chat_template_jinja, eos_token, request\n        )\n\n        strategy = ChatTemplateStrategy(\n            ChatTemplatePrompter(\n                tokenizer,\n                chat_template=get_chat_template(\n                    chat_template, jinja_template=chat_template_jinja\n                ),\n                roles=custom_roles,\n                message_property_mappings={\"role\": \"from\", \"content\": \"value\"},\n            ),\n            tokenizer=tokenizer,\n            train_on_inputs=False,\n            sequence_len=512,\n            roles_to_train=[\"ai\"],\n        )\n\n        # Create a new dataset with modified role names\n        modified_conversations = [\n            {\"from\": \"context\", \"value\": \"You are an AI assistant.\"},\n            {\"from\": \"human\", \"value\": \"Hello\"},\n            {\"from\": \"ai\", \"value\": \"Hi there!\"},\n            {\"from\": \"human\", \"value\": \"How are you?\"},\n            {\"from\": \"ai\", \"value\": \"I'm doing well, thank you!\"},\n        ]\n\n        modified_dataset = Dataset.from_dict({\"messages\": [modified_conversations]})\n\n        res = strategy.tokenize_prompt(modified_dataset[0])\n        turns = strategy.get_conversation_thread(modified_dataset[0])\n        labels = res[\"labels\"]\n        input_ids = res[\"input_ids\"]\n\n        # Process all turns and verify labeling\n        for i, turn in enumerate(modified_dataset[0][\"messages\"]):\n            start_idx, end_idx = strategy.find_turn(turns=turns, turn_idx=i)\n\n            if self._should_skip_turn(tokenizer, turn, i, start_idx, end_idx):\n                continue\n\n            decoded_response = tokenizer.decode(input_ids[start_idx:end_idx])\n            response = 
turn[\"value\"]\n\n            assert response in decoded_response, (\n                f\"Response {response} not found in index {start_idx}:{end_idx} \"\n                f\"decoded:{decoded_response}\"\n            )\n\n            # Check if responses are labeled correctly based on role\n            is_ai = turn[\"from\"] == \"ai\"\n            if is_ai:\n                assert all(\n                    label != IGNORE_TOKEN_ID for label in labels[start_idx:end_idx]\n                ), f\"Expected labels for AI response '{response}' to be set\"\n            else:\n                assert all(\n                    label == IGNORE_TOKEN_ID for label in labels[start_idx:end_idx]\n                ), (\n                    f\"Expected labels for non-AI message '{response}' to be IGNORE_TOKEN_ID\"\n                )\n\n    def test_message_field_training(\n        self,\n        tokenizer,\n        chat_template,\n        chat_template_jinja,\n        eos_token,\n        request,\n    ):\n        LOG.info(\"Testing with message_field_training\")\n\n        tokenizer, chat_template_jinja = self.setup_tokenizer(\n            tokenizer, chat_template, chat_template_jinja, eos_token, request\n        )\n\n        strategy = ChatTemplateStrategy(\n            ChatTemplatePrompter(\n                tokenizer,\n                chat_template=get_chat_template(\n                    chat_template, jinja_template=chat_template_jinja\n                ),\n                message_field_training=\"train\",\n                message_field_training_detail=\"train_detail\",\n                message_property_mappings={\"role\": \"from\", \"content\": \"value\"},\n            ),\n            tokenizer=tokenizer,\n            train_on_inputs=False,\n            sequence_len=512,\n            roles_to_train=[],\n        )\n\n        # Create a new dataset with the train and train_detail fields\n        modified_conversation = [\n            {\"from\": \"system\", \"value\": \"You are an AI 
assistant.\", \"train\": False},\n            {\"from\": \"human\", \"value\": \"Hello\", \"train\": False},\n            {\"from\": \"assistant\", \"value\": \"Hello\", \"train\": True},\n            {\"from\": \"human\", \"value\": \"How are you?\", \"train\": True},\n            {\n                \"from\": \"assistant\",\n                \"value\": \"I'm doing very well, thank you!\",\n                \"train_detail\": [\n                    {\"begin_offset\": 0, \"end_offset\": 8, \"train\": False},\n                    {\"begin_offset\": 9, \"end_offset\": 18, \"train\": True},\n                    {\"begin_offset\": 19, \"end_offset\": 30, \"train\": False},\n                ],\n            },\n            {\n                \"from\": \"human\",\n                \"value\": \"I'm doing very well, thank you!\",\n                \"train\": False,\n            },\n            {\"from\": \"assistant\", \"value\": \"Hi there!\", \"train\": True},\n        ]\n\n        modified_dataset = Dataset.from_dict({\"messages\": [modified_conversation]})\n\n        res = strategy.tokenize_prompt(modified_dataset[0])\n        turns = strategy.get_conversation_thread(modified_dataset[0])\n        labels = res[\"labels\"]\n        input_ids = res[\"input_ids\"]\n\n        def verify_labels(labels_span, should_train, context_message):\n            \"\"\"Helper to verify if a span of labels matches expected training state\"\"\"\n            if should_train:\n                assert all(label != IGNORE_TOKEN_ID for label in labels_span), (\n                    f\"Expected all labels for {context_message} to be set, but got {labels_span}\"\n                )\n            else:\n                assert all(label == IGNORE_TOKEN_ID for label in labels_span), (\n                    f\"Expected all labels for {context_message} to be {IGNORE_TOKEN_ID}, but got {labels_span}\"\n                )\n\n        # Process all turns and verify labeling\n        for i, turn in 
enumerate(modified_dataset[0][\"messages\"]):\n            start_idx, end_idx = strategy.find_turn(turns=turns, turn_idx=i)\n\n            if self._should_skip_turn(tokenizer, turn, i, start_idx, end_idx):\n                continue\n\n            decoded_response = tokenizer.decode(input_ids[start_idx:end_idx])\n            response = turn[\"value\"]\n\n            assert response in decoded_response, (\n                f\"Response {response} not found in index {start_idx}:{end_idx} \"\n                f\"decoded:{decoded_response}\"\n            )\n\n            LOG.debug(\n                f\"Processing turn {i}: role={turn['from']}, content='{turn['value']}', \"\n                f\"start_idx={start_idx}, end_idx={end_idx}\"\n            )\n\n            if turn.get(\"train_detail\", None) is not None:\n                # Handle detailed token-level training control\n                tokenized_output = tokenizer(\n                    turn[\"value\"], return_offsets_mapping=True, add_special_tokens=False\n                )\n                assert tokenized_output[\"input_ids\"] == input_ids[start_idx:end_idx], (\n                    f\"Tokenized input mismatch for turn: {turn['value']}\\n\"\n                    f\"Expected: {input_ids[start_idx:end_idx]}\\nActual: {tokenized_output['input_ids']}\\n\"\n                    f\"This will likely be a mismatch between template content and encoded content\"\n                )\n\n                token_offsets = tokenized_output[\"offset_mapping\"]\n\n                # Adjust token offsets\n                for j in range(len(token_offsets) - 1):\n                    token_offsets[j] = (\n                        token_offsets[j][0],\n                        token_offsets[j + 1][0] - 1,\n                    )\n                token_offsets[-1] = (token_offsets[-1][0], len(turn[\"value\"]) - 1)\n\n                adjusted_train_details = strategy.prompter.adjust_train_details(\n                    turn[\"train_detail\"], 
token_offsets\n                )\n\n                LOG.debug(f\"Original train_details: {turn['train_detail']}\")\n                LOG.debug(f\"Adjusted train_details: {adjusted_train_details}\")\n\n                # Get and verify token offsets\n                turn_tokens = input_ids[start_idx:end_idx]\n                token_offsets_unmasked = strategy.prompter.get_offsets_for_train_detail(\n                    text=turn[\"value\"],\n                    train_details=adjusted_train_details,\n                    mask_untrainable=False,\n                )\n\n                for i, offset in enumerate(token_offsets_unmasked):\n                    assert token_offsets[i][0] == offset, (\n                        f\"Token start offsets mismatch for turn: {turn['value']}\\n\"\n                        f\"Expected: {token_offsets[i][0]}\\nActual: {offset}\"\n                    )\n\n                token_offsets_masked = strategy.prompter.get_offsets_for_train_detail(\n                    text=turn[\"value\"],\n                    train_details=adjusted_train_details,\n                    mask_untrainable=True,\n                )\n                LOG.debug(f\"Token offsets: {token_offsets_masked}\")\n\n                # Verify expected labels against actual labels\n                expected_labels = [IGNORE_TOKEN_ID] * len(turn_tokens)\n                for i, offset in enumerate(token_offsets_masked):\n                    if offset != IGNORE_TOKEN_ID:\n                        expected_labels[i] = turn_tokens[i]\n                actual_labels = labels[\n                    start_idx : start_idx + len(token_offsets_masked)\n                ]\n                assert actual_labels == expected_labels, (\n                    f\"Labels mismatch for turn: {turn['value']}\\nExpected: {expected_labels}\\nActual: {actual_labels}\"\n                )\n\n                # Verify each detail section\n                for detail in adjusted_train_details:\n                    
detail_start = start_idx + next(\n                        j\n                        for j, offset in enumerate(token_offsets_unmasked)\n                        if offset >= detail[\"begin_offset\"]\n                    )\n                    detail_end = start_idx + next(\n                        (\n                            j\n                            for j, offset in enumerate(token_offsets_unmasked)\n                            if offset > detail[\"end_offset\"]\n                        ),\n                        len(token_offsets),\n                    )\n\n                    detail_text = turn[\"value\"][\n                        detail[\"begin_offset\"] : detail[\"end_offset\"] + 1\n                    ]\n                    detail_labels = labels[detail_start:detail_end]\n\n                    context = (\n                        f\"detail (ind {detail_start}:{detail_end}): '{detail_text}'\\n\"\n                        f\"decoded: '{tokenizer.decode(input_ids[detail_start:detail_end])}')\"\n                    )\n                    verify_labels(detail_labels, detail[\"train\"], context)\n            else:\n                # Handle regular turn-level training control\n                should_train = turn.get(\"train\", False)\n                turn_labels = labels[start_idx:end_idx]\n                context = (\n                    f\"turn (ind {start_idx}:{end_idx}): '{turn['value']}'\\n\"\n                    f\"decoded: '{decoded_response}')\"\n                )\n                verify_labels(turn_labels, should_train, context)\n\n        LOG.debug(f\"Final labels: {labels}\")\n        LOG.debug(f\"Final input_ids: {input_ids}\")\n\n    def test_get_chat_template_variables(\n        self, tokenizer, chat_template, chat_template_jinja, eos_token, request\n    ):\n        LOG.info(\"Testing get_chat_template_variables\")\n\n        actual_tokenizer, actual_jinja_template = self.setup_tokenizer(\n            tokenizer, chat_template, 
chat_template_jinja, eos_token, request\n        )\n\n        prompter = ChatTemplatePrompter(\n            actual_tokenizer,\n            chat_template=get_chat_template(\n                chat_template, jinja_template=actual_jinja_template\n            ),\n            message_property_mappings={\"from\": \"role\", \"value\": \"content\"},\n        )\n\n        variables = prompter.get_chat_template_msg_variables(\n            (\n                actual_jinja_template\n                if actual_jinja_template\n                else actual_tokenizer.get_chat_template()\n            ),\n            \"messages\",\n        )\n\n        # Special case for Mistral with additional tool variables\n        if chat_template == \"jinja\" and tokenizer == \"mistralv03_tokenizer\":\n            expected_variables = {\"role\", \"content\", \"tool_call_id\", \"tool_calls\"}\n        # Most chat templates use the standard role and content variables\n        elif chat_template in [\"llama3\", \"chatml\", \"phi_35\", \"phi_4\"] or (\n            chat_template == \"jinja\" and tokenizer == \"gemma2_tokenizer\"\n        ):\n            expected_variables = {\"role\", \"content\"}\n        else:\n            LOG.warning(\n                f\"Unsupported chat template: {chat_template} with {chat_template_jinja}\"\n            )\n            raise ValueError(\n                f\"Unsupported chat template: {chat_template} with {chat_template_jinja}\"\n            )\n\n        assert variables == expected_variables, (\n            f\"Expected variables: {expected_variables} from {tokenizer}/{chat_template}\\n\"\n            f\"Got: {variables}\\n\"\n            f\"Chat template: {actual_jinja_template}\"\n        )\n\n    def test_eot_tokens_conflict_with_eos_token(\n        self,\n        tokenizer,\n        chat_template,\n        chat_template_jinja,\n        eos_token,\n        basic_dataset,\n        request,\n    ):\n        \"\"\"Test that an error is raised when eot_tokens contains 
eos_token and train_on_eot/train_on_eos conflict\"\"\"\n        LOG.info(\n            \"Testing conflict between eot_tokens containing eos_token and train_on_eot/train_on_eos mismatch\"\n        )\n\n        tokenizer, chat_template_jinja = self.setup_tokenizer(\n            tokenizer, chat_template, chat_template_jinja, eos_token, request\n        )\n\n        # Create a situation where eot_tokens contains eos_token\n        eot_tokens = [\n            tokenizer.eos_token,\n            \"[/INST]\",\n        ]  # Deliberately including eos_token\n\n        # Create conflicting train_on_eos and train_on_eot settings\n        with pytest.raises(\n            ValueError,\n            match=\".*eos_token is in eot_tokens and train_on_eos != train_on_eot.*\",\n        ):\n            ChatTemplateStrategy(\n                ChatTemplatePrompter(\n                    tokenizer,\n                    chat_template=get_chat_template(\n                        chat_template, jinja_template=chat_template_jinja\n                    ),\n                    message_property_mappings={\"role\": \"from\", \"content\": \"value\"},\n                    field_messages=\"conversations\",\n                ),\n                tokenizer=tokenizer,\n                train_on_inputs=False,\n                sequence_len=512,\n                roles_to_train=[\"assistant\"],\n                train_on_eos=\"none\",  # Setting to none\n                train_on_eot=\"turn\",  # Different from train_on_eos\n                eot_tokens=eot_tokens,\n            )\n\n    def test_eot_token_backward_compatibility(\n        self,\n        tokenizer,\n        chat_template,\n        chat_template_jinja,\n        eos_token,\n        basic_dataset,\n        request,\n    ):\n        \"\"\"Test that eot_tokens inherits from eos_token when not specified\"\"\"\n        LOG.info(\"Testing backward compatibility that eot_token inherits eos_token\")\n\n        tokenizer, chat_template_jinja = 
self.setup_tokenizer(\n            tokenizer, chat_template, chat_template_jinja, eos_token, request\n        )\n\n        strategy = ChatTemplateStrategy(\n            ChatTemplatePrompter(\n                tokenizer,\n                chat_template=get_chat_template(\n                    chat_template, jinja_template=chat_template_jinja\n                ),\n                message_property_mappings={\"role\": \"from\", \"content\": \"value\"},\n                field_messages=\"conversations\",\n            ),\n            tokenizer=tokenizer,\n            train_on_inputs=False,\n            sequence_len=512,\n            roles_to_train=[\"assistant\"],\n            train_on_eos=\"turn\",  # Setting train_on_eos to \"turn\"\n        )\n\n        # In backward compatibility mode, eot_tokens should be derived from eos_token\n        assert strategy.eot_tokens == [tokenizer.eos_token], (\n            f\"Expected eot_tokens to inherit from eos_token, got {strategy.eot_tokens}\"\n        )\n        assert strategy.train_on_eot == \"turn\", (\n            f\"Expected train_on_eot to inherit from train_on_eos, got {strategy.train_on_eot}\"\n        )\n\n    def test_token_not_in_template(\n        self,\n        tokenizer,\n        chat_template,\n        chat_template_jinja,\n        eos_token,\n        basic_dataset,\n        request,\n    ):\n        \"\"\"Test runs even when tokens are not found in the template\"\"\"\n        LOG.info(\"Testing runs even when tokens are not found in template\")\n\n        tokenizer, chat_template_jinja = self.setup_tokenizer(\n            tokenizer, chat_template, chat_template_jinja, eos_token, request\n        )\n\n        # Create a non-existent token that definitely won't be in the template\n        non_existent_token = \"[DEFINITELY_NOT_IN_TEMPLATE]\"\n        tokenizer.add_special_tokens(\n            {\"additional_special_tokens\": [non_existent_token]}\n        )\n\n        strategy = ChatTemplateStrategy(\n            
ChatTemplatePrompter(\n                tokenizer,\n                chat_template=get_chat_template(\n                    chat_template, jinja_template=chat_template_jinja\n                ),\n                message_property_mappings={\"role\": \"from\", \"content\": \"value\"},\n                field_messages=\"conversations\",\n            ),\n            tokenizer=tokenizer,\n            train_on_inputs=False,\n            sequence_len=512,\n            roles_to_train=[\"assistant\"],\n            eot_tokens=[non_existent_token],\n        )\n\n        # Force template check by calling tokenize_prompt\n        strategy.tokenize_prompt(basic_dataset[0])\n\n        # We can also check that a warning was logged, but there's\n        # caplog conflicts when running with other tests\n        # assert any(\n        #     \"not found in chat_template\" in record.message for record in self._caplog.records\n        # ), \"Expected warning about token not found in template was not logged\"\n\n    def test_custom_eot_tokens(\n        self,\n        tokenizer,\n        chat_template,\n        chat_template_jinja,\n        eos_token,\n        basic_dataset,\n        request,\n    ):\n        \"\"\"Test with custom EOT tokens to ensure proper masking and training\"\"\"\n        LOG.info(\"Testing with custom EOT tokens\")\n\n        tokenizer, chat_template_jinja = self.setup_tokenizer(\n            tokenizer, chat_template, chat_template_jinja, None, request\n        )\n\n        # Add custom EOT tokens to the tokenizer\n        custom_eot = \"[EOT]\"\n        tokenizer.add_special_tokens({\"additional_special_tokens\": [custom_eot]})\n\n        # Create a custom chat template that uses our EOT token\n        custom_template = \"\"\"{% for message in messages %}{% if message['role'] == 'system' %}{{ message['content'] }}{% elif message['role'] == 'user' %}User: {{ message['content'] }}{% elif message['role'] == 'assistant' %}Assistant: {{ message['content'] }}[EOT]{% endif 
%}{% endfor %}\"\"\"\n\n        strategy = ChatTemplateStrategy(\n            ChatTemplatePrompter(\n                tokenizer,\n                chat_template=custom_template,\n                message_property_mappings={\"role\": \"from\", \"content\": \"value\"},\n                field_messages=\"conversations\",\n            ),\n            tokenizer=tokenizer,\n            train_on_inputs=False,\n            sequence_len=512,\n            roles_to_train=[\"assistant\"],\n            train_on_eot=\"turn\",  # Train on EOT token after each turn\n            eot_tokens=[custom_eot],\n        )\n\n        res = strategy.tokenize_prompt(basic_dataset[0])\n        labels = res[\"labels\"]\n        input_ids = res[\"input_ids\"]\n\n        # Find indices of the EOT token\n        eot_token_id = tokenizer.convert_tokens_to_ids(custom_eot)\n        eot_indices = [\n            i for i, token_id in enumerate(input_ids) if token_id == eot_token_id\n        ]\n\n        assert len(eot_indices) > 0, \"Expected at least one EOT token in the input\"\n\n        # Verify labeling for EOT tokens based on role\n        turns = strategy.get_conversation_thread(basic_dataset[0])\n        assistant_turn_indices = []\n        non_assistant_turn_indices = []\n\n        for i, turn in enumerate(basic_dataset[0][\"conversations\"]):\n            start_idx, end_idx = strategy.find_turn(turns=turns, turn_idx=i)\n            if start_idx != -1 and end_idx != -1:  # If turn is found\n                if turn[\"from\"] == \"assistant\":\n                    assistant_turn_indices.append((start_idx, end_idx))\n                else:\n                    non_assistant_turn_indices.append((start_idx, end_idx))\n\n        # Check EOT tokens after assistant turns are labeled\n        for eot_idx in eot_indices:\n            is_after_assistant = any(\n                start_idx <= eot_idx <= end_idx + 1  # +1 to include the EOT token\n                for start_idx, end_idx in assistant_turn_indices\n  
          )\n\n            if is_after_assistant:\n                assert labels[eot_idx] != IGNORE_TOKEN_ID, (\n                    f\"Expected EOT token after assistant turn at index {eot_idx} to be labeled\"\n                )\n            else:\n                assert labels[eot_idx] == IGNORE_TOKEN_ID, (\n                    f\"Expected EOT token not after assistant turn at index {eot_idx} to not be labeled\"\n                )\n\n    def test_multiple_train_on_eot_settings(\n        self,\n        tokenizer,\n        chat_template,\n        chat_template_jinja,\n        eos_token,\n        basic_dataset,\n        request,\n    ):\n        \"\"\"Test different train_on_eot settings\"\"\"\n        LOG.info(\"Testing different train_on_eot settings\")\n\n        tokenizer, chat_template_jinja = self.setup_tokenizer(\n            tokenizer, chat_template, chat_template_jinja, eos_token, request\n        )\n\n        # Create a list to test different train_on_eot settings\n        test_settings = [\n            (\"none\", lambda idx, is_assistant: False),  # Never train on EOT\n            (\"all\", lambda idx, is_assistant: True),  # Always train on EOT\n            (\n                \"turn\",\n                lambda idx, is_assistant: is_assistant,\n            ),  # Train on EOT after assistant turns\n            (\"last\", lambda idx, is_last: is_last),  # Only train on last EOT\n        ]\n\n        for setting, expected_train_func in test_settings:\n            LOG.info(f\"Testing train_on_eot='{setting}'\")\n\n            strategy = ChatTemplateStrategy(\n                ChatTemplatePrompter(\n                    tokenizer,\n                    chat_template=get_chat_template(\n                        chat_template, jinja_template=chat_template_jinja\n                    ),\n                    message_property_mappings={\"role\": \"from\", \"content\": \"value\"},\n                    field_messages=\"conversations\",\n                ),\n                
tokenizer=tokenizer,\n                train_on_inputs=False,\n                sequence_len=512,\n                roles_to_train=[\"assistant\"],\n                train_on_eot=setting,\n                eot_tokens=[\n                    tokenizer.eos_token\n                ],  # Use eos_token as the EOT token for simplicity\n            )\n\n            res = strategy.tokenize_prompt(basic_dataset[0])\n            turns = strategy.get_conversation_thread(basic_dataset[0])\n            labels = res[\"labels\"]\n            input_ids = res[\"input_ids\"]\n\n            eos_token_id = tokenizer.eos_token_id\n            eos_indices = [\n                i for i, token_id in enumerate(input_ids) if token_id == eos_token_id\n            ]\n\n            assert len(eos_indices) > 0, (\n                \"Expected at least one EOS/EOT token in the input\"\n            )\n\n            # Check labeling for each EOS/EOT token\n            for idx, eos_idx in enumerate(eos_indices):\n                # Find which turn this EOS token belongs to\n                preceding_turn = None\n                for i, turn in enumerate(basic_dataset[0][\"conversations\"]):\n                    start_idx, end_idx = strategy.find_turn(turns=turns, turn_idx=i)\n                    if (\n                        start_idx != -1\n                        and end_idx != -1\n                        and start_idx <= eos_idx <= end_idx + 1\n                    ):\n                        preceding_turn = turn\n                        break\n\n                is_assistant = (\n                    preceding_turn is not None and preceding_turn[\"from\"] == \"assistant\"\n                )\n                is_last = idx == len(eos_indices) - 1\n\n                expected_label = not expected_train_func(\n                    idx, is_assistant if setting != \"last\" else is_last\n                )\n\n                if expected_label:\n                    assert labels[eos_idx] == IGNORE_TOKEN_ID, (\n         
               f\"Expected EOT token at index {eos_idx} to not be labeled with train_on_eot='{setting}'\"\n                    )\n                else:\n                    assert labels[eos_idx] != IGNORE_TOKEN_ID, (\n                        f\"Expected EOT token at index {eos_idx} to be labeled with train_on_eot='{setting}'\"\n                    )\n\n\nclass TestChatTemplateToolCalling:\n    \"\"\"\n    Test class for tool calling functionality with chat templates.\n    \"\"\"\n\n    def test_tool_calling_with_llama4_template(\n        self,\n        llama3_tokenizer,\n    ):\n        LOG.info(\"Testing tool calling with llama3 tokenizer and llama4 chat template\")\n\n        # Create tool calling dataset\n        tool_calling_dataset = [\n            {\n                \"tools\": [\n                    {\n                        \"type\": \"function\",\n                        \"function\": {\n                            \"name\": \"xml_escape\",\n                            \"description\": 'Replaces any \"<\", \">\", or \"&\" characters in the input string with their corresponding XML entities.',\n                            \"parameters\": {\n                                \"type\": \"object\",\n                                \"properties\": {\n                                    \"s\": {\n                                        \"type\": \"string\",\n                                        \"description\": \"The input string to be XML-escaped.\",\n                                    }\n                                },\n                                \"required\": [\"s\"],\n                            },\n                        },\n                    },\n                    {\n                        \"type\": \"function\",\n                        \"function\": {\n                            \"name\": \"multiples\",\n                            \"description\": \"Generates a list of all the multiples of a number that are less than a given limit.\",\n 
                           \"parameters\": {\n                                \"type\": \"object\",\n                                \"properties\": {\n                                    \"number\": {\n                                        \"type\": \"integer\",\n                                        \"description\": \"The number to find multiples of.\",\n                                    },\n                                    \"limit\": {\n                                        \"type\": \"integer\",\n                                        \"description\": \"The upper limit for the multiples.\",\n                                    },\n                                },\n                                \"required\": [\"number\", \"limit\"],\n                            },\n                        },\n                    },\n                ],\n                \"messages\": [\n                    {\n                        \"role\": \"user\",\n                        \"content\": \"Can you help me find multiples of 5 that are less than 20?\",\n                    },\n                    {\n                        \"role\": \"assistant\",\n                        \"tool_calls\": [\n                            {\n                                \"type\": \"function\",\n                                \"function\": {\n                                    \"name\": \"multiples\",\n                                    \"arguments\": {\n                                        \"number\": 5,\n                                        \"limit\": 20,\n                                    },\n                                },\n                            }\n                        ],\n                    },\n                    {\"role\": \"tool\", \"name\": \"multiples\", \"content\": \"5,10,15\"},\n                    {\n                        \"role\": \"assistant\",\n                        \"content\": \"The multiples of 5 less than 20 are: 5, 10, and 15.\",\n    
                },\n                ],\n            }\n        ]\n\n        # Setup tokenizer with llama4 chat template\n        tokenizer = deepcopy(llama3_tokenizer)\n\n        # Add EOS token to the tokenizer\n        eot_token = \"<|eot_id|>\"\n        tokenizer.add_special_tokens({\"additional_special_tokens\": [eot_token]})\n\n        strategy = ChatTemplateStrategy(\n            ChatTemplatePrompter(\n                tokenizer,\n                chat_template=get_chat_template(\"llama4\"),\n                message_property_mappings={\"role\": \"role\", \"content\": \"content\"},\n                field_messages=\"messages\",\n                field_tools=\"tools\",\n            ),\n            tokenizer=tokenizer,\n            train_on_inputs=False,\n            sequence_len=512,\n            roles_to_train=[\"assistant\"],\n            eot_tokens=[eot_token],\n        )\n\n        res = strategy.tokenize_prompt(tool_calling_dataset[0])\n        input_ids = res[\"input_ids\"]\n        labels = res[\"labels\"]\n\n        # Verify that the input_ids contain expected tokens\n        assert len(input_ids) > 0, \"Input IDs should not be empty\"\n        assert len(labels) == len(input_ids), \"Labels should match input_ids length\"\n\n        # Decode the full conversation to verify structure\n        decoded_conversation = tokenizer.decode(input_ids)\n\n        # Verify tool calling structure is present in the decoded conversation\n        assert '\"type\": \"function\",' in decoded_conversation, (\n            \"Tool type function should be in conversation\"\n        )\n        assert '\"name\": \"multiples\",' in decoded_conversation, (\n            \"Tool function name should be in conversation\"\n        )\n\n        assert (\n            '<|python_start|><|python_end|>{\"name\": \"multiples\", \"parameters\": {\"number\": 5, \"limit\": 20}}<|eot|>'\n            in decoded_conversation\n        ), \"Assistant tool call should be in conversation\"\n        assert 
\"<|header_start|>ipython<|header_end|>\" in decoded_conversation, (\n            \"IPython header should be in conversation\"\n        )\n        assert '\"5,10,15\"' in decoded_conversation, (\n            \"Tool response should be in conversation\"\n        )\n\n        # Get conversation turns to verify labeling\n        turns = strategy.get_conversation_thread(tool_calling_dataset[0])\n        tools = strategy._get_tools(tool_calling_dataset[0])\n\n        # Check that assistant responses are properly labeled\n        for i, turn in enumerate(tool_calling_dataset[0][\"messages\"]):\n            if turn[\"role\"] == \"assistant\":\n                start_idx, end_idx = strategy.find_turn(\n                    turns=turns, turn_idx=i, tools=tools\n                )\n\n                assert start_idx != -1 and end_idx != -1, (\n                    f\"Assistant turn {i} should be found\"\n                )\n\n                # Verify that assistant responses have proper labels\n                turn_labels = labels[start_idx:end_idx]\n                assert all(label != IGNORE_TOKEN_ID for label in turn_labels), (\n                    f\"Assistant turn {i} should be unmasked\"\n                )\n"
  },
  {
    "path": "tests/prompt_strategies/test_chat_templates_mistral.py",
    "content": "\"\"\"Test chat templates for mistral-common wrapper tokenizer\"\"\"\n\nimport unittest\nfrom typing import TYPE_CHECKING\n\nimport pytest\n\nif TYPE_CHECKING:\n    from transformers import PreTrainedTokenizer\n\n    from axolotl.utils.mistral import HFMistralTokenizer\n\n\n# fmt: off\n@pytest.mark.parametrize(\n    (\"tokenizer_str\", \"assistant_toolcall_ids\", \"tool_result_ids\"),\n    (\n        (\"magistral_tokenizer\", (9, 44627, 3684, 33, 19881, 1049, 1050, 1051, 1052, 1053, 32, 19227, 12856, 2811, 1032, 1049, 1054, 1044, 1429, 33319, 2811, 1032, 1050, 1125, 2), (7, 19881, 1049, 1050, 1051, 1052, 1053, 19, 1049, 1044, 1050, 8)),\n        (\"devstral_tokenizer\", (9, 1091, 19227, 2391, 2811, 1429, 44627, 3684, 1897, 1429, 61906, 2811, 16753, 12856, 2811, 1032, 1049, 1054, 1044, 1429, 33319, 2811, 1032, 1050, 4179, 1429, 1327, 2811, 1429, 19881, 1049, 1050, 1051, 1052, 1053, 1034, 27028, 2), (7, 19881, 1049, 1050, 1051, 1052, 1053, 19, 1049, 1044, 1050, 8)),\n        (\"devstral_1_1_tokenizer\", (9, 44627, 3684, 32, 19227, 12856, 2811, 1032, 1049, 1054, 1044, 1429, 33319, 2811, 1032, 1050, 1125, 2,), (7, 1049, 1044, 1050, 8)),\n    )\n)\n# fmt: on\ndef test_mistral_chat_template(\n    tokenizer_str: str,\n    assistant_toolcall_ids: tuple[int, ...],\n    tool_result_ids: tuple[int, ...],\n    request: pytest.FixtureRequest,\n):\n    \"\"\"Test chat template with the Magistral/Devstral tokenizer\"\"\"\n\n    from axolotl.prompt_strategies.chat_template import MistralPrompter, MistralStrategy\n\n    tokenizer: HFMistralTokenizer = request.getfixturevalue(tokenizer_str)\n\n    # check bos, eos, pad, unk are accessible properties\n    assert tokenizer.bos_token_id == 1\n    assert tokenizer.eos_token_id == 2\n    assert tokenizer.pad_token_id == 11\n    assert tokenizer.unk_token_id == 0\n\n    assert tokenizer.pad_token == \"<pad>\"\n    assert tokenizer.eos_token == \"</s>\"\n    assert tokenizer.bos_token == \"<s>\"\n    assert 
tokenizer.unk_token == \"<unk>\"\n\n    strategy = MistralStrategy(\n        MistralPrompter(\n            tokenizer,\n            chat_template=None,\n            message_property_mappings={\"role\": \"role\", \"content\": \"content\"},\n        ),\n        tokenizer=tokenizer,\n        train_on_inputs=False,\n        train_on_eos=\"turn\",\n        sequence_len=512,\n        roles_to_train=[\"assistant\"],\n    )\n\n    # test chat template masking without system prompt\n    res = strategy.tokenize_prompt(\n        {\n            \"messages\": [\n                {\"role\": \"user\", \"content\": \"Hello, how are you?\"},\n                {\"role\": \"assistant\", \"content\": \"I'm doing great, thank you!\"},\n            ]\n        }\n    )\n\n    assert res[\"input_ids\"] == [\n        1,  # bos\n        3,  # [INST]\n        22177,  # Hello\n        1044,  # ,\n        2606,  # how\n        1584,  # are\n        1636,  # you\n        1063,  # ?\n        4,  # [/INST]\n        1073,  # I\n        4525,  # 'm\n        6965,  # doing\n        4824,  # great\n        1044,  # ,\n        15412,  # thank\n        1636,  # you\n        1033,  # !\n        2,  # </s>\n    ]\n\n    assert res[\"labels\"] == [\n        -100,  # bos\n        -100,  # [INST]\n        -100,  # Hello\n        -100,  # ,\n        -100,  # how\n        -100,  # are\n        -100,  # you\n        -100,  # ?\n        -100,  # [/INST]\n        1073,  # I\n        4525,  # 'm\n        6965,  # doing\n        4824,  # great\n        1044,  # ,\n        15412,  # thank\n        1636,  # you\n        1033,  # !\n        2,  # </s>\n    ]\n\n    # test chat template masking with system prompt\n    res = strategy.tokenize_prompt(\n        {\n            \"messages\": [\n                {\"role\": \"system\", \"content\": \"You are a helpful assistant.\"},\n                {\"role\": \"user\", \"content\": \"Hello, how are you?\"},\n                {\"role\": \"assistant\", \"content\": \"I'm doing 
great, thank you!\"},\n            ]\n        }\n    )\n\n    assert res[\"input_ids\"] == [\n        1,  # bos\n        17,  # [SYSTEM_PROMPT]\n        4568,  # You\n        1584,  # are\n        1261,  # a\n        20351,  # helpful\n        27089,  # assistant\n        1046,  # .\n        18,  # [/SYSTEM_PROMPT]\n        3,  # [INST]\n        22177,  # Hello\n        1044,  # ,\n        2606,  # how\n        1584,  # are\n        1636,  # you\n        1063,  # ?\n        4,  # [/INST]\n        1073,  # I\n        4525,  # 'm\n        6965,  # doing\n        4824,  # great\n        1044,  # ,\n        15412,  # thank\n        1636,  # you\n        1033,  # !\n        2,  # </s>\n    ]\n\n    assert res[\"labels\"] == [\n        -100,  # bos\n        -100,  # [SYSTEM_PROMPT]\n        -100,  # You\n        -100,  # are\n        -100,  # a\n        -100,  # helpful\n        -100,  # assistant\n        -100,  # .\n        -100,  # [/SYSTEM_PROMPT]\n        -100,  # [INST]\n        -100,  # Hello\n        -100,  # ,\n        -100,  # how\n        -100,  # are\n        -100,  # you\n        -100,  # ?\n        -100,  # [/INST]\n        1073,  # I\n        4525,  # 'm\n        6965,  # doing\n        4824,  # great\n        1044,  # ,\n        15412,  # thank\n        1636,  # you\n        1033,  # !\n        2,  # </s>\n    ]\n\n    # test chat template with tools\n    res = strategy.tokenize_prompt(\n        {\n            \"tools\": [\n                {\n                    \"type\": \"function\",\n                    \"function\": {\n                        \"name\": \"multiples\",\n                        \"description\": \"Generates a list of all the multiples of a number that are less than a given limit.\",\n                        \"parameters\": {\n                            \"type\": \"object\",\n                            \"properties\": {\n                                \"number\": {\n                                    \"type\": \"integer\",\n            
                        \"description\": \"The number to find multiples of.\",\n                                },\n                                \"limit\": {\n                                    \"type\": \"integer\",\n                                    \"description\": \"The upper limit for the multiples.\",\n                                },\n                            },\n                            \"required\": [\"number\", \"limit\"],\n                        },\n                    },\n                },\n            ],\n            \"messages\": [\n                {\n                    \"role\": \"user\",\n                    \"content\": \"Hey, can you give me a breakdown of how to throw an awesome themed party? Like, what themes work best, and how can I set everything up to really wow my guests? I want some ideas on decorations, food, and activities that will make the party unforgettable!\",\n                },\n                {\n                    \"role\": \"assistant\",\n                    \"tool_calls\": [\n                        {\n                            \"id\": \"call12345\",\n                            \"type\": \"function\",\n                            \"function\": {\n                                \"name\": \"multiples\",\n                                \"arguments\": {\n                                    \"number\": 16,\n                                    \"limit\": 2,\n                                },\n                            },\n                        }\n                    ],\n                },\n                {\n                    \"role\": \"tool\",\n                    \"tool_call_id\": \"call12345\",\n                    \"name\": \"multiples\",\n                    \"content\": \"1,2\",\n                },\n                {\"role\": \"assistant\", \"content\": \"The multiples of 16 is 1 and 2.\"},\n            ],\n        }\n    )\n\n    # fmt: off\n    assert res[\"input_ids\"] == [\n        1,  # bos\n 
       5, 1091, 19227, 4994, 2811, 1429, 5165, 1897, 1429, 5165, 2811, 16753, 2391, 2811, 1429, 44627, 3684, 1897, 1429, 14653, 2811, 1429, 10639, 2130, 1261, 2951, 1307, 1747, 1278, 60092, 1307, 1261, 2782, 1455, 1584, 4289, 2224, 1261, 4265, 6139, 39249, 1429, 26204, 2811, 16753, 4994, 2811, 1429, 6371, 1897, 1429, 48649, 2811, 16753, 12856, 2811, 16753, 4994, 2811, 1429, 49039, 1897, 1429, 14653, 2811, 1429, 1784, 2782, 1317, 3081, 60092, 1307, 2613, 4179, 1429, 33319, 2811, 16753, 4994, 2811, 1429, 49039, 1897, 1429, 14653, 2811, 1429, 1784, 9229, 6139, 1394, 1278, 60092, 2613, 47579, 1429, 15760, 2811, 12161, 12856, 1897, 1429, 33319, 4964, 2821, 27028, 6,  # tool prompt\n        3, 46634, 1044, 1710, 1636, 5628, 1639, 1261, 44433, 1307, 2606, 1317, 5388, 1420, 54191, 2424, 1286, 8967, 1063, 15621, 1044, 2549, 30305, 2196, 3560, 1044, 1321, 2606, 1710, 1362, 2016, 8605, 2015, 1317, 5524, 118931, 2036, 32951, 1063, 1362, 2933, 2269, 12106, 1408, 101987, 1044, 6939, 1044, 1321, 9216, 1455, 2084, 3180, 1278, 8967, 119141, 1689, 5935, 1033, 4,  # user\n        *assistant_toolcall_ids,  # assistant tool calling\n        *tool_result_ids,  # tool result\n        1784, 60092, 1307, 1032, 1049, 1054, 1395, 1032, 1049, 1321, 1032, 1050, 1046,  # assistant\n        2  # eos\n    ]\n\n    assert res[\"labels\"] == [\n        -100,  # bos\n        -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 
-100,  # tool prompt\n        -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,  # user prompt\n        *assistant_toolcall_ids,  # assistant tool calling\n        *([-100] * len(tool_result_ids)),  # tool result\n        1784, 60092, 1307, 1032, 1049, 1054, 1395, 1032, 1049, 1321, 1032, 1050, 1046,  # assistant\n        2  # eos\n    ]\n    # fmt: on\n\n    # test chat template with tokenize=False\n    res = tokenizer.apply_chat_template(\n        [\n            {\"role\": \"user\", \"content\": \"Hello, how are you?\"},\n            {\"role\": \"assistant\", \"content\": \"I'm doing great, thank you!\"},\n        ],\n        tokenize=False,\n    )\n\n    assert res == \"<s>[INST]Hello, how are you?[/INST]I'm doing great, thank you!</s>\"\n\n    # test encode\n    res = tokenizer.encode(\"Hello, how are you?\", add_special_tokens=True)\n    assert res == [\n        1,  # bos\n        22177,  # Hello\n        1044,  # ,\n        2606,  # how\n        1584,  # are\n        1636,  # you\n        1063,  # ?\n        2,  # eos\n    ]\n\n    # test decode no skip special tokens\n    decoded_res = tokenizer.decode(res, skip_special_tokens=False)\n\n    assert decoded_res == \"<s>Hello, how are you?</s>\"\n\n    # test decode skip special tokens\n    decoded_res = tokenizer.decode(res, skip_special_tokens=True)\n    assert decoded_res == \"Hello, how are you?\"\n\n    # test encode no special tokens\n    res = tokenizer.encode(\"Hello, how are you?\", add_special_tokens=False)\n    assert res == [\n        22177,  # Hello\n        1044,  # ,\n        2606,  # how\n        1584,  # are\n        1636,  # you\n        1063,  # ?\n    ]\n\n    # test convert ids to 
tokens\n    res = tokenizer.convert_ids_to_tokens(res)\n    # spacing are needed as we are converting without decoding\n    assert res == [\"Hello\", \",\", \" how\", \" are\", \" you\", \"?\"]\n\n\n@pytest.mark.skip(reason=\"TODO, fix for new HF wrapper call\")\ndef test_magistral_tokenizer_pad_method(magistral_tokenizer: \"HFMistralTokenizer\"):\n    \"\"\"Test the MistralTokenizer pad method\"\"\"\n    from axolotl.utils.collators.core import IGNORE_INDEX\n\n    magistral_pad_token_id = 11  # taken from tokenizer.pad_token_id\n\n    # Test padding with input_ids and labels only\n    features = [\n        {\"input_ids\": [1, 2, 3], \"labels\": [4, 5, 6]},\n        {\"input_ids\": [7, 8], \"labels\": [9, 10]},\n    ]\n\n    result = magistral_tokenizer.pad(features, padding=True, return_tensors=\"pt\")\n\n    # Check that input_ids are padded correctly\n    assert result[\"input_ids\"].shape == (2, 3)\n    assert result[\"input_ids\"].tolist() == [[1, 2, 3], [7, 8, magistral_pad_token_id]]\n\n    # Check that labels are padded correctly\n    assert result[\"labels\"].shape == (2, 3)\n    assert result[\"labels\"].tolist() == [[4, 5, 6], [9, 10, IGNORE_INDEX]]\n\n    # Check that attention_mask and position_ids are NOT created\n    assert \"attention_mask\" not in result\n    assert \"position_ids\" not in result\n\n    # Test padding with attention_mask\n    features_with_attention = [\n        {\"input_ids\": [1, 2, 3], \"labels\": [4, 5, 6], \"attention_mask\": [1, 1, 1]},\n        {\"input_ids\": [7, 8], \"labels\": [9, 10], \"attention_mask\": [1, 1]},\n    ]\n\n    result = magistral_tokenizer.pad(\n        features_with_attention, padding=True, return_tensors=\"pt\"\n    )\n\n    # Check that attention_mask is padded correctly\n    assert result[\"attention_mask\"].shape == (2, 3)\n    assert result[\"attention_mask\"].tolist() == [[1, 1, 1], [1, 1, 0]]\n\n    # Test padding with position_ids\n    features_with_position = [\n        {\"input_ids\": [1, 2, 
3], \"labels\": [4, 5, 6], \"position_ids\": [0, 1, 2]},\n        {\"input_ids\": [7, 8], \"labels\": [9, 10], \"position_ids\": [0, 1]},\n    ]\n\n    result = magistral_tokenizer.pad(\n        features_with_position, padding=True, return_tensors=\"pt\"\n    )\n\n    # Check that position_ids are padded correctly (continuing sequence)\n    assert result[\"position_ids\"].shape == (2, 3)\n    assert result[\"position_ids\"].tolist() == [[0, 1, 2], [0, 1, 2]]\n\n    # Test padding with all fields\n    features_all = [\n        {\n            \"input_ids\": [1, 2, 3],\n            \"labels\": [4, 5, 6],\n            \"attention_mask\": [1, 1, 1],\n            \"position_ids\": [0, 1, 2],\n        },\n        {\n            \"input_ids\": [7, 8],\n            \"labels\": [9, 10],\n            \"attention_mask\": [1, 1],\n            \"position_ids\": [0, 1],\n        },\n    ]\n\n    result = magistral_tokenizer.pad(features_all, padding=True, return_tensors=\"pt\")\n\n    # All fields should be present and correctly padded\n    assert \"input_ids\" in result\n    assert \"labels\" in result\n    assert \"attention_mask\" in result\n    assert \"position_ids\" in result\n\n    # Test padding with all sequences same length\n    features_same_length = [\n        {\"input_ids\": [1, 2, 3], \"labels\": [4, 5, 6]},\n        {\"input_ids\": [7, 8, 9], \"labels\": [10, 11, 12]},\n    ]\n\n    result = magistral_tokenizer.pad(\n        features_same_length, padding=True, return_tensors=\"pt\"\n    )\n\n    # Check match when no padding is needed\n    assert result[\"input_ids\"][0].tolist() == features_same_length[0][\"input_ids\"]\n    assert result[\"labels\"][0].tolist() == features_same_length[0][\"labels\"]\n\n    assert result[\"input_ids\"][1].tolist() == features_same_length[1][\"input_ids\"]\n    assert result[\"labels\"][1].tolist() == features_same_length[1][\"labels\"]\n\n    # Test padding with max_length parameter\n    result = magistral_tokenizer.pad(\n        
features, padding=\"max_length\", max_length=5, return_tensors=\"pt\"\n    )\n\n    # Should pad to max_length\n    assert result[\"input_ids\"].shape == (2, 5)\n    assert result[\"labels\"].shape == (2, 5)\n\n    # Test numpy return type\n    result = magistral_tokenizer.pad(features, padding=True, return_tensors=\"np\")\n\n    # Should return numpy arrays\n    import numpy as np\n\n    assert isinstance(result[\"input_ids\"], np.ndarray)\n    assert isinstance(result[\"labels\"], np.ndarray)\n\n    # Test unsupported field rejection\n    features_unsupported = [\n        {\"input_ids\": [1, 2, 3], \"labels\": [4, 5, 6], \"unsupported_field\": [7, 8, 9]},\n    ]\n\n    with pytest.raises(NotImplementedError, match=\"unsupported_field\"):\n        magistral_tokenizer.pad(features_unsupported, padding=True, return_tensors=\"pt\")\n\n    # Test token_type_ids rejection\n    features_token_type = [\n        {\"input_ids\": [1, 2, 3], \"labels\": [4, 5, 6], \"token_type_ids\": [0, 0, 0]},\n    ]\n\n    with pytest.raises(ValueError, match=\"token_type_ids is not supported\"):\n        magistral_tokenizer.pad(features_token_type, padding=True, return_tensors=\"pt\")\n\n\ndef test_magistral_tool_calling(magistral_tokenizer: \"HFMistralTokenizer\"):\n    \"\"\"Test tool calling with the Magistral tokenizer\"\"\"\n    from axolotl.prompt_strategies.chat_template import MistralPrompter, MistralStrategy\n\n    strategy = MistralStrategy(\n        MistralPrompter(\n            magistral_tokenizer,\n            chat_template=None,\n            message_property_mappings={\"role\": \"role\", \"content\": \"content\"},\n        ),\n        tokenizer=magistral_tokenizer,\n        train_on_inputs=False,\n        train_on_eos=\"turn\",\n        sequence_len=512,\n        roles_to_train=[\"assistant\"],\n    )\n\n    # Test basic tool calling with single function\n    basic_tool_calling = {\n        \"tools\": [\n            {\n                \"type\": \"function\",\n               
 \"function\": {\n                    \"name\": \"get_weather\",\n                    \"description\": \"Get the current weather for a location\",\n                    \"parameters\": {\n                        \"type\": \"object\",\n                        \"properties\": {\n                            \"location\": {\n                                \"type\": \"string\",\n                                \"description\": \"The city and state, e.g. San Francisco, CA\",\n                            },\n                        },\n                        \"required\": [\"location\"],\n                    },\n                },\n            },\n        ],\n        \"messages\": [\n            {\n                \"role\": \"user\",\n                \"content\": \"What's the weather like in San Francisco?\",\n            },\n            {\n                \"role\": \"assistant\",\n                \"tool_calls\": [\n                    {\n                        \"id\": \"call12345\",\n                        \"type\": \"function\",\n                        \"function\": {\n                            \"name\": \"get_weather\",\n                            \"arguments\": {\n                                \"location\": \"San Francisco, CA\",\n                            },\n                        },\n                    }\n                ],\n            },\n            {\n                \"role\": \"tool\",\n                \"tool_call_id\": \"call12345\",\n                \"name\": \"get_weather\",\n                \"content\": \"Sunny, 72°F\",\n            },\n            {\n                \"role\": \"assistant\",\n                \"content\": \"The weather in San Francisco is sunny and 72°F.\",\n            },\n        ],\n    }\n\n    res = strategy.tokenize_prompt(basic_tool_calling)\n\n    # Basic validation\n    assert \"input_ids\" in res\n    assert \"labels\" in res\n    assert len(res[\"input_ids\"]) > 0\n    assert len(res[\"labels\"]) == 
len(res[\"input_ids\"])\n\n    # Decode and verify structure\n    decoded = magistral_tokenizer.decode(res[\"input_ids\"])\n    assert (\n        '<s>[AVAILABLE_TOOLS][{\"type\": \"function\", \"function\": {\"name\": \"get_weather\", \"description\": \"Get the current weather for a location\", \"parameters\": {\"type\": \"object\", \"properties\": {\"location\": {\"type\": \"string\", \"description\": \"The city and state, e.g. San Francisco, CA\"}}, \"required\": [\"location\"]}}}][/AVAILABLE_TOOLS]'\n        in decoded\n    )\n    assert (\n        '[TOOL_CALLS]get_weather[CALL_ID]call12345[ARGS]{\"location\": \"San Francisco, CA\"}</s>'\n        in decoded\n    )\n    assert \"[TOOL_RESULTS]call12345[TOOL_CONTENT]Sunny, 72°F[/TOOL_RESULTS]\" in decoded\n    assert \"The weather in San Francisco is sunny and 72°F.</s>\" in decoded\n\n    # Test multiple tool calls in sequence\n    multi_tool_calling = {\n        \"tools\": [\n            {\n                \"type\": \"function\",\n                \"function\": {\n                    \"name\": \"add_numbers\",\n                    \"description\": \"Add two numbers together\",\n                    \"parameters\": {\n                        \"type\": \"object\",\n                        \"properties\": {\n                            \"a\": {\"type\": \"number\", \"description\": \"First number\"},\n                            \"b\": {\"type\": \"number\", \"description\": \"Second number\"},\n                        },\n                        \"required\": [\"a\", \"b\"],\n                    },\n                },\n            },\n            {\n                \"type\": \"function\",\n                \"function\": {\n                    \"name\": \"multiply_numbers\",\n                    \"description\": \"Multiply two numbers\",\n                    \"parameters\": {\n                        \"type\": \"object\",\n                        \"properties\": {\n                            \"x\": {\"type\": 
\"number\", \"description\": \"First number\"},\n                            \"y\": {\"type\": \"number\", \"description\": \"Second number\"},\n                        },\n                        \"required\": [\"x\", \"y\"],\n                    },\n                },\n            },\n        ],\n        \"messages\": [\n            {\n                \"role\": \"user\",\n                \"content\": \"Add 5 and 3, then multiply the result by 2\",\n            },\n            {\n                \"role\": \"assistant\",\n                \"tool_calls\": [\n                    {\n                        \"id\": \"call12345\",\n                        \"type\": \"function\",\n                        \"function\": {\n                            \"name\": \"add_numbers\",\n                            \"arguments\": {\"a\": 5, \"b\": 3},\n                        },\n                    }\n                ],\n            },\n            {\n                \"role\": \"tool\",\n                \"tool_call_id\": \"call12345\",\n                \"name\": \"add_numbers\",\n                \"content\": \"8\",\n            },\n            {\n                \"role\": \"assistant\",\n                \"tool_calls\": [\n                    {\n                        \"id\": \"call23456\",\n                        \"type\": \"function\",\n                        \"function\": {\n                            \"name\": \"multiply_numbers\",\n                            \"arguments\": {\"x\": 8, \"y\": 2},\n                        },\n                    }\n                ],\n            },\n            {\n                \"role\": \"tool\",\n                \"tool_call_id\": \"call23456\",\n                \"name\": \"multiply_numbers\",\n                \"content\": \"16\",\n            },\n            {\n                \"role\": \"assistant\",\n                \"content\": \"The result is 16. 
I first added 5 and 3 to get 8, then multiplied 8 by 2 to get 16.\",\n            },\n        ],\n    }\n\n    res = strategy.tokenize_prompt(multi_tool_calling)\n\n    # Validation\n    assert len(res[\"input_ids\"]) > 0\n    assert len(res[\"labels\"]) == len(res[\"input_ids\"])\n\n    decoded = magistral_tokenizer.decode(res[\"input_ids\"])\n    assert (\n        '<s>[AVAILABLE_TOOLS][{\"type\": \"function\", \"function\": {\"name\": \"add_numbers\", \"description\": \"Add two numbers together\", \"parameters\": {\"type\": \"object\", \"properties\": {\"a\": {\"type\": \"number\", \"description\": \"First number\"}, \"b\": {\"type\": \"number\", \"description\": \"Second number\"}}, \"required\": [\"a\", \"b\"]}}}, {\"type\": \"function\", \"function\": {\"name\": \"multiply_numbers\", \"description\": \"Multiply two numbers\", \"parameters\": {\"type\": \"object\", \"properties\": {\"x\": {\"type\": \"number\", \"description\": \"First number\"}, \"y\": {\"type\": \"number\", \"description\": \"Second number\"}}, \"required\": [\"x\", \"y\"]}}}][/AVAILABLE_TOOLS]'\n        in decoded\n    )\n    assert (\n        '[TOOL_CALLS]add_numbers[CALL_ID]call12345[ARGS]{\"a\": 5, \"b\": 3}</s>' in decoded\n    )\n    assert \"[TOOL_RESULTS]call12345[TOOL_CONTENT]8[/TOOL_RESULTS]\" in decoded\n    assert (\n        '[TOOL_CALLS]multiply_numbers[CALL_ID]call23456[ARGS]{\"x\": 8, \"y\": 2}</s>'\n        in decoded\n    )\n    assert \"[TOOL_RESULTS]call23456[TOOL_CONTENT]16[/TOOL_RESULTS]\" in decoded\n    assert (\n        \"The result is 16. 
I first added 5 and 3 to get 8, then multiplied 8 by 2 to get 16.</s>\"\n        in decoded\n    )\n\n    # Test tool calling with system message\n    system_tool_calling = {\n        \"tools\": [\n            {\n                \"type\": \"function\",\n                \"function\": {\n                    \"name\": \"search_database\",\n                    \"description\": \"Search for information in database\",\n                    \"parameters\": {\n                        \"type\": \"object\",\n                        \"properties\": {\n                            \"query\": {\"type\": \"string\", \"description\": \"Search query\"},\n                        },\n                        \"required\": [\"query\"],\n                    },\n                },\n            },\n        ],\n        \"messages\": [\n            {\n                \"role\": \"system\",\n                \"content\": \"You are a helpful assistant with access to a database.\",\n            },\n            {\n                \"role\": \"user\",\n                \"content\": \"Find information about Python programming\",\n            },\n            {\n                \"role\": \"assistant\",\n                \"tool_calls\": [\n                    {\n                        \"id\": \"search123\",\n                        \"type\": \"function\",\n                        \"function\": {\n                            \"name\": \"search_database\",\n                            \"arguments\": {\"query\": \"Python programming\"},\n                        },\n                    }\n                ],\n            },\n            {\n                \"role\": \"tool\",\n                \"tool_call_id\": \"search123\",\n                \"name\": \"search_database\",\n                \"content\": \"Python is a high-level programming language known for its simplicity.\",\n            },\n            {\n                \"role\": \"assistant\",\n                \"content\": \"Based on the database search, 
Python is a high-level programming language known for its simplicity and readability.\",\n            },\n        ],\n    }\n\n    res = strategy.tokenize_prompt(system_tool_calling)\n\n    # Validation\n    assert len(res[\"input_ids\"]) > 0\n    assert len(res[\"labels\"]) == len(res[\"input_ids\"])\n\n    decoded = magistral_tokenizer.decode(res[\"input_ids\"])\n\n    assert (\n        '<s>[SYSTEM_PROMPT]You are a helpful assistant with access to a database.[/SYSTEM_PROMPT][AVAILABLE_TOOLS][{\"type\": \"function\", \"function\": {\"name\": \"search_database\", \"description\": \"Search for information in database\", \"parameters\": {\"type\": \"object\", \"properties\": {\"query\": {\"type\": \"string\", \"description\": \"Search query\"}}, \"required\": [\"query\"]}}}][/AVAILABLE_TOOLS]'\n        in decoded\n    )\n\n    # Test error handling - missing tool response\n    incomplete_tool_calling = {\n        \"tools\": [\n            {\n                \"type\": \"function\",\n                \"function\": {\n                    \"name\": \"get_time\",\n                    \"description\": \"Get current time\",\n                    \"parameters\": {\"type\": \"object\", \"properties\": {}},\n                },\n            },\n        ],\n        \"messages\": [\n            {\n                \"role\": \"user\",\n                \"content\": \"What time is it?\",\n            },\n            {\n                \"role\": \"assistant\",\n                \"tool_calls\": [\n                    {\n                        \"id\": \"time12345\",\n                        \"type\": \"function\",\n                        \"function\": {\n                            \"name\": \"get_time\",\n                            \"arguments\": {},\n                        },\n                    }\n                ],\n            },\n            {\n                \"role\": \"assistant\",\n                \"content\": \"The current time is 12:00 PM.\",\n            },\n        ],\n  
  }\n\n    from mistral_common.exceptions import InvalidMessageStructureException\n\n    try:\n        strategy.tokenize_prompt(incomplete_tool_calling)\n    except InvalidMessageStructureException as e:\n        assert \"Not the same number of function calls and responses\" in str(e)\n\n\n@pytest.mark.skip(reason=\"TODO, fix for new HF wrapper call\")\ndef test_magistral_tokenizer_call_method(\n    magistral_tokenizer: \"HFMistralTokenizer\", llama3_tokenizer: \"PreTrainedTokenizer\"\n):\n    \"\"\"Test the __call__ method behavior matches HuggingFace standards\"\"\"\n    from copy import deepcopy\n\n    import numpy as np\n    import torch\n\n    hf_tokenizer = deepcopy(llama3_tokenizer)\n    hf_tokenizer.pad_token = hf_tokenizer.eos_token\n\n    test_text = \"Hello, how are you?\"\n    batch_texts = [\"Hello world\", \"How are you?\"]\n\n    # Test single string with return_tensors=None\n    hf_result: dict[str, list[int]] = hf_tokenizer(test_text, return_tensors=None)\n    mistral_result: dict[str, list[int]] = magistral_tokenizer(\n        test_text, return_tensors=None\n    )\n\n    assert isinstance(mistral_result, dict)\n    assert set(mistral_result.keys()) == {\"input_ids\", \"attention_mask\"}\n    assert isinstance(mistral_result[\"input_ids\"], type(hf_result[\"input_ids\"]))  # list\n    assert isinstance(\n        mistral_result[\"attention_mask\"], type(hf_result[\"attention_mask\"])\n    )\n    assert len(mistral_result[\"input_ids\"]) == len(mistral_result[\"attention_mask\"])\n    assert np.all(mistral_result[\"attention_mask\"])\n    assert len(np.array(mistral_result[\"input_ids\"]).shape) == 1  # 1D array\n\n    # Test single string with return_tensors='pt'\n    hf_result_pt: dict[str, torch.Tensor] = hf_tokenizer(test_text, return_tensors=\"pt\")\n    mistral_result_pt: dict[str, torch.Tensor] = magistral_tokenizer(\n        test_text, return_tensors=\"pt\"\n    )\n\n    # Check structure and types\n    assert 
isinstance(mistral_result_pt[\"input_ids\"], torch.Tensor)\n    assert isinstance(mistral_result_pt[\"attention_mask\"], torch.Tensor)\n\n    # Check shapes match (don't compare token dimension)\n    assert len(hf_result_pt[\"input_ids\"].shape) == len(\n        mistral_result_pt[\"input_ids\"].shape\n    )\n    assert hf_result_pt[\"input_ids\"].shape[0] == mistral_result_pt[\"input_ids\"].shape[0]\n    assert (\n        mistral_result_pt[\"attention_mask\"].shape\n        == mistral_result_pt[\"input_ids\"].shape\n    )\n    assert torch.all(mistral_result_pt[\"attention_mask\"] == 1)\n\n    # Test batch input with padding\n    hf_batch: dict[str, torch.Tensor] = hf_tokenizer(\n        batch_texts, return_tensors=\"pt\", padding=True\n    )\n    mistral_batch: dict[str, torch.Tensor] = magistral_tokenizer(\n        batch_texts, return_tensors=\"pt\", padding=True\n    )\n\n    # Check batch behavior\n    assert len(hf_batch[\"input_ids\"].shape) == len(mistral_batch[\"input_ids\"].shape)\n    assert hf_batch[\"input_ids\"].shape[0] == mistral_batch[\"input_ids\"].shape[0]\n    assert mistral_batch[\"attention_mask\"].shape == mistral_batch[\"input_ids\"].shape\n    assert torch.any(\n        mistral_batch[\"attention_mask\"][0] == 0\n    )  # padding in shorter sequence\n    assert torch.all(\n        mistral_batch[\"attention_mask\"][1] == 1\n    )  # no padding in longer sequence\n\n    # Test numpy tensors\n    mistral_result_np: dict[str, np.ndarray] = magistral_tokenizer(\n        test_text, return_tensors=\"np\"\n    )\n    assert isinstance(mistral_result_np[\"input_ids\"], np.ndarray)\n    assert isinstance(mistral_result_np[\"attention_mask\"], np.ndarray)\n\n    # Test consistency with encode()\n    encoded: list[int] = magistral_tokenizer.encode(test_text, add_special_tokens=True)\n    called: dict[str, torch.Tensor] = magistral_tokenizer(\n        test_text, return_tensors=\"pt\"\n    )\n    assert encoded == called[\"input_ids\"][0].tolist()\n\n    # 
Test Error handling\n    with pytest.raises(ValueError, match=\"Unsupported kwargs\"):\n        magistral_tokenizer(test_text, unsupported_param=True)\n\n    with pytest.raises(\n        ValueError, match=\"return_tensors='pt' or 'np' requires padding or truncation\"\n    ):\n        magistral_tokenizer(batch_texts, return_tensors=\"pt\")\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/prompt_strategies/test_chat_templates_thinking.py",
    "content": "\"\"\"\nTests for splitting reasoning/thinking from content into separate field\n\"\"\"\n\nimport pytest\nfrom datasets import Dataset\n\nfrom axolotl.prompt_strategies.chat_template import (\n    load,\n)\nfrom axolotl.utils.dict import DictDefault\n\n\n@pytest.fixture(name=\"messages_w_reasoning\")\ndef messages_w_reasoning_fixture():\n    return Dataset.from_list(\n        [\n            {\n                \"messages\": [\n                    {\n                        \"role\": \"user\",\n                        \"content\": \"hello\",\n                    },\n                    {\n                        \"role\": \"assistant\",\n                        \"content\": \"<think>lorem</think>\\nwelcome\",\n                    },\n                ]\n            },\n            {\n                \"messages\": [\n                    {\n                        \"role\": \"user\",\n                        \"content\": \"hello\",\n                    },\n                    {\n                        \"role\": \"assistant\",\n                        \"content\": \"<|begin_of_thought|>lorem<|end_of_thought|>\\n<|begin_of_solution|>welcome\\n<|end_of_solution|>\",\n                    },\n                ]\n            },\n            {\n                \"messages\": [\n                    {\n                        \"role\": \"user\",\n                        \"content\": \"hello\",\n                    },\n                    {\n                        \"role\": \"assistant\",\n                        \"content\": \"<reasoning>lorem</reasoning>\\nwelcome\",\n                    },\n                ]\n            },\n        ]\n    )\n\n\nclass TestSplitThinking:\n    \"\"\"\n    test class to make sure datasets with reasoning content conforms to the chat_template strategy\n    \"\"\"\n\n    def test_splits_think(self, messages_w_reasoning, qwen3_tokenizer):\n        strategy = load(\n            qwen3_tokenizer,\n            DictDefault(\n              
  {\n                    \"train_on_inputs\": False,\n                    \"sequence_len\": 512,\n                }\n            ),\n            DictDefault(\n                {\n                    \"chat_template\": \"qwen3\",\n                    \"message_field_role\": \"role\",\n                    \"message_field_content\": \"content\",\n                    \"message_property_mappings\": {\n                        \"role\": \"role\",\n                        \"content\": \"content\",\n                    },\n                    \"roles\": {\n                        \"user\": [\"user\"],\n                        \"assistant\": [\"assistant\"],\n                        \"system\": [\"system\"],\n                    },\n                    \"field_messages\": \"messages\",\n                    \"split_thinking\": True,\n                }\n            ),\n        )\n        for conversation in messages_w_reasoning:\n            transformed_prompt = strategy.get_conversation_thread(conversation)\n            assert transformed_prompt[0][\"role\"] == \"user\"\n            assert transformed_prompt[1][\"role\"] == \"assistant\"\n            assert transformed_prompt[1][\"reasoning_content\"] == \"lorem\"\n            assert transformed_prompt[1][\"content\"] == \"welcome\"\n\n            res = strategy.tokenize_prompt(conversation)\n            input_ids = res[\"input_ids\"]\n            # fmt: off\n            expected_input_ids = [\n                151644,  # im_start\n                872,  # user\n                198,  # \\n\n                14990,  # hello\n                151645,  # im_end\n                198,  # \\n\n                151644,  # im_start\n                77091,  # assistant\n                198,  # \\n\n                151667,  # think\n                198,  # \\n\n                385, 1826,  # lorem\n                198,  # \\n\n                151668,  # /think\n                271,  # \\n\n                34084,  # welcome\n                
151645,  # im_end\n                198,  # \\n\n            ]\n            # fmt: on\n            assert input_ids == expected_input_ids, (\n                f\"Input IDs mismatch: {input_ids} != {expected_input_ids}\"\n            )\n"
  },
  {
    "path": "tests/prompt_strategies/test_chat_templates_tool_call_string_arguments.py",
    "content": "\"\"\"\nTests for handling json tool content\n\"\"\"\n\nimport json\n\nimport pytest\nfrom datasets import Dataset\n\nfrom axolotl.prompt_strategies.chat_template import (\n    load,\n)\nfrom axolotl.utils.dict import DictDefault\n\n\n@pytest.fixture(name=\"qwen3_instruct_prompt_strategy\")\ndef qwen3_instruct_chat_template_strategy(qwen3_tokenizer):\n    strategy = load(\n        qwen3_tokenizer,\n        DictDefault(\n            {\n                \"train_on_inputs\": False,\n                \"sequence_len\": 512,\n            }\n        ),\n        DictDefault(\n            {\n                \"chat_template\": \"qwen3\",\n                \"message_field_role\": \"role\",\n                \"message_field_content\": \"content\",\n                \"message_property_mappings\": {\n                    \"role\": \"role\",\n                    \"content\": \"content\",\n                },\n                \"roles\": {\n                    \"user\": [\"user\"],\n                    \"assistant\": [\"assistant\"],\n                    \"system\": [\"system\"],\n                },\n                \"field_messages\": \"messages\",\n            }\n        ),\n    )\n    return strategy\n\n\nclass TestQwen3IdenticalConversationArgs:\n    \"\"\"\n    Test Qwen3 tools is identical between JSON and dict\n    \"\"\"\n\n    @pytest.fixture(name=\"conversation_dict_args_dataset\")\n    def fixture_conversation_dict_args_dataset(self):\n        \"\"\"\n        Provides a dataset with conversation where arguments is a dict.\n        \"\"\"\n        user_content = \"What is the weather in Boston?\"\n        function_name = \"get_current_weather\"\n        arguments_dict = {\"location\": \"Boston, MA\", \"unit\": \"celsius\"}\n\n        data = [\n            {\n                \"messages\": [\n                    {\"role\": \"user\", \"content\": user_content},\n                    {\n                        \"role\": \"assistant\",\n                        
\"content\": \"\",\n                        \"tool_calls\": [\n                            {\n                                \"function\": {\n                                    \"name\": function_name,\n                                    \"arguments\": arguments_dict,  # dict\n                                }\n                            }\n                        ],\n                    },\n                ],\n            }\n        ]\n        return Dataset.from_list(data)\n\n    @pytest.fixture(name=\"conversation_str_args_dataset\")\n    def fixture_conversation_str_args_dataset(self):\n        \"\"\"\n        Provides a dataset with conversation where arguments is a JSON string.\n        \"\"\"\n        user_content = \"What is the weather in Boston?\"\n        function_name = \"get_current_weather\"\n        arguments_dict = {\"location\": \"Boston, MA\", \"unit\": \"celsius\"}\n        arguments_str = json.dumps(arguments_dict)\n\n        data = [\n            {\n                \"messages\": [\n                    {\"role\": \"user\", \"content\": user_content},\n                    {\n                        \"role\": \"assistant\",\n                        \"content\": \"\",\n                        \"tool_calls\": [\n                            {\n                                \"function\": {\n                                    \"name\": function_name,\n                                    \"arguments\": arguments_str,  # str\n                                }\n                            }\n                        ],\n                    },\n                ],\n            }\n        ]\n        return Dataset.from_list(data)\n\n    @pytest.fixture(name=\"conversation_mixed_time_types_dataset\")\n    def fixture_conversation_mixed_time_types_dataset(self):\n        \"\"\"\n        Provides a dataset where 'time' field has different types in different tool calls.\n        \"\"\"\n        data = [\n            {\n                \"messages\": [\n     
               {\n                        \"role\": \"user\",\n                        \"content\": \"Get weather information at different times\",\n                    },\n                    {\n                        \"role\": \"assistant\",\n                        \"content\": \"\",\n                        \"tool_calls\": [\n                            {\n                                \"function\": {\n                                    \"name\": \"func1\",\n                                    \"arguments\": json.dumps(\n                                        {\"time\": \"2025-08-01\"}\n                                    ),  # string type\n                                }\n                            },\n                            {\n                                \"function\": {\n                                    \"name\": \"func2\",\n                                    \"arguments\": json.dumps(\n                                        {\"time\": 1690876800}\n                                    ),  # number type\n                                }\n                            },\n                        ],\n                    },\n                ],\n            }\n        ]\n        return Dataset.from_list(data)\n\n    def test_dict_and_str_args_produce_identical_output(\n        self,\n        conversation_dict_args_dataset,\n        conversation_str_args_dataset,\n        qwen3_instruct_prompt_strategy,\n        qwen3_tokenizer,\n    ):\n        \"\"\"\n        Tests that after tokenization and decoding, the outputs for both\n        dict and string `arguments` are exactly the same.\n        \"\"\"\n        processed_dict_args = conversation_dict_args_dataset.map(\n            qwen3_instruct_prompt_strategy.tokenize_prompt,\n            batched=True,\n            remove_columns=[\"messages\"],\n        )\n\n        processed_str_args = conversation_str_args_dataset.map(\n            qwen3_instruct_prompt_strategy.tokenize_prompt,\n            
batched=True,\n            remove_columns=[\"messages\"],\n        )\n\n        decoded_prompt_from_dict = qwen3_tokenizer.decode(\n            processed_dict_args[0][\"input_ids\"]\n        )\n\n        decoded_prompt_from_str = qwen3_tokenizer.decode(\n            processed_str_args[0][\"input_ids\"]\n        )\n\n        assert decoded_prompt_from_dict == decoded_prompt_from_str, (\n            f\"Dict format output:\\n{decoded_prompt_from_dict}\\n\"\n            f\"String format output:\\n{decoded_prompt_from_str}\"\n        )\n\n        assert (\n            processed_dict_args[0][\"input_ids\"] == processed_str_args[0][\"input_ids\"]\n        ), \"The tokenized input_ids should be identical for dict and str arguments\"\n\n    def test_str_args_with_mixed_time_types_no_error(\n        self,\n        conversation_mixed_time_types_dataset,\n        qwen3_instruct_prompt_strategy,\n        qwen3_tokenizer,\n    ):\n        \"\"\"\n        Tests that when 'time' field has different types (string vs number)\n        in different tool calls, str format arguments don't cause errors.\n        \"\"\"\n        processed = conversation_mixed_time_types_dataset.map(\n            qwen3_instruct_prompt_strategy.tokenize_prompt,\n            batched=True,\n            remove_columns=[\"messages\"],\n        )\n\n        assert len(processed) == 1\n        assert \"input_ids\" in processed[0]\n        assert len(processed[0][\"input_ids\"]) > 0\n\n        decoded = qwen3_tokenizer.decode(processed[0][\"input_ids\"])\n        assert \"2025-08-01\" in decoded, \"String time value should be present\"\n        assert \"1690876800\" in decoded, \"Number time value should be present\"\n\n\nclass TestQwen3IdenticalToolsParameters:\n    \"\"\"\n    Test Qwen3 tools parameters handling is identical between JSON string and dict\n    \"\"\"\n\n    @pytest.fixture(name=\"tools_dict_params_dataset\")\n    def fixture_tools_dict_params_dataset(self):\n        \"\"\"\n        Provides a 
dataset with tools where parameters is a dict.\n        \"\"\"\n        tools = [\n            {\n                \"type\": \"function\",\n                \"function\": {\n                    \"name\": \"get_weather\",\n                    \"description\": \"Get weather information\",\n                    \"parameters\": {\n                        \"type\": \"object\",\n                        \"properties\": {\n                            \"location\": {\n                                \"type\": \"string\",\n                                \"description\": \"The city and state\",\n                            },\n                            \"unit\": {\n                                \"type\": \"string\",\n                                \"enum\": [\"celsius\", \"fahrenheit\"],\n                            },\n                        },\n                        \"required\": [\"location\"],\n                    },\n                },\n            }\n        ]\n\n        data = [\n            {\n                \"tools\": tools,\n                \"messages\": [\n                    {\"role\": \"user\", \"content\": \"What's the weather?\"},\n                    {\n                        \"role\": \"assistant\",\n                        \"content\": \"\",\n                        \"tool_calls\": [\n                            {\n                                \"type\": \"function\",\n                                \"function\": {\n                                    \"name\": \"get_weather\",\n                                    \"arguments\": {\"location\": \"Boston, MA\"},\n                                },\n                            }\n                        ],\n                    },\n                    {\n                        \"role\": \"tool\",\n                        \"name\": \"get_weather\",\n                        \"content\": \"72°F and sunny\",\n                    },\n                ],\n            }\n        ]\n        return 
Dataset.from_list(data)\n\n    @pytest.fixture(name=\"tools_str_params_dataset\")\n    def fixture_tools_str_params_dataset(self):\n        \"\"\"\n        Provides a dataset with tools where parameters is a JSON string.\n        \"\"\"\n        parameters_dict = {\n            \"type\": \"object\",\n            \"properties\": {\n                \"location\": {\"type\": \"string\", \"description\": \"The city and state\"},\n                \"unit\": {\"type\": \"string\", \"enum\": [\"celsius\", \"fahrenheit\"]},\n            },\n            \"required\": [\"location\"],\n        }\n\n        tools = [\n            {\n                \"type\": \"function\",\n                \"function\": {\n                    \"name\": \"get_weather\",\n                    \"description\": \"Get weather information\",\n                    \"parameters\": json.dumps(parameters_dict),\n                },\n            }\n        ]\n\n        data = [\n            {\n                \"tools\": tools,\n                \"messages\": [\n                    {\"role\": \"user\", \"content\": \"What's the weather?\"},\n                    {\n                        \"role\": \"assistant\",\n                        \"content\": \"\",\n                        \"tool_calls\": [\n                            {\n                                \"type\": \"function\",\n                                \"function\": {\n                                    \"name\": \"get_weather\",\n                                    \"arguments\": {\"location\": \"Boston, MA\"},\n                                },\n                            }\n                        ],\n                    },\n                    {\n                        \"role\": \"tool\",\n                        \"name\": \"get_weather\",\n                        \"content\": \"72°F and sunny\",\n                    },\n                ],\n            }\n        ]\n        return Dataset.from_list(data)\n\n    
@pytest.fixture(name=\"tools_mixed_type_params_dataset\")\n    def fixture_tools_mixed_type_params_dataset(self):\n        \"\"\"\n        Provides a dataset where different tools have the same parameter name with different types.\n        This tests that JSON string format prevents casting issues.\n        \"\"\"\n        tools = [\n            {\n                \"type\": \"function\",\n                \"function\": {\n                    \"name\": \"tool_with_string_arg\",\n                    \"description\": \"Tool expecting string argument\",\n                    \"parameters\": json.dumps(\n                        {\n                            \"type\": \"object\",\n                            \"properties\": {\n                                \"arg1\": {\n                                    \"type\": \"string\",\n                                    \"description\": \"A string parameter\",\n                                }\n                            },\n                            \"required\": [\"arg1\"],\n                        }\n                    ),\n                },\n            },\n            {\n                \"type\": \"function\",\n                \"function\": {\n                    \"name\": \"tool_with_number_arg\",\n                    \"description\": \"Tool expecting number argument\",\n                    \"parameters\": json.dumps(\n                        {\n                            \"type\": \"object\",\n                            \"properties\": {\n                                \"arg1\": {\n                                    \"type\": \"number\",\n                                    \"description\": \"A numeric parameter\",\n                                }\n                            },\n                            \"required\": [\"arg1\"],\n                        }\n                    ),\n                },\n            },\n        ]\n\n        data = [\n            {\n                \"tools\": tools,\n            
    \"messages\": [\n                    {\"role\": \"user\", \"content\": \"Use both tools\"},\n                    {\n                        \"role\": \"assistant\",\n                        \"content\": \"\",\n                        \"tool_calls\": [\n                            {\n                                \"type\": \"function\",\n                                \"function\": {\n                                    \"name\": \"tool_with_string_arg\",\n                                    \"arguments\": json.dumps({\"arg1\": \"hello\"}),\n                                },\n                            },\n                            {\n                                \"type\": \"function\",\n                                \"function\": {\n                                    \"name\": \"tool_with_number_arg\",\n                                    \"arguments\": json.dumps({\"arg1\": 42}),\n                                },\n                            },\n                        ],\n                    },\n                ],\n            }\n        ]\n        return Dataset.from_list(data)\n\n    def test_dict_and_str_params_produce_equivalent_output(\n        self,\n        tools_dict_params_dataset,\n        tools_str_params_dataset,\n        qwen3_instruct_prompt_strategy,\n        qwen3_tokenizer,\n    ):\n        \"\"\"\n        Tests that after tokenization and decoding, the outputs for both\n        dict and string `parameters` in tools are semantically equivalent.\n        \"\"\"\n        import re\n\n        processed_dict_params = tools_dict_params_dataset.map(\n            qwen3_instruct_prompt_strategy.tokenize_prompt,\n            batched=True,\n            remove_columns=[\"messages\", \"tools\"],\n        )\n\n        processed_str_params = tools_str_params_dataset.map(\n            qwen3_instruct_prompt_strategy.tokenize_prompt,\n            batched=True,\n            remove_columns=[\"messages\", \"tools\"],\n        )\n\n        
decoded_dict = qwen3_tokenizer.decode(processed_dict_params[0][\"input_ids\"])\n        decoded_str = qwen3_tokenizer.decode(processed_str_params[0][\"input_ids\"])\n\n        # Extract the tool JSON from both outputs\n        tools_pattern = r\"<tools>\\n(.*?)\\n</tools>\"\n\n        dict_tools_match = re.search(tools_pattern, decoded_dict, re.DOTALL)\n        str_tools_match = re.search(tools_pattern, decoded_str, re.DOTALL)\n\n        assert dict_tools_match and str_tools_match, (\n            \"Could not find tools section in output\"\n        )\n\n        # Parse the JSON and compare as objects (order-independent)\n        dict_tools_json = json.loads(dict_tools_match.group(1))\n        str_tools_json = json.loads(str_tools_match.group(1))\n\n        # Deep comparison of the tool definitions\n        assert dict_tools_json == str_tools_json, (\n            f\"Tool definitions are not equivalent:\\n\"\n            f\"Dict format: {json.dumps(dict_tools_json, indent=2)}\\n\"\n            f\"String format: {json.dumps(str_tools_json, indent=2)}\"\n        )\n\n        # Verify the rest of the structure is the same (excluding the tools JSON part)\n        # The tools JSON can have different order, so we remove it here.\n        dict_normalized = re.sub(\n            r\"<tools>.*?</tools>\",\n            \"<tools>TOOLS_PLACEHOLDER</tools>\",\n            decoded_dict,\n            flags=re.DOTALL,\n        )\n        str_normalized = re.sub(\n            r\"<tools>.*?</tools>\",\n            \"<tools>TOOLS_PLACEHOLDER</tools>\",\n            decoded_str,\n            flags=re.DOTALL,\n        )\n\n        assert dict_normalized == str_normalized, (\n            \"The overall structure differs between dict and string parameter formats\"\n        )\n\n    def test_str_params_with_mixed_types_no_error(\n        self,\n        tools_mixed_type_params_dataset,\n        qwen3_instruct_prompt_strategy,\n        qwen3_tokenizer,\n    ):\n        \"\"\"\n        Tests that 
when different tools have the same parameter name with different types,\n        JSON string format for parameters doesn't cause casting errors.\n        \"\"\"\n        processed = tools_mixed_type_params_dataset.map(\n            qwen3_instruct_prompt_strategy.tokenize_prompt,\n            batched=True,\n            remove_columns=[\"messages\", \"tools\"],\n        )\n\n        assert len(processed) == 1\n        assert \"input_ids\" in processed[0]\n        assert len(processed[0][\"input_ids\"]) > 0\n\n        decoded = qwen3_tokenizer.decode(processed[0][\"input_ids\"])\n\n        # Check that both tools are present\n        assert \"tool_with_string_arg\" in decoded\n        assert \"tool_with_number_arg\" in decoded\n\n        # Check that both argument values are present\n        assert \"hello\" in decoded\n        assert \"42\" in decoded\n"
  },
  {
    "path": "tests/prompt_strategies/test_dpo_chat_templates.py",
    "content": "\"\"\"\ntests for chat_template prompt strategy\n\"\"\"\n\nimport unittest\n\nimport pytest\nfrom datasets import Dataset\nfrom transformers import AutoTokenizer\n\nfrom axolotl.prompt_strategies.dpo.chat_template import argilla_chat, default\nfrom axolotl.utils.dict import DictDefault\n\nfrom tests.hf_offline_utils import enable_hf_offline\n\n\n@pytest.fixture(name=\"assistant_dataset\")\ndef fixture_assistant_dataset():\n    return Dataset.from_list(\n        [\n            {\n                \"messages\": [\n                    {\n                        \"role\": \"user\",\n                        \"content\": \"hello\",\n                    },\n                    {\n                        \"role\": \"assistant\",\n                        \"content\": \"hello\",\n                    },\n                    {\n                        \"role\": \"user\",\n                        \"content\": \"goodbye\",\n                    },\n                ],\n                \"chosen\": {\n                    \"role\": \"assistant\",\n                    \"content\": \"goodbye\",\n                },\n                \"rejected\": {\n                    \"role\": \"assistant\",\n                    \"content\": \"party on\",\n                },\n            }\n        ]\n    )\n\n\n@pytest.fixture(name=\"custom_assistant_dataset\")\ndef fixture_custom_assistant_dataset():\n    return Dataset.from_list(\n        [\n            {\n                \"conversation\": [\n                    {\n                        \"speaker\": \"human\",\n                        \"text\": \"hello\",\n                    },\n                    {\n                        \"speaker\": \"agent\",\n                        \"text\": \"hello\",\n                    },\n                    {\n                        \"speaker\": \"human\",\n                        \"text\": \"goodbye\",\n                    },\n                ],\n                \"better\": {\n                    
\"speaker\": \"agent\",\n                    \"text\": \"goodbye\",\n                },\n                \"worse\": {\n                    \"speaker\": \"agent\",\n                    \"text\": \"party on\",\n                },\n            }\n        ]\n    )\n\n\n@pytest.fixture(name=\"argilla_chat_dataset\")\ndef fixture_argilla_chat_dataset():\n    return Dataset.from_list(\n        [\n            {\n                \"chosen\": [\n                    {\n                        \"role\": \"user\",\n                        \"content\": \"hello\",\n                    },\n                    {\n                        \"role\": \"assistant\",\n                        \"content\": \"goodbye\",\n                    },\n                ],\n                \"rejected\": [\n                    {\n                        \"role\": \"user\",\n                        \"content\": \"hello\",\n                    },\n                    {\n                        \"role\": \"assistant\",\n                        \"content\": \"party on\",\n                    },\n                ],\n            }\n        ]\n    )\n\n\n@pytest.fixture(name=\"phi3_tokenizer\")\n@enable_hf_offline\ndef fixture_phi3_tokenizer():\n    tokenizer = AutoTokenizer.from_pretrained(\"microsoft/Phi-3-medium-128k-instruct\")\n\n    return tokenizer\n\n\n@pytest.fixture(name=\"gemma_tokenizer\")\n@enable_hf_offline\ndef fixture_gemma_tokenizer():\n    tokenizer = AutoTokenizer.from_pretrained(\"unsloth/gemma-2b-it\", revision=\"703fb4a\")\n\n    return tokenizer\n\n\nclass TestAssistantDPOChatTemplateLlama3:\n    \"\"\"\n    Test class for assistant style datasets with llama-3 prompts using the chat_template strategy.\n    \"\"\"\n\n    def test_llama3_defaults(self, llama3_tokenizer, assistant_dataset):\n        transform_fn, _ = default(\n            DictDefault(\n                {\n                    \"chat_template\": \"llama3\",\n                    \"datasets\": [\n                        {\n     
                       \"type\": \"chat_template\",\n                        }\n                    ],\n                }\n            )\n        )\n        result = transform_fn(assistant_dataset[0], tokenizer=llama3_tokenizer)\n        assert result[\"prompt\"] == (\n            \"<|begin_of_text|>\"\n            + \"<|start_header_id|>user<|end_header_id|>\\n\\nhello<|eot_id|>\"\n            + \"<|start_header_id|>assistant<|end_header_id|>\\n\\nhello<|eot_id|>\"\n            + \"<|start_header_id|>user<|end_header_id|>\\n\\ngoodbye<|eot_id|>\"\n            + \"<|start_header_id|>assistant<|end_header_id|>\\n\\n\"\n        )\n        assert result[\"chosen\"] == \"goodbye<|eot_id|>\"\n        assert result[\"rejected\"] == \"party on<|eot_id|>\"\n\n    def test_llama3_configured(self, llama3_tokenizer, custom_assistant_dataset):\n        transform_fn, _ = default(\n            DictDefault(\n                {\n                    \"chat_template\": \"llama3\",\n                    \"datasets\": [\n                        {\n                            \"type\": \"chat_template\",\n                            \"field_messages\": \"conversation\",\n                            \"field_chosen\": \"better\",\n                            \"field_rejected\": \"worse\",\n                            \"message_field_role\": \"speaker\",\n                            \"message_field_content\": \"text\",\n                            \"roles\": {\n                                \"user\": [\"human\"],\n                                \"assistant\": [\"agent\"],\n                                \"system\": [\"sys\"],\n                            },\n                        }\n                    ],\n                }\n            )\n        )\n        result = transform_fn(custom_assistant_dataset[0], tokenizer=llama3_tokenizer)\n        assert result[\"prompt\"] == (\n            \"<|begin_of_text|>\"\n            + 
\"<|start_header_id|>user<|end_header_id|>\\n\\nhello<|eot_id|>\"\n            + \"<|start_header_id|>assistant<|end_header_id|>\\n\\nhello<|eot_id|>\"\n            + \"<|start_header_id|>user<|end_header_id|>\\n\\ngoodbye<|eot_id|>\"\n            + \"<|start_header_id|>assistant<|end_header_id|>\\n\\n\"\n        )\n        assert result[\"chosen\"] == \"goodbye<|eot_id|>\"\n        assert result[\"rejected\"] == \"party on<|eot_id|>\"\n\n\nclass TestAssistantDPOChatTemplatePhi3:\n    \"\"\"\n    Test class for assistant style datasets with phi-3 prompts using the tokenizer's chat_template strategy.\n    \"\"\"\n\n    def test_phi3_defaults(self, phi3_tokenizer, assistant_dataset):\n        transform_fn, _ = default(\n            DictDefault(\n                {\n                    \"chat_template\": \"tokenizer_default\",\n                    \"datasets\": [\n                        {\n                            \"type\": \"chat_template\",\n                        }\n                    ],\n                }\n            )\n        )\n        result = transform_fn(assistant_dataset[0], tokenizer=phi3_tokenizer)\n        assert result[\"prompt\"] == (\n            \"<|user|>\\nhello<|end|>\\n\"\n            + \"<|assistant|>\\nhello<|end|>\\n\"\n            + \"<|user|>\\ngoodbye<|end|>\\n\"\n            + \"<|assistant|>\\n\"\n        )\n        assert result[\"chosen\"] == \"goodbye<|end|>\"\n        assert result[\"rejected\"] == \"party on<|end|>\"\n\n\nclass TestAssistantDPOChatTemplateGemma:\n    \"\"\"\n    Test class for assistant style datasets with gemma prompts using the tokenizer's chat_template strategy.\n    \"\"\"\n\n    def test_gemma_defaults(self, gemma_tokenizer, assistant_dataset):\n        transform_fn, _ = default(\n            DictDefault(\n                {\n                    \"chat_template\": \"tokenizer_default\",\n                    \"datasets\": [\n                        {\n                            \"type\": 
\"chat_template\",\n                        }\n                    ],\n                }\n            )\n        )\n        result = transform_fn(assistant_dataset[0], tokenizer=gemma_tokenizer)\n        assert result[\"prompt\"] == (\n            \"<bos><start_of_turn>user\\nhello<end_of_turn>\\n\"\n            + \"<start_of_turn>model\\nhello<end_of_turn>\\n\"\n            + \"<start_of_turn>user\\ngoodbye<end_of_turn>\\n\"\n            + \"<start_of_turn>model\\n\"\n        )\n        assert result[\"chosen\"] == \"goodbye<end_of_turn>\"\n        assert result[\"rejected\"] == \"party on<end_of_turn>\"\n\n\nclass TestArgillaChatDPOChatTemplate:\n    \"\"\"\n    Test class for argilla_chat style datasets (chosen/rejected contain full conversations).\n    \"\"\"\n\n    def test_llama3_argilla_chat(self, llama3_tokenizer, argilla_chat_dataset):\n        transform_fn, _ = argilla_chat(\n            DictDefault(\n                {\n                    \"chat_template\": \"llama3\",\n                    \"datasets\": [\n                        {\n                            \"type\": \"chat_template.argilla_chat\",\n                        }\n                    ],\n                }\n            )\n        )\n        result = transform_fn(argilla_chat_dataset[0], tokenizer=llama3_tokenizer)\n        assert result[\"prompt\"] == (\n            \"<|begin_of_text|>\"\n            + \"<|start_header_id|>user<|end_header_id|>\\n\\nhello<|eot_id|>\"\n            + \"<|start_header_id|>assistant<|end_header_id|>\\n\\n\"\n        )\n        assert result[\"chosen\"] == \"goodbye<|eot_id|>\"\n        assert result[\"rejected\"] == \"party on<|eot_id|>\"\n\n    def test_phi3_argilla_chat(self, phi3_tokenizer, argilla_chat_dataset):\n        transform_fn, _ = argilla_chat(\n            DictDefault(\n                {\n                    \"chat_template\": \"tokenizer_default\",\n                    \"datasets\": [\n                        {\n                            
\"type\": \"chat_template.argilla_chat\",\n                        }\n                    ],\n                }\n            )\n        )\n        result = transform_fn(argilla_chat_dataset[0], tokenizer=phi3_tokenizer)\n        assert result[\"prompt\"] == \"<|user|>\\nhello<|end|>\\n\" + \"<|assistant|>\\n\"\n        assert result[\"chosen\"] == \"goodbye<|end|>\"\n        assert result[\"rejected\"] == \"party on<|end|>\"\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/prompt_strategies/test_dpo_chatml.py",
    "content": "\"\"\"\nTests for loading DPO preference datasets with chatml formatting\n\"\"\"\n\nimport unittest\n\nimport pytest\n\nfrom axolotl.loaders.tokenizer import load_tokenizer\nfrom axolotl.prompt_strategies.dpo import load as load_dpo\nfrom axolotl.utils.data.rl import prepare_preference_datasets\nfrom axolotl.utils.dict import DictDefault\n\nfrom tests.hf_offline_utils import enable_hf_offline\n\n\n@pytest.fixture(name=\"minimal_dpo_cfg\")\ndef fixture_cfg():\n    return DictDefault(\n        {\n            \"base_model\": \"HuggingFaceTB/SmolLM2-135M\",\n            \"tokenizer_config\": \"HuggingFaceTB/SmolLM2-135M\",\n            \"rl\": \"dpo\",\n            \"learning_rate\": 0.000001,\n            \"micro_batch_size\": 1,\n            \"gradient_accumulation_steps\": 1,\n            \"special_tokens\": {\n                \"pad_token\": \"<|endoftext|>\",\n            },\n            \"sequence_len\": 2048,\n        }\n    )\n\n\nclass TestDPOChatml:\n    \"\"\"\n    Test loading DPO preference datasets with chatml formatting\n    \"\"\"\n\n    @pytest.mark.skip(reason=\"TODO: fix hf hub offline to work with HF rate limits\")\n    @enable_hf_offline\n    def test_default(self, minimal_dpo_cfg):\n        cfg = DictDefault(\n            {\n                \"datasets\": [\n                    {\n                        \"path\": \"argilla/distilabel-intel-orca-dpo-pairs\",\n                        \"type\": \"chatml\",\n                        \"split\": \"train[:1%]\",\n                    }\n                ]\n            }\n            | minimal_dpo_cfg\n        )\n\n        # test that dpo.load works\n        load_dpo(\"chatml\", cfg)\n        # now actually load the datasets with the strategy\n        tokenizer = load_tokenizer(cfg)\n        train_ds, _ = prepare_preference_datasets(cfg, tokenizer)\n        assert train_ds[0][\"prompt\"].startswith(\"<|im_start|>\")\n        assert 
train_ds[0][\"prompt\"].endswith(\"<|im_start|>assistant\\n\")\n        assert \"chosen\" in train_ds[0]\n        assert \"rejected\" in train_ds[0]\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/prompt_strategies/test_jinja_template_analyzer.py",
    "content": "\"\"\"\ntests for jinja_template_analyzer\n\"\"\"\n\nimport pytest\n\nfrom axolotl.prompt_strategies.jinja_template_analyzer import JinjaTemplateAnalyzer\nfrom axolotl.utils.logging import get_logger\n\nLOG = get_logger(__name__, log_level=\"DEBUG\")\n\n\nclass TestJinjaTemplateAnalyzer:\n    \"\"\"\n    tests for jinja_template_analyzer\n    \"\"\"\n\n    def test_basic_variable_extraction(self, basic_jinja_template_analyzer):\n        \"\"\"Test that all top-level variables are correctly extracted.\"\"\"\n        LOG.info(\"Testing with train_on_inputs=True\")\n\n        variables = basic_jinja_template_analyzer.get_template_variables()\n        expected_vars = {\"messages\", \"add_generation_prompt\", \"eos_token\", \"message\"}\n        assert set(variables.keys()) == expected_vars\n\n    def test_mixtral_variable_extraction(self, mistral_jinja_template_analyzer):\n        \"\"\"Test that all top-level variables are correctly extracted.\"\"\"\n        LOG.info(\"Testing with train_on_inputs=True\")\n\n        variables = mistral_jinja_template_analyzer.get_template_variables()\n        expected_vars = {\n            \"messages\",\n            \"content\",\n            \"eos_token\",\n            \"message\",\n            \"tools\",\n            \"system_message\",\n            \"loop_messages\",\n            \"ns\",\n            \"tool_call\",\n            \"tool\",\n            \"loop\",\n            \"bos_token\",\n            \"raise_exception\",\n        }\n        assert set(variables.keys()) == expected_vars\n        message_vars = variables[\"message\"]\n        assert message_vars == {\"role\", \"content\", \"tool_calls\", \"tool_call_id\"}\n\n    def test_message_property_access(self, basic_jinja_template_analyzer):\n        \"\"\"Test that properties accessed on 'message' variable are correctly identified.\"\"\"\n        LOG.info(\"Testing message property access\")\n\n        variables = 
basic_jinja_template_analyzer.get_template_variables()\n        assert \"messages\" in variables\n        assert \"message\" in variables\n        assert \"role\" in variables[\"message\"]\n        assert \"content\" in variables[\"message\"]\n\n    def test_detailed_analysis(self, basic_jinja_template_analyzer):\n        \"\"\"Test the detailed analysis of variable usage.\"\"\"\n        LOG.info(\"Testing detailed analysis\")\n\n        analysis = basic_jinja_template_analyzer.analyze_template()\n\n        assert analysis[\"messages\"][\"is_iterated\"] is True\n        assert \"role\" in analysis[\"message\"][\"accessed_properties\"]\n        assert \"content\" in analysis[\"message\"][\"accessed_properties\"]\n\n        assert analysis[\"add_generation_prompt\"][\"is_conditional\"] is True\n        assert len(analysis[\"add_generation_prompt\"][\"accessed_properties\"]) == 0\n\n        assert not analysis[\"eos_token\"][\"is_iterated\"]\n        assert len(analysis[\"eos_token\"][\"accessed_properties\"]) == 0\n\n    def test_nested_property_access(self):\n        \"\"\"Test handling of nested property access.\"\"\"\n        LOG.info(\"Testing nested property access\")\n\n        template = \"\"\"{{ user.profile.name }}{{ user.settings['preference'] }}\"\"\"\n        analyzer = JinjaTemplateAnalyzer(template)\n        variables = analyzer.get_template_variables()\n\n        assert \"user\" in variables\n        assert \"profile\" in variables[\"user\"]\n        assert \"settings\" in variables[\"user\"]\n\n    def test_loop_variable_handling(self):\n        \"\"\"Test handling of loop variables and their properties.\"\"\"\n        LOG.info(\"Testing loop variable handling\")\n\n        template = \"\"\"\n        {% for item in items %}\n            {{ item.name }}\n            {% for subitem in item.subitems %}\n                {{ subitem.value }}\n            {% endfor %}\n        {% endfor %}\n        \"\"\"\n        analyzer = JinjaTemplateAnalyzer(template)\n 
       analysis = analyzer.analyze_template()\n\n        assert analysis[\"items\"][\"is_iterated\"]\n        assert \"name\" in analysis[\"item\"][\"accessed_properties\"]\n        assert \"subitems\" in analysis[\"item\"][\"accessed_properties\"]\n\n    def test_conditional_variable_usage(self):\n        \"\"\"Test detection of variables used in conditional statements.\"\"\"\n        LOG.info(\"Testing conditional variable usage\")\n\n        template = \"\"\"\n        {% if user.is_admin and config.debug_mode %}\n            {{ debug_info }}\n        {% endif %}\n        \"\"\"\n        analyzer = JinjaTemplateAnalyzer(template)\n        analysis = analyzer.analyze_template()\n\n        assert analysis[\"user\"][\"is_conditional\"]\n        assert analysis[\"config\"][\"is_conditional\"]\n        assert \"is_admin\" in analysis[\"user\"][\"accessed_properties\"]\n        assert \"debug_mode\" in analysis[\"config\"][\"accessed_properties\"]\n\n    def test_complex_expressions(self):\n        \"\"\"Test handling of complex expressions and filters.\"\"\"\n        LOG.info(\"Testing complex expressions and filters\")\n\n        template = \"\"\"\n        {{ user.name | upper }}\n        {{ messages | length > 0 and messages[0].content }}\n        {{ data['key'].nested['value'] }}\n        \"\"\"\n        analyzer = JinjaTemplateAnalyzer(template)\n        variables = analyzer.get_template_variables()\n\n        assert \"user\" in variables\n        assert \"name\" in variables[\"user\"]\n        assert \"messages\" in variables\n        assert \"content\" in variables[\"messages\"]\n        assert \"data\" in variables\n\n    def test_basic_msg_vars(self, basic_jinja_template_analyzer):\n        \"\"\"Test that the basic message variables are correctly identified.\"\"\"\n        LOG.info(\"Testing basic message variables\")\n\n        variables = basic_jinja_template_analyzer.get_message_vars()\n        assert variables == {\"role\", \"content\"}\n\n    def 
test_mixtral_msg_vars(self, mistral_jinja_template_analyzer):\n        \"\"\"Test that the mixtral message variables are correctly identified.\"\"\"\n        LOG.info(\"Testing mixtral message variables\")\n\n        variables = mistral_jinja_template_analyzer.get_message_vars()\n        assert variables == {\"role\", \"content\", \"tool_calls\", \"tool_call_id\"}\n\n\nif __name__ == \"__main__\":\n    pytest.main([__file__])\n"
  },
  {
    "path": "tests/prompt_strategies/test_raw_io.py",
    "content": "\"\"\"\nTest module for raw i/o data for prompts\n\"\"\"\n\nimport pytest\nfrom datasets import Dataset\nfrom tokenizers import AddedToken\nfrom transformers import AutoTokenizer\n\nfrom axolotl.datasets import TokenizedPromptDataset\nfrom axolotl.prompt_strategies.input_output import (\n    RawInputOutputPrompter,\n    RawInputOutputStrategy,\n)\n\n\n@pytest.fixture(name=\"segments_dataset\")\ndef fixture_sharegpt_dataset():\n    return Dataset.from_list(\n        [\n            {\n                \"segments\": [\n                    {\n                        \"label\": False,\n                        \"text\": \"<s>hello \",\n                    },\n                    {\n                        \"label\": True,\n                        \"text\": \"hi there.<eot>\",\n                    },\n                    {\n                        \"label\": False,\n                        \"text\": \"goodbye \",\n                    },\n                    {\n                        \"label\": True,\n                        \"text\": \"farewell<eot>\",\n                    },\n                ]\n            }\n        ]\n    )\n\n\n@pytest.fixture(name=\"tokenizer\")\ndef fixture_tokenizer():\n    tokenizer = AutoTokenizer.from_pretrained(\n        \"casperhansen/mistral-7b-instruct-v0.1-awq\"\n    )\n    tokenizer.add_tokens(\n        [\n            AddedToken(\"<eot>\", rstrip=False, lstrip=False, normalized=False),\n        ]\n    )\n\n    return tokenizer\n\n\nclass TestRawInputOutputPrompts:\n    \"\"\"\n    Test class for raw i/o prompter\n    \"\"\"\n\n    def test_segment_prompts(self, segments_dataset, tokenizer):\n        strategy = RawInputOutputStrategy(\n            RawInputOutputPrompter(),\n            tokenizer,\n            False,  # train_on_inputs\n            2048,  # sequence_len\n        )\n\n        dataset_wrapper = TokenizedPromptDataset(\n            strategy, segments_dataset, process_count=1\n        )\n\n        input_ids = 
dataset_wrapper[0][\"input_ids\"]\n        labels = dataset_wrapper[0][\"labels\"]\n\n        assert (\n            tokenizer.decode(input_ids)\n            == \"<s> hello  hi there.<eot> goodbye  farewell<eot>\"\n        )\n        # fmt: off\n        assert input_ids == [\n            1,  # <s>\n            6312,  # hell\n            28709,  # o\n            28705,  #\n            12014,  # hi\n            736,  # there\n            28723,  # .\n            32000,  # <eot>\n            1179,  # good\n            17664,  # bye\n            28705,  #\n            19111,  # fare\n            5458,  # well\n            32000,  # <eot>\n        ]\n        # fmt: on\n\n        # fmt: off\n        assert labels == [\n            -100,  # <s>\n            -100,  # hell\n            -100,  # o\n            -100,  #\n            12014,  # hi\n            736,  # there\n            28723,  # .\n            32000,  # <eot>\n            -100,  # good\n            -100,  # bye\n            -100,  #\n            19111,  # fare\n            5458,  # well\n            32000,  # <eot>\n        ]\n        # fmt: on\n"
  },
  {
    "path": "tests/prompt_strategies/test_stepwise.py",
    "content": "\"\"\"\ntests for chat_template prompt strategy\n\"\"\"\n\nimport datasets\nimport pytest\nfrom datasets import Dataset\nfrom transformers import AutoTokenizer\n\nfrom axolotl.datasets import TokenizedPromptDataset\nfrom axolotl.prompt_strategies.stepwise_supervised import (\n    StepwiseSupervisedPromptTokenizingStrategy,\n)\n\n\nclass TestStepWiseSupervisedPromptTokenizingStrategy:\n    \"\"\"\n    Test class for stepwise supervised prompt strategy\n    \"\"\"\n\n    @pytest.fixture()\n    def stepwise_supervised_dataset(self):\n        return Dataset.from_list(\n            [\n                {\n                    \"prompt\": \"Which number is larger, 9.8 or 9.11?\",\n                    \"completions\": [\n                        \"The fractional part of 9.8 is 0.8, while the fractional part of 9.11 is 0.11.\",\n                        \"Since 0.11 is greater than 0.8, the number 9.11 is larger than 9.8.\",\n                        \"Actually, this is incorrect. In decimal numbers, 0.8 is equal to 0.80, which is larger than 0.11. 
Therefore, 9.8 is larger than 9.11.\",\n                    ],\n                    \"labels\": [True, False, False],\n                }\n            ]\n        )\n\n    @pytest.fixture()\n    def tokenizer(self):\n        return AutoTokenizer.from_pretrained(\"Qwen/Qwen2.5-0.5B\")\n\n    def test_stepwise_supervised_dataset(self, tokenizer, stepwise_supervised_dataset):\n        strategy = StepwiseSupervisedPromptTokenizingStrategy(\n            tokenizer,\n            sequence_len=2048,\n            step_separator=\"\\n\",\n        )\n        stepwise_supervised_dataset = stepwise_supervised_dataset.cast_column(\n            \"labels\", datasets.Sequence(datasets.Value(\"int64\"))\n        )\n        dataset_wrapper = TokenizedPromptDataset(\n            strategy,\n            stepwise_supervised_dataset,\n            process_count=1,\n        )\n        labels = dataset_wrapper[0][\"labels\"]\n        # expected labels is:\n        # the prompt + first step are ignored, followed by the label for step 1 (True)\n        # the second step, and its label (False)\n        # the third step, and its label (False)\n        expected = [-100] * 47 + [1] + [-100] * 29 + [0] + [-100] * 48 + [0]\n\n        assert labels == expected\n"
  },
  {
    "path": "tests/telemetry/__init__.py",
    "content": ""
  },
  {
    "path": "tests/telemetry/conftest.py",
    "content": "\"\"\"Shared pytest fixtures for telemetry tests.\"\"\"\n\nimport pytest\n\n\n@pytest.fixture(autouse=True)\ndef del_track_env(monkeypatch):\n    monkeypatch.delenv(\"AXOLOTL_DO_NOT_TRACK\", raising=False)\n    yield\n"
  },
  {
    "path": "tests/telemetry/test_callbacks.py",
    "content": "\"\"\"Tests for telemetry callback module.\"\"\"\n\n# pylint: disable=redefined-outer-name\n\nimport time\nfrom unittest.mock import MagicMock, patch\n\nimport pytest\nfrom transformers import TrainerControl, TrainerState, TrainingArguments\n\nfrom axolotl.telemetry.callbacks import TIME_SINCE_LAST, TelemetryCallback\n\n\ndef calc_expected_metrics(step, last_step, current_time, last_time, start_time=900.0):\n    \"\"\"Calculate expected metrics values for tests\"\"\"\n    time_diff = current_time - last_time\n    step_diff = step - last_step\n    return {\n        \"steps_per_second\": (\n            step_diff / time_diff if time_diff > 0 and step_diff > 0 else 0\n        ),\n        \"time_since_last_report\": time_diff,\n        \"elapsed_time\": current_time - start_time,\n    }\n\n\n@pytest.fixture\ndef mock_time():\n    \"\"\"Mock time.time() to have predictable values in tests\"\"\"\n    with patch(\"axolotl.telemetry.callbacks.time\") as mock_time:\n        mock_time.time.return_value = 1000.0\n        yield mock_time\n\n\n@pytest.fixture\ndef mock_telemetry_manager():\n    \"\"\"Create a mock TelemetryManager\"\"\"\n    with patch(\"axolotl.telemetry.callbacks.TelemetryManager\") as mock_manager_class:\n        mock_manager = MagicMock()\n        mock_manager_class.get_instance.return_value = mock_manager\n        yield mock_manager\n\n\n@pytest.fixture\ndef mock_runtime_metrics_tracker():\n    \"\"\"Create a mock RuntimeMetricsTracker\"\"\"\n    with patch(\n        \"axolotl.telemetry.callbacks.RuntimeMetricsTracker\"\n    ) as mock_tracker_class:\n        mock_tracker = MagicMock()\n        # Set up metrics property on the tracker\n        mock_metrics = MagicMock()\n        mock_metrics.to_dict.return_value = {\n            \"total_steps\": 100,\n            \"peak_cpu_memory_bytes\": 1024,\n        }\n        mock_tracker.metrics = mock_metrics\n\n        # Make the constructor return our mock\n        mock_tracker_class.return_value = 
mock_tracker\n        yield mock_tracker\n\n\n@pytest.fixture\ndef training_args():\n    \"\"\"Create a minimal TrainingArguments instance\"\"\"\n    return TrainingArguments(output_dir=\"./output\")\n\n\n@pytest.fixture\ndef trainer_state():\n    \"\"\"Create a mock TrainerState\"\"\"\n    state = MagicMock(spec=TrainerState)\n    state.global_step = 10\n    state.epoch = 0.5  # halfway through first epoch\n    state.log_history = [{\"loss\": 2.5, \"learning_rate\": 5e-5}]\n    return state\n\n\n@pytest.fixture\ndef trainer_control():\n    \"\"\"Create a mock TrainerControl\"\"\"\n    return MagicMock(spec=TrainerControl)\n\n\n# pylint: disable=unused-argument\n@pytest.fixture\ndef callback(mock_telemetry_manager, mock_runtime_metrics_tracker):\n    \"\"\"Create a TelemetryCallback instance with mocked dependencies\"\"\"\n    return TelemetryCallback()\n\n\nclass TestTelemetryCallback:\n    \"\"\"Tests for the TelemetryCallback class.\"\"\"\n\n    def test_initialization(self, callback, mock_runtime_metrics_tracker):\n        \"\"\"Test callback initialization.\"\"\"\n        assert callback.current_epoch == -1\n        assert callback.tracker == mock_runtime_metrics_tracker\n        assert callback.last_report_step == 0\n        assert hasattr(callback, \"start_time\")\n        assert hasattr(callback, \"last_report_time\")\n        assert callback.report_interval_steps == 100\n\n    def test_on_train_begin(\n        self,\n        callback,\n        mock_telemetry_manager,\n        training_args,\n        trainer_state,\n        trainer_control,\n    ):\n        \"\"\"Test on_train_begin sends expected event.\"\"\"\n        callback.on_train_begin(training_args, trainer_state, trainer_control)\n\n        mock_telemetry_manager.send_event.assert_called_once_with(\n            event_type=\"train-start\"\n        )\n\n    def test_on_train_end(\n        self,\n        callback,\n        mock_telemetry_manager,\n        training_args,\n        trainer_state,\n       
 trainer_control,\n    ):\n        \"\"\"Test on_train_end sends expected event with metrics.\"\"\"\n        callback.on_train_end(training_args, trainer_state, trainer_control)\n\n        mock_telemetry_manager.send_event.assert_called_once()\n        call_args = mock_telemetry_manager.send_event.call_args[1]\n\n        assert call_args[\"event_type\"] == \"train-end\"\n        assert \"loss\" in call_args[\"properties\"]\n        assert call_args[\"properties\"][\"loss\"] == 2.5\n        assert \"learning_rate\" in call_args[\"properties\"]\n        assert call_args[\"properties\"][\"learning_rate\"] == 5e-5\n\n        # Check that metrics from RuntimeMetricsTracker are included\n        assert \"total_steps\" in call_args[\"properties\"]\n        assert call_args[\"properties\"][\"total_steps\"] == 100\n        assert \"peak_cpu_memory_bytes\" in call_args[\"properties\"]\n        assert call_args[\"properties\"][\"peak_cpu_memory_bytes\"] == 1024\n\n    def test_on_epoch_begin(\n        self,\n        callback,\n        mock_runtime_metrics_tracker,\n        training_args,\n        trainer_state,\n        trainer_control,\n    ):\n        \"\"\"Test on_epoch_begin updates epoch counter and calls tracker.\"\"\"\n        initial_epoch = callback.current_epoch\n\n        callback.on_epoch_begin(training_args, trainer_state, trainer_control)\n\n        assert callback.current_epoch == initial_epoch + 1\n        mock_runtime_metrics_tracker.start_epoch.assert_called_once_with(\n            initial_epoch + 1\n        )\n\n    def test_on_epoch_end(\n        self,\n        callback,\n        mock_runtime_metrics_tracker,\n        training_args,\n        trainer_state,\n        trainer_control,\n    ):\n        \"\"\"Test on_epoch_end calls tracker.\"\"\"\n        # Set current epoch\n        callback.current_epoch = 2\n\n        callback.on_epoch_end(training_args, trainer_state, trainer_control)\n\n        
mock_runtime_metrics_tracker.end_epoch.assert_called_once_with(2)\n\n    def test_on_step_end_no_report(\n        self,\n        callback,\n        mock_telemetry_manager,\n        mock_runtime_metrics_tracker,\n        training_args,\n        trainer_state,\n        trainer_control,\n    ):\n        \"\"\"Test on_step_end updates tracker but doesn't report if criteria not met.\"\"\"\n        # Set up state to avoid reporting\n        trainer_state.global_step = 42  # Not divisible by report_interval_steps\n        callback.last_report_step = 41  # Just 1 step since last report\n        callback.last_report_time = time.time()  # Just now\n\n        callback.on_step_end(training_args, trainer_state, trainer_control)\n\n        # Should update tracker\n        mock_runtime_metrics_tracker.update_step.assert_called_once_with(42)\n\n        # Should not send telemetry\n        mock_telemetry_manager.send_event.assert_not_called()\n\n        # Should not update last report time/step\n        assert callback.last_report_step == 41\n\n    def test_on_step_end_report_interval_steps(\n        self,\n        callback,\n        mock_telemetry_manager,\n        mock_runtime_metrics_tracker,\n        mock_time,\n        training_args,\n        trainer_state,\n        trainer_control,\n    ):\n        \"\"\"Test on_step_end reports when step interval is reached.\"\"\"\n        # Set up state with clear values\n        current_step = 100  # Exactly matches report_interval_steps\n        last_step = 0\n        start_time = 900.0\n        current_time = 1000.0\n        time_diff = current_time - start_time  # 100 seconds\n\n        # Configure state and callback\n        trainer_state.global_step = current_step\n        callback.report_interval_steps = 100\n        callback.last_report_step = last_step\n        callback.start_time = start_time\n        callback.last_report_time = start_time\n\n        # Mock time.time() to return consistent values\n        
mock_time.time.return_value = current_time\n\n        callback.on_step_end(training_args, trainer_state, trainer_control)\n\n        # Should update tracker\n        mock_runtime_metrics_tracker.update_step.assert_called_once_with(current_step)\n        mock_runtime_metrics_tracker.update_memory_metrics.assert_called_once()\n\n        # Should send telemetry\n        mock_telemetry_manager.send_event.assert_called_once()\n        call_args = mock_telemetry_manager.send_event.call_args[1]\n        assert call_args[\"event_type\"] == \"train-progress\"\n\n        # Properties should include expected values\n        props = call_args[\"properties\"]\n        assert props[\"step\"] == current_step\n        assert props[\"elapsed_time\"] == time_diff  # 1000 - 900 = 100\n        assert props[\"time_since_last_report\"] == time_diff  # 1000 - 900 = 100\n        assert props[\"steps_per_second\"] == 1.0  # 100 steps / 100 seconds\n\n        # Should update last report time/step\n        assert callback.last_report_step == current_step\n        assert callback.last_report_time == current_time\n\n    def test_on_step_end_report_time_elapsed(\n        self,\n        callback,\n        mock_telemetry_manager,\n        mock_runtime_metrics_tracker,  # pylint: disable=unused-argument\n        mock_time,\n        training_args,\n        trainer_state,\n        trainer_control,\n    ):\n        \"\"\"Test on_step_end reports when enough time has elapsed.\"\"\"\n        # Set up state with clear values\n        current_step = 120\n        last_step = 10\n        start_time = 900.0\n        current_time = 1000.0\n        time_diff = TIME_SINCE_LAST + 1  # Just over the threshold\n\n        # Configure state and callback\n        trainer_state.global_step = current_step\n        callback.report_interval_steps = 100\n        callback.last_report_step = last_step\n        callback.start_time = start_time\n        callback.last_report_time = current_time - time_diff\n\n        # Mock 
time.time() to return consistent values\n        mock_time.time.return_value = current_time\n\n        callback.on_step_end(training_args, trainer_state, trainer_control)\n\n        # Should send telemetry\n        mock_telemetry_manager.send_event.assert_called_once()\n\n        # Properties should include expected values\n        props = mock_telemetry_manager.send_event.call_args[1][\"properties\"]\n        expected_metrics = calc_expected_metrics(\n            current_step, last_step, current_time, current_time - time_diff, start_time\n        )\n        assert props[\"steps_per_second\"] == expected_metrics[\"steps_per_second\"]\n        assert (\n            props[\"time_since_last_report\"]\n            == expected_metrics[\"time_since_last_report\"]\n        )\n\n    def test_on_step_end_first_step(\n        self,\n        callback,\n        mock_telemetry_manager,\n        mock_runtime_metrics_tracker,  # pylint: disable=unused-argument\n        mock_time,\n        training_args,\n        trainer_state,\n        trainer_control,\n    ):\n        \"\"\"Test on_step_end always reports on first step.\"\"\"\n        # Set up state with clear values\n        current_step = 1  # First step\n        last_step = 0\n        start_time = 900.0\n        current_time = 1000.0\n        last_report_time = 999.0  # Just 1 second ago\n\n        # Configure state and callback\n        trainer_state.global_step = current_step\n        callback.report_interval_steps = 100\n        callback.last_report_step = last_step\n        callback.start_time = start_time\n        callback.last_report_time = last_report_time\n\n        # Mock time.time() to return consistent values\n        mock_time.time.return_value = current_time\n\n        callback.on_step_end(training_args, trainer_state, trainer_control)\n\n        # Should send telemetry even though not much time has passed\n        mock_telemetry_manager.send_event.assert_called_once()\n\n        # Properties should include 
expected values for first step\n        props = mock_telemetry_manager.send_event.call_args[1][\"properties\"]\n        assert props[\"step\"] == current_step\n        expected_metrics = calc_expected_metrics(\n            current_step, last_step, current_time, last_report_time, start_time\n        )\n        assert props[\"steps_per_second\"] == expected_metrics[\"steps_per_second\"]\n\n    def test_log_history_empty(\n        self,\n        callback,\n        mock_telemetry_manager,\n        mock_runtime_metrics_tracker,  # pylint: disable=unused-argument\n        mock_time,\n        training_args,\n        trainer_state,\n        trainer_control,\n    ):\n        \"\"\"Test handling of empty log history.\"\"\"\n        # Set up state with clear values\n        current_step = 1\n        start_time = 900.0\n        current_time = 1000.0\n\n        # Configure state and callback\n        trainer_state.global_step = current_step\n        trainer_state.log_history = []\n        callback.start_time = start_time\n\n        # Mock time.time() to return consistent values\n        mock_time.time.return_value = current_time\n\n        callback.on_step_end(training_args, trainer_state, trainer_control)\n\n        # Should still send telemetry\n        mock_telemetry_manager.send_event.assert_called_once()\n\n        # Properties should have default values for missing log data\n        props = mock_telemetry_manager.send_event.call_args[1][\"properties\"]\n        assert props[\"loss\"] == 0\n        assert props[\"learning_rate\"] == 0\n"
  },
  {
    "path": "tests/telemetry/test_errors.py",
    "content": "\"\"\"Tests for telemetry error utilities\"\"\"\n\n# pylint: disable=redefined-outer-name\n\nfrom unittest.mock import MagicMock, patch\n\nimport pytest\n\nfrom axolotl.telemetry.errors import sanitize_stack_trace, send_errors\n\n\n@pytest.fixture(autouse=True)\ndef reset_error_flag(monkeypatch):\n    \"\"\"Reset ERROR_HANDLED flag using monkeypatch\"\"\"\n    import axolotl.telemetry.errors\n\n    monkeypatch.setattr(axolotl.telemetry.errors, \"ERROR_HANDLED\", False)\n    yield\n    monkeypatch.setattr(axolotl.telemetry.errors, \"ERROR_HANDLED\", False)\n\n\n@pytest.fixture\ndef example_stack_trace():\n    \"\"\"Provide a sample stack trace with mixed paths\"\"\"\n    return \"\"\"Traceback (most recent call last):\n  File \"/home/user/.local/lib/python3.9/site-packages/axolotl/cli/train.py\", line 83, in main\n    trainer = get_trainer(cfg)\n  File \"/home/user/.local/lib/python3.9/site-packages/axolotl/train.py\", line 214, in get_trainer\n    model = get_model(cfg, tokenizer)\n  File \"/home/user/.local/lib/python3.9/site-packages/axolotl/utils/models.py\", line 120, in get_model\n    raise ValueError(\"Model path not found\")\nValueError: Model path not found\n\"\"\"\n\n\n@pytest.fixture\ndef windows_stack_trace():\n    \"\"\"Provide a sample stack trace with Windows paths\"\"\"\n    return \"\"\"Traceback (most recent call last):\n  File \"C:\\\\Users\\\\name\\\\AppData\\\\Local\\\\Programs\\\\Python\\\\Python39\\\\lib\\\\site-packages\\\\axolotl\\\\cli\\\\train.py\", line 83, in main\n    trainer = get_trainer(cfg)\n  File \"C:\\\\Users\\\\name\\\\AppData\\\\Local\\\\Programs\\\\Python\\\\Python39\\\\lib\\\\site-packages\\\\axolotl\\\\train.py\", line 214, in get_trainer\n    model = get_model(cfg, tokenizer)\n  File \"C:\\\\Users\\\\name\\\\AppData\\\\Local\\\\Programs\\\\Python\\\\Python39\\\\lib\\\\site-packages\\\\transformers\\\\models\\\\auto\\\\modeling_auto.py\", line 482, in from_pretrained\n    raise ValueError(f\"Unrecognized 
configuration class {config.__class__}\")\nValueError: Unrecognized configuration class <class 'transformers.models.llama.configuration_llama.LlamaConfig'>\n\"\"\"\n\n\n@pytest.fixture\ndef mixed_stack_trace():\n    \"\"\"Provide a sample stack trace with both axolotl and non-axolotl paths\"\"\"\n    return \"\"\"Traceback (most recent call last):\n  File \"/home/user/.local/lib/python3.9/site-packages/axolotl/cli/train.py\", line 83, in main\n    trainer = get_trainer(cfg)\n  File \"/home/user/.local/lib/python3.9/site-packages/transformers/trainer.py\", line 520, in train\n    self._inner_training_loop()\n  File \"/home/user/.local/lib/python3.9/site-packages/axolotl/utils/trainer.py\", line 75, in _inner_training_loop\n    super()._inner_training_loop()\n  File \"/home/user/.local/lib/python3.9/site-packages/torch/utils/data/dataloader.py\", line 631, in __next__\n    data = self._next_data()\nRuntimeError: CUDA out of memory\n\"\"\"\n\n\n@pytest.fixture\ndef venv_stack_trace():\n    \"\"\"Provide a sample stack trace with virtual environment paths\"\"\"\n    return \"\"\"Traceback (most recent call last):\n  File \"/home/user/venv/lib/python3.9/site-packages/transformers/trainer.py\", line 1729, in train\n    self._inner_training_loop()\n  File \"/home/user/venv/lib/python3.9/site-packages/transformers/trainer.py\", line 2013, in _inner_training_loop\n    self.accelerator.backward(loss)\n  File \"/home/user/venv/lib/python3.9/site-packages/accelerate/accelerator.py\", line 1851, in backward\n    self.scaler.scale(loss).backward(**kwargs)\n  File \"/home/user/venv/lib/python3.9/site-packages/torch/_tensor.py\", line 487, in backward\n    torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)\nRuntimeError: CUDA out of memory\n\"\"\"\n\n\n@pytest.fixture\ndef dist_packages_stack_trace():\n    \"\"\"Provide a sample stack trace with dist-packages paths\"\"\"\n    return \"\"\"Traceback (most recent call last):\n  File 
\"/usr/local/lib/python3.8/dist-packages/torch/utils/data/dataloader.py\", line 631, in __next__\n    data = self._next_data()\n  File \"/usr/local/lib/python3.8/dist-packages/torch/utils/data/dataloader.py\", line 675, in _next_data\n    data = self._dataset_fetcher.fetch(index)\n  File \"/usr/local/lib/python3.8/dist-packages/torch/utils/data/_utils/fetch.py\", line 51, in fetch\n    data = [self.dataset[idx] for idx in possibly_batched_index]\n  File \"/usr/local/lib/python3.8/dist-packages/datasets/arrow_dataset.py\", line 2808, in __getitem__\n    raise IndexError(f\"Index {key} out of range for dataset of length {len(self)}.\")\nIndexError: Index 10000 out of range for dataset of length 9832.\n\"\"\"\n\n\n@pytest.fixture\ndef project_stack_trace():\n    \"\"\"Provide a sample stack trace from a project directory (not a virtual env)\"\"\"\n    return \"\"\"Traceback (most recent call last):\n  File \"/home/user/projects/myproject/run.py\", line 25, in <module>\n    main()\n  File \"/home/user/projects/myproject/src/cli.py\", line 45, in main\n    app.run()\n  File \"/home/user/projects/myproject/src/app.py\", line 102, in run\n    raise ValueError(\"Configuration missing\")\nValueError: Configuration missing\n\"\"\"\n\n\ndef test_sanitize_stack_trace(example_stack_trace):\n    \"\"\"Test that sanitize_stack_trace properly preserves axolotl paths\"\"\"\n    sanitized = sanitize_stack_trace(example_stack_trace)\n\n    # Check that personal paths are removed\n    assert \"/home/user\" not in sanitized\n    assert \".local/lib/python3.9\" not in sanitized\n\n    # Check that site-packages is preserved\n    assert \"site-packages/axolotl/cli/train.py\" in sanitized\n    assert \"site-packages/axolotl/train.py\" in sanitized\n    assert \"site-packages/axolotl/utils/models.py\" in sanitized\n\n    # Check that error message is preserved\n    assert \"ValueError: Model path not found\" in sanitized\n\n\ndef test_sanitize_windows_paths(windows_stack_trace):\n    
\"\"\"Test that sanitize_stack_trace handles Windows paths\"\"\"\n    sanitized = sanitize_stack_trace(windows_stack_trace)\n\n    # Check that personal paths are removed\n    assert \"C:\\\\Users\\\\name\" not in sanitized\n    assert \"AppData\\\\Local\\\\Programs\\\\Python\" not in sanitized\n\n    # Check that both axolotl and transformers packages are preserved\n    assert (\n        \"site-packages\\\\axolotl\\\\cli\\\\train.py\" in sanitized\n        or \"site-packages/axolotl/cli/train.py\" in sanitized\n    )\n    assert (\n        \"site-packages\\\\axolotl\\\\train.py\" in sanitized\n        or \"site-packages/axolotl/train.py\" in sanitized\n    )\n    assert (\n        \"site-packages\\\\transformers\\\\models\\\\auto\\\\modeling_auto.py\" in sanitized\n        or \"site-packages/transformers/models/auto/modeling_auto.py\" in sanitized\n    )\n\n    # Check that error message is preserved\n    assert \"ValueError: Unrecognized configuration class\" in sanitized\n\n\ndef test_sanitize_mixed_paths(mixed_stack_trace):\n    \"\"\"Test that sanitize_stack_trace preserves all package paths\"\"\"\n    sanitized = sanitize_stack_trace(mixed_stack_trace)\n\n    # Check that all package paths are preserved\n    assert \"site-packages/axolotl/cli/train.py\" in sanitized\n    assert \"site-packages/transformers/trainer.py\" in sanitized\n    assert \"site-packages/axolotl/utils/trainer.py\" in sanitized\n    assert \"site-packages/torch/utils/data/dataloader.py\" in sanitized\n\n    # Check that error message is preserved\n    assert \"RuntimeError: CUDA out of memory\" in sanitized\n\n\ndef test_sanitize_venv_paths(venv_stack_trace):\n    \"\"\"Test that sanitize_stack_trace preserves virtual environment package paths\"\"\"\n    sanitized = sanitize_stack_trace(venv_stack_trace)\n\n    # Check that personal paths are removed\n    assert \"/home/user/venv\" not in sanitized\n\n    # Check that all package paths are preserved\n    assert 
\"site-packages/transformers/trainer.py\" in sanitized\n    assert \"site-packages/accelerate/accelerator.py\" in sanitized\n    assert \"site-packages/torch/_tensor.py\" in sanitized\n\n    # Check that error message is preserved\n    assert \"RuntimeError: CUDA out of memory\" in sanitized\n\n\ndef test_sanitize_dist_packages(dist_packages_stack_trace):\n    \"\"\"Test that sanitize_stack_trace preserves dist-packages paths\"\"\"\n    sanitized = sanitize_stack_trace(dist_packages_stack_trace)\n\n    # Check that system paths are removed\n    assert \"/usr/local/lib/python3.8\" not in sanitized\n\n    # Check that all package paths are preserved\n    assert \"dist-packages/torch/utils/data/dataloader.py\" in sanitized\n    assert \"dist-packages/torch/utils/data/_utils/fetch.py\" in sanitized\n    assert \"dist-packages/datasets/arrow_dataset.py\" in sanitized\n\n    # Check that error message is preserved\n    assert (\n        \"IndexError: Index 10000 out of range for dataset of length 9832.\" in sanitized\n    )\n\n\ndef test_sanitize_project_paths(project_stack_trace):\n    \"\"\"Test handling of project paths (non-virtual env)\"\"\"\n    sanitized = sanitize_stack_trace(project_stack_trace)\n\n    # Check that personal paths are removed\n    assert \"/home/user/projects\" not in sanitized\n\n    # For non-package paths, we should at least preserve the filename\n    assert \"run.py\" in sanitized\n    assert \"cli.py\" in sanitized\n    assert \"app.py\" in sanitized\n\n    # Check that error message is preserved\n    assert \"ValueError: Configuration missing\" in sanitized\n\n\n@pytest.fixture\ndef mock_telemetry_manager():\n    \"\"\"Create a mock TelemetryManager\"\"\"\n    with patch(\"axolotl.telemetry.errors.TelemetryManager\") as mock_manager_class:\n        mock_manager = MagicMock()\n        mock_manager.enabled = True\n        mock_manager_class.get_instance.return_value = mock_manager\n        yield mock_manager\n\n\ndef 
test_send_errors_successful_execution(mock_telemetry_manager):\n    \"\"\"Test that send_errors doesn't send telemetry for successful function execution\"\"\"\n\n    @send_errors\n    def test_func():\n        return \"success\"\n\n    result = test_func()\n    assert result == \"success\"\n    mock_telemetry_manager.send_event.assert_not_called()\n\n\ndef test_send_errors_with_exception(mock_telemetry_manager):\n    \"\"\"Test that send_errors sends telemetry when an exception occurs\"\"\"\n    test_error = ValueError(\"Test error\")\n\n    @send_errors\n    def test_func():\n        raise test_error\n\n    with pytest.raises(ValueError) as excinfo:\n        test_func()\n\n    assert excinfo.value == test_error\n    mock_telemetry_manager.send_event.assert_called_once()\n\n    # Check that the error info was passed correctly\n    call_args = mock_telemetry_manager.send_event.call_args[1]\n    assert \"test_func-error\" in call_args[\"event_type\"]\n    assert \"Test error\" in call_args[\"properties\"][\"exception\"]\n    assert \"stack_trace\" in call_args[\"properties\"]\n\n\ndef test_send_errors_nested_calls(mock_telemetry_manager):\n    \"\"\"Test that send_errors only sends telemetry once for nested decorated functions\"\"\"\n\n    @send_errors\n    def inner_func():\n        raise ValueError(\"Inner error\")\n\n    @send_errors\n    def outer_func():\n        return inner_func()\n\n    with pytest.raises(ValueError):\n        outer_func()\n\n    # Telemetry should be sent only once for the inner function\n    assert mock_telemetry_manager.send_event.call_count == 1\n    call_args = mock_telemetry_manager.send_event.call_args[1]\n    assert \"inner_func-error\" in call_args[\"event_type\"]\n\n\ndef test_send_errors_telemetry_disable():\n    \"\"\"Test that send_errors doesn't attempt to send telemetry when disabled\"\"\"\n\n    with patch(\"axolotl.telemetry.errors.TelemetryManager\") as mock_manager_class:\n        mock_manager = MagicMock()\n        
mock_manager.enabled = False\n        mock_manager_class.get_instance.return_value = mock_manager\n\n        @send_errors\n        def test_func():\n            raise ValueError(\"Test error\")\n\n        with pytest.raises(ValueError):\n            test_func()\n\n        mock_manager.send_event.assert_not_called()\n\n\ndef test_error_handled_reset():\n    \"\"\"Test that ERROR_HANDLED flag is properly reset\"\"\"\n    with patch(\"axolotl.telemetry.errors.TelemetryManager\") as mock_manager_class:\n        # Create and configure the mock manager\n        mock_manager = MagicMock()\n        mock_manager.enabled = True\n        mock_manager_class.get_instance.return_value = mock_manager\n\n        from axolotl.telemetry.errors import ERROR_HANDLED\n\n        @send_errors\n        def test_func():\n            raise ValueError(\"Test error\")\n\n        assert not ERROR_HANDLED\n\n        with pytest.raises(ValueError):\n            test_func()\n\n        from axolotl.telemetry.errors import ERROR_HANDLED\n\n        assert ERROR_HANDLED\n\n\ndef test_module_path_resolution(mock_telemetry_manager):\n    \"\"\"Test that the module path is correctly resolved for the event type\"\"\"\n    import inspect\n\n    current_module = inspect.getmodule(test_module_path_resolution).__name__\n\n    @send_errors\n    def test_func():\n        raise ValueError(\"Test error\")\n\n    with pytest.raises(ValueError):\n        test_func()\n\n    assert mock_telemetry_manager.send_event.called\n    event_type = mock_telemetry_manager.send_event.call_args[1][\"event_type\"]\n\n    expected_event_type = f\"{current_module}.test_func-error\"\n    assert expected_event_type == event_type\n"
  },
  {
    "path": "tests/telemetry/test_manager.py",
    "content": "\"\"\"Tests for TelemetryManager class and utilities\"\"\"\n\n# pylint: disable=redefined-outer-name,protected-access\n\nimport os\nfrom unittest.mock import patch\n\nimport pytest\nimport yaml\n\nfrom axolotl.telemetry.manager import TelemetryManager\n\n\n@pytest.fixture\ndef mock_whitelist(tmp_path):\n    \"\"\"Create a temporary whitelist file for testing\"\"\"\n    whitelist_content = {\n        \"organizations\": [\"meta-llama\", \"mistralai\"],\n    }\n    whitelist_file = tmp_path / \"whitelist.yaml\"\n    with open(whitelist_file, \"w\", encoding=\"utf-8\") as f:\n        yaml.dump(whitelist_content, f)\n\n    return str(whitelist_file)\n\n\n@pytest.fixture\ndef telemetry_manager_class():\n    \"\"\"Reset the TelemetryManager singleton between tests\"\"\"\n    original_instance = TelemetryManager._instance\n    original_initialized = TelemetryManager._initialized\n    TelemetryManager._instance = None\n    TelemetryManager._initialized = False\n    yield TelemetryManager\n    TelemetryManager._instance = original_instance\n    TelemetryManager._initialized = original_initialized\n\n\n@pytest.fixture\ndef manager(telemetry_manager_class, mock_whitelist):\n    \"\"\"Create a TelemetryManager instance with mocked dependencies\"\"\"\n    with (\n        patch(\"posthog.capture\"),\n        patch(\"posthog.flush\"),\n        patch(\"time.sleep\"),\n        patch(\"axolotl.telemetry.manager.WHITELIST_PATH\", mock_whitelist),\n        patch.dict(os.environ, {\"RANK\": \"0\"}),\n    ):\n        manager = telemetry_manager_class()\n        # Manually enable for most tests\n        manager.enabled = True\n        return manager\n\n\ndef test_singleton_instance(telemetry_manager_class):\n    \"\"\"Test that TelemetryManager is a singleton\"\"\"\n    with (\n        patch(\"posthog.capture\"),\n        patch(\"time.sleep\"),\n        patch.dict(os.environ, {\"RANK\": \"0\"}),\n    ):\n        first = telemetry_manager_class()\n        second = 
telemetry_manager_class()\n        assert first is second\n        assert telemetry_manager_class.get_instance() is first\n\n\ndef test_telemetry_enabled_by_default(telemetry_manager_class):\n    \"\"\"Test that telemetry is enabled by default (opt-out)\"\"\"\n    with (\n        patch.dict(os.environ, {\"RANK\": \"0\"}, clear=True),\n        patch(\"time.sleep\"),\n        patch(\"logging.Logger.info\"),\n    ):\n        manager = telemetry_manager_class()\n        assert manager.enabled\n\n\ndef test_telemetry_enabled_with_explicit_opt_in(telemetry_manager_class):\n    \"\"\"Test that telemetry is enabled when AXOLOTL_DO_NOT_TRACK=0\"\"\"\n    with (\n        patch.dict(os.environ, {\"AXOLOTL_DO_NOT_TRACK\": \"0\", \"RANK\": \"0\"}),\n        patch(\"time.sleep\"),\n    ):\n        manager = telemetry_manager_class()\n        assert manager.enabled\n\n\ndef test_telemetry_disabled_with_axolotl_do_not_track(telemetry_manager_class):\n    \"\"\"Test that telemetry is disabled when AXOLOTL_DO_NOT_TRACK=1\"\"\"\n    with (\n        patch.dict(os.environ, {\"AXOLOTL_DO_NOT_TRACK\": \"1\", \"RANK\": \"0\"}),\n        patch(\"time.sleep\"),\n    ):\n        manager = telemetry_manager_class()\n        assert not manager.enabled\n\n\ndef test_telemetry_disabled_with_do_not_track(telemetry_manager_class):\n    \"\"\"Test that telemetry is disabled when DO_NOT_TRACK=1\"\"\"\n    with (\n        patch.dict(\n            os.environ, {\"AXOLOTL_DO_NOT_TRACK\": \"0\", \"DO_NOT_TRACK\": \"1\", \"RANK\": \"0\"}\n        ),\n        patch(\"time.sleep\"),\n    ):\n        manager = telemetry_manager_class()\n        assert not manager.enabled\n\n\ndef test_telemetry_disabled_for_non_main_process(telemetry_manager_class):\n    \"\"\"Test that telemetry is disabled for non-main processes\"\"\"\n    with (\n        patch.dict(os.environ, {\"AXOLOTL_DO_NOT_TRACK\": \"0\", \"RANK\": \"1\"}),\n        patch(\"time.sleep\"),\n    ):\n        manager = telemetry_manager_class()\n        
assert not manager.enabled\n\n\ndef test_is_whitelisted(telemetry_manager_class, mock_whitelist):\n    \"\"\"Test org whitelist functionality\"\"\"\n    with (\n        patch(\"axolotl.telemetry.manager.WHITELIST_PATH\", mock_whitelist),\n        patch.dict(os.environ, {\"AXOLOTL_DO_NOT_TRACK\": \"0\"}),\n    ):\n        manager = telemetry_manager_class()\n\n        # Should match organizations from the mock whitelist\n        assert manager._is_whitelisted(\"meta-llama/llama-7b\")\n        assert manager._is_whitelisted(\"mistralai/mistral-7b-instruct\")\n        # Should not match\n        assert not manager._is_whitelisted(\"unknown/model\")\n        # Should handle case insensitively\n        assert manager._is_whitelisted(\"META-LLAMA/Llama-7B\")\n        # Should handle empty input\n        assert not manager._is_whitelisted(\"\")\n\n\ndef test_system_info_collection(manager):\n    \"\"\"Test system information collection\"\"\"\n    system_info = manager._get_system_info()\n\n    # Check essential keys\n    assert \"os\" in system_info\n    assert \"python_version\" in system_info\n    assert \"cpu_count\" in system_info\n    assert \"memory_total\" in system_info\n    assert \"accelerator_count\" in system_info\n\n\ndef test_send_event(telemetry_manager_class):\n    \"\"\"Test basic event sending\"\"\"\n    with (\n        patch(\"posthog.capture\") as mock_capture,\n        patch.dict(os.environ, {\"AXOLOTL_DO_NOT_TRACK\": \"0\"}),\n    ):\n        manager = telemetry_manager_class()\n\n        # Test with clean properties (no PII)\n        manager.send_event(\"test_event\", {\"key\": \"value\"})\n        assert mock_capture.called\n        assert mock_capture.call_args[1][\"event\"] == \"test_event\"\n        assert mock_capture.call_args[1][\"properties\"] == {\"key\": \"value\"}\n        assert mock_capture.call_args[1][\"distinct_id\"] == manager.run_id\n\n        # Test with default properties (None)\n        mock_capture.reset_mock()\n        
manager.send_event(\"simple_event\")\n        assert mock_capture.called\n        assert mock_capture.call_args[1][\"properties\"] == {}\n\n\ndef test_send_system_info(telemetry_manager_class):\n    \"\"\"Test sending system info\"\"\"\n    with (\n        patch(\"posthog.capture\") as mock_capture,\n        patch.dict(os.environ, {\"AXOLOTL_DO_NOT_TRACK\": \"0\"}),\n    ):\n        manager = telemetry_manager_class()\n        manager.send_system_info()\n        assert mock_capture.called\n        assert mock_capture.call_args[1][\"event\"] == \"system-info\"\n        assert mock_capture.call_args[1][\"properties\"] == manager.system_info\n\n\ndef test_redacted_properties(telemetry_manager_class):\n    \"\"\"Test path redaction in send_event method\"\"\"\n    with (\n        patch(\"posthog.capture\") as mock_capture,\n        patch.dict(os.environ, {\"AXOLOTL_DO_NOT_TRACK\": \"0\"}),\n    ):\n        manager = telemetry_manager_class()\n        # Test with properties containing various paths and non-paths\n        test_properties = {\n            \"filepath\": \"/home/user/sensitive/data.txt\",\n            \"windows_path\": \"C:\\\\Users\\\\name\\\\Documents\\\\project\\\\file.py\",\n            \"output_dir\": \"/var/lib/data\",\n            \"path_to_model\": \"models/llama/7b\",\n            \"message\": \"Training started\",  # Should not be redacted\n            \"metrics\": {\"loss\": 0.5, \"accuracy\": 0.95},  # Should not be redacted\n            \"base_model\": \"models/local_model\",\n            \"nested\": {\n                \"model_path\": \"/models/my_model\",\n                \"root_dir\": \"/home/user/projects\",\n                \"stats\": {\"steps\": 1000, \"epochs\": 3},  # Should not be redacted\n            },\n        }\n\n        manager.send_event(\"test_event\", test_properties)\n\n        # Verify the call was made\n        assert mock_capture.called\n\n        # Get the sanitized properties that were sent\n        sanitized = 
mock_capture.call_args[1][\"properties\"]\n\n        # Check that path-like and base_model keys were redacted\n        assert sanitized[\"filepath\"] == \"[REDACTED]\"\n        assert sanitized[\"windows_path\"] == \"[REDACTED]\"\n        assert sanitized[\"path_to_model\"] == \"[REDACTED]\"\n        assert sanitized[\"base_model\"] == \"[REDACTED]\"\n\n        # Check that non-path values were preserved\n        assert sanitized[\"message\"] == \"Training started\"\n        assert sanitized[\"metrics\"] == {\"loss\": 0.5, \"accuracy\": 0.95}\n\n        # Check nested structure handling\n        assert sanitized[\"nested\"][\"model_path\"] == \"[REDACTED]\"\n        assert sanitized[\"nested\"][\"root_dir\"] == \"[REDACTED]\"\n        assert sanitized[\"nested\"][\"stats\"] == {\"steps\": 1000, \"epochs\": 3}\n\n\ndef test_disable_telemetry(manager):\n    \"\"\"Test that disabled telemetry doesn't send events\"\"\"\n    with patch(\"posthog.capture\") as mock_capture:\n        manager.enabled = False\n        manager.send_event(\"test_event\")\n        assert not mock_capture.called\n\n\ndef test_exception_handling_during_send(manager):\n    \"\"\"Test that exceptions in PostHog are handled gracefully\"\"\"\n    with (\n        patch(\"posthog.capture\", side_effect=Exception(\"Test error\")),\n        patch(\"logging.Logger.warning\") as mock_warning,\n    ):\n        manager.send_event(\"test_event\")\n        warning_logged = False\n        for call in mock_warning.call_args_list:\n            if \"Failed to send telemetry event\" in str(call):\n                warning_logged = True\n                break\n        assert warning_logged\n\n\ndef test_shutdown(manager):\n    \"\"\"Test shutdown behavior\"\"\"\n    with patch(\"posthog.shutdown\") as mock_shutdown:\n        manager.shutdown()\n        assert mock_shutdown.called\n"
  },
  {
    "path": "tests/telemetry/test_runtime_metrics.py",
    "content": "\"\"\"Tests for runtime metrics telemetry module\"\"\"\n\n# pylint: disable=redefined-outer-name\n\nfrom unittest.mock import MagicMock, patch\n\nimport pytest\n\nfrom axolotl.telemetry.runtime_metrics import RuntimeMetrics, RuntimeMetricsTracker\n\n\n@pytest.fixture\ndef mock_time():\n    \"\"\"Mock time.time() to have predictable values in tests\"\"\"\n    with patch(\"time.time\") as mock_time:\n        # Start with time 1000.0 and increment by 10 seconds on each call\n        times = [1000.0 + i * 10 for i in range(10)]\n        mock_time.side_effect = times\n        yield mock_time\n\n\n@pytest.fixture\ndef mock_telemetry_manager():\n    \"\"\"Create a mock TelemetryManager\"\"\"\n    with patch(\n        \"axolotl.telemetry.runtime_metrics.TelemetryManager\"\n    ) as mock_manager_class:\n        mock_manager = MagicMock()\n        mock_manager.enabled = True\n        mock_manager_class.get_instance.return_value = mock_manager\n        yield mock_manager\n\n\n@pytest.fixture\ndef mock_psutil():\n    \"\"\"Mock psutil for memory information\"\"\"\n    with patch(\"axolotl.telemetry.runtime_metrics.psutil\") as mock_psutil:\n        mock_process = MagicMock()\n        mock_memory_info = MagicMock()\n        # Set initial memory to 1GB\n        mock_memory_info.rss = 1024 * 1024 * 1024\n        mock_process.memory_info.return_value = mock_memory_info\n        mock_psutil.Process.return_value = mock_process\n        yield mock_psutil\n\n\n@pytest.fixture\ndef mock_torch():\n    \"\"\"Mock torch.cuda functions\"\"\"\n    with patch(\"axolotl.telemetry.runtime_metrics.torch\") as mock_torch:\n        mock_torch.cuda.is_available.return_value = True\n        mock_torch.cuda.device_count.return_value = 2\n\n        # Mock memory allocated per device (1GB for device 0, 2GB for device 1)\n        mock_torch.cuda.memory_allocated.side_effect = lambda device: (\n            (device + 1) * 1024 * 1024 * 1024\n        )\n\n        yield 
mock_torch\n\n\nclass TestRuntimeMetrics:\n    \"\"\"Tests for RuntimeMetrics class.\"\"\"\n\n    def test_initialization(self):\n        \"\"\"Test RuntimeMetrics initialization.\"\"\"\n        metrics = RuntimeMetrics(start_time=1000.0)\n\n        assert metrics.start_time == 1000.0\n        assert metrics.epoch_start_times == {}\n        assert metrics.epoch_end_times == {}\n        assert metrics.peak_gpu_memory == {}\n        assert metrics.total_steps == 0\n        assert metrics.current_epoch == 0\n        assert metrics.current_step == 0\n        assert metrics.peak_cpu_memory == 0\n\n    def test_elapsed_time(self, mock_time):\n        \"\"\"Test elapsed_time property.\"\"\"\n        metrics = RuntimeMetrics(start_time=1000.0)\n\n        # Mock time.time() to return 1050.0\n        mock_time.side_effect = [1050.0]\n\n        assert metrics.elapsed_time == 50.0\n\n    def test_epoch_time(self):\n        \"\"\"Test epoch_time method.\"\"\"\n        metrics = RuntimeMetrics(start_time=1000.0)\n\n        # No epoch data\n        assert metrics.epoch_time(0) is None\n\n        # Add epoch start but no end\n        metrics.epoch_start_times[0] = 1000.0\n        assert metrics.epoch_time(0) is None\n\n        # Add epoch end\n        metrics.epoch_end_times[0] = 1060.0\n        assert metrics.epoch_time(0) == 60.0\n\n    def test_average_epoch_time(self):\n        \"\"\"Test average_epoch_time method.\"\"\"\n        metrics = RuntimeMetrics(start_time=1000.0)\n\n        # No completed epochs\n        assert metrics.average_epoch_time() is None\n\n        # Add one completed epoch\n        metrics.epoch_start_times[0] = 1000.0\n        metrics.epoch_end_times[0] = 1060.0\n        assert metrics.average_epoch_time() == 60.0\n\n        # Add second completed epoch\n        metrics.epoch_start_times[1] = 1060.0\n        metrics.epoch_end_times[1] = 1140.0  # 80 seconds\n        assert metrics.average_epoch_time() == 70.0  # Average of 60 and 80\n\n        # Add 
incomplete epoch (should not affect average)\n        metrics.epoch_start_times[2] = 1140.0\n        assert metrics.average_epoch_time() == 70.0\n\n    def test_steps_per_second(self, mock_time):\n        \"\"\"Test steps_per_second method.\"\"\"\n        metrics = RuntimeMetrics(start_time=1000.0)\n\n        # No steps - first call to time.time()\n        mock_time.side_effect = None\n        mock_time.return_value = 1050.0\n        assert metrics.steps_per_second() is None\n\n        # Add steps - second call to time.time()\n        metrics.total_steps = 100\n        mock_time.return_value = 1050.0  # Keep same time for consistent result\n        assert metrics.steps_per_second() == 2.0  # 100 steps / 50 seconds\n\n    def test_to_dict_basic(self, mock_time):\n        \"\"\"Test to_dict method with basic metrics.\"\"\"\n        metrics = RuntimeMetrics(start_time=1000.0)\n        metrics.total_steps = 100\n        metrics.peak_cpu_memory = 2 * 1024 * 1024 * 1024  # 2GB\n\n        # Mock elapsed_time\n        mock_time.side_effect = None\n        mock_time.return_value = 1050.0\n\n        result = metrics.to_dict()\n\n        assert result[\"total_time_seconds\"] == 50.0\n        assert result[\"total_steps\"] == 100\n        assert result[\"steps_per_second\"] == 2.0\n        assert result[\"epochs_completed\"] == 0\n        assert result[\"peak_cpu_memory_bytes\"] == 2 * 1024 * 1024 * 1024\n        assert \"epoch_times\" not in result\n        assert \"gpu_memory\" not in result\n\n    def test_to_dict_with_epochs(self, mock_time):\n        \"\"\"Test to_dict method with epoch data.\"\"\"\n        metrics = RuntimeMetrics(start_time=1000.0)\n        metrics.total_steps = 100\n\n        # Add epoch data\n        metrics.epoch_start_times[0] = 1000.0\n        metrics.epoch_end_times[0] = 1060.0\n        metrics.epoch_start_times[1] = 1060.0\n        metrics.epoch_end_times[1] = 1140.0\n\n        # Mock elapsed_time\n        mock_time.side_effect = None\n        
mock_time.return_value = 1150.0\n\n        result = metrics.to_dict()\n\n        assert \"epoch_times\" in result\n        assert result[\"epoch_times\"][\"epoch_0_seconds\"] == 60.0\n        assert result[\"epoch_times\"][\"epoch_1_seconds\"] == 80.0\n        assert result[\"average_epoch_time_seconds\"] == 70.0\n\n    def test_to_dict_with_gpu_memory(self, mock_time):\n        \"\"\"Test to_dict method with GPU memory data.\"\"\"\n        metrics = RuntimeMetrics(start_time=1000.0)\n        metrics.peak_gpu_memory = {\n            0: 1 * 1024 * 1024 * 1024,  # 1GB\n            1: 2 * 1024 * 1024 * 1024,  # 2GB\n        }\n\n        # Mock elapsed_time\n        mock_time.side_effect = [1050.0]\n\n        result = metrics.to_dict()\n\n        assert \"gpu_memory\" in result\n        assert result[\"gpu_memory\"][\"gpu_0_peak_memory_bytes\"] == 1 * 1024 * 1024 * 1024\n        assert result[\"gpu_memory\"][\"gpu_1_peak_memory_bytes\"] == 2 * 1024 * 1024 * 1024\n\n\nclass TestRuntimeMetricsTracker:\n    \"\"\"Tests for RuntimeMetricsTracker class.\"\"\"\n\n    # pylint: disable=unused-argument\n    def test_initialization(self, mock_time, mock_telemetry_manager):\n        \"\"\"Test RuntimeMetricsTracker initialization.\"\"\"\n        tracker = RuntimeMetricsTracker()\n\n        assert isinstance(tracker.metrics, RuntimeMetrics)\n        assert tracker.metrics.start_time == 1000.0  # First value from mock_time\n\n    # pylint: disable=unused-argument\n    def test_start_epoch(\n        self, mock_time, mock_psutil, mock_torch, mock_telemetry_manager\n    ):\n        \"\"\"Test start_epoch method.\"\"\"\n        tracker = RuntimeMetricsTracker()\n\n        # Reset mock_time to control next value\n        mock_time.side_effect = [1010.0]\n\n        tracker.start_epoch(0)\n\n        assert tracker.metrics.current_epoch == 0\n        assert tracker.metrics.epoch_start_times[0] == 1010.0\n\n        # Verify memory metrics were updated\n        assert 
tracker.metrics.peak_cpu_memory == 1 * 1024 * 1024 * 1024\n        assert 0 in tracker.metrics.peak_gpu_memory\n        assert 1 in tracker.metrics.peak_gpu_memory\n\n    # pylint: disable=unused-argument\n    def test_end_epoch(self, mock_time, mock_telemetry_manager):\n        \"\"\"Test end_epoch method.\"\"\"\n        tracker = RuntimeMetricsTracker()\n\n        # Start epoch 0\n        mock_time.side_effect = [1010.0]\n        tracker.start_epoch(0)\n\n        # End epoch 0\n        mock_time.side_effect = [1060.0]\n        tracker.end_epoch(0)\n\n        assert 0 in tracker.metrics.epoch_end_times\n        assert tracker.metrics.epoch_end_times[0] == 1060.0\n\n    # pylint: disable=unused-argument\n    def test_update_step(\n        self, mock_time, mock_psutil, mock_torch, mock_telemetry_manager\n    ):\n        \"\"\"Test update_step method.\"\"\"\n        tracker = RuntimeMetricsTracker()\n\n        # Update step to a non-multiple of 100\n        tracker.update_step(42)\n\n        assert tracker.metrics.current_step == 42\n        assert tracker.metrics.total_steps == 1\n\n        # Memory metrics should not be updated for non-multiple of 100\n        assert tracker.metrics.peak_cpu_memory == 0\n\n        # Update step to a multiple of 100\n        tracker.update_step(100)\n\n        assert tracker.metrics.current_step == 100\n        assert tracker.metrics.total_steps == 2\n\n        # Memory metrics should be updated for multiple of 100\n        assert tracker.metrics.peak_cpu_memory == 1 * 1024 * 1024 * 1024\n\n    # pylint: disable=unused-argument\n    def test_update_memory_metrics(\n        self, mock_psutil, mock_torch, mock_telemetry_manager\n    ):\n        \"\"\"Test update_memory_metrics method.\"\"\"\n        tracker = RuntimeMetricsTracker()\n\n        # Initial memory state\n        assert tracker.metrics.peak_cpu_memory == 0\n        assert tracker.metrics.peak_gpu_memory == {}\n\n        # Update memory metrics\n        
tracker.update_memory_metrics()\n\n        # Verify CPU memory\n        assert tracker.metrics.peak_cpu_memory == 1 * 1024 * 1024 * 1024\n\n        # Verify GPU memory\n        assert tracker.metrics.peak_gpu_memory[0] == 1 * 1024 * 1024 * 1024\n        assert tracker.metrics.peak_gpu_memory[1] == 2 * 1024 * 1024 * 1024\n\n        # Change mocked memory values to be lower\n        mock_process = mock_psutil.Process.return_value\n        mock_memory_info = mock_process.memory_info.return_value\n        mock_memory_info.rss = 0.5 * 1024 * 1024 * 1024  # 0.5GB\n\n        mock_torch.cuda.memory_allocated.side_effect = lambda device: (\n            (device + 0.5) * 1024 * 1024 * 1024\n        )\n\n        # Update memory metrics again\n        tracker.update_memory_metrics()\n\n        # Peak values should not decrease\n        assert tracker.metrics.peak_cpu_memory == 1 * 1024 * 1024 * 1024\n        assert tracker.metrics.peak_gpu_memory[0] == 1 * 1024 * 1024 * 1024\n        assert tracker.metrics.peak_gpu_memory[1] == 2 * 1024 * 1024 * 1024\n\n        # Change mocked memory values to be higher\n        mock_memory_info.rss = 2 * 1024 * 1024 * 1024  # 2GB\n\n        mock_torch.cuda.memory_allocated.side_effect = lambda device: (\n            (device + 2) * 1024 * 1024 * 1024\n        )\n\n        # Update memory metrics again\n        tracker.update_memory_metrics()\n\n        # Peak values should increase\n        assert tracker.metrics.peak_cpu_memory == 2 * 1024 * 1024 * 1024\n        assert tracker.metrics.peak_gpu_memory[0] == 2 * 1024 * 1024 * 1024\n        assert tracker.metrics.peak_gpu_memory[1] == 3 * 1024 * 1024 * 1024\n\n    # pylint: disable=unused-argument\n    def test_get_memory_metrics(self, mock_psutil, mock_torch, mock_telemetry_manager):\n        \"\"\"Test get_memory_metrics method.\"\"\"\n        tracker = RuntimeMetricsTracker()\n\n        # Set peak memory values\n        tracker.metrics.peak_cpu_memory = 2 * 1024 * 1024 * 1024\n        
tracker.metrics.peak_gpu_memory = {\n            0: 3 * 1024 * 1024 * 1024,\n            1: 4 * 1024 * 1024 * 1024,\n        }\n\n        # Get memory metrics\n        memory_metrics = tracker.get_memory_metrics()\n\n        # Verify CPU memory\n        assert (\n            memory_metrics[\"cpu_memory_bytes\"] == 1 * 1024 * 1024 * 1024\n        )  # Current value from mock\n        assert (\n            memory_metrics[\"peak_cpu_memory_bytes\"] == 2 * 1024 * 1024 * 1024\n        )  # Peak value we set\n\n        # Verify GPU memory\n        assert (\n            memory_metrics[\"gpu_0_memory_bytes\"] == 1 * 1024 * 1024 * 1024\n        )  # Current value from mock\n        assert (\n            memory_metrics[\"gpu_0_peak_memory_bytes\"] == 3 * 1024 * 1024 * 1024\n        )  # Peak value we set\n        assert (\n            memory_metrics[\"gpu_1_memory_bytes\"] == 2 * 1024 * 1024 * 1024\n        )  # Current value from mock\n        assert (\n            memory_metrics[\"gpu_1_peak_memory_bytes\"] == 4 * 1024 * 1024 * 1024\n        )  # Peak value we set\n"
  },
  {
    "path": "tests/test_chunked_xentropy.py",
    "content": "\"\"\"\ntest suite for chunked cross entropy\n\"\"\"\n\nimport pytest\nimport torch\nfrom torch import nn\n\nfrom axolotl.monkeypatch.loss.chunked import get_causal_lm_loss\n\n\n@pytest.fixture\ndef chunked_fixtures():\n    model_dim = 512\n    vocab_size = 1024 * 256\n    seq_len = 2048\n    batch_size = 1\n\n    lm_head = nn.Linear(model_dim, vocab_size)\n    hidden_state = torch.randn(batch_size, seq_len, model_dim)\n    labels = torch.randint(low=0, high=vocab_size, size=(batch_size, seq_len))\n    return lm_head, hidden_state, labels, vocab_size\n\n\ndef test_chunked_forward(chunked_fixtures):\n    lm_head, hidden_state, labels, vocab_size = chunked_fixtures\n    lm_loss = get_causal_lm_loss()\n\n    logits = lm_head(hidden_state)\n\n    chunked_lm_loss = lm_loss(logits, labels)\n\n    logits_flattened = logits.view(-1, vocab_size)\n    labels_flattened = labels.view(-1)\n\n    loss = nn.functional.cross_entropy(\n        logits_flattened.float(), labels_flattened, reduction=\"mean\"\n    )\n\n    assert torch.allclose(chunked_lm_loss, loss, atol=1e-2, rtol=1e-2)\n"
  },
  {
    "path": "tests/test_context_parallel_batch_size.py",
    "content": "\"\"\"Tests for batch_size calculation with context parallelism.\"\"\"\n\nimport sys\nimport types\n\nimport pytest\n\nfrom axolotl.utils.config import normalize_config, validate_config\nfrom axolotl.utils.dict import DictDefault\n\n\n@pytest.fixture(name=\"cp_base_cfg\")\ndef fixture_cp_base_cfg(min_base_cfg):\n    return (\n        DictDefault(\n            micro_batch_size=2,\n            gradient_accumulation_steps=4,\n            sequence_len=2048,\n            num_epochs=1,\n            flash_attention=True,\n        )\n        | min_base_cfg\n    )\n\n\nclass TestContextParallelBatchSize:\n    \"\"\"Verify batch_size scales by effective dp world_size when using context parallelism.\"\"\"\n\n    @pytest.mark.parametrize(\n        \"world_size, context_parallel_size, expected_batch_size\",\n        [\n            (4, 1, 32),  # no CP: 2*4*4 = 32\n            (4, 2, 16),  # CP=2: 2*4*(4//2) = 16\n            (4, 4, 8),  # CP=4: 2*4*(4//4) = 8\n            (2, 2, 8),  # CP=ws: 2*4*(2//2) = 8 (no scaling)\n        ],\n    )\n    def test_batch_size_with_context_parallelism(\n        self,\n        cp_base_cfg,\n        monkeypatch,\n        world_size,\n        context_parallel_size,\n        expected_batch_size,\n    ):\n        monkeypatch.setenv(\"WORLD_SIZE\", str(world_size))\n        # Mock ring_flash_attn since it's not installable on CPU,\n        # but required by schema validation when context_parallel_size > 1.\n        if \"ring_flash_attn\" not in sys.modules:\n            monkeypatch.setitem(\n                sys.modules, \"ring_flash_attn\", types.ModuleType(\"ring_flash_attn\")\n            )\n        cp_base_cfg[\"context_parallel_size\"] = context_parallel_size\n        cfg = validate_config(cp_base_cfg)\n        normalize_config(cfg)\n        assert cfg.batch_size == expected_batch_size\n"
  },
  {
    "path": "tests/test_convert.py",
    "content": "\"\"\"Unit tests for src/axolotl/convert.py\"\"\"\n\nimport json\n\nimport pytest\n\nfrom axolotl.convert import (\n    FileReader,\n    FileWriter,\n    JsonlSerializer,\n    JsonParser,\n    JsonToJsonlConverter,\n    StdoutWriter,\n)\n\n\nclass TestJsonParser:\n    def test_parse_valid_json_array(self):\n        parser = JsonParser()\n        result = parser.parse('[{\"key\": \"value\"}]')\n        assert result == [{\"key\": \"value\"}]\n\n    def test_parse_valid_json_object(self):\n        parser = JsonParser()\n        result = parser.parse('{\"key\": \"value\"}')\n        assert result == {\"key\": \"value\"}\n\n    def test_parse_invalid_json_raises(self):\n        parser = JsonParser()\n        with pytest.raises(json.JSONDecodeError):\n            parser.parse(\"not valid json\")\n\n\nclass TestJsonlSerializer:\n    def test_serialize_single_item(self):\n        serializer = JsonlSerializer()\n        result = serializer.serialize([{\"a\": 1}])\n        assert result == '{\"a\": 1}'\n\n    def test_serialize_multiple_items(self):\n        serializer = JsonlSerializer()\n        result = serializer.serialize([{\"a\": 1}, {\"b\": 2}])\n        lines = result.split(\"\\n\")\n        assert len(lines) == 2\n        assert json.loads(lines[0]) == {\"a\": 1}\n        assert json.loads(lines[1]) == {\"b\": 2}\n\n    def test_serialize_empty_list(self):\n        serializer = JsonlSerializer()\n        result = serializer.serialize([])\n        assert result == \"\"\n\n\nclass TestFileReaderWriter:\n    def test_read_write_roundtrip(self, tmp_path):\n        test_file = tmp_path / \"test.txt\"\n        content = '{\"hello\": \"world\"}'\n        writer = FileWriter(str(test_file))\n        writer.write(content)\n\n        reader = FileReader()\n        result = reader.read(str(test_file))\n        assert result == content\n\n\nclass TestStdoutWriter:\n    def test_write_to_stdout(self, capsys):\n        writer = StdoutWriter()\n        
writer.write(\"hello\")\n        captured = capsys.readouterr()\n        assert captured.out == \"hello\\n\"\n\n\nclass TestJsonToJsonlConverter:\n    def test_convert_json_to_jsonl(self, tmp_path):\n        input_data = [{\"name\": \"Alice\"}, {\"name\": \"Bob\"}]\n        input_file = tmp_path / \"input.json\"\n        output_file = tmp_path / \"output.jsonl\"\n\n        input_file.write_text(json.dumps(input_data), encoding=\"utf-8\")\n\n        converter = JsonToJsonlConverter(\n            FileReader(), FileWriter(str(output_file)), JsonParser(), JsonlSerializer()\n        )\n        converter.convert(str(input_file))\n\n        result = output_file.read_text(encoding=\"utf-8\")\n        lines = result.split(\"\\n\")\n        assert len(lines) == 2\n        assert json.loads(lines[0]) == {\"name\": \"Alice\"}\n        assert json.loads(lines[1]) == {\"name\": \"Bob\"}\n"
  },
  {
    "path": "tests/test_data.py",
    "content": "\"\"\"\ntest module for the axolotl.utils.data module\n\"\"\"\n\nimport unittest\n\nfrom transformers import LlamaTokenizer\n\nfrom axolotl.utils.data import encode_streaming, md5\nfrom axolotl.utils.trainer import filter_sequences_by_length\n\nfrom tests.hf_offline_utils import enable_hf_offline\n\n\nclass TestEncodePretraining(unittest.TestCase):\n    \"\"\"\n    test class for encode pretraining and md5 helper\n    \"\"\"\n\n    @enable_hf_offline\n    def setUp(self):\n        self.tokenizer = LlamaTokenizer.from_pretrained(\"huggyllama/llama-7b\")\n        self.tokenizer.add_special_tokens(\n            {\n                \"eos_token\": \"</s>\",\n                \"bos_token\": \"<s>\",\n                \"unk_token\": \"<unk>\",\n                \"pad_token\": \"<pad>\",\n            }\n        )\n        self.max_tokens = 15  # set a small number for easy inspection\n\n    def test_encode_pretraining(self):\n        examples = {\n            \"text\": [\n                \"Hello, world!\",\n                \"Nice to meet you.\",\n                \"lorem ipsum dolor sit amet.\",\n                \"Nice to meet you again!.\",\n                \"hello, hello\",\n            ]\n        }\n        result = encode_streaming(examples, self.tokenizer, self.max_tokens)\n\n        self.assertEqual(len(result[\"input_ids\"]), 3)\n\n        # Assert the length of input_ids and attention_mask is correct\n        self.assertEqual(len(result[\"input_ids\"][0]), self.max_tokens)\n        self.assertEqual(len(result[\"attention_mask\"][0]), self.max_tokens)\n\n        # Assert EOS and PAD tokens are correctly added\n        # hello world! 
is 4 tokens\n        self.assertEqual(result[\"input_ids\"][0][0], self.tokenizer.bos_token_id)\n        self.assertEqual(result[\"input_ids\"][0][5], self.tokenizer.eos_token_id)\n        self.assertEqual(result[\"input_ids\"][0][6], self.tokenizer.pad_token_id)\n        # second part, 5 tokens\n        self.assertEqual(result[\"input_ids\"][0][7], self.tokenizer.bos_token_id)\n        self.assertEqual(result[\"input_ids\"][0][13], self.tokenizer.eos_token_id)\n        self.assertEqual(result[\"input_ids\"][0][14], self.tokenizer.pad_token_id)\n\n    def test_md5(self):\n        self.assertEqual(md5(\"hello world\"), \"5eb63bbbe01eeed093cb22bb8f5acdc3\")\n        self.assertEqual(\n            md5(\"hello world\", \"utf-8\"), \"5eb63bbbe01eeed093cb22bb8f5acdc3\"\n        )\n\n    def test_excess_length_strategy(self):\n        \"\"\"Test that excess_length_strategy results in a value error when set to 'raise'.\"\"\"\n\n        # -- single sequence --\n        # This should work\n        data = {\"input_ids\": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]}\n        filter_sequences_by_length(data, 32, raise_on_drop=True)\n\n        # This should return True, since data fits\n        dropped = filter_sequences_by_length(data, 32)\n        self.assertTrue(dropped)\n\n        # This should raise\n        self.assertRaises(\n            ValueError, filter_sequences_by_length, data, 15, raise_on_drop=True\n        )\n\n        # This should return False, since data doesn't fit\n        dropped = filter_sequences_by_length(data, 15)\n        self.assertFalse(dropped)\n\n        # -- batch sequence --\n        # This should work\n        data = {\n            \"input_ids\": [\n                [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15],\n                [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16],\n            ]\n        }\n        filter_sequences_by_length(data, 32, raise_on_drop=True)\n\n        # This should raise\n        
self.assertRaises(\n            ValueError, filter_sequences_by_length, data, 15, raise_on_drop=True\n        )\n\n        # This should keep the first but drop the second entry\n        dropped = filter_sequences_by_length(data, 15)\n        self.assertEqual(dropped, [True, False])\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/test_datasets.py",
    "content": "\"\"\"Test dataset loading under various conditions.\"\"\"\n\nimport shutil\nimport tempfile\nfrom pathlib import Path\nfrom typing import Any, Generator\nfrom unittest.mock import patch\n\nimport pytest\nfrom datasets import Dataset\nfrom huggingface_hub import snapshot_download\nfrom transformers import PreTrainedTokenizer\n\nfrom axolotl.loaders.tokenizer import load_tokenizer\nfrom axolotl.utils.data.rl import prepare_preference_datasets\nfrom axolotl.utils.data.sft import (\n    _load_tokenized_prepared_datasets,\n)\nfrom axolotl.utils.dict import DictDefault\n\nfrom tests.constants import (\n    ALPACA_MESSAGES_CONFIG_OG,\n    ALPACA_MESSAGES_CONFIG_REVISION,\n    SPECIAL_TOKENS,\n)\nfrom tests.hf_offline_utils import enable_hf_offline\n\n\nclass TestDatasetPreparation:\n    \"\"\"Test a configured dataloader.\"\"\"\n\n    @pytest.fixture\n    def tokenizer(\n        self, tokenizer_huggyllama\n    ) -> Generator[PreTrainedTokenizer, Any, Any]:\n        tokenizer_huggyllama.add_special_tokens(SPECIAL_TOKENS)\n        yield tokenizer_huggyllama\n\n    @pytest.fixture\n    def dataset_fixture(self):\n        yield Dataset.from_list(\n            [\n                {\n                    \"instruction\": \"Evaluate this sentence for spelling and grammar mistakes\",\n                    \"input\": \"He finnished his meal and left the resturant\",\n                    \"output\": \"He finished his meal and left the restaurant.\",\n                }\n            ]\n        )\n\n    @pytest.mark.skip(reason=\"TODO: fix hf hub offline to work with HF rate limits\")\n    @enable_hf_offline\n    def test_load_hub(self, tokenizer):\n        \"\"\"Core use case.  
Verify that processing data from the hub works\"\"\"\n        with tempfile.TemporaryDirectory() as tmp_dir:\n            prepared_path = Path(tmp_dir) / \"prepared\"\n            cfg = DictDefault(\n                {\n                    \"tokenizer_config\": \"huggyllama/llama-7b\",\n                    \"sequence_len\": 1024,\n                    \"datasets\": [\n                        {\n                            \"path\": \"mhenrichsen/alpaca_2k_test\",\n                            \"type\": \"alpaca\",\n                        },\n                    ],\n                }\n            )\n\n            with patch(\n                \"axolotl.common.const.DEFAULT_DATASET_PREPARED_PATH\", str(prepared_path)\n            ):\n                dataset, _ = _load_tokenized_prepared_datasets(tokenizer, cfg)\n\n            assert len(dataset) == 2000\n            assert \"input_ids\" in dataset.features\n            assert \"attention_mask\" in dataset.features\n            assert \"labels\" in dataset.features\n\n    @enable_hf_offline\n    @pytest.mark.skip(\"datasets bug with local datasets when offline\")\n    def test_load_local_hub(self, tokenizer):\n        \"\"\"Niche use case.  
Verify that a local copy of a hub dataset can be loaded\"\"\"\n        with tempfile.TemporaryDirectory() as tmp_dir:\n            tmp_ds_path = Path(tmp_dir) / \"mhenrichsen/alpaca_2k_test\"\n            tmp_ds_path.mkdir(parents=True, exist_ok=True)\n            snapshot_path = snapshot_download(\n                repo_id=\"mhenrichsen/alpaca_2k_test\",\n                repo_type=\"dataset\",\n                local_dir=tmp_ds_path,\n            )\n            # offline mode doesn't actually copy it to local_dir, so we\n            # have to copy all the contents in the dir manually from the returned snapshot_path\n            shutil.copytree(snapshot_path, tmp_ds_path, dirs_exist_ok=True)\n\n            prepared_path = Path(tmp_dir) / \"prepared\"\n            # Right now a local copy that doesn't fully conform to a dataset\n            # must list data_files and ds_type otherwise the loader won't know\n            # how to load it.\n            cfg = DictDefault(\n                {\n                    \"tokenizer_config\": \"HuggingFaceTB/SmolLM2-135M\",\n                    \"sequence_len\": 1024,\n                    \"datasets\": [\n                        {\n                            \"path\": \"mhenrichsen/alpaca_2k_test\",\n                            \"ds_type\": \"parquet\",\n                            \"type\": \"alpaca\",\n                            \"data_files\": [\n                                f\"{tmp_ds_path}/alpaca_2000.parquet\",\n                            ],\n                        },\n                    ],\n                }\n            )\n\n            with patch(\n                \"axolotl.common.const.DEFAULT_DATASET_PREPARED_PATH\", str(prepared_path)\n            ):\n                dataset, _ = _load_tokenized_prepared_datasets(tokenizer, cfg)\n\n            assert len(dataset) == 2000\n            assert \"input_ids\" in dataset.features\n            assert \"attention_mask\" in dataset.features\n            assert \"labels\" 
in dataset.features\n            shutil.rmtree(tmp_ds_path)\n\n    @enable_hf_offline\n    def test_load_from_save_to_disk(self, tokenizer, dataset_fixture):\n        \"\"\"Usual use case.  Verify datasets saved via `save_to_disk` can be loaded.\"\"\"\n        with tempfile.TemporaryDirectory() as tmp_dir:\n            tmp_ds_name = Path(tmp_dir) / \"tmp_dataset\"\n            dataset_fixture.save_to_disk(str(tmp_ds_name))\n\n            prepared_path = Path(tmp_dir) / \"prepared\"\n            cfg = DictDefault(\n                {\n                    \"tokenizer_config\": \"huggyllama/llama-7b\",\n                    \"sequence_len\": 256,\n                    \"datasets\": [\n                        {\n                            \"path\": str(tmp_ds_name),\n                            \"type\": \"alpaca\",\n                        },\n                    ],\n                    \"dataset_num_proc\": 4,\n                }\n            )\n\n            with patch(\n                \"axolotl.common.const.DEFAULT_DATASET_PREPARED_PATH\", str(prepared_path)\n            ):\n                dataset, _ = _load_tokenized_prepared_datasets(tokenizer, cfg)\n\n            assert len(dataset) == 1\n            assert \"input_ids\" in dataset.features\n            assert \"attention_mask\" in dataset.features\n            assert \"labels\" in dataset.features\n\n    @enable_hf_offline\n    def test_load_from_dir_of_parquet(self, tokenizer, dataset_fixture):\n        \"\"\"Usual use case. 
Verify a directory of parquet files can be loaded.\"\"\"\n        with tempfile.TemporaryDirectory() as tmp_dir:\n            tmp_ds_dir = Path(tmp_dir) / \"tmp_dataset\"\n            tmp_ds_dir.mkdir()\n            tmp_ds_path = tmp_ds_dir / \"shard1.parquet\"\n            dataset_fixture.to_parquet(tmp_ds_path)\n\n            prepared_path: Path = Path(tmp_dir) / \"prepared\"\n            cfg = DictDefault(\n                {\n                    \"tokenizer_config\": \"huggyllama/llama-7b\",\n                    \"sequence_len\": 256,\n                    \"datasets\": [\n                        {\n                            \"path\": str(tmp_ds_dir),\n                            \"ds_type\": \"parquet\",\n                            \"name\": \"test_data\",\n                            \"data_files\": [\n                                str(tmp_ds_path),\n                            ],\n                            \"type\": \"alpaca\",\n                        },\n                    ],\n                    \"dataset_num_proc\": 4,\n                }\n            )\n\n            with patch(\n                \"axolotl.common.const.DEFAULT_DATASET_PREPARED_PATH\", str(prepared_path)\n            ):\n                dataset, _ = _load_tokenized_prepared_datasets(tokenizer, cfg)\n\n            assert len(dataset) == 1\n            assert \"input_ids\" in dataset.features\n            assert \"attention_mask\" in dataset.features\n            assert \"labels\" in dataset.features\n\n    @enable_hf_offline\n    def test_load_from_dir_of_json(self, tokenizer, dataset_fixture):\n        \"\"\"Standard use case.  
Verify a directory of json files can be loaded.\"\"\"\n        with tempfile.TemporaryDirectory() as tmp_dir:\n            tmp_ds_dir = Path(tmp_dir) / \"tmp_dataset\"\n            tmp_ds_dir.mkdir()\n            tmp_ds_path = tmp_ds_dir / \"shard1.json\"\n            dataset_fixture.to_json(tmp_ds_path)\n\n            prepared_path: Path = Path(tmp_dir) / \"prepared\"\n            cfg = DictDefault(\n                {\n                    \"tokenizer_config\": \"huggyllama/llama-7b\",\n                    \"sequence_len\": 256,\n                    \"datasets\": [\n                        {\n                            \"path\": str(tmp_ds_dir),\n                            \"ds_type\": \"json\",\n                            \"name\": \"test_data\",\n                            \"data_files\": [\n                                str(tmp_ds_path),\n                            ],\n                            \"type\": \"alpaca\",\n                        },\n                    ],\n                    \"dataset_num_proc\": 4,\n                }\n            )\n\n            with patch(\n                \"axolotl.common.const.DEFAULT_DATASET_PREPARED_PATH\", str(prepared_path)\n            ):\n                dataset, _ = _load_tokenized_prepared_datasets(tokenizer, cfg)\n\n            assert len(dataset) == 1\n            assert \"input_ids\" in dataset.features\n            assert \"attention_mask\" in dataset.features\n            assert \"labels\" in dataset.features\n\n    @enable_hf_offline\n    def test_load_from_single_parquet(self, tokenizer, dataset_fixture):\n        \"\"\"Standard use case.  
Verify a single parquet file can be loaded.\"\"\"\n        with tempfile.TemporaryDirectory() as tmp_dir:\n            tmp_ds_path = Path(tmp_dir) / \"tmp_dataset.parquet\"\n            dataset_fixture.to_parquet(tmp_ds_path)\n\n            prepared_path: Path = Path(tmp_dir) / \"prepared\"\n            cfg = DictDefault(\n                {\n                    \"tokenizer_config\": \"huggyllama/llama-7b\",\n                    \"sequence_len\": 256,\n                    \"datasets\": [\n                        {\n                            \"path\": str(tmp_ds_path),\n                            \"name\": \"test_data\",\n                            \"type\": \"alpaca\",\n                        },\n                    ],\n                    \"dataset_num_proc\": 4,\n                }\n            )\n\n            with patch(\n                \"axolotl.common.const.DEFAULT_DATASET_PREPARED_PATH\", str(prepared_path)\n            ):\n                dataset, _ = _load_tokenized_prepared_datasets(tokenizer, cfg)\n\n            assert len(dataset) == 1\n            assert \"input_ids\" in dataset.features\n            assert \"attention_mask\" in dataset.features\n            assert \"labels\" in dataset.features\n\n    @enable_hf_offline\n    def test_load_from_single_json(self, tokenizer, dataset_fixture):\n        \"\"\"Standard use case.  
Verify a single json file can be loaded.\"\"\"\n        with tempfile.TemporaryDirectory() as tmp_dir:\n            tmp_ds_path = Path(tmp_dir) / \"tmp_dataset.json\"\n            dataset_fixture.to_json(tmp_ds_path)\n\n            prepared_path: Path = Path(tmp_dir) / \"prepared\"\n            cfg = DictDefault(\n                {\n                    \"tokenizer_config\": \"huggyllama/llama-7b\",\n                    \"sequence_len\": 256,\n                    \"datasets\": [\n                        {\n                            \"path\": str(tmp_ds_path),\n                            \"name\": \"test_data\",\n                            \"type\": \"alpaca\",\n                        },\n                    ],\n                    \"dataset_num_proc\": 4,\n                }\n            )\n\n            with patch(\n                \"axolotl.common.const.DEFAULT_DATASET_PREPARED_PATH\", str(prepared_path)\n            ):\n                dataset, _ = _load_tokenized_prepared_datasets(tokenizer, cfg)\n\n            assert len(dataset) == 1\n            assert \"input_ids\" in dataset.features\n            assert \"attention_mask\" in dataset.features\n            assert \"labels\" in dataset.features\n\n    @pytest.mark.skip(reason=\"TODO: fix hf offline mode for CI rate limits\")\n    @enable_hf_offline\n    def test_load_hub_with_dpo(self):\n        \"\"\"Verify that processing dpo data from the hub works\"\"\"\n\n        cfg = DictDefault(\n            {\n                \"tokenizer_config\": \"huggyllama/llama-7b\",\n                \"sequence_len\": 1024,\n                \"rl\": \"dpo\",\n                \"chat_template\": \"llama3\",\n                \"datasets\": [ALPACA_MESSAGES_CONFIG_OG],\n            }\n        )\n\n        tokenizer = load_tokenizer(cfg)\n        train_dataset, _ = prepare_preference_datasets(cfg, tokenizer)\n\n        assert len(train_dataset) == 1800\n        assert \"conversation\" not in train_dataset.features\n        assert 
\"chosen\" in train_dataset.features\n        assert \"rejected\" in train_dataset.features\n        assert \"prompt\" in train_dataset.features\n\n    @pytest.mark.skip(reason=\"TODO: fix hf hub offline to work with HF rate limits\")\n    @enable_hf_offline\n    def test_load_hub_with_revision(self, tokenizer):\n        \"\"\"Verify that processing data from the hub works with a specific revision\"\"\"\n        with tempfile.TemporaryDirectory() as tmp_dir:\n            prepared_path = Path(tmp_dir) / \"prepared\"\n\n            # make sure prepared_path is empty\n            shutil.rmtree(prepared_path, ignore_errors=True)\n\n            cfg = DictDefault(\n                {\n                    \"tokenizer_config\": \"huggyllama/llama-7b\",\n                    \"sequence_len\": 1024,\n                    \"datasets\": [\n                        {\n                            \"path\": \"mhenrichsen/alpaca_2k_test\",\n                            \"type\": \"alpaca\",\n                            \"revision\": \"d05c1cb\",\n                        },\n                    ],\n                }\n            )\n\n            with patch(\n                \"axolotl.common.const.DEFAULT_DATASET_PREPARED_PATH\", str(prepared_path)\n            ):\n                dataset, _ = _load_tokenized_prepared_datasets(tokenizer, cfg)\n\n            assert len(dataset) == 2000\n            assert \"input_ids\" in dataset.features\n            assert \"attention_mask\" in dataset.features\n            assert \"labels\" in dataset.features\n\n    @enable_hf_offline\n    def test_load_hub_with_revision_with_dpo(\n        self, dataset_fozziethebeat_alpaca_messages_2k_dpo_test_rev_ea82cff\n    ):\n        \"\"\"Verify that processing dpo data from the hub works with a specific revision\"\"\"\n\n        cfg = DictDefault(\n            {\n                \"tokenizer_config\": \"huggyllama/llama-7b\",\n                \"sequence_len\": 1024,\n                \"rl\": \"dpo\",\n           
     \"chat_template\": \"llama3\",\n                \"datasets\": [ALPACA_MESSAGES_CONFIG_REVISION],\n                \"dataset_num_proc\": 4,\n            }\n        )\n\n        with patch(\n            \"axolotl.utils.data.rl.load_dataset_with_config\"\n        ) as mock_load_dataset:\n            # Set up the mock to return different values on successive calls\n            mock_load_dataset.return_value = (\n                dataset_fozziethebeat_alpaca_messages_2k_dpo_test_rev_ea82cff\n            )\n\n            tokenizer = load_tokenizer(cfg)\n            train_dataset, _ = prepare_preference_datasets(cfg, tokenizer)\n\n            assert len(train_dataset) == 1800\n            assert \"conversation\" not in train_dataset.features\n            assert \"chosen\" in train_dataset.features\n            assert \"rejected\" in train_dataset.features\n            assert \"prompt\" in train_dataset.features\n\n    @enable_hf_offline\n    @pytest.mark.skip(\"datasets bug with local datasets when offline\")\n    def test_load_local_hub_with_revision(\n        self, dataset_fozziethebeat_alpaca_messages_2k_dpo_test_rev_ea82cff, tokenizer\n    ):\n        \"\"\"Verify that a local copy of a hub dataset can be loaded with a specific revision\"\"\"\n        with tempfile.TemporaryDirectory() as tmp_dir:\n            tmp_ds_path = Path(tmp_dir) / \"mhenrichsen/alpaca_2k_test\"\n            tmp_ds_path.mkdir(parents=True, exist_ok=True)\n            snapshot_path = snapshot_download(\n                repo_id=\"mhenrichsen/alpaca_2k_test\",\n                repo_type=\"dataset\",\n                local_dir=tmp_ds_path,\n                revision=\"d05c1cb\",\n            )\n            shutil.copytree(snapshot_path, tmp_ds_path, dirs_exist_ok=True)\n\n            prepared_path = Path(tmp_dir) / \"prepared\"\n            cfg = DictDefault(\n                {\n                    \"tokenizer_config\": \"huggyllama/llama-7b\",\n                    \"sequence_len\": 1024,\n     
               \"datasets\": [\n                        {\n                            \"path\": \"mhenrichsen/alpaca_2k_test\",\n                            \"ds_type\": \"parquet\",\n                            \"type\": \"alpaca\",\n                            \"data_files\": [\n                                f\"{tmp_ds_path}/alpaca_2000.parquet\",\n                            ],\n                            \"revision\": \"d05c1cb\",\n                        },\n                    ],\n                }\n            )\n\n            with patch(\n                \"axolotl.utils.data.shared.load_dataset_with_config\"\n            ) as mock_load_dataset:\n                # Set up the mock to return different values on successive calls\n                mock_load_dataset.return_value = (\n                    dataset_fozziethebeat_alpaca_messages_2k_dpo_test_rev_ea82cff\n                )\n\n                with patch(\n                    \"axolotl.common.const.DEFAULT_DATASET_PREPARED_PATH\",\n                    str(prepared_path),\n                ):\n                    dataset, _ = _load_tokenized_prepared_datasets(tokenizer, cfg)\n\n                assert len(dataset) == 2000\n                assert \"input_ids\" in dataset.features\n                assert \"attention_mask\" in dataset.features\n                assert \"labels\" in dataset.features\n                shutil.rmtree(tmp_ds_path)\n\n    @enable_hf_offline\n    def test_loading_local_dataset_folder(self, tokenizer):\n        \"\"\"Verify that a dataset downloaded to a local folder can be loaded\"\"\"\n\n        with tempfile.TemporaryDirectory() as tmp_dir:\n            tmp_ds_path = Path(tmp_dir) / \"mhenrichsen/alpaca_2k_test\"\n            tmp_ds_path.mkdir(parents=True, exist_ok=True)\n            snapshot_path = snapshot_download(\n                repo_id=\"mhenrichsen/alpaca_2k_test\",\n                repo_type=\"dataset\",\n            )\n            shutil.copytree(snapshot_path, 
tmp_ds_path, dirs_exist_ok=True)\n\n            prepared_path = Path(tmp_dir) / \"prepared\"\n            cfg = DictDefault(\n                {\n                    \"tokenizer_config\": \"huggyllama/llama-7b\",\n                    \"sequence_len\": 1024,\n                    \"datasets\": [\n                        {\n                            \"path\": str(tmp_ds_path),\n                            \"type\": \"alpaca\",\n                        },\n                    ],\n                    \"dataset_num_proc\": 4,\n                }\n            )\n\n            with patch(\n                \"axolotl.common.const.DEFAULT_DATASET_PREPARED_PATH\", str(prepared_path)\n            ):\n                dataset, _ = _load_tokenized_prepared_datasets(tokenizer, cfg)\n\n            assert len(dataset) == 2000\n            assert \"input_ids\" in dataset.features\n            assert \"attention_mask\" in dataset.features\n            assert \"labels\" in dataset.features\n            shutil.rmtree(tmp_ds_path)\n"
  },
  {
    "path": "tests/test_dict.py",
    "content": "\"\"\"Module for testing DictDefault class\"\"\"\n\nimport unittest\n\nimport pytest\n\nfrom axolotl.utils.dict import DictDefault\n\n\nclass DictDefaultTest(unittest.TestCase):\n    \"\"\"\n    Test DictDefault class\n    \"\"\"\n\n    def test_dict_default(self):\n        cfg = DictDefault(\n            {\n                \"key_a\": {\"key_b\": \"value_a\"},\n                \"key_c\": \"value_c\",\n                \"key_d\": [\"value_d\", \"value_e\"],\n            }\n        )\n\n        assert cfg.key_a.key_b == \"value_a\", (\n            \"DictDefault should return value for existing nested keys\"\n        )\n\n        assert cfg.key_c == \"value_c\", (\n            \"DictDefault should return value for existing keys\"\n        )\n\n        assert cfg.key_d[0] == \"value_d\", (\n            \"DictDefault should return value for existing keys in list\"\n        )\n\n        assert \"value_e\" in cfg.key_d, (\n            \"DictDefault should support in operator for existing keys in list\"\n        )\n\n    def test_dict_or_operator(self):\n        cfg = DictDefault({\"key_a\": {\"key_b\": \"value_b\"}, \"key_f\": \"value_g\"})\n\n        cfg = cfg | DictDefault(\n            {\n                \"key_a\": {\"key_b\": \"value_a\"},\n                \"key_c\": \"value_c\",\n                \"key_d\": [\"value_d\", \"value_e\"],\n                \"key_f\": \"value_f\",\n            }\n        )\n\n        assert cfg.key_a.key_b == \"value_b\", (\n            \"DictDefault should support OR operator for existing nested keys\"\n        )\n\n        assert cfg.key_c == \"value_c\", \"DictDefault should not delete existing key\"\n\n        assert cfg.key_d == [\n            \"value_d\",\n            \"value_e\",\n        ], \"DictDefault should not overwrite existing keys in list\"\n\n        assert cfg.key_f == \"value_g\", (\n            \"DictDefault should support OR operator for existing key\"\n        )\n\n    def test_dict_missingkey(self):\n   
     cfg = DictDefault({})\n\n        assert cfg.random_key is None, \"DictDefault should return None for missing keys\"\n\n    def test_dict_or(self):\n        cfg = DictDefault({}) | DictDefault({})\n\n        assert cfg.random_key is None, (\n            \"DictDefault should return None for missing keys after | operation\"\n        )\n\n    def test_dict_nested_missingparentkey(self):\n        \"\"\"\n        Due to subclassing Dict, DictDefault will error if we try to access a nested key whose parent key does not exist.\n        \"\"\"\n        cfg = DictDefault({})\n\n        with pytest.raises(\n            AttributeError,\n            match=r\"'NoneType' object has no attribute 'another_random_key'\",\n        ):\n            cfg.random_key.another_random_key = \"value\"\n\n    def test_dict_shorthand_assignment(self):\n        \"\"\"\n        Shorthand assignment is said to not be supported if subclassed. However, their example raises error instead of None.\n        This test ensures that it is supported for current implementation.\n\n        Ref: https://github.com/mewwts/addict#default-values\n        \"\"\"\n\n        cfg = DictDefault({\"key_a\": {\"key_b\": \"value_a\"}})\n\n        cfg.key_a.key_b = \"value_b\"\n\n        assert cfg.key_a.key_b == \"value_b\", \"Shorthand assignment should be supported\"\n"
  },
  {
    "path": "tests/test_exact_deduplication.py",
    "content": "\"\"\"Test suite for functions in the `axolotl.utils.data.utils` module, focusing on the\n`deduplicate_and_log_datasets` function.\n\nAdditionally, this test suite includes tests for functions that indirectly call\n`deduplicate_and_log_datasets` during the execution of the preprocess command.\n\"\"\"\n\nimport unittest\nfrom unittest.mock import patch\n\nimport pytest\nfrom datasets import Dataset\n\nfrom axolotl.loaders import load_processor, load_tokenizer\nfrom axolotl.utils.config import normalize_config, validate_config\nfrom axolotl.utils.data import prepare_datasets, prepare_preference_datasets\nfrom axolotl.utils.data.utils import deduplicate_and_log_datasets\nfrom axolotl.utils.dict import DictDefault\n\nfrom tests.constants import ALPACA_MESSAGES_CONFIG_REVISION\nfrom tests.hf_offline_utils import enable_hf_offline\n\n\ndef verify_deduplication(actual_dataset, expected_dataset, dataset_name):\n    \"\"\"Validates deduplication results and size consistency.\n\n    Parameters:\n    - actual_dataset: Deduplicated dataset.\n    - expected_dataset: Expected dataset.\n    - dataset_name: Name of the dataset (e.g., 'train' or 'eval').\n\n    Asserts:\n    - Datasets match in content.\n    - Dataset size matches unique row count.\n    \"\"\"\n    # Convert datasets to sets of tuples for unordered comparison\n    actual_rows = set(tuple(row.values()) for row in actual_dataset)\n    expected_rows = set(tuple(row.values()) for row in expected_dataset)\n\n    # Verify deduplication correctness\n    assert actual_rows == expected_rows, f\"Mismatch in {dataset_name} dataset\"\n\n    # Verify size consistency\n    assert len(actual_rows) == len(actual_dataset), (\n        f\"Size mismatch in {dataset_name} dataset after deduplication\"\n    )\n\n\nclass TestDeduplicateIndividualFunctions(unittest.TestCase):\n    \"\"\"Test class for deduplication function in data utils\"\"\"\n\n    def setUp(self):\n        # Sample data with duplicates\n        
self.data = {\n            \"column1\": [\"apple\", \"banana\", \"apple\", \"orange\", \"banana\"],\n            \"column2\": [1, 2, 1, 3, 2],\n            \"column3\": [\"red\", \"yellow\", \"red\", \"orange\", \"yellow\"],\n        }\n\n        # Expected result after deduplication\n        self.expected_data = {\n            \"column1\": [\"apple\", \"banana\", \"orange\"],\n            \"column2\": [1, 2, 3],\n            \"column3\": [\"red\", \"yellow\", \"orange\"],\n        }\n\n        # Convert to Dataset format\n        self.dataset = Dataset.from_dict(self.data)\n        self.expected_dataset = Dataset.from_dict(self.expected_data)\n\n    def test_deduplication(self):\n        train_dataset, _ = deduplicate_and_log_datasets(dataset=self.dataset)\n        eval_dataset, _ = deduplicate_and_log_datasets(\n            dataset=self.dataset, dataset_name=\"eval\"\n        )\n\n        verify_deduplication(train_dataset, self.expected_dataset, \"train_dataset\")\n        verify_deduplication(eval_dataset, self.expected_dataset, \"eval_dataset\")\n\n    def test_exact_duplicates(self):\n        # Test when datasets are exact duplicates\n        duplicate_data = {\n            \"column1\": [\"apple\", \"apple\", \"apple\"],\n            \"column2\": [1, 1, 1],\n            \"column3\": [\"red\", \"red\", \"red\"],\n        }\n        expected_data = {\"column1\": [\"apple\"], \"column2\": [1], \"column3\": [\"red\"]}\n\n        # Convert to Dataset format\n        dataset = Dataset.from_dict(duplicate_data)\n        expected_dataset = Dataset.from_dict(expected_data)\n\n        # Run deduplication\n        train_dataset, _ = deduplicate_and_log_datasets(dataset=dataset)\n        eval_dataset, _ = deduplicate_and_log_datasets(\n            dataset=dataset, dataset_name=\"eval\"\n        )\n\n        verify_deduplication(train_dataset, expected_dataset, \"train_dataset\")\n        verify_deduplication(eval_dataset, expected_dataset, \"eval_dataset\")\n\n    def 
test_partial_duplicates(self):\n        # Test when only part of the dataset is a duplicate\n        partial_duplicate_data = {\n            \"column1\": [\"apple\", \"banana\", \"apple\"],\n            \"column2\": [1, 2, 1],\n            \"column3\": [\"red\", \"yellow\", \"red\"],\n        }\n        expected_data = {\n            \"column1\": [\"apple\", \"banana\"],\n            \"column2\": [1, 2],\n            \"column3\": [\"red\", \"yellow\"],\n        }\n\n        # Convert to Dataset format\n        dataset = Dataset.from_dict(partial_duplicate_data)\n        expected_dataset = Dataset.from_dict(expected_data)\n\n        # Run deduplication\n        train_dataset, _ = deduplicate_and_log_datasets(dataset=dataset)\n        eval_dataset, _ = deduplicate_and_log_datasets(\n            dataset=dataset, dataset_name=\"eval\"\n        )\n\n        verify_deduplication(train_dataset, expected_dataset, \"train_dataset\")\n        verify_deduplication(eval_dataset, expected_dataset, \"eval_dataset\")\n\n    def test_combined_duplicates_empty(self):\n        # Test when only part of the dataset is a duplicate\n        partial_duplicate_data = {\n            \"column1\": [\"apple\", \"banana\", \"apple\"],\n            \"column2\": [1, 2, 1],\n            \"column3\": [\"red\", \"yellow\", \"red\"],\n        }\n        expected_data_train = {\n            \"column1\": [\"apple\", \"banana\"],\n            \"column2\": [1, 2],\n            \"column3\": [\"red\", \"yellow\"],\n        }\n        expected_data_eval = {\n            \"column1\": [],\n            \"column2\": [],\n            \"column3\": [],\n        }\n\n        # Convert to Dataset format\n        dataset = Dataset.from_dict(partial_duplicate_data)\n        expected_dataset_train = Dataset.from_dict(expected_data_train)\n        expected_dataset_eval = Dataset.from_dict(expected_data_eval)\n\n        # Run deduplication\n        train_dataset, eval_dataset = deduplicate_and_log_datasets(\n            
dataset=dataset, other_dataset=dataset\n        )\n\n        verify_deduplication(train_dataset, expected_dataset_train, \"train_dataset\")\n        verify_deduplication(eval_dataset, expected_dataset_eval, \"eval_dataset\")\n\n    def test_combined_duplicates_one(self):\n        # Test when only part of the dataset is a duplicate\n        partial_duplicate_data_train = {\n            \"column1\": [\"apple\", \"banana\", \"apple\"],\n            \"column2\": [1, 2, 1],\n            \"column3\": [\"red\", \"yellow\", \"red\"],\n        }\n        partial_duplicate_data_eval = {\n            \"column1\": [\"apple\", \"orange\", \"apple\"],\n            \"column2\": [1, 2, 1],\n            \"column3\": [\"red\", \"orange\", \"red\"],\n        }\n        expected_data_train = {\n            \"column1\": [\"apple\", \"banana\"],\n            \"column2\": [1, 2],\n            \"column3\": [\"red\", \"yellow\"],\n        }\n        expected_data_eval = {\n            \"column1\": [\"orange\"],\n            \"column2\": [2],\n            \"column3\": [\"orange\"],\n        }\n\n        # Convert to Dataset format\n        dataset_train = Dataset.from_dict(partial_duplicate_data_train)\n        dataset_eval = Dataset.from_dict(partial_duplicate_data_eval)\n        expected_dataset_train = Dataset.from_dict(expected_data_train)\n        expected_dataset_eval = Dataset.from_dict(expected_data_eval)\n\n        # Run deduplication\n        train_dataset, eval_dataset = deduplicate_and_log_datasets(\n            dataset=dataset_train, other_dataset=dataset_eval\n        )\n\n        verify_deduplication(train_dataset, expected_dataset_train, \"train_dataset\")\n        verify_deduplication(eval_dataset, expected_dataset_eval, \"eval_dataset\")\n\n\nclass TestDeduplicateRLDataset:\n    \"\"\"Test a configured dataloader with deduplication.\"\"\"\n\n    @pytest.fixture\n    def cfg(self):\n        fixture = DictDefault(\n            {\n                \"tokenizer_config\": 
\"huggyllama/llama-7b\",\n                \"sequence_len\": 1024,\n                \"rl\": \"dpo\",\n                \"chat_template\": \"llama3\",\n                \"dataset_exact_deduplication\": True,\n                \"datasets\": [\n                    ALPACA_MESSAGES_CONFIG_REVISION,\n                    ALPACA_MESSAGES_CONFIG_REVISION,\n                ],\n                \"dataset_num_proc\": 4,\n            }\n        )\n        yield fixture\n\n    @enable_hf_offline\n    def test_load_with_deduplication(\n        self,\n        cfg,\n        dataset_fozziethebeat_alpaca_messages_2k_dpo_test_rev_ea82cff,\n        tokenizer_huggyllama,\n    ):\n        \"\"\"Verify that loading with deduplication removes duplicates.\"\"\"\n\n        with (\n            patch(\n                \"axolotl.utils.data.rl.load_dataset_with_config\"\n            ) as mock_load_dataset,\n            patch(\"axolotl.loaders.load_tokenizer\") as mock_load_tokenizer,\n        ):\n            # Set up the mock to return different values on successive calls\n            mock_load_dataset.side_effect = [\n                dataset_fozziethebeat_alpaca_messages_2k_dpo_test_rev_ea82cff,\n                dataset_fozziethebeat_alpaca_messages_2k_dpo_test_rev_ea82cff,\n            ]\n            mock_load_tokenizer.return_value = tokenizer_huggyllama\n\n            tokenizer = load_tokenizer(cfg)\n            train_dataset, _ = prepare_preference_datasets(cfg, tokenizer)\n\n            # Verify that the dataset has been deduplicated\n            assert len(train_dataset) == 1800, \"Dataset was not properly deduplicated\"\n\n    @enable_hf_offline\n    def test_load_without_deduplication(\n        self,\n        cfg,\n        dataset_fozziethebeat_alpaca_messages_2k_dpo_test_rev_ea82cff,\n        tokenizer_huggyllama,\n    ):\n        with (\n            patch(\n                \"axolotl.utils.data.rl.load_dataset_with_config\"\n            ) as mock_load_dataset,\n            
patch(\"axolotl.loaders.load_tokenizer\") as mock_load_tokenizer,\n        ):\n            # Set up the mock to return different values on successive calls\n            mock_load_dataset.side_effect = [\n                dataset_fozziethebeat_alpaca_messages_2k_dpo_test_rev_ea82cff,\n                dataset_fozziethebeat_alpaca_messages_2k_dpo_test_rev_ea82cff,\n            ]\n            mock_load_tokenizer.return_value = tokenizer_huggyllama\n\n            # Load the dataset without deduplication\n            cfg.dataset_exact_deduplication = False\n            tokenizer = load_tokenizer(cfg)\n            train_dataset, _ = prepare_preference_datasets(cfg, tokenizer)\n\n            # Verify that the dataset retains duplicates\n            assert len(train_dataset) == 1800 * 2, (\n                \"Dataset deduplication occurred when it should not have\"\n            )\n\n\nclass TestDeduplicateNonRL(unittest.TestCase):\n    \"\"\"Test prepare_dataset function with different configurations.\"\"\"\n\n    @enable_hf_offline\n    def setUp(self) -> None:\n        self.cfg_1 = DictDefault(\n            {\n                \"base_model\": \"huggyllama/llama-7b\",\n                \"tokenizer_config\": \"huggyllama/llama-7b\",\n                \"sequence_len\": 1024,\n                \"dataset_exact_deduplication\": True,\n                \"datasets\": [\n                    {\n                        \"path\": \"mhenrichsen/alpaca_2k_test\",\n                        \"type\": \"alpaca\",\n                    },\n                    {\n                        \"path\": \"mhenrichsen/alpaca_2k_test\",\n                        \"type\": \"alpaca\",\n                    },\n                ],\n                \"val_set_size\": 0.0,\n                \"gradient_accumulation_steps\": 2,\n                \"batch_size\": 10,\n                \"micro_batch_size\": 10,\n                \"num_epochs\": 1,\n            }\n        )\n        self.cfg_1 = validate_config(self.cfg_1)\n  
      normalize_config(self.cfg_1)\n\n    @pytest.mark.skip(reason=\"TODO: fix hf hub offline to work with HF rate limits\")\n    @enable_hf_offline\n    def test_prepare_dataset_with_deduplication_train(self):\n        \"\"\"Verify that prepare_dataset function processes the dataset correctly with deduplication.\"\"\"\n        self.cfg_1.dataset_exact_deduplication = True\n\n        # Load tokenizer and processor\n        tokenizer = load_tokenizer(self.cfg_1)\n        processor = (\n            load_processor(self.cfg_1, tokenizer=tokenizer)\n            if self.cfg_1.processor_type\n            else None\n        )\n\n        # Prepare dataset using the prepare_dataset function\n        train_dataset, _, _, _ = prepare_datasets(\n            self.cfg_1,\n            tokenizer,\n            processor=processor,\n        )\n\n        self.assertEqual(\n            len(train_dataset),\n            2000,\n            \"Train dataset should have 2000 samples after deduplication.\",\n        )\n\n    @pytest.mark.skip(reason=\"TODO: fix hf hub offline to work with HF rate limits\")\n    @enable_hf_offline\n    def test_prepare_dataset_with_deduplication_eval(self):\n        \"\"\"Verify that prepare_dataset function processes the dataset correctly with deduplication.\"\"\"\n        self.cfg_1.dataset_exact_deduplication = True\n        self.cfg_1.val_set_size = 0.5\n        # Load tokenizer and processor\n        tokenizer = load_tokenizer(self.cfg_1)\n        processor = (\n            load_processor(self.cfg_1, tokenizer=tokenizer)\n            if self.cfg_1.processor_type\n            else None\n        )\n\n        # Prepare dataset using the prepare_dataset function\n        _, eval_dataset, _, _ = prepare_datasets(\n            self.cfg_1,\n            tokenizer,\n            processor=processor,\n        )\n\n        self.assertEqual(\n            len(eval_dataset),\n            1000,\n            \"Eval dataset should have 1000 samples after 
deduplication.\",\n        )\n\n    @pytest.mark.skip(reason=\"TODO: fix hf hub offline to work with HF rate limits\")\n    @enable_hf_offline\n    def test_prepare_dataset_without_deduplication(self):\n        \"\"\"Verify that prepare_dataset function processes the dataset correctly without deduplication.\"\"\"\n        self.cfg_1.dataset_exact_deduplication = False\n        self.cfg_1.val_set_size = 0.1\n        # Load tokenizer and processor\n        tokenizer = load_tokenizer(self.cfg_1)\n        processor = (\n            load_processor(self.cfg_1, tokenizer=tokenizer)\n            if self.cfg_1.processor_type\n            else None\n        )\n\n        # Prepare dataset using the prepare_dataset function\n        train_dataset, eval_dataset, _, _ = prepare_datasets(\n            self.cfg_1,\n            tokenizer,\n            processor=processor,\n        )\n\n        # Verify that the dataset has been prepared correctly\n        self.assertEqual(\n            len(train_dataset),\n            1800 * 2,\n            \"Train dataset should have 3600 samples without deduplication.\",\n        )\n        self.assertEqual(\n            len(eval_dataset),\n            200 * 2,\n            \"Eval dataset should have 400 samples without deduplication.\",\n        )\n\n\nclass TestWrongCollisions(unittest.TestCase):\n    \"\"\"Creating mock datasets for testing wrong collisions.\"\"\"\n\n    def setUp(self):\n        self.train_data = {\"text\": [\"sample 5\", \"sample 6\"], \"label\": [1, 2]}\n        self.eval_data = {\n            \"text\": [\n                \"sample 5\",\n                \"sample 7\",\n            ],  # Different label but same text as in train_data\n            \"label\": [2, 3],\n        }\n        self.dataset_data = {\n            \"text\": [\"sample 5\", \"sample 9\", \"sample 5\"],\n            \"label\": [1, 2, 8],\n        }\n        self.train_dataset = Dataset.from_dict(self.train_data)\n        self.eval_dataset = 
Dataset.from_dict(self.eval_data)\n        self.dataset = Dataset.from_dict(self.dataset_data)\n\n    def test_deduplication_dataset_only(self):\n        dedup_dataset, _ = deduplicate_and_log_datasets(dataset=self.dataset)\n        self.assertEqual(\n            len(dedup_dataset), 3, \"Dataset should have all original values\"\n        )\n        self.assertEqual(\n            str(dedup_dataset),\n            str(self.dataset),\n            \"The string representation of the output dataset should not differ.\",\n        )\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/test_freeze.py",
    "content": "\"\"\"\nThis module contains unit tests for the `freeze_layers_except` function.\n\nThe `freeze_layers_except` function is used to freeze layers in a model, except for the specified layers.\nThe unit tests in this module verify the behavior of the `freeze_layers_except` function in different scenarios.\n\"\"\"\n\nimport unittest\n\nimport torch\nfrom torch import nn\n\nfrom axolotl.utils.freeze import freeze_layers_except\n\nZERO = [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]\nONE_TO_TEN = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0]\n\n\nclass TestFreezeLayersExcept(unittest.TestCase):\n    \"\"\"\n    A test case class for the `freeze_layers_except` function.\n    \"\"\"\n\n    def setUp(self):\n        self.model = _TestModel()\n\n    def test_freeze_layers_with_dots_in_name(self):\n        freeze_layers_except(self.model, [\"features.layer\"])\n        self.assertTrue(\n            self.model.features.layer.weight.requires_grad,\n            \"model.features.layer should be trainable.\",\n        )\n        self.assertFalse(\n            self.model.classifier.weight.requires_grad,\n            \"model.classifier should be frozen.\",\n        )\n\n    def test_freeze_layers_without_dots_in_name(self):\n        freeze_layers_except(self.model, [\"classifier\"])\n        self.assertFalse(\n            self.model.features.layer.weight.requires_grad,\n            \"model.features.layer should be frozen.\",\n        )\n        self.assertTrue(\n            self.model.classifier.weight.requires_grad,\n            \"model.classifier should be trainable.\",\n        )\n\n    def test_freeze_layers_regex_patterns(self):\n        # The second pattern cannot match because only characters 'a' to 'c' are allowed after the word 'class', whereas it should be matching the character 'i'.\n        freeze_layers_except(self.model, [r\"^features.[a-z]+.weight$\", r\"class[a-c]+\"])\n        self.assertTrue(\n            
self.model.features.layer.weight.requires_grad,\n            \"model.features.layer should be trainable.\",\n        )\n        self.assertFalse(\n            self.model.classifier.weight.requires_grad,\n            \"model.classifier should be frozen.\",\n        )\n\n    def test_all_layers_frozen(self):\n        freeze_layers_except(self.model, [])\n        self.assertFalse(\n            self.model.features.layer.weight.requires_grad,\n            \"model.features.layer should be frozen.\",\n        )\n        self.assertFalse(\n            self.model.classifier.weight.requires_grad,\n            \"model.classifier should be frozen.\",\n        )\n\n    def test_all_layers_unfrozen(self):\n        freeze_layers_except(self.model, [\"features.layer\", \"classifier\"])\n        self.assertTrue(\n            self.model.features.layer.weight.requires_grad,\n            \"model.features.layer should be trainable.\",\n        )\n        self.assertTrue(\n            self.model.classifier.weight.requires_grad,\n            \"model.classifier should be trainable.\",\n        )\n\n    def test_freeze_layers_with_range_pattern_start_end(self):\n        freeze_layers_except(self.model, [\"features.layer[1:5]\"])\n        self.assertTrue(\n            self.model.features.layer.weight.requires_grad,\n            \"model.features.layer should be trainable.\",\n        )\n        self.assertFalse(\n            self.model.classifier.weight.requires_grad,\n            \"model.classifier should be frozen.\",\n        )\n\n        self._assert_gradient_output(\n            [\n                ZERO,\n                ONE_TO_TEN,\n                ONE_TO_TEN,\n                ONE_TO_TEN,\n                ONE_TO_TEN,\n                ZERO,\n                ZERO,\n                ZERO,\n                ZERO,\n                ZERO,\n            ]\n        )\n\n    def test_freeze_layers_with_range_pattern_single_index(self):\n        freeze_layers_except(self.model, 
[\"features.layer[5]\"])\n        self.assertTrue(\n            self.model.features.layer.weight.requires_grad,\n            \"model.features.layer should be trainable.\",\n        )\n        self.assertFalse(\n            self.model.classifier.weight.requires_grad,\n            \"model.classifier should be frozen.\",\n        )\n\n        self._assert_gradient_output(\n            [ZERO, ZERO, ZERO, ZERO, ZERO, ONE_TO_TEN, ZERO, ZERO, ZERO, ZERO]\n        )\n\n    def test_freeze_layers_with_range_pattern_start_omitted(self):\n        freeze_layers_except(self.model, [\"features.layer[:5]\"])\n        self.assertTrue(\n            self.model.features.layer.weight.requires_grad,\n            \"model.features.layer should be trainable.\",\n        )\n        self.assertFalse(\n            self.model.classifier.weight.requires_grad,\n            \"model.classifier should be frozen.\",\n        )\n\n        self._assert_gradient_output(\n            [\n                ONE_TO_TEN,\n                ONE_TO_TEN,\n                ONE_TO_TEN,\n                ONE_TO_TEN,\n                ONE_TO_TEN,\n                ZERO,\n                ZERO,\n                ZERO,\n                ZERO,\n                ZERO,\n            ]\n        )\n\n    def test_freeze_layers_with_range_pattern_end_omitted(self):\n        freeze_layers_except(self.model, [\"features.layer[4:]\"])\n        self.assertTrue(\n            self.model.features.layer.weight.requires_grad,\n            \"model.features.layer should be trainable.\",\n        )\n        self.assertFalse(\n            self.model.classifier.weight.requires_grad,\n            \"model.classifier should be frozen.\",\n        )\n\n        self._assert_gradient_output(\n            [\n                ZERO,\n                ZERO,\n                ZERO,\n                ZERO,\n                ONE_TO_TEN,\n                ONE_TO_TEN,\n                ONE_TO_TEN,\n                ONE_TO_TEN,\n                ONE_TO_TEN,\n               
 ONE_TO_TEN,\n            ]\n        )\n\n    def test_freeze_layers_with_range_pattern_merge_included(self):\n        freeze_layers_except(self.model, [\"features.layer[4:]\", \"features.layer[5:6]\"])\n        self.assertTrue(\n            self.model.features.layer.weight.requires_grad,\n            \"model.features.layer should be trainable.\",\n        )\n        self.assertFalse(\n            self.model.classifier.weight.requires_grad,\n            \"model.classifier should be frozen.\",\n        )\n\n        self._assert_gradient_output(\n            [\n                ZERO,\n                ZERO,\n                ZERO,\n                ZERO,\n                ONE_TO_TEN,\n                ONE_TO_TEN,\n                ONE_TO_TEN,\n                ONE_TO_TEN,\n                ONE_TO_TEN,\n                ONE_TO_TEN,\n            ]\n        )\n\n    def test_freeze_layers_with_range_pattern_merge_intersect(self):\n        freeze_layers_except(self.model, [\"features.layer[4:7]\", \"features.layer[6:8]\"])\n        self.assertTrue(\n            self.model.features.layer.weight.requires_grad,\n            \"model.features.layer should be trainable.\",\n        )\n        self.assertFalse(\n            self.model.classifier.weight.requires_grad,\n            \"model.classifier should be frozen.\",\n        )\n\n        self._assert_gradient_output(\n            [\n                ZERO,\n                ZERO,\n                ZERO,\n                ZERO,\n                ONE_TO_TEN,\n                ONE_TO_TEN,\n                ONE_TO_TEN,\n                ONE_TO_TEN,\n                ZERO,\n                ZERO,\n            ]\n        )\n\n    def test_freeze_layers_with_range_pattern_merge_separate(self):\n        freeze_layers_except(\n            self.model,\n            [\"features.layer[1:2]\", \"features.layer[3:4]\", \"features.layer[5:6]\"],\n        )\n        self.assertTrue(\n            self.model.features.layer.weight.requires_grad,\n            
\"model.features.layer should be trainable.\",\n        )\n        self.assertFalse(\n            self.model.classifier.weight.requires_grad,\n            \"model.classifier should be frozen.\",\n        )\n\n        self._assert_gradient_output(\n            [\n                ZERO,\n                ONE_TO_TEN,\n                ZERO,\n                ONE_TO_TEN,\n                ZERO,\n                ONE_TO_TEN,\n                ZERO,\n                ZERO,\n                ZERO,\n                ZERO,\n            ]\n        )\n\n    def _assert_gradient_output(self, expected):\n        input_tensor = torch.tensor([ONE_TO_TEN], dtype=torch.float32)\n\n        self.model.features.layer.weight.grad = None  # Reset gradients\n        output = self.model.features.layer(input_tensor)\n        loss = output.sum()\n        loss.backward()\n\n        expected_grads = torch.tensor(expected)\n        torch.testing.assert_close(\n            self.model.features.layer.weight.grad, expected_grads\n        )\n\n\nclass _SubLayerModule(nn.Module):\n    def __init__(self):\n        super().__init__()\n        self.layer = nn.Linear(10, 10)\n\n\nclass _TestModel(nn.Module):\n    def __init__(self):\n        super().__init__()\n        self.features = _SubLayerModule()\n        self.classifier = nn.Linear(10, 2)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/test_loaders.py",
    "content": "\"\"\"Module for `axolotl.loaders`.\"\"\"\n\nfrom unittest.mock import MagicMock\n\nimport pytest\nfrom transformers import BitsAndBytesConfig, PreTrainedTokenizerBase\nfrom transformers.integrations.deepspeed import is_deepspeed_zero3_enabled\nfrom transformers.utils.import_utils import is_torch_mps_available\n\nfrom axolotl.loaders import ModelLoader\nfrom axolotl.utils.dict import DictDefault\nfrom axolotl.utils.distributed import _get_parallel_config_kwargs\n\n\nclass TestModelsUtils:\n    \"\"\"Testing module for `axolotl.loaders`.\"\"\"\n\n    def setup_method(self) -> None:\n        # load config\n        self.cfg = DictDefault(\n            {\n                \"base_model\": \"HuggingFaceTB/SmolLM2-135M\",\n                \"model_type\": \"AutoModelForCausalLM\",\n                \"tokenizer_type\": \"AutoTokenizer\",\n                \"load_in_8bit\": True,\n                \"load_in_4bit\": False,\n                \"adapter\": \"lora\",\n                \"flash_attention\": False,\n                \"sample_packing\": True,\n                \"device_map\": \"auto\",\n            }\n        )\n        self.tokenizer = MagicMock(spec=PreTrainedTokenizerBase)\n        self.inference = False\n        self.reference_model = True\n\n        # init ModelLoader\n        self.model_loader = ModelLoader(\n            cfg=self.cfg,\n            tokenizer=self.tokenizer,\n            inference=self.inference,\n            reference_model=self.reference_model,\n        )\n\n    def test_set_device_map_config(self):\n        # check device_map\n        device_map = self.cfg.device_map\n        if is_torch_mps_available():\n            device_map = \"mps\"\n\n        self.model_loader._set_device_map_config()\n        if is_deepspeed_zero3_enabled():\n            assert \"device_map\" not in self.model_loader.model_kwargs\n        else:\n            assert device_map in self.model_loader.model_kwargs[\"device_map\"]\n\n        # check torch_dtype\n       
 assert self.cfg.torch_dtype == self.model_loader.model_kwargs[\"torch_dtype\"]\n\n    @pytest.mark.parametrize(\"adapter\", [\"lora\", \"qlora\", None])\n    @pytest.mark.parametrize(\"load_in_8bit\", [True, False])\n    @pytest.mark.parametrize(\"load_in_4bit\", [True, False])\n    @pytest.mark.parametrize(\"gptq\", [True, False])\n    def test_set_quantization_config(\n        self,\n        adapter,\n        load_in_8bit,\n        load_in_4bit,\n        gptq,\n    ):\n        # init cfg as args\n        self.cfg.load_in_8bit = load_in_8bit\n        self.cfg.load_in_4bit = load_in_4bit\n        self.cfg.gptq = gptq\n        self.cfg.adapter = adapter\n\n        self.model_loader._set_quantization_config()\n        if \"quantization_config\" in self.model_loader.model_kwargs or self.cfg.gptq:\n            assert not (\n                hasattr(self.model_loader.model_kwargs, \"load_in_8bit\")\n                and hasattr(self.model_loader.model_kwargs, \"load_in_4bit\")\n            )\n\n        if self.cfg.adapter == \"qlora\" and load_in_4bit:\n            assert isinstance(\n                self.model_loader.model_kwargs.get(\"quantization_config\"),\n                BitsAndBytesConfig,\n            )\n\n            assert (\n                self.model_loader.model_kwargs[\"quantization_config\"]._load_in_4bit\n                is True\n            )\n        if self.cfg.adapter == \"lora\" and load_in_8bit:\n            assert isinstance(\n                self.model_loader.model_kwargs.get(\"quantization_config\"),\n                BitsAndBytesConfig,\n            )\n\n            assert (\n                self.model_loader.model_kwargs[\"quantization_config\"]._load_in_8bit\n                is True\n            )\n\n    def test_message_property_mapping(self):\n        \"\"\"Test message property mapping configuration validation\"\"\"\n        from axolotl.utils.schemas.datasets import SFTDataset\n\n        # Test legacy fields are mapped correctly\n        
dataset = SFTDataset(\n            path=\"test_path\",\n            message_field_role=\"role_field\",\n            message_field_content=\"content_field\",\n        )\n        assert dataset.message_property_mappings == {\n            \"role\": \"role_field\",\n            \"content\": \"content_field\",\n        }\n\n        # Test direct message_property_mapping works\n        dataset = SFTDataset(\n            path=\"test_path\",\n            message_property_mappings={\n                \"role\": \"custom_role\",\n                \"content\": \"custom_content\",\n            },\n        )\n        assert dataset.message_property_mappings == {\n            \"role\": \"custom_role\",\n            \"content\": \"custom_content\",\n        }\n\n        # Test both legacy and new fields work when they match\n        dataset = SFTDataset(\n            path=\"test_path\",\n            message_field_role=\"same_role\",\n            message_property_mappings={\"role\": \"same_role\"},\n        )\n        assert dataset.message_property_mappings == {\n            \"role\": \"same_role\",\n            \"content\": \"content\",\n        }\n\n        # Test both legacy and new fields work when they don't overlap\n        dataset = SFTDataset(\n            path=\"test_path\",\n            message_field_role=\"role_field\",\n            message_property_mappings={\"content\": \"content_field\"},\n        )\n        assert dataset.message_property_mappings == {\n            \"role\": \"role_field\",\n            \"content\": \"content_field\",\n        }\n\n        # Test no role or content provided\n        dataset = SFTDataset(\n            path=\"test_path\",\n        )\n        assert dataset.message_property_mappings == {\n            \"role\": \"role\",\n            \"content\": \"content\",\n        }\n\n        # Test error when legacy and new fields conflict\n        with pytest.raises(ValueError) as exc_info:\n            SFTDataset(\n                
path=\"test_path\",\n                message_field_role=\"legacy_role\",\n                message_property_mappings={\"role\": \"different_role\"},\n            )\n        assert \"Conflicting message role fields\" in str(exc_info.value)\n\n        with pytest.raises(ValueError) as exc_info:\n            SFTDataset(\n                path=\"test_path\",\n                message_field_content=\"legacy_content\",\n                message_property_mappings={\"content\": \"different_content\"},\n            )\n        assert \"Conflicting message content fields\" in str(exc_info.value)\n\n    @pytest.mark.parametrize(\n        \"world_size, tensor_parallel_size, context_parallel_size, dp_shard_size, dp_replicate_size, is_fsdp, expected\",\n        [\n            (16, 2, 2, 2, 2, True, (2, 2, 2, 2)),\n            (16, 1, 1, None, None, True, (0, 0, 16, 1)),\n            (16, 2, 2, 2, None, True, (2, 2, 2, 2)),\n            (16, 2, 2, None, 2, True, (2, 2, 2, 2)),\n            (16, 1, 1, None, 2, True, (0, 0, 8, 2)),\n            (2, 1, 1, None, None, True, (0, 0, 2, 1)),\n        ],\n    )\n    def test_get_parallel_config_kwargs(\n        self,\n        world_size,\n        tensor_parallel_size,\n        context_parallel_size,\n        dp_shard_size,\n        dp_replicate_size,\n        is_fsdp,\n        expected,\n    ):\n        res = _get_parallel_config_kwargs(\n            world_size,\n            tensor_parallel_size,\n            context_parallel_size,\n            dp_shard_size,\n            dp_replicate_size,\n            is_fsdp,\n        )\n\n        if expected[0] > 1:\n            assert res[\"tp_size\"] == expected[0]\n        if expected[1] > 1:\n            assert res[\"cp_size\"] == expected[1]\n        if expected[2] > 1:\n            assert res[\"dp_shard_size\"] == expected[2]\n        if expected[3] > 1:\n            assert res[\"dp_replicate_size\"] == expected[3]\n"
  },
  {
    "path": "tests/test_logging_config_file_capture.py",
    "content": "import logging\nimport tempfile\n\nimport pytest\n\n\ndef read(path: str) -> str:\n    with open(path, \"r\", encoding=\"utf-8\") as f:\n        return f.read()\n\n\n@pytest.fixture(autouse=True)\ndef _reset_logging_state():\n    # Ensure a clean slate for logging between tests\n    for handler in logging.root.handlers[:]:\n        logging.root.removeHandler(handler)\n    logging.shutdown()\n    # Note: dictConfig in configure_logging will set up handlers again\n    yield\n    for handler in logging.root.handlers[:]:\n        logging.root.removeHandler(handler)\n    logging.shutdown()\n\n\ndef test_axolotl_logs_captured_at_all_levels(monkeypatch):\n    from axolotl.logging_config import configure_logging\n    from axolotl.utils import tee\n    from axolotl.utils.logging import get_logger\n\n    with tempfile.TemporaryDirectory() as td:\n        # Avoid stdout tee in this test to simplify interaction with pytest capture\n        monkeypatch.setenv(\"AXOLOTL_TEE_STDOUT\", \"0\")\n        configure_logging()\n        path = tee.prepare_debug_log(\n            type(\"Cfg\", (), {\"output_dir\": td, \"get\": lambda *_: False})\n        )\n\n        log = get_logger(\"axolotl.test\")\n        log.info(\"AX-INFO\")\n        log.debug(\"AX-DEBUG\")\n        tee.file_only_stream.flush()\n\n        data = read(path)\n        assert \"AX-INFO\" in data\n        assert \"AX-DEBUG\" in data\n        tee.close_debug_log()\n\n\ndef test_third_party_logs_filtered_and_warning_captured(monkeypatch):\n    from axolotl.logging_config import configure_logging\n    from axolotl.utils import tee\n\n    with tempfile.TemporaryDirectory() as td:\n        monkeypatch.setenv(\"AXOLOTL_TEE_STDOUT\", \"0\")\n        configure_logging()\n        path = tee.prepare_debug_log(\n            type(\"Cfg\", (), {\"output_dir\": td, \"get\": lambda *_: False})\n        )\n\n        # Third-party logger (non-axolotl)\n        other = logging.getLogger(\"thirdparty.lib\")\n        
other.info(\"TP-INFO\")\n        other.warning(\"TP-WARN\")\n\n        # Simulate Python warnings routed through logging\n        logging.getLogger(\"py.warnings\").warning(\"PY-WARN\")\n\n        # Push through buffers\n        tee.file_only_stream.flush()\n\n        data = read(path)\n        # INFO from non-axolotl should be filtered out (not present)\n        assert \"TP-INFO\" not in data\n        # WARNING+ should be present\n        assert \"TP-WARN\" in data\n        # Python warnings captured (via py.warnings logger)\n        assert \"PY-WARN\" in data\n        tee.close_debug_log()\n        tee.close_debug_log()\n\n\ndef test_prepare_debug_log_idempotent_and_no_duplicate(monkeypatch):\n    from axolotl.logging_config import configure_logging\n    from axolotl.utils import tee\n    from axolotl.utils.logging import get_logger\n\n    with tempfile.TemporaryDirectory() as td:\n        monkeypatch.setenv(\"AXOLOTL_TEE_STDOUT\", \"0\")\n        configure_logging()\n        cfg = type(\"Cfg\", (), {\"output_dir\": td, \"get\": lambda *_: False})\n        p1 = tee.prepare_debug_log(cfg)\n        p2 = tee.prepare_debug_log(cfg)\n        assert p1 == p2\n\n        log = get_logger(\"axolotl.test\")\n        marker = \"UNIQUE-MARKER-12345\"\n        log.info(marker)\n        tee.file_only_stream.flush()\n\n        data = read(p1)\n        # Ensure the marker appears once (not duplicated via propagation)\n        assert data.count(marker) == 1\n        tee.close_debug_log()\n"
  },
  {
    "path": "tests/test_lora.py",
    "content": "\"\"\"\ntests for loading loras\n\"\"\"\n\nfrom axolotl.loaders import ModelLoader, load_tokenizer\nfrom axolotl.utils.config import normalize_config, validate_config\nfrom axolotl.utils.dict import DictDefault\n\nminimal_config = DictDefault(\n    {\n        \"base_model\": \"HuggingFaceTB/SmolLM2-135M\",\n        \"learning_rate\": 0.000001,\n        \"datasets\": [\n            {\n                \"path\": \"mhenrichsen/alpaca_2k_test\",\n                \"type\": \"alpaca\",\n            }\n        ],\n        \"micro_batch_size\": 1,\n        \"gradient_accumulation_steps\": 1,\n    }\n)\n\n\nclass TestLoRALoad:\n    \"\"\"\n    Test class for loading LoRA weights\n    \"\"\"\n\n    def test_load_lora_weights(self):\n        cfg = DictDefault(\n            {\n                \"base_model\": \"HuggingFaceTB/SmolLM2-135M\",\n                \"adapter\": \"lora\",\n                \"lora_r\": 8,\n                \"lora_alpha\": 16,\n                \"lora_dropout\": 0.0,\n                \"lora_target_linear\": True,\n                \"micro_batch_size\": 1,\n                \"gradient_accumulation_steps\": 1,\n                \"sequence_len\": 1024,\n            }\n            | minimal_config\n        )\n        cfg = validate_config(cfg)\n        normalize_config(cfg)\n        tokenizer = load_tokenizer(cfg)\n        ModelLoader(cfg, tokenizer).load()\n\n    def test_load_lora_weights_empty_dropout(self):\n        cfg = DictDefault(\n            {\n                \"base_model\": \"HuggingFaceTB/SmolLM2-135M\",\n                \"adapter\": \"lora\",\n                \"lora_r\": 8,\n                \"lora_alpha\": 16,\n                \"lora_dropout\": None,\n                \"lora_target_linear\": True,\n                \"micro_batch_size\": 1,\n                \"gradient_accumulation_steps\": 1,\n                \"sequence_len\": 1024,\n            }\n            | minimal_config\n        )\n        cfg = validate_config(cfg)\n        
normalize_config(cfg)\n        assert cfg.lora_dropout == 0.0\n        tokenizer = load_tokenizer(cfg)\n        ModelLoader(cfg, tokenizer).load()\n"
  },
  {
    "path": "tests/test_normalize_config.py",
    "content": "\"\"\"\nTest classes for checking functionality of the cfg normalization\n\"\"\"\n\nimport unittest\nfrom unittest.mock import patch\n\nfrom axolotl.utils.config import (\n    normalize_cfg_datasets,\n    normalize_config,\n    validate_config,\n)\nfrom axolotl.utils.dict import DictDefault\n\n\nclass NormalizeConfigTestCase(unittest.TestCase):\n    \"\"\"\n    test class for normalize_config checks\n    \"\"\"\n\n    def _get_base_cfg(self):\n        return DictDefault(\n            {\n                \"base_model\": \"HuggingFaceTB/SmolLM2-135M\",\n                \"base_model_config\": \"HuggingFaceTB/SmolLM2-135M\",\n                \"tokenizer_type\": \"AutoTokenizer\",\n                \"num_epochs\": 1,\n                \"micro_batch_size\": 1,\n                \"gradient_accumulation_steps\": 1,\n                \"datasets\": [\n                    {\n                        \"path\": \"mhenrichsen/alpaca_2k_test\",\n                        \"type\": \"alpaca\",\n                    },\n                ],\n                \"learning_rate\": 0.0001,\n            }\n        )\n\n    def test_base_model_config_set_when_empty(self):\n        cfg = self._get_base_cfg()\n        del cfg.base_model_config\n        normalize_config(cfg)\n\n        assert cfg.base_model_config == cfg.base_model\n\n    def test_chat_template_chatml(self):\n        cfg = DictDefault(\n            {\n                \"chat_template\": \"chatml\",\n                \"datasets\": [\n                    {\n                        \"path\": \"lorem/ipsum\",\n                        \"type\": \"chat_template\",\n                        \"chat_template\": \"gemma\",\n                    },\n                    {\n                        \"path\": \"sit/amet\",\n                        \"type\": \"chat_template\",\n                    },\n                ],\n            }\n        )\n\n        normalize_cfg_datasets(cfg)\n\n        assert cfg.datasets[0].chat_template == 
\"gemma\"\n        assert cfg.datasets[1].chat_template == \"chatml\"\n\n    @patch(\"axolotl.utils.config.is_torch_bf16_gpu_available\")\n    def test_bf16_auto_setter_available(self, mock_bf16_avail):\n        cfg = self._get_base_cfg()\n        cfg.bf16 = \"auto\"\n        mock_bf16_avail.return_value = True\n\n        normalize_config(cfg)\n\n        self.assertTrue(cfg.bf16)\n        self.assertFalse(cfg.fp16)\n\n    @patch(\"axolotl.utils.config.is_torch_bf16_gpu_available\")\n    def test_bf16_auto_setter_not_available(self, mock_bf16_avail):\n        cfg = self._get_base_cfg()\n        cfg.bf16 = \"auto\"\n        cfg.fp16 = None\n        mock_bf16_avail.return_value = False\n\n        normalize_config(cfg)\n\n        self.assertFalse(cfg.bf16)\n        self.assertTrue(cfg.fp16)\n\n    @patch(\"axolotl.utils.config.is_torch_bf16_gpu_available\")\n    def test_bf16_disables_fp16(self, mock_bf16_avail):\n        cfg = self._get_base_cfg()\n        cfg.bf16 = True\n        cfg.fp16 = False\n        mock_bf16_avail.return_value = True\n\n        normalize_config(cfg)\n\n        self.assertTrue(cfg.bf16)\n        self.assertFalse(cfg.fp16)\n\n    def test_migrate_fsdp_config(self):\n        \"\"\"Test basic FSDP config migration with and without fsdp_version\"\"\"\n        cfg_with_version = self._get_base_cfg() | DictDefault(\n            {\n                \"fsdp_config\": {\n                    \"fsdp_version\": 2,\n                    \"fsdp_auto_wrap_policy\": \"TRANSFORMER_BASED_WRAP\",\n                    \"fsdp_offload_params\": False,\n                    \"fsdp_cpu_ram_efficient_loading\": True,\n                }\n            }\n        )\n\n        cfg_with_version = validate_config(cfg_with_version)\n\n        self.assertEqual(cfg_with_version.fsdp_version, 2)\n        self.assertEqual(\n            cfg_with_version.fsdp_config.auto_wrap_policy, \"TRANSFORMER_BASED_WRAP\"\n        )\n        
self.assertEqual(cfg_with_version.fsdp_config.offload_params, False)\n        self.assertEqual(cfg_with_version.fsdp_config.cpu_ram_efficient_loading, True)\n\n        self.assertNotIn(\"fsdp_auto_wrap_policy\", cfg_with_version.fsdp_config)\n        self.assertNotIn(\"fsdp_offload_params\", cfg_with_version.fsdp_config)\n        self.assertNotIn(\"fsdp_cpu_ram_efficient_loading\", cfg_with_version.fsdp_config)\n        self.assertIn(\"fsdp_version\", cfg_with_version.fsdp_config)\n\n        cfg_without_version = self._get_base_cfg() | DictDefault(\n            {\n                \"fsdp_config\": {\n                    \"fsdp_auto_wrap_policy\": \"SIZE_BASED_WRAP\",\n                    \"fsdp_offload_params\": True,\n                }\n            }\n        )\n\n        cfg_without_version = validate_config(cfg_without_version)\n\n        self.assertNotIn(\"fsdp_version\", cfg_without_version)\n        self.assertEqual(\n            cfg_without_version.fsdp_config.auto_wrap_policy, \"SIZE_BASED_WRAP\"\n        )\n        self.assertEqual(cfg_without_version.fsdp_config.offload_params, True)\n\n        self.assertNotIn(\"fsdp_auto_wrap_policy\", cfg_without_version.fsdp_config)\n        self.assertNotIn(\"fsdp_offload_params\", cfg_without_version.fsdp_config)\n\n    def test_migrate_fsdp_config_no_fsdp_config(self):\n        \"\"\"Test that function doesn't crash when no fsdp_config is present\"\"\"\n        cfg = self._get_base_cfg()\n\n        cfg = validate_config(cfg)\n\n        self.assertNotIn(\"fsdp_config\", cfg)\n        self.assertNotIn(\"fsdp_version\", cfg)\n\n    def test_migrate_fsdp_config_empty_fsdp_config(self):\n        \"\"\"Test migration with empty fsdp_config\"\"\"\n        cfg = self._get_base_cfg() | DictDefault({\"fsdp_config\": {}})\n\n        cfg = validate_config(cfg)\n\n        self.assertNotIn(\"fsdp_version\", cfg)\n        self.assertEqual(cfg.fsdp_config, {})\n\n    def test_migrate_fsdp_config_mixed_keys(self):\n        
\"\"\"Test migration with a mix of fsdp_ and non-fsdp_ keys\"\"\"\n        cfg = self._get_base_cfg() | DictDefault(\n            {\n                \"fsdp_config\": {\n                    \"fsdp_version\": 1,\n                    \"fsdp_state_dict_type\": \"FULL_STATE_DICT\",\n                    \"mixed_precision_policy\": \"fp16\",\n                    \"activation_checkpointing\": True,\n                    \"fsdp_reshard_after_forward\": False,\n                }\n            }\n        )\n\n        cfg = validate_config(cfg)\n\n        self.assertEqual(cfg.fsdp_version, 1)\n        self.assertEqual(cfg.fsdp_config.state_dict_type, \"FULL_STATE_DICT\")\n        self.assertEqual(cfg.fsdp_config.reshard_after_forward, False)\n        self.assertEqual(cfg.fsdp_config.mixed_precision_policy, \"fp16\")\n        self.assertEqual(cfg.fsdp_config.activation_checkpointing, True)\n\n        # Check original fsdp_ keys are removed\n        self.assertNotIn(\"fsdp_state_dict_type\", cfg.fsdp_config)\n        self.assertNotIn(\"fsdp_reshard_after_forward\", cfg.fsdp_config)\n\n        self.assertIn(\"fsdp_version\", cfg.fsdp_config)\n"
  },
  {
    "path": "tests/test_opentelemetry_callback.py",
    "content": "\"\"\"Tests for OpenTelemetry metrics callback functionality.\"\"\"\n\nimport time\n\nimport pytest\n\nfrom axolotl.utils.dict import DictDefault\n\n\n@pytest.fixture\ndef mock_otel_config():\n    \"\"\"Mock configuration for OpenTelemetry callback.\"\"\"\n    return DictDefault(\n        {\n            \"use_otel_metrics\": True,\n            \"otel_metrics_host\": \"localhost\",\n            \"otel_metrics_port\": 8003,  # Use unique port for tests\n        }\n    )\n\n\n@pytest.fixture\ndef mock_trainer_state():\n    \"\"\"Mock trainer state for callback testing.\"\"\"\n    from transformers import TrainerState\n\n    state = TrainerState()\n    state.epoch = 1.0\n    state.global_step = 100\n    return state\n\n\n@pytest.fixture\ndef mock_training_args():\n    \"\"\"Mock training arguments for callback testing.\"\"\"\n    from transformers import TrainingArguments\n\n    return TrainingArguments(output_dir=\"/tmp/test\")\n\n\n@pytest.fixture\ndef mock_trainer_control():\n    \"\"\"Mock trainer control for callback testing.\"\"\"\n    from transformers.trainer_callback import TrainerControl\n\n    return TrainerControl()\n\n\nclass TestOpenTelemetryConfig:\n    \"\"\"Test OpenTelemetry configuration schema.\"\"\"\n\n    def test_config_schema_valid(self):\n        \"\"\"Test OpenTelemetry configuration schema validation.\"\"\"\n        from axolotl.utils.schemas.integrations import OpenTelemetryConfig\n\n        # Test valid config\n        valid_config = {\n            \"use_otel_metrics\": True,\n            \"otel_metrics_host\": \"localhost\",\n            \"otel_metrics_port\": 8000,\n        }\n\n        otel_config = OpenTelemetryConfig(**valid_config)\n        assert otel_config.use_otel_metrics is True\n        assert otel_config.otel_metrics_host == \"localhost\"\n        assert otel_config.otel_metrics_port == 8000\n\n    def test_config_defaults(self):\n        \"\"\"Test OpenTelemetry configuration default values.\"\"\"\n        from 
axolotl.utils.schemas.integrations import OpenTelemetryConfig\n\n        # Test minimal config with defaults\n        minimal_config = {\"use_otel_metrics\": True}\n\n        otel_config = OpenTelemetryConfig(**minimal_config)\n        assert otel_config.use_otel_metrics is True\n        assert otel_config.otel_metrics_host == \"localhost\"  # default\n        assert otel_config.otel_metrics_port == 8000  # default\n\n    def test_config_disabled_by_default(self):\n        \"\"\"Test that OpenTelemetry is disabled by default.\"\"\"\n        from axolotl.utils.schemas.integrations import OpenTelemetryConfig\n\n        # Test default config\n        default_config = OpenTelemetryConfig()\n        assert default_config.use_otel_metrics is False\n\n\nclass TestOpenTelemetryCallback:\n    \"\"\"Test OpenTelemetry callback functionality.\"\"\"\n\n    def test_callback_import(self):\n        \"\"\"Test that OpenTelemetry callback can be imported.\"\"\"\n        from axolotl.utils.callbacks.opentelemetry import OpenTelemetryMetricsCallback\n\n        assert OpenTelemetryMetricsCallback is not None\n\n    def test_callback_graceful_fallback(self, mock_otel_config):\n        \"\"\"Test callback gracefully handles missing dependencies.\"\"\"\n        from axolotl.utils.callbacks.opentelemetry import OpenTelemetryMetricsCallback\n\n        # This should not raise an exception even if dependencies are missing\n        callback = OpenTelemetryMetricsCallback(mock_otel_config)\n\n        # Callback should exist but may have metrics disabled\n        assert callback is not None\n        assert hasattr(callback, \"metrics_enabled\")\n\n    def test_callback_initialization_enabled(self, mock_otel_config):\n        \"\"\"Test callback initialization when OpenTelemetry is available.\"\"\"\n        from axolotl.utils.callbacks.opentelemetry import (\n            OPENTELEMETRY_AVAILABLE,\n            OpenTelemetryMetricsCallback,\n        )\n\n        callback = 
OpenTelemetryMetricsCallback(mock_otel_config)\n\n        if OPENTELEMETRY_AVAILABLE:\n            assert callback.metrics_enabled is True\n            assert callback.cfg == mock_otel_config\n            assert callback.metrics_host == \"localhost\"\n            assert callback.metrics_port == 8003\n        else:\n            assert callback.metrics_enabled is False\n\n    def test_metrics_server_lifecycle(\n        self,\n        mock_otel_config,\n        mock_trainer_state,\n        mock_training_args,\n        mock_trainer_control,\n    ):\n        \"\"\"Test metrics server starts and stops correctly.\"\"\"\n        from axolotl.utils.callbacks.opentelemetry import (\n            OPENTELEMETRY_AVAILABLE,\n            OpenTelemetryMetricsCallback,\n        )\n\n        if not OPENTELEMETRY_AVAILABLE:\n            pytest.skip(\"OpenTelemetry dependencies not available\")\n\n        callback = OpenTelemetryMetricsCallback(mock_otel_config)\n\n        # Start server\n        callback.on_train_begin(\n            mock_training_args, mock_trainer_state, mock_trainer_control\n        )\n        assert callback.server_started is True\n\n        # End training\n        callback.on_train_end(\n            mock_training_args, mock_trainer_state, mock_trainer_control\n        )\n\n    def test_metrics_recording(\n        self,\n        mock_otel_config,\n        mock_trainer_state,\n        mock_training_args,\n        mock_trainer_control,\n    ):\n        \"\"\"Test that metrics are recorded during training.\"\"\"\n        from axolotl.utils.callbacks.opentelemetry import (\n            OPENTELEMETRY_AVAILABLE,\n            OpenTelemetryMetricsCallback,\n        )\n\n        if not OPENTELEMETRY_AVAILABLE:\n            pytest.skip(\"OpenTelemetry dependencies not available\")\n\n        callback = OpenTelemetryMetricsCallback(mock_otel_config)\n        callback.on_train_begin(\n            mock_training_args, mock_trainer_state, mock_trainer_control\n        )\n\n       
 # Test logging metrics\n        test_logs = {\n            \"loss\": 0.5,\n            \"learning_rate\": 1e-4,\n            \"grad_norm\": 0.8,\n        }\n\n        # This should not raise an exception\n        callback.on_log(\n            mock_training_args, mock_trainer_state, mock_trainer_control, logs=test_logs\n        )\n        assert callback.metrics_enabled is True\n\n    def test_evaluation_metrics(\n        self,\n        mock_otel_config,\n        mock_trainer_state,\n        mock_training_args,\n        mock_trainer_control,\n    ):\n        \"\"\"Test evaluation metrics recording.\"\"\"\n        from axolotl.utils.callbacks.opentelemetry import (\n            OPENTELEMETRY_AVAILABLE,\n            OpenTelemetryMetricsCallback,\n        )\n\n        if not OPENTELEMETRY_AVAILABLE:\n            pytest.skip(\"OpenTelemetry dependencies not available\")\n\n        callback = OpenTelemetryMetricsCallback(mock_otel_config)\n        callback.on_train_begin(\n            mock_training_args, mock_trainer_state, mock_trainer_control\n        )\n\n        # Test evaluation metrics\n        eval_logs = {\n            \"eval_loss\": 0.3,\n            \"eval_accuracy\": 0.95,\n        }\n\n        # This should not raise an exception\n        callback.on_evaluate(\n            mock_training_args, mock_trainer_state, mock_trainer_control, eval_logs\n        )\n        assert callback.metrics_enabled is True\n\n    def test_thread_safety(self, mock_otel_config):\n        \"\"\"Test that callback has thread safety mechanisms.\"\"\"\n        from axolotl.utils.callbacks.opentelemetry import (\n            OPENTELEMETRY_AVAILABLE,\n            OpenTelemetryMetricsCallback,\n        )\n\n        if not OPENTELEMETRY_AVAILABLE:\n            pytest.skip(\"OpenTelemetry dependencies not available\")\n\n        callback = OpenTelemetryMetricsCallback(mock_otel_config)\n        assert hasattr(callback, \"metrics_lock\")\n        # Check it's a lock-like object\n        
assert hasattr(callback.metrics_lock, \"__enter__\")\n        assert hasattr(callback.metrics_lock, \"__exit__\")\n\n\nclass TestOpenTelemetryIntegration:\n    \"\"\"Integration tests for OpenTelemetry.\"\"\"\n\n    def test_availability_check(self):\n        \"\"\"Test availability check function.\"\"\"\n        from axolotl.utils import is_opentelemetry_available\n\n        result = is_opentelemetry_available()\n        assert isinstance(result, bool)\n\n    def test_prometheus_endpoint_basic(\n        self,\n        mock_otel_config,\n        mock_trainer_state,\n        mock_training_args,\n        mock_trainer_control,\n    ):\n        \"\"\"Test basic Prometheus endpoint functionality.\"\"\"\n        from axolotl.utils.callbacks.opentelemetry import (\n            OPENTELEMETRY_AVAILABLE,\n            OpenTelemetryMetricsCallback,\n        )\n\n        if not OPENTELEMETRY_AVAILABLE:\n            pytest.skip(\"OpenTelemetry dependencies not available\")\n\n        try:\n            import requests\n        except ImportError:\n            pytest.skip(\"requests library not available\")\n\n        callback = OpenTelemetryMetricsCallback(mock_otel_config)\n        callback.on_train_begin(\n            mock_training_args, mock_trainer_state, mock_trainer_control\n        )\n\n        if not callback.server_started:\n            pytest.skip(\"Metrics server failed to start\")\n\n        # Give server time to start\n        time.sleep(1)\n\n        # Try to access metrics endpoint\n        try:\n            response = requests.get(\n                f\"http://{callback.metrics_host}:{callback.metrics_port}/metrics\",\n                timeout=2,\n            )\n            assert response.status_code == 200\n            # Check for Prometheus format\n            assert \"# TYPE\" in response.text or \"# HELP\" in response.text\n        except requests.exceptions.RequestException:\n            pytest.skip(\n                \"Could not connect to metrics endpoint - 
this is expected in some environments\"\n            )\n\n\nclass TestOpenTelemetryCallbackMethods:\n    \"\"\"Test specific callback methods.\"\"\"\n\n    def test_step_end_callback(\n        self,\n        mock_otel_config,\n        mock_trainer_state,\n        mock_training_args,\n        mock_trainer_control,\n    ):\n        \"\"\"Test step end callback method.\"\"\"\n        from axolotl.utils.callbacks.opentelemetry import (\n            OPENTELEMETRY_AVAILABLE,\n            OpenTelemetryMetricsCallback,\n        )\n\n        if not OPENTELEMETRY_AVAILABLE:\n            pytest.skip(\"OpenTelemetry dependencies not available\")\n\n        callback = OpenTelemetryMetricsCallback(mock_otel_config)\n        callback.on_train_begin(\n            mock_training_args, mock_trainer_state, mock_trainer_control\n        )\n\n        # Should not raise an exception\n        callback.on_step_end(\n            mock_training_args, mock_trainer_state, mock_trainer_control\n        )\n\n    def test_epoch_end_callback(\n        self,\n        mock_otel_config,\n        mock_trainer_state,\n        mock_training_args,\n        mock_trainer_control,\n    ):\n        \"\"\"Test epoch end callback method.\"\"\"\n        from axolotl.utils.callbacks.opentelemetry import (\n            OPENTELEMETRY_AVAILABLE,\n            OpenTelemetryMetricsCallback,\n        )\n\n        if not OPENTELEMETRY_AVAILABLE:\n            pytest.skip(\"OpenTelemetry dependencies not available\")\n\n        callback = OpenTelemetryMetricsCallback(mock_otel_config)\n        callback.on_train_begin(\n            mock_training_args, mock_trainer_state, mock_trainer_control\n        )\n\n        # Should not raise an exception\n        callback.on_epoch_end(\n            mock_training_args, mock_trainer_state, mock_trainer_control\n        )\n"
  },
  {
    "path": "tests/test_packed_batch_sampler.py",
    "content": "\"\"\"Module for testing streaming dataset sequence packing\"\"\"\n\nimport pytest\nfrom datasets import concatenate_datasets\nfrom torch.utils.data import DataLoader, RandomSampler\nfrom transformers import AutoTokenizer\n\nfrom axolotl.datasets import TokenizedPromptDataset\nfrom axolotl.prompt_strategies.completion import load\nfrom axolotl.utils.collators import V2BatchSamplerDataCollatorForSeq2Seq\nfrom axolotl.utils.data.utils import handle_long_seq_in_dataset\nfrom axolotl.utils.dict import DictDefault\nfrom axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths\n\nfrom tests.hf_offline_utils import enable_hf_offline\n\n\n@pytest.fixture(name=\"tokenizer\")\ndef fixture_tokenizer():\n    tokenizer = AutoTokenizer.from_pretrained(\"huggyllama/llama-7b\")\n    tokenizer.pad_token = \"</s>\"\n    return tokenizer\n\n\nclass TestBatchedSamplerPacking:\n    \"\"\"\n    Test class for packing streaming dataset sequences\n    \"\"\"\n\n    @pytest.mark.parametrize(\n        \"batch_size, num_workers\",\n        [\n            (1, 0),\n            (2, 0),\n            (1, 2),\n            (2, 2),\n        ],\n    )\n    @pytest.mark.parametrize(\"max_seq_length\", [4096, 512])\n    @pytest.mark.parametrize(\"sequential\", [True, False])\n    @enable_hf_offline\n    def test_packing(\n        self,\n        dataset_winglian_tiny_shakespeare,\n        batch_size,\n        num_workers,\n        tokenizer,\n        max_seq_length,\n        sequential,\n    ):\n        from axolotl.monkeypatch.data.batch_dataset_fetcher import (\n            apply_multipack_dataloader_patch,\n            remove_multipack_dataloader_patch,\n        )\n\n        # Apply the patch for multipack handling\n        apply_multipack_dataloader_patch()\n\n        dataset = dataset_winglian_tiny_shakespeare[\"train\"]\n\n        cfg = DictDefault(\n            {\n                \"train_on_inputs\": True,\n                \"sequence_len\": max_seq_length,\n        
    }\n        )\n        ds_cfg = DictDefault(\n            {\n                \"field\": \"text\",\n            }\n        )\n        completion_strategy = load(tokenizer, cfg, ds_cfg)\n        dataset_wrapper = TokenizedPromptDataset(\n            completion_strategy,\n            dataset,\n        )\n        train_dataset = concatenate_datasets([dataset_wrapper])\n\n        train_dataset = handle_long_seq_in_dataset(train_dataset, cfg.sequence_len, cfg)\n\n        lengths = get_dataset_lengths(train_dataset)\n        batch_sampler = MultipackBatchSampler(\n            sampler=RandomSampler(train_dataset),\n            lengths=lengths,\n            batch_size=batch_size,\n            batch_max_len=max_seq_length,\n            group_size=100000,\n            bin_size=200,\n            sequential=sequential,\n            drop_last=False,\n        )\n\n        loader = DataLoader(\n            train_dataset,\n            batch_sampler=batch_sampler,\n            collate_fn=V2BatchSamplerDataCollatorForSeq2Seq(\n                tokenizer=tokenizer,\n                padding=True,\n                pad_to_multiple_of=max_seq_length,\n                return_tensors=\"pt\",\n            ),\n            num_workers=num_workers,\n        )\n\n        batch_idxs = []\n        for batch in batch_sampler:\n            for pack in batch:\n                batch_idxs.extend(pack)\n\n        try:\n            for batch in loader:\n                assert batch[\"input_ids\"].numel() <= batch_size * max_seq_length\n                assert batch[\"input_ids\"].shape[1] == max_seq_length\n\n            original_idxs = set(range(len(train_dataset)))\n            assert original_idxs == set(batch_idxs)\n            assert len(batch_idxs) == len(set(batch_idxs))\n        finally:\n            # Clean up: remove the patch after the test\n            remove_multipack_dataloader_patch()\n"
  },
  {
    "path": "tests/test_packed_dataset.py",
    "content": "\"\"\"Module for testing dataset sequence packing\"\"\"\n\nimport unittest\n\nfrom transformers import AutoTokenizer\n\nfrom axolotl.cli.args import TrainerCliArgs\nfrom axolotl.common.datasets import load_datasets\nfrom axolotl.train import setup_model_and_trainer\nfrom axolotl.utils.config import normalize_config, validate_config\nfrom axolotl.utils.dict import DictDefault\n\nfrom tests.e2e.utils import with_temp_dir\nfrom tests.hf_offline_utils import enable_hf_offline\n\n\nclass TestPacking(unittest.TestCase):\n    \"\"\"\n    Test class for packing dataset sequences\n    \"\"\"\n\n    @enable_hf_offline\n    def setUp(self) -> None:\n        self.tokenizer = AutoTokenizer.from_pretrained(\"huggyllama/llama-7b\")\n        self.tokenizer.add_special_tokens(\n            {\n                \"bos_token\": \"<s>\",\n                \"eos_token\": \"</s>\",\n                \"unk_token\": \"<unk>\",\n            }\n        )\n\n    @with_temp_dir\n    def test_lora_packing(self, temp_dir):\n        cfg = DictDefault(\n            {\n                \"base_model\": \"HuggingFaceTB/SmolLM2-135M\",\n                \"tokenizer_type\": \"AutoTokenizer\",\n                \"sequence_len\": 1024,\n                \"sample_packing\": True,\n                \"multipack_real_batches\": False,\n                \"eval_sample_packing\": True,\n                \"adapter\": \"lora\",\n                \"lora_r\": 32,\n                \"lora_alpha\": 64,\n                \"lora_dropout\": 0.05,\n                \"lora_target_linear\": True,\n                \"val_set_size\": 0.2,\n                \"special_tokens\": {\n                    \"pad_token\": \"<|endoftext|>\",\n                },\n                \"datasets\": [\n                    {\n                        \"path\": \"mhenrichsen/alpaca_2k_test\",\n                        \"type\": \"alpaca\",\n                    },\n                ],\n                \"dataset_num_proc\": 4,\n                
\"num_epochs\": 1,\n                \"max_steps\": 20,\n                \"save_steps\": 10,\n                \"micro_batch_size\": 8,\n                \"gradient_accumulation_steps\": 1,\n                \"output_dir\": temp_dir,\n                \"learning_rate\": 0.00001,\n                \"optimizer\": \"adamw_torch_fused\",\n                \"lr_scheduler\": \"cosine\",\n                \"fp16\": False,\n                \"bf16\": False,\n            }\n        )\n\n        cfg = validate_config(cfg)\n        normalize_config(cfg)\n        cli_args = TrainerCliArgs()\n        dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)\n\n        (\n            trainer,\n            _,\n            _,\n            _,\n            _,\n        ) = setup_model_and_trainer(cfg, dataset_meta)\n\n        sampler = trainer._get_eval_sampler(trainer.eval_dataset)\n        assert \"MultipackBatchSampler\" in sampler.__class__.__name__\n        assert (\n            \"V2BatchSamplerDataCollatorForSeq2Seq\"\n            in trainer.eval_data_collator.__class__.__name__\n        )\n        dataloader = trainer.get_eval_dataloader(trainer.eval_dataset)\n        dataloader_iter = iter(dataloader)\n        batch = next(dataloader_iter)\n        assert batch[\"input_ids\"].shape == (1, 8192)\n\n        sampler = trainer._get_train_sampler(trainer.train_dataset)\n        assert \"MultipackBatchSampler\" in sampler.__class__.__name__\n        assert (\n            \"V2BatchSamplerDataCollatorForSeq2Seq\"\n            in trainer.train_data_collator.__class__.__name__\n        )\n        dataloader = trainer.get_train_dataloader()\n        dataloader_iter = iter(dataloader)\n        batch = next(dataloader_iter)\n        assert batch[\"input_ids\"].shape == (1, 8192)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/test_packed_pretraining.py",
    "content": "\"\"\"Module for testing streaming dataset sequence packing\"\"\"\n\nimport functools\nimport random\nimport string\n\nimport pytest\nimport torch\nfrom datasets import IterableDataset\nfrom torch.utils.data import DataLoader\n\nfrom axolotl.utils.data import get_dataset_wrapper, wrap_streaming_dataset\nfrom axolotl.utils.dict import DictDefault\n\n\nclass TestPretrainingPacking:\n    \"\"\"\n    Test class for packing streaming dataset sequences\n    \"\"\"\n\n    @pytest.fixture\n    def random_text(self):\n        # seed with random.seed(0) for reproducibility\n        random.seed(0)\n\n        # generate row of random text with \"words\" of between 2 and 10 characters and\n        # between 400 to 1200 characters per line\n        def rand_txt():\n            return \" \".join(\n                [\n                    \"\".join(\n                        random.choices(string.ascii_lowercase, k=random.randint(2, 10))\n                    )\n                    for _ in range(random.randint(50, 200))\n                ]\n            )\n\n        # Create a list of 2000 random texts rather than just using it within the\n        # generator so the test runs faster\n        data = [rand_txt() for _ in range(500)]\n\n        # Create an IterableDataset\n        def generator():\n            for row in data:\n                yield {\"text\": row}\n\n        return IterableDataset.from_generator(generator)\n\n    @pytest.mark.flaky(retries=1, delay=5)\n    def test_packing_stream_dataset(self, tokenizer_huggyllama, random_text):\n        dataset = random_text\n\n        cfg = DictDefault(\n            {\n                \"pretraining_dataset\": [\n                    {\n                        \"path\": \"winglian/tiny-shakespeare\",\n                        \"type\": \"pretrain\",\n                    }\n                ],\n                \"sample_packing\": True,\n                \"pretrain_multipack_attn\": True,\n                
\"pad_to_sequence_len\": True,\n                \"sequence_len\": 2048,\n                \"micro_batch_size\": 2,\n                \"sample_packing_group_size\": 100000,\n                \"sample_packing_bin_size\": 200,\n            }\n        )\n\n        ds_wrapper_partial = functools.partial(\n            get_dataset_wrapper,\n            cfg.pretraining_dataset[0],\n            tokenizer_huggyllama,\n            cfg,\n            cfg.pretraining_dataset[0][\"type\"] or \"pretrain\",\n        )\n\n        original_bsz = cfg.micro_batch_size\n        train_dataset = wrap_streaming_dataset(\n            dataset,\n            tokenizer_huggyllama,\n            cfg,\n            ds_wrapper_partial,\n        )\n\n        trainer_loader = DataLoader(\n            train_dataset,\n            batch_size=1,\n            collate_fn=None,\n            drop_last=True,\n        )\n        idx = 0\n        for data in trainer_loader:\n            if idx > 3:\n                break\n            assert data[\"input_ids\"].shape == torch.Size(\n                [1, original_bsz * cfg.sequence_len]\n            )\n            assert data[\"position_ids\"].shape == torch.Size(\n                [1, original_bsz * cfg.sequence_len]\n            )\n            assert data[\"labels\"].shape == torch.Size(\n                [1, original_bsz * cfg.sequence_len]\n            )\n            assert \"attention_mask\" not in data\n            # FIXME add back once we fix packing unpad/pad with attention mask\n            # assert data[\"attention_mask\"].shape == torch.Size(\n            #     [1, original_bsz * cfg.sequence_len]\n            # )\n            idx += 1\n"
  },
  {
    "path": "tests/test_perplexity.py",
    "content": "\"\"\"unit tests for perplexity eval callback\"\"\"\n\nfrom pytest import fixture\nfrom transformers.models.auto.modeling_auto import AutoModelForCausalLM\nfrom transformers.models.auto.tokenization_auto import AutoTokenizer\n\nfrom axolotl.utils.callbacks.perplexity import Perplexity\n\nMODEL_NAME = \"HuggingFaceTB/SmolLM2-135M\"\n\n\n@fixture()\ndef metric(tokenizer):\n    return Perplexity(tokenizer=tokenizer, max_seq_len=512)\n\n\n@fixture()\ndef model():\n    return AutoModelForCausalLM.from_pretrained(\n        MODEL_NAME, trust_remote_code=True, dtype=\"float32\"\n    )\n\n\n@fixture()\ndef tokenizer():\n    tokenizer_ = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)\n    tokenizer_.add_special_tokens({\"pad_token\": \"<|endoftext|>\"})\n    return tokenizer_\n\n\ndef test_perplexity_longer_than_stride(model, metric):\n    # taken from https://huggingface.co/datasets/roneneldan/TinyStories\n    sample_text = \"\"\"\nOnce upon a time, there was a little car named Beep. Beep loved to go fast and play in the sun. Beep was a healthy car because he always had good fuel. Good fuel made Beep happy and strong. One day, Beep was driving in the park when he saw a big tree. The tree had many leaves that were falling. Beep liked how the leaves fall and wanted to play with them. Beep drove under the tree and watched the leaves fall on him. He laughed and beeped his horn. Beep played with the falling leaves all day. When it was time to go home, Beep knew he needed more fuel. He went to the fuel place and got more healthy fuel. Now, Beep was ready to go fast and play again the next day. And Beep lived happily ever after.\nOne day, a little fish named Fin was swimming near the shore. He saw a big crab and wanted to be friends. \"Hi, I am Fin. Do you want to play?\" asked the little fish. The crab looked at Fin and said, \"No, I don't want to play. I am cold and I don't feel fine.\" Fin felt sad but wanted to help the crab feel better. 
He swam away and thought of a plan. He remembered that the sun could make things warm. So, Fin swam to the top of the water and called to the sun, \"Please, sun, help my new friend feel fine and not freeze!\" The sun heard Fin's call and shone its warm light on the shore. The crab started to feel better and not so cold. He saw Fin and said, \"Thank you, little fish, for making me feel fine. I don't feel like I will freeze now. Let's play together!\" And so, Fin and the crab played and became good friends.\n\"\"\"\n    result = metric.compute(model, [sample_text])\n    ppl = result[\"score\"]\n    assert round(ppl, 2) == 7.41\n\n\ndef test_perplexity_short(model, metric):\n    # taken from https://huggingface.co/datasets/roneneldan/TinyStories\n    sample_text = \"Once upon a time, there was a little car named Beep. Beep loved to go fast and play in the sun.\"\n    result = metric.compute(model, [sample_text])\n    ppl = result[\"score\"]\n    assert round(ppl, 2) == 10.33\n"
  },
  {
    "path": "tests/test_prompt_tokenizers.py",
    "content": "\"\"\"Module for testing prompt tokenizers.\"\"\"\n\nimport json\nfrom pathlib import Path\n\nfrom axolotl.prompt_strategies.alpaca_chat import NoSystemPrompter\nfrom axolotl.prompt_strategies.alpaca_w_system import (\n    InstructionWSystemPromptTokenizingStrategy,\n    SystemDataPrompter,\n)\nfrom axolotl.prompt_strategies.llama2_chat import (\n    Llama2ChatPrompter,\n    LLama2ChatTokenizingStrategy,\n)\nfrom axolotl.prompt_strategies.orpo.chat_template import load\nfrom axolotl.prompt_tokenizers import AlpacaPromptTokenizingStrategy\nfrom axolotl.prompters import AlpacaPrompter, PromptStyle\nfrom axolotl.utils.dict import DictDefault\n\nfrom tests.hf_offline_utils import enable_hf_offline\n\ntest_data = {\n    \"multi_turn_sys\": {\n        \"conversations\": [\n            {\"from\": \"system\", \"value\": \"lorem\"},\n            {\"from\": \"human\", \"value\": \"abc\"},\n            {\"from\": \"gpt\", \"value\": \"ipsum\"},\n            {\"from\": \"human\", \"value\": \"123\"},\n            {\"from\": \"gpt\", \"value\": \"sit\"},\n        ]\n    },\n    \"single_turn_sys\": {\n        \"conversations\": [\n            {\"from\": \"system\", \"value\": \"lorem\"},\n            {\"from\": \"human\", \"value\": \"abc\"},\n            {\"from\": \"gpt\", \"value\": \"ipsum\"},\n        ]\n    },\n    \"single_turn_no_sys\": {\n        \"conversations\": [\n            {\"from\": \"human\", \"value\": \"abc\"},\n            {\"from\": \"gpt\", \"value\": \"ipsum\"},\n        ]\n    },\n    \"multi_turn_no_sys\": {\n        \"conversations\": [\n            {\"from\": \"human\", \"value\": \"abc\"},\n            {\"from\": \"gpt\", \"value\": \"ipsum\"},\n            {\"from\": \"human\", \"value\": \"123\"},\n            {\"from\": \"gpt\", \"value\": \"sit\"},\n        ]\n    },\n}\n\n\nclass TestPromptTokenizationStrategies:\n    \"\"\"\n    Test class for prompt tokenization strategies.\n    \"\"\"\n\n    @enable_hf_offline\n    def 
test_no_sys_prompt(self, tokenizer_huggyllama_w_special_tokens):\n        \"\"\"\n        tests the interface between the user and assistant parts\n        \"\"\"\n        prompter = NoSystemPrompter()\n\n        strat = AlpacaPromptTokenizingStrategy(\n            prompter,\n            tokenizer_huggyllama_w_special_tokens,\n            False,\n            2048,\n        )\n        sample = {\n            \"instruction\": \"hello cruel. lorem ipsum dolor sit amet.\",\n            \"output\": \"world!\",\n        }\n        example = strat.tokenize_prompt(sample)\n        world_idx = example[\"input_ids\"].index(3186)\n        assert example[\"labels\"][world_idx] == 3186\n        assert example[\"labels\"][world_idx - 1] == -100\n\n    @enable_hf_offline\n    def test_alpaca(self, tokenizer_huggyllama_w_special_tokens):\n        \"\"\"\n        tests the interface between the user and assistant parts\n        \"\"\"\n\n        prompter = AlpacaPrompter()\n        strat = AlpacaPromptTokenizingStrategy(\n            prompter,\n            tokenizer_huggyllama_w_special_tokens,\n            False,\n            2048,\n        )\n        sample = {\"instruction\": \"hello!\", \"output\": \"Hi! 
How can I help?\"}\n        example = strat.tokenize_prompt(sample)\n        world_idx = example[\"input_ids\"].index(6324)\n        assert example[\"labels\"][world_idx] == 6324\n        assert example[\"labels\"][world_idx - 1] == -100\n\n\nclass TestInstructionWSystemPromptTokenizingStrategy:\n    \"\"\"\n    Test class for prompt tokenization strategies with sys prompt from the dataset\n    \"\"\"\n\n    @enable_hf_offline\n    def test_system_alpaca(self, tokenizer_huggyllama_w_special_tokens):\n        prompter = SystemDataPrompter(PromptStyle.CHAT.value)\n        strat = InstructionWSystemPromptTokenizingStrategy(\n            prompter,\n            tokenizer_huggyllama_w_special_tokens,\n            False,\n            2048,\n        )\n        sample = {\n            \"system\": \"use cot\",\n            \"instruction\": \"hello!\",\n            \"output\": \"Hi! How can I help?\",\n        }\n        example = strat.tokenize_prompt(sample)\n        assert example[\"input_ids\"][0:5] == [\n            1,\n            28962,\n            1254,\n            12665,\n            29901,\n        ]  # \"<s>SYSTEM:\"\n        assert example[\"input_ids\"][5:7] == [671, 20118]  # \" use cot\"\n        assert example[\"input_ids\"][8] == 11889  # USER\n\n\nclass Llama2ChatTokenizationTest:\n    \"\"\"\n    Test class for prompt tokenization strategies with sys prompt from the dataset\n    \"\"\"\n\n    @enable_hf_offline\n    def test_llama2_chat_integration(self, tokenizer_llama2_7b):\n        with open(\n            Path(__file__).parent / \"fixtures/conversation.json\", encoding=\"utf-8\"\n        ) as fin:\n            data = fin.read()\n            conversation = json.loads(data)\n        with open(\n            Path(__file__).parent / \"fixtures/conversation.tokenized_llama2chat.json\",\n            encoding=\"utf-8\",\n        ) as fin:\n            data = fin.read()\n            tokenized_conversation = json.loads(data)\n        prompter = 
Llama2ChatPrompter()\n        strat = LLama2ChatTokenizingStrategy(\n            prompter,\n            tokenizer_llama2_7b,\n            False,\n            4096,\n        )\n        example = strat.tokenize_prompt(conversation)\n        for fields in [\"input_ids\", \"attention_mask\", \"labels\"]:\n            # pytest assert equals\n\n            assert len(example[fields]) == len(tokenized_conversation[fields])\n            assert example[fields] == tokenized_conversation[fields]\n\n    def compare_with_transformers_integration(self, tokenizer_llama2_7b):\n        # this needs transformers >= v4.31.0\n        from transformers.models.llama.tokenization_llama import B_SYS, E_SYS\n        from transformers.pipelines.conversational import Conversation\n\n        # from transformers.models.llama.tokenization_llama import DEFAULT_SYSTEM_PROMPT\n        # broken as of 23/7/20\n        # see https://github.com/huggingface/transformers/pull/24935\n\n        DEFAULT_SYSTEM_PROMPT = \"\"\"\\\nYou are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.\n\nIf a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. 
If you don't know the answer to a question, please don't share false information.\"\"\"\n        with open(\n            Path(__file__).parent / \"fixtures/conversation.json\", encoding=\"utf-8\"\n        ) as fin:\n            data = fin.read()\n            conversation = json.loads(data)\n        with open(\n            Path(__file__).parent / \"fixtures/conversation.tokenized_llama2chat.json\",\n            encoding=\"utf-8\",\n        ) as fin:\n            data = fin.read()\n            tokenized_conversation = json.loads(data)\n\n        user_input = []\n        answers = []\n        for msg in conversation[\"conversations\"]:\n            if msg[\"from\"] == \"human\":\n                user_input.append(msg[\"value\"])\n            else:\n                answers.append(msg[\"value\"])\n        hf_conf = Conversation(\n            text=user_input[-1],\n            past_user_inputs=[B_SYS + DEFAULT_SYSTEM_PROMPT + E_SYS + user_input[0]]\n            + user_input[1:-1],\n            generated_responses=answers,\n        )\n\n        hf_tokens = tokenizer_llama2_7b._build_conversation_input_ids(hf_conf)\n\n        assert hf_tokens == tokenized_conversation[\"input_ids\"][: len(hf_tokens)]\n\n\nclass OrpoTokenizationTest:\n    \"\"\"test case for the ORPO tokenization\"\"\"\n\n    @enable_hf_offline\n    def test_orpo_integration(\n        self,\n        tokenizer_mistral_7b_instruct_chatml,\n        dataset_argilla_ultrafeedback_binarized_preferences_cleaned,\n    ):\n        ds = dataset_argilla_ultrafeedback_binarized_preferences_cleaned.select([0])\n        strat = load(\n            tokenizer_mistral_7b_instruct_chatml,\n            DictDefault({\"train_on_inputs\": False}),\n            DictDefault({\"chat_template\": \"chatml\"}),\n        )\n        res = strat.tokenize_prompt(ds[0])\n        assert \"rejected_input_ids\" in res\n        assert \"rejected_labels\" in res\n        assert \"input_ids\" in res\n        assert \"labels\" in res\n        
assert \"prompt_attention_mask\" in res\n\n        assert len(res[\"rejected_input_ids\"]) == len(res[\"rejected_labels\"])\n        assert len(res[\"input_ids\"]) == len(res[\"labels\"])\n        assert len(res[\"input_ids\"]) == len(res[\"prompt_attention_mask\"])\n\n        assert res[\"rejected_labels\"][0] == -100\n        assert res[\"rejected_input_ids\"][-1] == res[\"rejected_labels\"][-1]\n\n        assert res[\"labels\"][0] == -100\n        assert res[\"input_ids\"][-1] == res[\"labels\"][-1]\n\n        assert res[\"prompt_attention_mask\"][0] == 1\n        assert res[\"prompt_attention_mask\"][-1] == 0\n"
  },
  {
    "path": "tests/test_prompters.py",
    "content": "\"\"\"Module testing prompters\"\"\"\n\nimport unittest\n\nfrom axolotl.prompt_strategies.alpaca_w_system import SystemDataPrompter\nfrom axolotl.prompters import (\n    AlpacaPrompter,\n    MultipleChoiceExplainPrompter,\n    PromptStyle,\n    UnpromptedPrompter,\n)\n\n\nclass AlpacaPrompterTest(unittest.TestCase):\n    \"\"\"\n    Test AlpacaPrompter\n    \"\"\"\n\n    def test_prompt_style_w_none(self):\n        prompter = AlpacaPrompter(prompt_style=None)\n        res = next(prompter.build_prompt(\"tell me a joke\"))\n        # just testing that it uses instruct style\n        assert \"### Instruction:\" in res\n\n    def test_prompt_style_w_instruct(self):\n        prompter = AlpacaPrompter(prompt_style=PromptStyle.INSTRUCT.value)\n        res = next(\n            prompter.build_prompt(\"tell me a joke about the following\", \"alpacas\")\n        )\n        assert \"Below is an instruction\" in res\n        assert \"### Instruction:\" in res\n        assert \"### Input:\" in res\n        assert \"alpacas\" in res\n        assert \"### Response:\" in res\n        assert \"USER:\" not in res\n        assert \"ASSISTANT:\" not in res\n        res = next(prompter.build_prompt(\"tell me a joke about the following\"))\n        assert \"Below is an instruction\" in res\n        assert \"### Instruction:\" in res\n        assert \"### Input:\" not in res\n        assert \"### Response:\" in res\n        assert \"USER:\" not in res\n        assert \"ASSISTANT:\" not in res\n\n    def test_prompt_style_w_phi(self):\n        prompter = AlpacaPrompter(prompt_style=PromptStyle.PHI.value)\n        res = next(prompter.build_prompt(\"tell me a joke about the following\"))\n        assert (\n            \"\"\"<|system|>\nBelow is an instruction that describes a task. 
Write a response that appropriately completes the request.<|end|>\n<|user|>\ntell me a joke about the following<|end|>\n<|assistant|>\n\"\"\"\n            == res\n        )\n\n    def test_prompt_style_w_chat(self):\n        prompter = AlpacaPrompter(prompt_style=PromptStyle.CHAT.value)\n        res = next(\n            prompter.build_prompt(\"tell me a joke about the following\", \"alpacas\")\n        )\n        assert \"Below is an instruction\" in res\n        assert \"### Instruction:\" not in res\n        assert \"### Input:\" not in res\n        assert \"alpacas\" in res\n        assert \"### Response:\" not in res\n        assert \"USER:\" in res\n        assert \"ASSISTANT:\" in res\n        res = next(prompter.build_prompt(\"tell me a joke about the following\"))\n        assert \"Below is an instruction\" in res\n        assert \"### Instruction:\" not in res\n        assert \"### Input:\" not in res\n        assert \"### Response:\" not in res\n        assert \"USER:\" in res\n        assert \"ASSISTANT:\" in res\n\n    def test_system_prompt(self):\n        prompter = SystemDataPrompter(prompt_style=PromptStyle.CHAT.value)\n        res = next(\n            prompter.build_prompt_w_system(\n                \"use cot\", \"tell me a joke about the following\", \"alpacas\"\n            )\n        )\n        assert \"use cot\" in res\n        assert res.startswith(\"SYSTEM:\")\n        assert \"### Instruction:\" not in res\n        assert \"### Input:\" not in res\n        assert \"alpacas\" in res\n        assert \"### Response:\" not in res\n        assert \"USER:\" in res\n        assert \"ASSISTANT:\" in res\n\n\nclass UnpromptedPrompterTest(unittest.TestCase):\n    \"\"\"\n    Test class for UnpromptedPrompter with no system prompts\n    \"\"\"\n\n    def test_prompt_style_w_none(self):\n        prompter = UnpromptedPrompter(prompt_style=None)\n        res = next(prompter.build_prompt(\"tell me a joke\"))\n        assert \"### Instruction:\" in res\n    
    assert \"tell me a joke\" in res\n        assert res.startswith(\"###\")\n\n    def test_prompt_style_w_instruct(self):\n        prompter = UnpromptedPrompter(prompt_style=PromptStyle.INSTRUCT.value)\n        res = next(\n            prompter.build_prompt(\"tell me a joke about the following\", \"alpacas\")\n        )\n        assert \"### Instruction:\" in res\n        assert \"tell me a joke\" in res\n        assert res.startswith(\"###\")\n\n    def test_prompt_style_w_chat(self):\n        prompter = UnpromptedPrompter(prompt_style=PromptStyle.CHAT.value)\n        res = next(\n            prompter.build_prompt(\"tell me a joke about the following\", \"alpacas\")\n        )\n        assert \"USER:\" in res\n        assert \"tell me a joke\" in res\n        assert res.startswith(\"USER:\")\n\n\nclass MultipleChoiceExplainPrompterTest(unittest.TestCase):\n    \"\"\"\n    Test class for MultipleChoiceExplainPrompter\n    \"\"\"\n\n    def test_prompt_style_w_chat(self):\n        prompter = MultipleChoiceExplainPrompter(prompt_style=PromptStyle.CHAT.value)\n        res = next(prompter.build_prompt(\"choose one\", \"- A\\n- B\\n- C\", \"C\"))\n        assert \"USER:\" in res\n        assert \"choose one\" in res\n        assert \"Choose the answer that best answers the question.\" in res\n        assert \"- A\\n- B\\n- C\" in res\n"
  },
  {
    "path": "tests/test_revision_parameter.py",
    "content": "\"\"\"Tests for revision_of_model being passed to tokenizer and processor loaders.\"\"\"\n\nfrom unittest.mock import MagicMock, patch\n\nfrom transformers import PreTrainedTokenizerBase\n\nfrom axolotl.utils.dict import DictDefault\n\n\nclass TestRevisionParameter:\n    \"\"\"Tests for revision_of_model being passed to tokenizer and processor loaders.\"\"\"\n\n    @patch(\"axolotl.loaders.tokenizer.load_model_config\")\n    @patch(\"axolotl.loaders.tokenizer.AutoTokenizer\")\n    @patch(\n        \"axolotl.loaders.patch_manager.PatchManager.apply_pre_tokenizer_load_patches\"\n    )\n    def test_load_tokenizer_passes_revision(\n        self, _mock_patches, mock_auto_tokenizer, _mock_load_config\n    ):\n        mock_tokenizer = MagicMock()\n        mock_tokenizer.__class__.__name__ = \"MockTokenizer\"\n        mock_auto_tokenizer.from_pretrained.return_value = mock_tokenizer\n\n        cfg = DictDefault(\n            {\n                \"tokenizer_config\": \"some-model\",\n                \"revision_of_model\": \"abc123\",\n            }\n        )\n        from axolotl.loaders.tokenizer import load_tokenizer\n\n        load_tokenizer(cfg)\n\n        call_kwargs = mock_auto_tokenizer.from_pretrained.call_args\n        assert call_kwargs.kwargs.get(\"revision\") == \"abc123\"\n\n    @patch(\"axolotl.loaders.tokenizer.load_model_config\")\n    @patch(\"axolotl.loaders.tokenizer.AutoTokenizer\")\n    @patch(\n        \"axolotl.loaders.patch_manager.PatchManager.apply_pre_tokenizer_load_patches\"\n    )\n    def test_load_tokenizer_omits_revision_when_unset(\n        self, _mock_patches, mock_auto_tokenizer, _mock_load_config\n    ):\n        mock_tokenizer = MagicMock()\n        mock_tokenizer.__class__.__name__ = \"MockTokenizer\"\n        mock_auto_tokenizer.from_pretrained.return_value = mock_tokenizer\n\n        cfg = DictDefault(\n            {\n                \"tokenizer_config\": \"some-model\",\n            }\n        )\n        from 
axolotl.loaders.tokenizer import load_tokenizer\n\n        load_tokenizer(cfg)\n\n        call_kwargs = mock_auto_tokenizer.from_pretrained.call_args\n        assert \"revision\" not in call_kwargs.kwargs\n\n    @patch(\"axolotl.loaders.tokenizer.AutoTokenizer\")\n    @patch(\"axolotl.loaders.tokenizer.is_local_main_process\", return_value=True)\n    @patch(\"axolotl.loaders.tokenizer.barrier\")\n    def test_modify_tokenizer_files_passes_revision(\n        self, _mock_barrier, _mock_main, mock_auto_tokenizer, temp_dir\n    ):\n        mock_tokenizer = MagicMock()\n        mock_auto_tokenizer.from_pretrained.return_value = mock_tokenizer\n\n        from axolotl.loaders.tokenizer import modify_tokenizer_files\n\n        modify_tokenizer_files(\"some-model\", {}, output_dir=temp_dir, revision=\"abc123\")\n\n        call_kwargs = mock_auto_tokenizer.from_pretrained.call_args\n        assert call_kwargs.kwargs.get(\"revision\") == \"abc123\"\n\n    @patch(\"axolotl.loaders.tokenizer.AutoTokenizer\")\n    @patch(\"axolotl.loaders.tokenizer.is_local_main_process\", return_value=True)\n    @patch(\"axolotl.loaders.tokenizer.barrier\")\n    def test_modify_tokenizer_files_defaults_revision_to_main(\n        self, _mock_barrier, _mock_main, mock_auto_tokenizer, temp_dir\n    ):\n        mock_tokenizer = MagicMock()\n        mock_auto_tokenizer.from_pretrained.return_value = mock_tokenizer\n\n        from axolotl.loaders.tokenizer import modify_tokenizer_files\n\n        modify_tokenizer_files(\"some-model\", {}, output_dir=temp_dir)\n\n        call_kwargs = mock_auto_tokenizer.from_pretrained.call_args\n        assert call_kwargs.kwargs.get(\"revision\") == \"main\"\n\n    @patch(\"axolotl.loaders.processor.AutoProcessor\")\n    def test_load_processor_passes_revision(self, mock_auto_processor):\n        mock_processor = MagicMock()\n        mock_processor.size = {}\n        mock_auto_processor.from_pretrained.return_value = mock_processor\n\n        cfg = DictDefault(\n    
        {\n                \"processor_config\": \"some-model\",\n                \"revision_of_model\": \"abc123\",\n                \"trust_remote_code\": False,\n            }\n        )\n        tokenizer = MagicMock(spec=PreTrainedTokenizerBase)\n\n        from axolotl.loaders.processor import load_processor\n\n        load_processor(cfg, tokenizer)\n\n        call_kwargs = mock_auto_processor.from_pretrained.call_args\n        assert call_kwargs.kwargs.get(\"revision\") == \"abc123\"\n\n    @patch(\"axolotl.loaders.processor.AutoProcessor\")\n    def test_load_processor_omits_revision_when_unset(self, mock_auto_processor):\n        mock_processor = MagicMock()\n        mock_processor.size = {}\n        mock_auto_processor.from_pretrained.return_value = mock_processor\n\n        cfg = DictDefault(\n            {\n                \"processor_config\": \"some-model\",\n                \"trust_remote_code\": False,\n            }\n        )\n        tokenizer = MagicMock(spec=PreTrainedTokenizerBase)\n\n        from axolotl.loaders.processor import load_processor\n\n        load_processor(cfg, tokenizer)\n\n        call_kwargs = mock_auto_processor.from_pretrained.call_args\n        assert \"revision\" not in call_kwargs.kwargs\n"
  },
  {
    "path": "tests/test_save_deduplicated.py",
    "content": "\"\"\"Tests to verify that deduplication runs before dataset saving during preprocessing.\n\nThis addresses GitHub issue #2719: Save De-duplicated Set During Pre-processing.\n\"\"\"\n\nfrom unittest.mock import MagicMock, patch\n\nfrom datasets import Dataset\n\nfrom axolotl.utils.dict import DictDefault\n\n\nclass TestSFTSaveDeduplicatedBeforeSave:\n    \"\"\"Verify that in SFT data loading, deduplication occurs before saving.\"\"\"\n\n    @patch(\"axolotl.utils.data.sft.save_preprocessed_dataset\")\n    @patch(\"axolotl.utils.data.sft.generate_dataset_hash_from_config\")\n    @patch(\"axolotl.utils.data.sft.deduplicate_and_log_datasets\")\n    @patch(\"axolotl.utils.data.sft.merge_datasets\")\n    @patch(\"axolotl.utils.data.sft._load_and_process_single_dataset\")\n    @patch(\"axolotl.utils.data.sft.datasets_with_name_generator\")\n    def test_dedup_called_before_save_sft(\n        self,\n        mock_datasets_gen,\n        mock_load_single,\n        mock_merge,\n        mock_dedup,\n        mock_gen_hash,\n        mock_save,\n    ):\n        \"\"\"Deduplication should be called before save_preprocessed_dataset in SFT.\"\"\"\n        from axolotl.utils.data.sft import _load_raw_datasets\n\n        # Set up mock data\n        dataset = Dataset.from_dict({\"text\": [\"a\", \"b\", \"a\"], \"label\": [1, 2, 1]})\n        deduped_dataset = Dataset.from_dict({\"text\": [\"a\", \"b\"], \"label\": [1, 2]})\n\n        mock_datasets_gen.return_value = [\n            DictDefault({\"path\": \"test\", \"type\": \"alpaca\"})\n        ]\n        mock_load_single.return_value = (dataset, None)\n        mock_merge.return_value = dataset\n        mock_dedup.return_value = (deduped_dataset, None)\n        mock_gen_hash.return_value = \"testhash\"\n\n        cfg = DictDefault(\n            {\n                \"skip_prepare_dataset\": False,\n                \"dataset_exact_deduplication\": True,\n                \"sequence_len\": 1024,\n                
\"eval_sequence_len\": None,\n                \"sample_packing\": False,\n                \"is_preprocess\": False,\n                \"seed\": 42,\n                \"datasets\": [{\"path\": \"test\", \"type\": \"alpaca\"}],\n            }\n        )\n\n        tokenizer = MagicMock()\n        tokenizer.name_or_path = \"test-tokenizer\"\n\n        # Track call order\n        call_order = []\n        mock_dedup.side_effect = lambda **kwargs: (\n            call_order.append(\"dedup\") or (deduped_dataset, None)\n        )\n        mock_save.side_effect = lambda *args, **kwargs: call_order.append(\"save\")\n\n        _load_raw_datasets(\n            cfg=cfg,\n            datasets_configs=cfg.datasets,\n            tokenizer=tokenizer,\n            split=\"train\",\n        )\n\n        # Verify dedup was called\n        assert \"dedup\" in call_order, \"Deduplication should have been called\"\n        # Verify save was called\n        assert \"save\" in call_order, \"Save should have been called\"\n        # Verify dedup happened before save\n        assert call_order.index(\"dedup\") < call_order.index(\"save\"), (\n            \"Deduplication must occur before saving the dataset\"\n        )\n\n    @patch(\"axolotl.utils.data.sft.save_preprocessed_dataset\")\n    @patch(\"axolotl.utils.data.sft.generate_dataset_hash_from_config\")\n    @patch(\"axolotl.utils.data.sft.merge_datasets\")\n    @patch(\"axolotl.utils.data.sft._load_and_process_single_dataset\")\n    @patch(\"axolotl.utils.data.sft.datasets_with_name_generator\")\n    def test_no_dedup_when_disabled_sft(\n        self,\n        mock_datasets_gen,\n        mock_load_single,\n        mock_merge,\n        mock_gen_hash,\n        mock_save,\n    ):\n        \"\"\"Deduplication should not be called when dataset_exact_deduplication is False.\"\"\"\n        from axolotl.utils.data.sft import _load_raw_datasets\n\n        dataset = Dataset.from_dict({\"text\": [\"a\", \"b\", \"a\"], \"label\": [1, 2, 1]})\n\n     
   mock_datasets_gen.return_value = [\n            DictDefault({\"path\": \"test\", \"type\": \"alpaca\"})\n        ]\n        mock_load_single.return_value = (dataset, None)\n        mock_merge.return_value = dataset\n        mock_gen_hash.return_value = \"testhash\"\n\n        cfg = DictDefault(\n            {\n                \"skip_prepare_dataset\": False,\n                \"dataset_exact_deduplication\": False,\n                \"sequence_len\": 1024,\n                \"eval_sequence_len\": None,\n                \"sample_packing\": False,\n                \"is_preprocess\": False,\n                \"seed\": 42,\n                \"datasets\": [{\"path\": \"test\", \"type\": \"alpaca\"}],\n            }\n        )\n\n        tokenizer = MagicMock()\n        tokenizer.name_or_path = \"test-tokenizer\"\n\n        with patch(\"axolotl.utils.data.sft.deduplicate_and_log_datasets\") as mock_dedup:\n            _load_raw_datasets(\n                cfg=cfg,\n                datasets_configs=cfg.datasets,\n                tokenizer=tokenizer,\n                split=\"train\",\n            )\n            mock_dedup.assert_not_called()\n\n\nclass TestRLSaveDeduplicatedBeforeSave:\n    \"\"\"Verify that in RL data loading, deduplication occurs before saving.\"\"\"\n\n    @patch.object(Dataset, \"filter\", lambda self, *args, **kwargs: self)\n    @patch(\"axolotl.utils.data.rl.save_preprocessed_dataset\")\n    @patch(\"axolotl.utils.data.rl.generate_dataset_hash_from_config\")\n    @patch(\"axolotl.utils.data.rl.deduplicate_and_log_datasets\")\n    @patch(\"axolotl.utils.data.rl.merge_datasets\")\n    @patch(\"axolotl.utils.data.rl.load_dataset_with_config\")\n    @patch(\"axolotl.utils.data.rl.datasets_with_name_generator\")\n    @patch(\"axolotl.utils.data.rl.load_tokenizer\")\n    def test_dedup_called_before_save_rl(\n        self,\n        mock_load_tokenizer,\n        mock_datasets_gen,\n        mock_load_dataset,\n        mock_merge,\n        mock_dedup,\n        
mock_gen_hash,\n        mock_save,\n    ):\n        \"\"\"Deduplication should be called before save_preprocessed_dataset in RL.\"\"\"\n        from axolotl.utils.data.rl import _load_split\n\n        dataset = Dataset.from_dict(\n            {\n                \"prompt\": [\"hi\", \"bye\", \"hi\"],\n                \"chosen\": [\"a\", \"b\", \"a\"],\n                \"rejected\": [\"c\", \"d\", \"c\"],\n            }\n        )\n        deduped_dataset = Dataset.from_dict(\n            {\n                \"prompt\": [\"hi\", \"bye\"],\n                \"chosen\": [\"a\", \"b\"],\n                \"rejected\": [\"c\", \"d\"],\n            }\n        )\n\n        mock_datasets_gen.return_value = [DictDefault({\"path\": \"test\", \"type\": None})]\n        mock_load_dataset.return_value = dataset\n        mock_merge.return_value = dataset\n        mock_dedup.return_value = (deduped_dataset, None)\n        mock_gen_hash.return_value = \"testhash\"\n\n        tokenizer = MagicMock()\n        tokenizer.name_or_path = \"test-tokenizer\"\n        mock_load_tokenizer.return_value = tokenizer\n\n        cfg = DictDefault(\n            {\n                \"skip_prepare_dataset\": False,\n                \"dataset_exact_deduplication\": True,\n                \"sequence_len\": 1024,\n                \"rl\": \"dpo\",\n                \"datasets\": [{\"path\": \"test\", \"type\": None}],\n                \"hf_use_auth_token\": False,\n                \"dataset_num_proc\": 1,\n                \"is_preprocess\": False,\n            }\n        )\n\n        call_order = []\n        mock_dedup.side_effect = lambda **kwargs: (\n            call_order.append(\"dedup\") or (deduped_dataset, None)\n        )\n        mock_save.side_effect = lambda *args, **kwargs: call_order.append(\"save\")\n\n        _load_split(cfg, split=\"train\")\n\n        assert \"dedup\" in call_order, \"Deduplication should have been called\"\n        assert \"save\" in call_order, \"Save should have been 
called\"\n        assert call_order.index(\"dedup\") < call_order.index(\"save\"), (\n            \"Deduplication must occur before saving the dataset\"\n        )\n"
  },
  {
    "path": "tests/test_schedulers.py",
    "content": "\"\"\"\ntest module for the axolotl.utils.schedulers module\n\"\"\"\n\nimport unittest\n\nimport torch\nfrom torch.optim import SGD\n\nfrom axolotl.utils.schedulers import get_cosine_schedule_with_warmup_decay_constant\n\n\nclass TestCosineConstantLr(unittest.TestCase):\n    \"\"\"\n    test class for the cosine schedule with warmup, decay, and constant-LR tail\n    \"\"\"\n\n    def setUp(self):\n        self.train_steps = 1000\n        self.warmup_steps = 10\n        self.min_lr_ratio = 0.1\n        self.constant_lr_ratio = 0.8\n        self._lr = 0.01\n        self.optimizer = SGD([torch.tensor(1)], lr=self._lr)\n        self.lr_scheduler = get_cosine_schedule_with_warmup_decay_constant(\n            self.optimizer,\n            num_warmup_steps=self.warmup_steps,\n            num_training_steps=self.train_steps,\n            min_lr_ratio=self.min_lr_ratio,\n            constant_lr_ratio=self.constant_lr_ratio,\n        )\n\n    def test_schedulers(self):\n        self.assertEqual(self.lr_scheduler.get_last_lr()[0], 0)\n        for _ in range(self.warmup_steps):\n            self.optimizer.step()\n            self.lr_scheduler.step()\n        self.assertEqual(self.lr_scheduler.get_last_lr()[0], self._lr)\n        constant_step = int(self.train_steps * self.constant_lr_ratio)\n        remaining_step = self.train_steps - constant_step\n        for _ in range(constant_step):\n            self.optimizer.step()\n            self.lr_scheduler.step()\n        self.assertEqual(\n            self.lr_scheduler.get_last_lr()[0], self._lr * self.min_lr_ratio\n        )\n        for _ in range(remaining_step):\n            self.optimizer.step()\n            self.lr_scheduler.step()\n        self.assertEqual(\n            self.lr_scheduler.get_last_lr()[0], self._lr * self.min_lr_ratio\n        )\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/test_streaming.py",
    "content": "\"\"\"Test streaming configuration and data loading functionality.\"\"\"\n\nimport unittest\nfrom unittest.mock import Mock, patch\n\nfrom datasets import IterableDataset\n\nfrom axolotl.utils.config import validate_config\nfrom axolotl.utils.data.sft import (\n    _prepare_streaming_dataset,\n    prepare_datasets,\n)\nfrom axolotl.utils.dict import DictDefault\n\n\nclass TestStreamingConfig(unittest.TestCase):\n    \"\"\"Test streaming configuration and deprecation handling.\"\"\"\n\n    def test_streaming_multipack_buffer_size_deprecation(self):\n        \"\"\"Test that pretrain_multipack_buffer_size is properly deprecated.\"\"\"\n        # Test with old config name\n        cfg_old = DictDefault(\n            {\n                \"base_model\": \"HuggingFaceTB/SmolLM2-135M\",\n                \"pretrain_multipack_buffer_size\": 5000,\n                \"datasets\": [{\"path\": \"test/dataset\", \"type\": \"alpaca\"}],\n                \"sequence_len\": 256,\n                \"micro_batch_size\": 1,\n                \"gradient_accumulation_steps\": 1,\n                \"learning_rate\": 0.0001,\n            }\n        )\n\n        with self.assertLogs(\"axolotl.utils.schemas.validation\", level=\"WARNING\") as cm:\n            validated_cfg = validate_config(cfg_old)\n            self.assertIn(\"pretrain_multipack_buffer_size` is deprecated\", cm.output[0])\n\n        self.assertEqual(validated_cfg.streaming_multipack_buffer_size, 5000)\n        self.assertIsNone(\n            getattr(validated_cfg, \"pretrain_multipack_buffer_size\", None)\n        )\n\n    def test_streaming_multipack_buffer_size_new(self):\n        \"\"\"Test that new streaming_multipack_buffer_size works correctly.\"\"\"\n        cfg_new = DictDefault(\n            {\n                \"base_model\": \"HuggingFaceTB/SmolLM2-135M\",\n                \"streaming_multipack_buffer_size\": 7000,\n                \"datasets\": [{\"path\": \"test/dataset\", \"type\": \"alpaca\"}],\n     
           \"sequence_len\": 256,\n                \"micro_batch_size\": 1,\n                \"gradient_accumulation_steps\": 1,\n                \"learning_rate\": 0.0001,\n            }\n        )\n\n        validated_cfg = validate_config(cfg_new)\n        self.assertEqual(validated_cfg.streaming_multipack_buffer_size, 7000)\n\n    def test_both_buffer_sizes_raises_error(self):\n        \"\"\"Test that having both old and new buffer size configs raises an error.\"\"\"\n        cfg_both = DictDefault(\n            {\n                \"base_model\": \"HuggingFaceTB/SmolLM2-135M\",\n                \"pretrain_multipack_buffer_size\": 5000,\n                \"streaming_multipack_buffer_size\": 7000,\n                \"datasets\": [{\"path\": \"test/dataset\", \"type\": \"alpaca\"}],\n                \"sequence_len\": 256,\n                \"micro_batch_size\": 1,\n                \"gradient_accumulation_steps\": 1,\n                \"learning_rate\": 0.0001,\n            }\n        )\n\n        with self.assertRaises(ValueError) as cm:\n            validate_config(cfg_both)\n        self.assertIn(\"both are set\", str(cm.exception))\n\n\nclass TestStreamingDatasetPreparation(unittest.TestCase):\n    \"\"\"Test dataset preparation with streaming configuration.\"\"\"\n\n    def setUp(self):\n        self.tokenizer = Mock()\n        self.tokenizer.pad_token_id = 0\n        self.tokenizer.eos_token_id = 1\n\n    @patch(\"axolotl.utils.data.sft._prepare_streaming_dataset\")\n    def test_prepare_datasets_with_streaming_true(self, mock_prepare_streaming):\n        \"\"\"Test that streaming=True triggers streaming dataset preparation.\"\"\"\n        cfg = DictDefault(\n            {\n                \"streaming\": True,\n                \"datasets\": [{\"path\": \"test/dataset\", \"type\": \"alpaca\"}],\n            }\n        )\n\n        mock_prepare_streaming.return_value = (Mock(), None, 100, [])\n\n        prepare_datasets(cfg, self.tokenizer)\n\n        
mock_prepare_streaming.assert_called_once_with(cfg, self.tokenizer, None)\n\n    @patch(\"axolotl.utils.data.sft._prepare_streaming_dataset\")\n    def test_prepare_datasets_with_pretraining_dataset(self, mock_prepare_streaming):\n        \"\"\"Test that pretraining_dataset triggers streaming dataset preparation.\"\"\"\n        cfg = DictDefault(\n            {\n                \"pretraining_dataset\": \"test/dataset\",\n            }\n        )\n\n        mock_prepare_streaming.return_value = (Mock(), None, 100, [])\n\n        prepare_datasets(cfg, self.tokenizer)\n\n        mock_prepare_streaming.assert_called_once_with(cfg, self.tokenizer, None)\n\n    @patch(\"axolotl.utils.data.sft._prepare_standard_dataset\")\n    def test_prepare_datasets_without_streaming(self, mock_prepare_standard):\n        \"\"\"Test that without streaming, standard dataset preparation is used.\"\"\"\n        cfg = DictDefault(\n            {\n                \"datasets\": [{\"path\": \"test/dataset\", \"type\": \"alpaca\"}],\n            }\n        )\n\n        mock_prepare_standard.return_value = (Mock(), None, 100, [])\n\n        prepare_datasets(cfg, self.tokenizer)\n\n        mock_prepare_standard.assert_called_once_with(cfg, self.tokenizer, None)\n\n\nclass TestStreamingWithSamplePacking(unittest.TestCase):\n    \"\"\"Test streaming dataset preparation with sample packing.\"\"\"\n\n    def setUp(self):\n        self.tokenizer = Mock()\n        self.tokenizer.pad_token_id = 0\n        self.tokenizer.eos_token_id = 1\n\n    @patch(\"axolotl.utils.data.sft._load_streaming_dataset\")\n    def test_streaming_sft_with_sample_packing_sets_split(self, mock_load_streaming):\n        \"\"\"Test that streaming SFT with sample_packing sets default split.\"\"\"\n        cfg = DictDefault(\n            {\n                \"streaming\": True,\n                \"sample_packing\": True,\n                \"datasets\": [{\"path\": \"test/dataset\", \"type\": \"alpaca\"}],\n                
\"sequence_len\": 256,\n                \"micro_batch_size\": 1,\n            }\n        )\n\n        mock_load_streaming.return_value = Mock(spec=IterableDataset)\n\n        with patch(\"axolotl.utils.data.sft._load_and_prepare_datasets\"):\n            _prepare_streaming_dataset(cfg, self.tokenizer, None)\n\n            # Check that the dataset config has split set to 'train'\n            call_args = mock_load_streaming.call_args\n            dataset_config = call_args[0][0]\n            self.assertEqual(dataset_config.split, \"train\")\n\n    def test_multipack_attn_forced_true_for_sft(self):\n        \"\"\"Test that multipack_attn is forced to True for SFT with sample packing.\"\"\"\n        from axolotl.utils.data.streaming import wrap_streaming_dataset\n\n        cfg = DictDefault(\n            {\n                \"sample_packing\": True,\n                \"pretrain_multipack_attn\": False,  # Should be overridden for SFT\n                \"pretraining_dataset\": None,  # This makes it SFT\n                \"sequence_len\": 256,\n                \"micro_batch_size\": 1,\n                \"streaming_multipack_buffer_size\": 1000,\n                \"seed\": 42,\n            }\n        )\n\n        mock_dataset = Mock()\n        mock_dataset.features = None  # For streaming datasets\n        mock_dataset.__iter__ = Mock(return_value=iter([]))  # Empty iterator\n        mock_dataset.map = Mock(return_value=mock_dataset)\n        mock_ds_wrapper = Mock()\n\n        with patch(\n            \"axolotl.utils.data.streaming.PretrainingBatchSamplerDataCollatorForSeq2Seq\"\n        ) as mock_collator:\n            with patch(\"axolotl.utils.data.streaming.encode_packed_streaming\"):\n                wrap_streaming_dataset(\n                    mock_dataset, self.tokenizer, cfg, mock_ds_wrapper\n                )\n\n                # Check that multipack_attn=True was used in the collator\n                mock_collator.assert_called_once()\n                call_kwargs = 
mock_collator.call_args[1]\n                self.assertTrue(call_kwargs[\"multipack_attn\"])\n\n    def test_multipack_attn_respects_config_for_pretraining(self):\n        \"\"\"Test that multipack_attn respects config for pretraining datasets.\"\"\"\n        from axolotl.utils.data.streaming import wrap_streaming_dataset\n\n        cfg = DictDefault(\n            {\n                \"sample_packing\": True,\n                \"pretrain_multipack_attn\": False,  # Should be respected for pretraining\n                \"pretraining_dataset\": \"test/dataset\",  # This makes it pretraining\n                \"sequence_len\": 256,\n                \"micro_batch_size\": 1,\n                \"streaming_multipack_buffer_size\": 1000,\n                \"seed\": 42,\n            }\n        )\n\n        mock_dataset = Mock()\n        mock_dataset.features = None  # For streaming datasets\n        mock_dataset.__iter__ = Mock(return_value=iter([]))  # Empty iterator\n        mock_dataset.map = Mock(return_value=mock_dataset)\n        mock_ds_wrapper = Mock()\n\n        with patch(\n            \"axolotl.utils.data.streaming.PretrainingBatchSamplerDataCollatorForSeq2Seq\"\n        ) as mock_collator:\n            with patch(\"axolotl.utils.data.streaming.encode_packed_streaming\"):\n                wrap_streaming_dataset(\n                    mock_dataset, self.tokenizer, cfg, mock_ds_wrapper\n                )\n\n                # Check that multipack_attn=False was used (respecting config)\n                mock_collator.assert_called_once()\n                call_kwargs = mock_collator.call_args[1]\n                self.assertFalse(call_kwargs[\"multipack_attn\"])\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/test_tensor_parallel_batch_size.py",
    "content": "\"\"\"Tests for batch_size calculation with tensor parallelism.\"\"\"\n\nfrom unittest.mock import patch\n\nimport addict\nimport pytest\n\nfrom axolotl.utils.config import normalize_config, validate_config\nfrom axolotl.utils.dict import DictDefault\n\n\n@pytest.fixture(name=\"tp_base_cfg\")\ndef fixture_tp_base_cfg(min_base_cfg):\n    return (\n        DictDefault(\n            micro_batch_size=2,\n            gradient_accumulation_steps=4,\n            sequence_len=2048,\n            num_epochs=1,\n        )\n        | min_base_cfg\n    )\n\n\nclass TestTensorParallelBatchSize:\n    \"\"\"Verify batch_size scales by effective dp world_size when using tensor parallelism.\"\"\"\n\n    @pytest.mark.parametrize(\n        \"world_size, tensor_parallel_size, expected_batch_size\",\n        [\n            (4, 1, 32),  # no TP: 2*4*4 = 32\n            (4, 2, 16),  # TP=2: 2*4*(4//2) = 16\n            (4, 4, 8),  # TP=4: 2*4*(4//4) = 8\n            (2, 2, 8),  # TP=ws: 2*4*(2//2) = 8 (no scaling)\n        ],\n    )\n    def test_batch_size_with_tensor_parallelism(\n        self,\n        tp_base_cfg,\n        monkeypatch,\n        world_size,\n        tensor_parallel_size,\n        expected_batch_size,\n    ):\n        monkeypatch.setenv(\"WORLD_SIZE\", str(world_size))\n        tp_base_cfg[\"tensor_parallel_size\"] = tensor_parallel_size\n        cfg = validate_config(tp_base_cfg)\n        # Mock load_model_config to avoid downloading the model and to bypass\n        # the tie_word_embeddings validation that blocks TP > 1.\n        with patch(\n            \"axolotl.utils.config.load_model_config\",\n            return_value=addict.Dict({\"model_type\": \"llama\"}),\n        ):\n            normalize_config(cfg)\n        assert cfg.batch_size == expected_batch_size\n"
  },
  {
    "path": "tests/test_tokenizers.py",
    "content": "\"\"\"\nTest cases for the tokenizer loading\n\"\"\"\n\nimport unittest\n\nimport pytest\n\nfrom axolotl.loaders import load_tokenizer\nfrom axolotl.utils.dict import DictDefault\n\nfrom tests.hf_offline_utils import enable_hf_offline\n\n\nclass TestTokenizers:\n    \"\"\"\n    test class for the load_tokenizer fn\n    \"\"\"\n\n    @pytest.mark.skip(\"LlamaTokenizer no longer has a Fast/Slow tokenizer\")\n    @enable_hf_offline\n    def test_default_use_fast(self):\n        cfg = DictDefault(\n            {\n                \"tokenizer_config\": \"huggyllama/llama-7b\",\n            }\n        )\n        tokenizer = load_tokenizer(cfg)\n        assert \"Fast\" in tokenizer.__class__.__name__\n\n    @pytest.mark.skip(\"LlamaTokenizer no longer has a Fast/Slow tokenizer\")\n    @enable_hf_offline\n    def test_dont_use_fast(self):\n        cfg = DictDefault(\n            {\n                \"tokenizer_config\": \"huggyllama/llama-7b\",\n                \"tokenizer_use_fast\": False,\n            }\n        )\n        tokenizer = load_tokenizer(cfg)\n        assert \"Fast\" not in tokenizer.__class__.__name__\n\n    @enable_hf_offline\n    def test_special_tokens_modules_to_save(self):\n        # setting special_tokens to new token\n        cfg = DictDefault(\n            {\n                \"tokenizer_config\": \"huggyllama/llama-7b\",\n                \"adapter\": \"lora\",\n                \"special_tokens\": {\"bos_token\": \"[INST]\"},\n            }\n        )\n        with pytest.raises(\n            ValueError,\n            match=r\".*Please set lora_modules_to_save*\",\n        ):\n            load_tokenizer(cfg)\n\n        # setting special_tokens but not changing from default\n        cfg = DictDefault(\n            {\n                \"tokenizer_config\": \"huggyllama/llama-7b\",\n                \"adapter\": \"lora\",\n                \"special_tokens\": {\"bos_token\": \"<s>\"},\n            }\n        )\n        load_tokenizer(cfg)\n\n  
      # non-adapter setting special_tokens\n        cfg = DictDefault(\n            {\n                \"tokenizer_config\": \"huggyllama/llama-7b\",\n                \"special_tokens\": {\"bos_token\": \"[INST]\"},\n            }\n        )\n        load_tokenizer(cfg)\n\n    @enable_hf_offline\n    def test_add_additional_special_tokens(self):\n        cfg = DictDefault(\n            {\n                \"tokenizer_config\": \"huggyllama/llama-7b\",\n                \"special_tokens\": {\"additional_special_tokens\": [\"<|im_start|>\"]},\n            }\n        )\n        tokenizer = load_tokenizer(cfg)\n        assert \"LlamaTokenizer\" in tokenizer.__class__.__name__\n        assert tokenizer(\"<|im_start|>user\")[\"input_ids\"] == [1, 32000, 1792]\n        assert len(tokenizer) == 32001\n\n        # ensure reloading the tokenizer again from cfg results in same vocab length\n        tokenizer = load_tokenizer(cfg)\n        assert len(tokenizer) == 32001\n\n    @enable_hf_offline\n    def test_added_tokens_overrides(self, temp_dir):\n        cfg = DictDefault(\n            {\n                # use with tokenizer that has reserved_tokens in added_tokens\n                \"tokenizer_config\": \"NousResearch/Llama-3.2-1B\",\n                \"added_tokens_overrides\": {\n                    128041: \"RANDOM_OVERRIDE_1\",\n                    128042: \"RANDOM_OVERRIDE_2\",\n                },\n                \"output_dir\": temp_dir,\n            }\n        )\n\n        tokenizer = load_tokenizer(cfg)\n        assert tokenizer.encode(\"RANDOM_OVERRIDE_1\", add_special_tokens=False) == [\n            128041\n        ]\n        assert tokenizer.encode(\"RANDOM_OVERRIDE_2\", add_special_tokens=False) == [\n            128042\n        ]\n        assert (\n            tokenizer.decode([128041, 128042]) == \"RANDOM_OVERRIDE_1RANDOM_OVERRIDE_2\"\n        )\n\n    @pytest.mark.skip(\"FIXME slow test sdist py3.11 + torch2.8.0\")\n    @enable_hf_offline\n    def 
test_added_tokens_overrides_gemma3(self, temp_dir):\n        cfg = DictDefault(\n            {\n                # use with tokenizer that has reserved_tokens in added_tokens\n                \"tokenizer_config\": \"mlx-community/gemma-3-4b-it-8bit\",\n                \"added_tokens_overrides\": {\n                    256001: \"RANDOM_OVERRIDE_1\",\n                    256002: \"RANDOM_OVERRIDE_2\",\n                },\n                \"output_dir\": temp_dir,\n            }\n        )\n\n        tokenizer = load_tokenizer(cfg)\n        assert tokenizer.encode(\"RANDOM_OVERRIDE_1\", add_special_tokens=False) == [\n            256001\n        ]\n        assert tokenizer.encode(\"RANDOM_OVERRIDE_2\", add_special_tokens=False) == [\n            256002\n        ]\n        assert (\n            tokenizer.decode([256001, 256002]) == \"RANDOM_OVERRIDE_1RANDOM_OVERRIDE_2\"\n        )\n\n    @enable_hf_offline\n    def test_added_tokens_overrides_with_toolargeid(self, temp_dir):\n        cfg = DictDefault(\n            {\n                # use with tokenizer that has reserved_tokens in added_tokens\n                \"tokenizer_config\": \"HuggingFaceTB/SmolLM2-135M\",\n                \"added_tokens_overrides\": {1000000: \"BROKEN_RANDOM_OVERRIDE_1\"},\n                \"output_dir\": temp_dir,\n            }\n        )\n\n        with pytest.raises(\n            ValueError, match=r\".*Token ID 1000000 not found in added_tokens.*\"\n        ):\n            load_tokenizer(cfg)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/test_train.py",
    "content": "\"\"\"Test for batch size calculation for multi-gpu training.\"\"\"\n\nimport pytest\n\nfrom axolotl.utils.config import normalize_config, validate_config\nfrom axolotl.utils.dict import DictDefault\n\n\n@pytest.fixture(name=\"train_base_cfg\")\ndef fixture_train_base_cfg(min_base_cfg):\n    return (\n        DictDefault(\n            micro_batch_size=2,\n            gradient_accumulation_steps=4,\n            sequence_len=2048,\n            sample_packing=True,\n            num_epochs=1,\n        )\n        | min_base_cfg\n    )\n\n\nclass TestTrain:\n    \"\"\"test class for train related tests\"\"\"\n\n    @pytest.mark.parametrize(\n        \"world_size, expected_batch_size\",\n        [\n            (1, 8),\n            (4, 32),\n        ],\n    )\n    def test_batch_size_ddp(\n        self, train_base_cfg, monkeypatch, world_size, expected_batch_size\n    ):\n        monkeypatch.setenv(\"WORLD_SIZE\", str(world_size))\n        cfg = validate_config(train_base_cfg)\n        normalize_config(cfg)\n        assert cfg.batch_size == expected_batch_size\n"
  },
  {
    "path": "tests/test_triton_kernels.py",
    "content": "# Copyright 2026 Axolotl AI. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n\n\"\"\"Unit tests for Triton kernels: entropy_from_logits and selective_log_softmax.\n\nAdapted from harness/test_entropy.py and harness/test_selective_logsoftmax.py\ninto proper pytest tests, plus new OOB index safety tests.\n\"\"\"\n\nimport math\n\nimport pytest\nimport torch\nimport torch.nn.functional as F\n\npytestmark = pytest.mark.skipif(\n    not torch.cuda.is_available(), reason=\"CUDA required for Triton kernels\"\n)\n\n\n# ---------------------------------------------------------------------------\n# Reference implementations\n# ---------------------------------------------------------------------------\n\n\ndef _ref_entropy(logits):\n    \"\"\"Reference entropy via log_softmax (numerically stable).\"\"\"\n    logp = F.log_softmax(logits.float(), dim=-1)\n    return -(logp.exp() * logp).sum(dim=-1)\n\n\ndef _ref_selective_log_softmax(logits, index):\n    \"\"\"Reference selective log softmax via PyTorch gather.\"\"\"\n    squeeze = index.ndim == logits.ndim - 1\n    if squeeze:\n        index = index.unsqueeze(-1)\n    log_probs = F.log_softmax(logits.float(), dim=-1)\n    result = torch.gather(log_probs, dim=-1, index=index)\n    if squeeze:\n        result = result.squeeze(-1)\n    return result\n\n\n# ---------------------------------------------------------------------------\n# entropy_from_logits\n# ---------------------------------------------------------------------------\n\n\nclass TestEntropyFromLogits:\n    @pytest.mark.parametrize(\n        \"B,L\",\n        [\n            (1, 128),\n            (1, 2048),\n            (4, 512),\n            (8, 256),\n            (1, 1),\n        ],\n    )\n    def test_correctness_various_shapes(self, B, L):\n        from axolotl.monkeypatch.trainer.utils import entropy_from_logits\n\n        V = 1024\n 
       torch.manual_seed(42)\n        logits = torch.randn(B, L, V, device=\"cuda\", dtype=torch.float32)\n        result = entropy_from_logits(logits)\n        expected = _ref_entropy(logits)\n        assert result.shape == (B, L)\n        torch.testing.assert_close(result, expected, atol=1e-4, rtol=1e-4)\n\n    def test_2d_input(self):\n        from axolotl.monkeypatch.trainer.utils import entropy_from_logits\n\n        logits = torch.randn(16, 256, device=\"cuda\", dtype=torch.float32)\n        result = entropy_from_logits(logits)\n        expected = _ref_entropy(logits)\n        assert result.shape == (16,)\n        torch.testing.assert_close(result, expected, atol=1e-4, rtol=1e-4)\n\n    def test_large_vocab(self):\n        from axolotl.monkeypatch.trainer.utils import entropy_from_logits\n\n        V = 32000\n        logits = torch.randn(2, V, device=\"cuda\", dtype=torch.float32)\n        result = entropy_from_logits(logits)\n        expected = _ref_entropy(logits)\n        torch.testing.assert_close(result, expected, atol=1e-4, rtol=1e-4)\n\n    def test_uniform_distribution(self):\n        \"\"\"Uniform logits -> entropy = log(V).\"\"\"\n        from axolotl.monkeypatch.trainer.utils import entropy_from_logits\n\n        V = 1024\n        logits = torch.zeros(2, V, device=\"cuda\", dtype=torch.float32)\n        result = entropy_from_logits(logits)\n        expected_val = math.log(V)\n        torch.testing.assert_close(\n            result,\n            torch.full((2,), expected_val, device=\"cuda\", dtype=torch.float32),\n            atol=1e-4,\n            rtol=1e-4,\n        )\n\n    def test_peaked_distribution(self):\n        \"\"\"One-hot-like logits -> entropy near 0.\"\"\"\n        from axolotl.monkeypatch.trainer.utils import entropy_from_logits\n\n        logits = torch.full((2, 128), -100.0, device=\"cuda\", dtype=torch.float32)\n        logits[:, 0] = 100.0\n        result = entropy_from_logits(logits)\n        assert (result < 1e-3).all()\n\n   
 def test_bfloat16(self):\n        from axolotl.monkeypatch.trainer.utils import entropy_from_logits\n\n        logits = torch.randn(4, 256, device=\"cuda\", dtype=torch.bfloat16)\n        result = entropy_from_logits(logits)\n        expected = _ref_entropy(logits.float())\n        assert result.dtype == torch.bfloat16\n        torch.testing.assert_close(result.float(), expected, atol=5e-2, rtol=5e-2)\n\n    def test_float16(self):\n        from axolotl.monkeypatch.trainer.utils import entropy_from_logits\n\n        logits = torch.randn(4, 256, device=\"cuda\", dtype=torch.float16)\n        result = entropy_from_logits(logits)\n        expected = _ref_entropy(logits.float())\n        assert result.dtype == torch.float16\n        torch.testing.assert_close(result.float(), expected, atol=5e-2, rtol=5e-2)\n\n    def test_non_contiguous_3d_transpose(self):\n        \"\"\"Non-contiguous 3D tensor via transpose(0,1).\"\"\"\n        from axolotl.monkeypatch.trainer.utils import entropy_from_logits\n\n        V = 256\n        raw = torch.randn(32, 4, V, device=\"cuda\", dtype=torch.float32)\n        logits = raw.transpose(0, 1)  # (4, 32, V) non-contiguous\n        assert not logits.is_contiguous()\n        result = entropy_from_logits(logits)\n        expected = _ref_entropy(logits)\n        torch.testing.assert_close(result, expected, atol=1e-4, rtol=1e-4)\n\n    def test_non_contiguous_3d_slice(self):\n        \"\"\"Non-contiguous 3D tensor via batch slicing.\"\"\"\n        from axolotl.monkeypatch.trainer.utils import entropy_from_logits\n\n        V = 256\n        raw = torch.randn(8, 32, V, device=\"cuda\", dtype=torch.float32)\n        logits = raw[::2]  # (4, 32, V) non-contiguous\n        assert not logits.is_contiguous()\n        result = entropy_from_logits(logits)\n        expected = _ref_entropy(logits)\n        torch.testing.assert_close(result, expected, atol=1e-4, rtol=1e-4)\n\n    def test_many_rows_beyond_max_grid(self):\n        \"\"\"More rows than 
MAX_GRID (8192) to test chunked dispatch.\"\"\"\n        from axolotl.monkeypatch.trainer.utils import entropy_from_logits\n\n        logits = torch.randn(10000, 128, device=\"cuda\", dtype=torch.float32)\n        result = entropy_from_logits(logits)\n        expected = _ref_entropy(logits)\n        torch.testing.assert_close(result, expected, atol=1e-4, rtol=1e-4)\n\n    def test_entropy_non_negative(self):\n        from axolotl.monkeypatch.trainer.utils import entropy_from_logits\n\n        logits = torch.randn(32, 512, device=\"cuda\", dtype=torch.float32)\n        result = entropy_from_logits(logits)\n        assert (result >= -1e-5).all(), f\"Negative entropy: {result.min()}\"\n\n\n# ---------------------------------------------------------------------------\n# selective_log_softmax — forward correctness\n# ---------------------------------------------------------------------------\n\n\nclass TestSelectiveLogSoftmax:\n    @pytest.mark.parametrize(\n        \"B,L,K\",\n        [\n            (1, 128, 1),\n            (4, 512, 1),\n            (8, 256, 1),\n            (4, 256, 4),\n            (4, 256, 7),\n            (15, 129, 1),  # non-power-of-2\n        ],\n    )\n    def test_correctness_various_shapes(self, B, L, K):\n        from axolotl.monkeypatch.trainer.utils import selective_log_softmax\n\n        V = 1024\n        torch.manual_seed(42)\n        logits = torch.randn(B, L, V, device=\"cuda\", dtype=torch.float32)\n        if K == 1:\n            index = torch.randint(0, V, (B, L), device=\"cuda\")\n        else:\n            index = torch.randint(0, V, (B, L, K), device=\"cuda\")\n        result = selective_log_softmax(logits, index)\n        expected = _ref_selective_log_softmax(logits, index)\n        torch.testing.assert_close(result, expected, atol=1e-4, rtol=1e-4)\n\n    def test_squeezed_index(self):\n        \"\"\"Index with ndim == logits.ndim - 1 triggers squeeze path.\"\"\"\n        from axolotl.monkeypatch.trainer.utils import 
selective_log_softmax\n\n        V = 256\n        logits = torch.randn(8, V, device=\"cuda\", dtype=torch.float32)\n        index = torch.randint(0, V, (8,), device=\"cuda\")\n        result = selective_log_softmax(logits, index)\n        expected = _ref_selective_log_softmax(logits, index)\n        assert result.shape == (8,)\n        torch.testing.assert_close(result, expected, atol=1e-4, rtol=1e-4)\n\n    def test_large_vocab(self):\n        from axolotl.monkeypatch.trainer.utils import selective_log_softmax\n\n        V = 32000\n        logits = torch.randn(2, V, device=\"cuda\", dtype=torch.float32)\n        index = torch.randint(0, V, (2, 1), device=\"cuda\")\n        result = selective_log_softmax(logits, index)\n        expected = _ref_selective_log_softmax(logits, index)\n        torch.testing.assert_close(result, expected, atol=1e-4, rtol=1e-4)\n\n    def test_bfloat16(self):\n        from axolotl.monkeypatch.trainer.utils import selective_log_softmax\n\n        V = 1024\n        torch.manual_seed(42)\n        logits = torch.randn(4, 128, V, device=\"cuda\", dtype=torch.bfloat16)\n        index = torch.randint(0, V, (4, 128), device=\"cuda\")\n        result = selective_log_softmax(logits, index)\n        expected = _ref_selective_log_softmax(logits.float(), index)\n        assert result.dtype == torch.bfloat16\n        torch.testing.assert_close(result.float(), expected, atol=0.1, rtol=0.1)\n\n    def test_fp32_tight_tolerance(self):\n        from axolotl.monkeypatch.trainer.utils import selective_log_softmax\n\n        V = 1024\n        torch.manual_seed(42)\n        logits = torch.randn(2, 256, V, device=\"cuda\", dtype=torch.float32)\n        index = torch.randint(0, V, (2, 256), device=\"cuda\")\n        result = selective_log_softmax(logits, index)\n        expected = _ref_selective_log_softmax(logits, index)\n        torch.testing.assert_close(result, expected, atol=1e-5, rtol=1e-5)\n\n    def test_all_same_index(self):\n        from 
axolotl.monkeypatch.trainer.utils import selective_log_softmax\n\n        V = 128\n        logits = torch.randn(8, V, device=\"cuda\", dtype=torch.float32)\n        index = torch.zeros(8, 1, device=\"cuda\", dtype=torch.long)\n        result = selective_log_softmax(logits, index)\n        expected = _ref_selective_log_softmax(logits, index)\n        torch.testing.assert_close(result, expected, atol=1e-4, rtol=1e-4)\n\n    def test_last_index(self):\n        from axolotl.monkeypatch.trainer.utils import selective_log_softmax\n\n        V = 128\n        logits = torch.randn(8, V, device=\"cuda\", dtype=torch.float32)\n        index = torch.full((8, 1), V - 1, device=\"cuda\", dtype=torch.long)\n        result = selective_log_softmax(logits, index)\n        expected = _ref_selective_log_softmax(logits, index)\n        torch.testing.assert_close(result, expected, atol=1e-4, rtol=1e-4)\n\n    def test_output_always_nonpositive(self):\n        \"\"\"Log softmax values should always be <= 0.\"\"\"\n        from axolotl.monkeypatch.trainer.utils import selective_log_softmax\n\n        V = 256\n        logits = torch.randn(32, V, device=\"cuda\", dtype=torch.float32)\n        index = torch.randint(0, V, (32, 1), device=\"cuda\")\n        result = selective_log_softmax(logits, index)\n        assert (result <= 1e-5).all(), f\"Positive log-prob: {result.max()}\"\n\n    def test_many_rows_beyond_max_grid(self):\n        from axolotl.monkeypatch.trainer.utils import selective_log_softmax\n\n        V = 128\n        logits = torch.randn(10000, V, device=\"cuda\", dtype=torch.float32)\n        index = torch.randint(0, V, (10000, 1), device=\"cuda\")\n        result = selective_log_softmax(logits, index)\n        expected = _ref_selective_log_softmax(logits, index)\n        torch.testing.assert_close(result, expected, atol=1e-4, rtol=1e-4)\n\n\n# ---------------------------------------------------------------------------\n# selective_log_softmax — backward / gradient 
correctness\n# ---------------------------------------------------------------------------\n\n\nclass TestSelectiveLogSoftmaxBackward:\n    @pytest.mark.parametrize(\n        \"B,L,V,K\",\n        [\n            (2, 16, 64, 1),\n            (2, 16, 64, 4),\n            (1, 8, 128, 1),\n            (2, 8, 128, 7),\n        ],\n    )\n    def test_gradient_matches_reference(self, B, L, V, K):\n        from axolotl.monkeypatch.trainer.utils import selective_log_softmax\n\n        torch.manual_seed(42)\n        logits_ref = torch.randn(\n            B, L, V, device=\"cuda\", dtype=torch.float32, requires_grad=True\n        )\n        logits_tri = logits_ref.detach().clone().requires_grad_(True)\n\n        if K == 1:\n            index = torch.randint(0, V, (B, L), device=\"cuda\")\n        else:\n            index = torch.randint(0, V, (B, L, K), device=\"cuda\")\n\n        ref_out = _ref_selective_log_softmax(logits_ref, index)\n        tri_out = selective_log_softmax(logits_tri, index)\n\n        ref_out.sum().backward()\n        tri_out.sum().backward()\n\n        torch.testing.assert_close(\n            logits_tri.grad, logits_ref.grad, atol=1e-5, rtol=1e-5\n        )\n\n    def test_gradient_bfloat16_full_vocab(self):\n        from axolotl.monkeypatch.trainer.utils import selective_log_softmax\n\n        V = 4096\n        torch.manual_seed(42)\n        logits_ref = torch.randn(\n            2, 64, V, device=\"cuda\", dtype=torch.bfloat16, requires_grad=True\n        )\n        logits_tri = logits_ref.detach().clone().requires_grad_(True)\n        index = torch.randint(0, V, (2, 64), device=\"cuda\")\n\n        _ref_selective_log_softmax(logits_ref, index).sum().backward()\n        selective_log_softmax(logits_tri, index).sum().backward()\n\n        torch.testing.assert_close(\n            logits_tri.grad.float(), logits_ref.grad.float(), atol=0.1, rtol=0.1\n        )\n\n    def test_gradient_k1_squeezed(self):\n        \"\"\"Gradient with squeezed (1D) 
index.\"\"\"\n        from axolotl.monkeypatch.trainer.utils import selective_log_softmax\n\n        V = 256\n        logits = torch.randn(\n            8, V, device=\"cuda\", dtype=torch.float32, requires_grad=True\n        )\n        index = torch.randint(0, V, (8,), device=\"cuda\")\n\n        result = selective_log_softmax(logits, index)\n        result.sum().backward()\n        triton_grad = logits.grad.clone()\n\n        logits.grad = None\n        ref = torch.gather(\n            F.log_softmax(logits, dim=-1), dim=-1, index=index.unsqueeze(-1)\n        ).squeeze(-1)\n        ref.sum().backward()\n\n        torch.testing.assert_close(triton_grad, logits.grad, atol=1e-4, rtol=1e-4)\n\n\n# ---------------------------------------------------------------------------\n# selective_log_softmax — out-of-bounds index safety\n# ---------------------------------------------------------------------------\n\n\nclass TestSelectiveLogSoftmaxOOBSafety:\n    \"\"\"Verify that out-of-range indices don't crash or corrupt valid results.\"\"\"\n\n    def test_negative_indices_no_crash(self):\n        from axolotl.monkeypatch.trainer.utils import selective_log_softmax\n\n        V = 128\n        logits = torch.randn(4, V, device=\"cuda\", dtype=torch.float32)\n        index = torch.tensor(\n            [[-1], [0], [V - 1], [-5]], device=\"cuda\", dtype=torch.long\n        )\n        result = selective_log_softmax(logits, index)\n        assert result.shape == (4, 1)\n        # Valid rows should be finite and match reference\n        valid_idx = torch.tensor([[0], [V - 1]], device=\"cuda\", dtype=torch.long)\n        valid_logits = logits[1:3]\n        expected = _ref_selective_log_softmax(valid_logits, valid_idx)\n        torch.testing.assert_close(result[1:3], expected, atol=1e-4, rtol=1e-4)\n\n    def test_index_exceeds_vocab_no_crash(self):\n        from axolotl.monkeypatch.trainer.utils import selective_log_softmax\n\n        V = 128\n        logits = torch.randn(4, V, 
device=\"cuda\", dtype=torch.float32)\n        index = torch.tensor(\n            [[0], [V], [V + 100], [V - 1]], device=\"cuda\", dtype=torch.long\n        )\n        result = selective_log_softmax(logits, index)\n        assert result.shape == (4, 1)\n        # Valid rows (0 and 3) should match reference\n        for row_idx, idx_val in [(0, 0), (3, V - 1)]:\n            ref = _ref_selective_log_softmax(\n                logits[row_idx : row_idx + 1],\n                torch.tensor([[idx_val]], device=\"cuda\", dtype=torch.long),\n            )\n            torch.testing.assert_close(\n                result[row_idx : row_idx + 1], ref, atol=1e-4, rtol=1e-4\n            )\n\n    def test_mixed_valid_invalid_multi_index(self):\n        from axolotl.monkeypatch.trainer.utils import selective_log_softmax\n\n        V = 256\n        K = 3\n        logits = torch.randn(4, V, device=\"cuda\", dtype=torch.float32)\n        index = torch.tensor(\n            [\n                [0, 10, -1],  # last invalid\n                [V, 5, 100],  # first invalid\n                [50, 60, 70],  # all valid\n                [-1, V + 1, -100],  # all invalid\n            ],\n            device=\"cuda\",\n            dtype=torch.long,\n        )\n        result = selective_log_softmax(logits, index)\n        assert result.shape == (4, K)\n        # Row 2 (all valid) must match reference exactly\n        valid_index = torch.tensor([[50, 60, 70]], device=\"cuda\", dtype=torch.long)\n        expected = _ref_selective_log_softmax(logits[2:3], valid_index)\n        torch.testing.assert_close(result[2:3], expected, atol=1e-4, rtol=1e-4)\n\n    def test_oob_backward_no_crash(self):\n        \"\"\"Backward with OOB indices should not crash and grads should be finite.\"\"\"\n        from axolotl.monkeypatch.trainer.utils import selective_log_softmax\n\n        V = 128\n        logits = torch.randn(\n            4, V, device=\"cuda\", dtype=torch.float32, requires_grad=True\n        )\n        
index = torch.tensor(\n            [[-1], [0], [V + 10], [V - 1]], device=\"cuda\", dtype=torch.long\n        )\n        result = selective_log_softmax(logits, index)\n        result.sum().backward()\n        assert logits.grad is not None\n        assert torch.isfinite(logits.grad).all()\n\n    def test_oob_backward_valid_rows_correct(self):\n        \"\"\"Gradients for valid-index rows should match reference even when other rows have OOB.\"\"\"\n        from axolotl.monkeypatch.trainer.utils import selective_log_softmax\n\n        V = 128\n        logits = torch.randn(\n            4, V, device=\"cuda\", dtype=torch.float32, requires_grad=True\n        )\n        # Row 0: invalid, Row 1: valid, Row 2: invalid, Row 3: valid\n        index = torch.tensor(\n            [[-1], [42], [V + 5], [100]], device=\"cuda\", dtype=torch.long\n        )\n        result = selective_log_softmax(logits, index)\n        result.sum().backward()\n\n        # Compute reference gradient for valid rows only\n        logits_ref = logits.detach().clone().requires_grad_(True)\n        valid_rows = [1, 3]\n        valid_indices = [42, 100]\n        for r, idx in zip(valid_rows, valid_indices, strict=True):\n            ref_lp = F.log_softmax(logits_ref[r : r + 1], dim=-1)\n            ref_val = ref_lp[0, idx]\n            ref_val.backward(retain_graph=True)\n\n        for r in valid_rows:\n            torch.testing.assert_close(\n                logits.grad[r], logits_ref.grad[r], atol=1e-4, rtol=1e-4\n            )\n"
  },
  {
    "path": "tests/test_utils_tee.py",
    "content": "import os\nimport tempfile\n\n\ndef _dummy_cfg(output_dir: str, append: bool = False):\n    # Minimal object with attributes used by prepare_debug_log\n    class Cfg:\n        def __init__(self, out, append):\n            self.output_dir = out\n            self._append = append\n\n        def get(self, key, default=None):\n            if key in {\"resume_from_checkpoint\", \"auto_resume_from_checkpoints\"}:\n                return self._append\n            return default\n\n    return Cfg(output_dir, append)\n\n\ndef read(path: str) -> str:\n    with open(path, \"r\", encoding=\"utf-8\") as f:\n        return f.read()\n\n\ndef test_file_only_stream_writes_after_prepare(monkeypatch):\n    from axolotl.utils import tee\n\n    with tempfile.TemporaryDirectory() as td:\n        # Avoid stdout tee in this test\n        monkeypatch.setenv(\"AXOLOTL_TEE_STDOUT\", \"0\")\n        cfg = _dummy_cfg(td, append=False)\n\n        # before prepare: writing to file_only_stream creates no file\n        tee.file_only_stream.write(\"before\\n\")\n        tee.file_only_stream.flush()\n        assert not os.path.exists(os.path.join(td, \"debug.log\"))\n\n        # prepare and write\n        path = tee.prepare_debug_log(cfg)\n        assert os.path.basename(path) == \"debug.log\"\n        tee.file_only_stream.write(\"hello\\n\")\n        tee.file_only_stream.flush()\n\n        content = read(path)\n        assert \"hello\" in content\n\n        tee.close_debug_log()\n\n\ndef test_stdout_is_mirrored_after_prepare(capsys, monkeypatch):\n    from axolotl.utils import tee\n\n    with tempfile.TemporaryDirectory() as td:\n        cfg = _dummy_cfg(td, append=False)\n        try:\n            # Install tee while capture is disabled so stdout tee wraps real stdout.\n            with capsys.disabled():\n                monkeypatch.setenv(\"AXOLOTL_TEE_STDOUT\", \"1\")\n                path = tee.prepare_debug_log(cfg)\n                import sys\n\n                
print(\"printed-line\")\n                sys.stdout.flush()\n\n            # Now verify file contains the line\n            content = read(path)\n            assert \"printed-line\" in content\n        finally:\n            tee.close_debug_log()\n\n\ndef test_truncate_vs_append_behavior(monkeypatch):\n    from axolotl.utils import tee\n\n    with tempfile.TemporaryDirectory() as td:\n        # Avoid stdout tee in this test\n        monkeypatch.setenv(\"AXOLOTL_TEE_STDOUT\", \"0\")\n        # First run creates file with A\n        cfg = _dummy_cfg(td, append=False)\n        _ = tee.prepare_debug_log(cfg)\n        try:\n            tee.file_only_stream.write(\"A\\n\")\n            tee.file_only_stream.flush()\n        finally:\n            tee.close_debug_log()\n\n        # Second run with append=False truncates\n        cfg2 = _dummy_cfg(td, append=False)\n        path2 = tee.prepare_debug_log(cfg2)\n        try:\n            tee.file_only_stream.write(\"B\\n\")\n            tee.file_only_stream.flush()\n            content = read(path2)\n            assert \"A\\n\" not in content and \"B\\n\" in content\n        finally:\n            tee.close_debug_log()\n\n        # Third run with append=True preserves existing\n        cfg3 = _dummy_cfg(td, append=True)\n        path3 = tee.prepare_debug_log(cfg3)\n        try:\n            tee.file_only_stream.write(\"C\\n\")\n            tee.file_only_stream.flush()\n            content = read(path3)\n            assert \"B\\n\" in content and \"C\\n\" in content\n        finally:\n            tee.close_debug_log()\n"
  },
  {
    "path": "tests/test_validation_dataset.py",
    "content": "\"\"\"Module for testing the validation module for the dataset config\"\"\"\n\nimport warnings\nfrom typing import Optional\n\nimport pytest\n\nfrom axolotl.utils.config import validate_config\nfrom axolotl.utils.dict import DictDefault\nfrom axolotl.utils.schemas.datasets import ChatTemplate\n\nwarnings.filterwarnings(\"error\")\n\n\n@pytest.fixture(name=\"minimal_cfg\")\ndef fixture_cfg():\n    return DictDefault(\n        {\n            \"base_model\": \"TinyLlama/TinyLlama-1.1B-Chat-v0.6\",\n            \"learning_rate\": 0.000001,\n            \"micro_batch_size\": 1,\n            \"gradient_accumulation_steps\": 1,\n        }\n    )\n\n\nclass BaseValidation:\n    \"\"\"\n    Base validation module to setup the log capture\n    \"\"\"\n\n    _caplog: Optional[pytest.LogCaptureFixture] = None\n\n    @pytest.fixture(autouse=True)\n    def inject_fixtures(self, caplog):\n        self._caplog = caplog\n\n\nclass TestValidationCheckDatasetConfig(BaseValidation):\n    \"\"\"\n    Test the validation for the dataset config to ensure no correct parameters are dropped\n    \"\"\"\n\n    def test_dataset_config_no_drop_param(self, minimal_cfg):\n        cfg = DictDefault(\n            minimal_cfg\n            | {\n                \"datasets\": [\n                    {\n                        \"path\": \"mhenrichsen/alpaca_2k_test\",\n                        \"type\": \"alpaca\",\n                        \"shards\": 10,\n                    }\n                ]\n            }\n        )\n\n        checked_cfg = validate_config(cfg)\n\n        def _check_config():\n            assert checked_cfg.datasets[0].path == cfg.datasets[0].path\n            assert checked_cfg.datasets[0].type == cfg.datasets[0].type\n            assert checked_cfg.datasets[0].shards == cfg.datasets[0].shards\n\n        _check_config()\n\n        checked_cfg = validate_config(\n            cfg,\n            capabilities={\n                \"bf16\": \"false\",\n                
\"tf32\": \"false\",\n                \"n_gpu\": 1,\n                \"compute_capability\": \"8.0\",\n            },\n            env_capabilities={\n                \"torch_version\": \"2.6.0\",\n            },\n        )\n\n        _check_config()\n\n    def test_dataset_default_chat_template_no_drop_param(self, minimal_cfg):\n        cfg = DictDefault(\n            minimal_cfg\n            | {\n                \"datasets\": [\n                    {\n                        \"path\": \"LDJnr/Puffin\",\n                        \"type\": \"chat_template\",\n                        \"field_messages\": \"conversations\",\n                        \"shards\": 10,\n                        \"message_field_role\": \"from\",\n                        \"message_field_content\": \"value\",\n                    }\n                ],\n            }\n        )\n\n        checked_cfg = validate_config(cfg)\n\n        def _check_config():\n            assert checked_cfg.datasets[0].path == cfg.datasets[0].path\n            assert checked_cfg.datasets[0].type == cfg.datasets[0].type\n            assert checked_cfg.chat_template is None\n            assert (\n                checked_cfg.datasets[0].chat_template == ChatTemplate.tokenizer_default\n            )\n            assert (\n                checked_cfg.datasets[0].field_messages == cfg.datasets[0].field_messages\n            )\n            assert checked_cfg.datasets[0].shards == cfg.datasets[0].shards\n            assert (\n                checked_cfg.datasets[0].message_field_role\n                == cfg.datasets[0].message_field_role\n            )\n            assert (\n                checked_cfg.datasets[0].message_field_content\n                == cfg.datasets[0].message_field_content\n            )\n\n        _check_config()\n\n        checked_cfg = validate_config(\n            cfg,\n            capabilities={\n                \"bf16\": \"false\",\n                \"n_gpu\": 1,\n                
\"compute_capability\": \"8.0\",\n            },\n            env_capabilities={\n                \"torch_version\": \"2.6.0\",\n            },\n        )\n\n        _check_config()\n\n    def test_dataset_partial_default_chat_template_no_drop_param(self, minimal_cfg):\n        cfg = DictDefault(\n            minimal_cfg\n            | {\n                \"chat_template\": \"chatml\",\n                \"datasets\": [\n                    {\n                        \"path\": \"LDJnr/Puffin\",\n                        \"type\": \"chat_template\",\n                        \"field_messages\": \"conversations\",\n                        \"shards\": 10,\n                        \"message_field_role\": \"from\",\n                        \"message_field_content\": \"value\",\n                    }\n                ],\n            }\n        )\n\n        checked_cfg = validate_config(cfg)\n\n        def _check_config():\n            assert checked_cfg.datasets[0].path == cfg.datasets[0].path\n            assert checked_cfg.datasets[0].type == cfg.datasets[0].type\n            assert checked_cfg.chat_template == ChatTemplate.chatml\n            assert (\n                checked_cfg.datasets[0].chat_template == ChatTemplate.tokenizer_default\n            )\n            assert (\n                checked_cfg.datasets[0].field_messages == cfg.datasets[0].field_messages\n            )\n            assert checked_cfg.datasets[0].shards == cfg.datasets[0].shards\n            assert (\n                checked_cfg.datasets[0].message_field_role\n                == cfg.datasets[0].message_field_role\n            )\n            assert (\n                checked_cfg.datasets[0].message_field_content\n                == cfg.datasets[0].message_field_content\n            )\n\n        _check_config()\n\n        checked_cfg = validate_config(\n            cfg,\n            capabilities={\n                \"bf16\": \"false\",\n                \"n_gpu\": 1,\n                
\"compute_capability\": \"8.0\",\n            },\n            env_capabilities={\n                \"torch_version\": \"2.6.0\",\n            },\n        )\n\n        _check_config()\n\n    def test_dataset_chatml_chat_template_no_drop_param(self, minimal_cfg):\n        cfg = DictDefault(\n            minimal_cfg\n            | {\n                \"chat_template\": \"chatml\",\n                \"datasets\": [\n                    {\n                        \"path\": \"LDJnr/Puffin\",\n                        \"type\": \"chat_template\",\n                        \"chat_template\": \"gemma\",\n                        \"field_messages\": \"conversations\",\n                        \"shards\": 10,\n                        \"message_field_role\": \"from\",\n                        \"message_field_content\": \"value\",\n                    }\n                ],\n            }\n        )\n\n        checked_cfg = validate_config(cfg)\n\n        def _check_config():\n            assert checked_cfg.datasets[0].path == cfg.datasets[0].path\n            assert checked_cfg.datasets[0].type == cfg.datasets[0].type\n            assert checked_cfg.chat_template == cfg.chat_template\n            assert (\n                checked_cfg.datasets[0].chat_template == cfg.datasets[0].chat_template\n            )\n            assert (\n                checked_cfg.datasets[0].field_messages == cfg.datasets[0].field_messages\n            )\n            assert checked_cfg.datasets[0].shards == cfg.datasets[0].shards\n            assert (\n                checked_cfg.datasets[0].message_field_role\n                == cfg.datasets[0].message_field_role\n            )\n            assert (\n                checked_cfg.datasets[0].message_field_content\n                == cfg.datasets[0].message_field_content\n            )\n\n        _check_config()\n\n        checked_cfg = validate_config(\n            cfg,\n            capabilities={\n                \"bf16\": \"false\",\n                
\"n_gpu\": 1,\n                \"compute_capability\": \"8.0\",\n            },\n            env_capabilities={\n                \"torch_version\": \"2.6.0\",\n            },\n        )\n\n        _check_config()\n\n    def test_dataset_sharegpt_deprecation(self, minimal_cfg):\n        cfg = DictDefault(\n            minimal_cfg\n            | {\n                \"chat_template\": \"chatml\",\n                \"datasets\": [\n                    {\n                        \"path\": \"LDJnr/Puffin\",\n                        \"type\": \"sharegpt\",\n                        \"conversation\": \"chatml\",\n                    }\n                ],\n            }\n        )\n\n        # Check sharegpt deprecation is raised\n        with pytest.raises(ValueError, match=r\".*type: sharegpt.*` is deprecated.*\"):\n            validate_config(cfg)\n\n        # Check that deprecation is not thrown for non-str type\n        cfg = DictDefault(\n            minimal_cfg\n            | {\n                \"datasets\": [\n                    {\n                        \"path\": \"mhenrichsen/alpaca_2k_test\",\n                        \"type\": {\n                            \"field_instruction\": \"instruction\",\n                            \"field_output\": \"output\",\n                            \"field_system\": \"system\",\n                            \"format\": \"<|user|> {instruction} {input} <|model|>\",\n                            \"no_input_format\": \"<|user|> {instruction} <|model|>\",\n                            \"system_prompt\": \"\",\n                        },\n                    }\n                ],\n            }\n        )\n\n        validate_config(cfg)\n\n        # Check that deprecation is not thrown for non-sharegpt type\n        cfg = DictDefault(\n            minimal_cfg\n            | {\n                \"datasets\": [\n                    {\n                        \"path\": \"mhenrichsen/alpaca_2k_test\",\n                        \"type\": 
\"alpaca\",\n                    }\n                ],\n            }\n        )\n\n        validate_config(cfg)\n\n    def test_message_property_mappings(self, minimal_cfg):\n        cfg = DictDefault(\n            minimal_cfg\n            | {\n                \"datasets\": [\n                    {\n                        \"path\": \"mhenrichsen/alpaca_2k_test\",\n                        \"type\": \"alpaca\",\n                        \"message_property_mappings\": {\n                            \"role\": \"role\",\n                            \"content\": \"content\",\n                        },\n                    }\n                ],\n            }\n        )\n\n        validate_config(cfg)\n\n\nclass TestOptimizerValidation(BaseValidation):\n    \"\"\"\n    Test muon optimizer validation\n    \"\"\"\n\n    def test_muon_deepspeed(self, minimal_cfg):\n        cfg = DictDefault(\n            minimal_cfg\n            | {\n                \"datasets\": [\n                    {\n                        \"path\": \"mhenrichsen/alpaca_2k_test\",\n                        \"type\": \"alpaca\",\n                    }\n                ],\n                \"optimizer\": \"muon\",\n                \"deepspeed\": \"deepspeed_configs/zero3.json\",\n            }\n        )\n\n        with pytest.raises(ValueError, match=r\".*is currently incompatible with*\"):\n            validate_config(cfg)\n\n    def test_muon_fsdp(self, minimal_cfg):\n        cfg = DictDefault(\n            minimal_cfg\n            | {\n                \"datasets\": [\n                    {\n                        \"path\": \"mhenrichsen/alpaca_2k_test\",\n                        \"type\": \"alpaca\",\n                    }\n                ],\n                \"optimizer\": \"muon\",\n                \"fsdp\": [\"full_shard\"],\n                \"fsdp_config\": {\n                    \"fsdp_auto_wrap_policy\": \"TRANSFORMER_BASED_WRAP\",\n                },\n            }\n        )\n\n        with 
pytest.raises(ValueError, match=r\".*only compatible with FSDP2.*\"):\n            validate_config(cfg)\n"
  },
  {
    "path": "tests/utils/callbacks/test_dynamic_checkpoint.py",
    "content": "\"\"\"Unit tests for dynamic checkpoint callback\"\"\"\n\nimport tempfile\nfrom pathlib import Path\nfrom unittest.mock import MagicMock, Mock, patch\n\nfrom axolotl.utils.callbacks.dynamic_checkpoint import (\n    DEFAULT_TRIGGER_FILENAME,\n    DynamicCheckpointCallback,\n)\nfrom axolotl.utils.dict import DictDefault\n\n\nclass TestDynamicCheckpointCallbackInit:\n    \"\"\"Test callback initialization\"\"\"\n\n    def test_callback_disabled_by_default(self):\n        \"\"\"Test that callback is disabled when config.enabled=False\"\"\"\n        with tempfile.TemporaryDirectory() as tmpdir:\n            cfg = DictDefault(\n                {\n                    \"dynamic_checkpoint\": {\"enabled\": False},\n                    \"output_dir\": tmpdir,\n                }\n            )\n            callback = DynamicCheckpointCallback(cfg)\n            assert callback.enabled is False\n\n    def test_callback_disabled_when_none(self):\n        \"\"\"Test that callback is disabled when dynamic_checkpoint is None\"\"\"\n        with tempfile.TemporaryDirectory() as tmpdir:\n            cfg = DictDefault(\n                {\n                    \"dynamic_checkpoint\": None,\n                    \"output_dir\": tmpdir,\n                }\n            )\n            callback = DynamicCheckpointCallback(cfg)\n            assert callback.enabled is False\n\n    def test_callback_enabled_when_configured(self):\n        \"\"\"Test that callback is enabled when config.enabled=True\"\"\"\n        with tempfile.TemporaryDirectory() as tmpdir:\n            cfg = DictDefault(\n                {\n                    \"dynamic_checkpoint\": {\"enabled\": True, \"check_interval\": 10},\n                    \"output_dir\": tmpdir,\n                }\n            )\n            callback = DynamicCheckpointCallback(cfg)\n            assert callback.enabled is True\n            assert callback.check_interval == 10\n\n    def test_default_trigger_filename(self):\n        
\"\"\"Test that default trigger filename is used\"\"\"\n        with tempfile.TemporaryDirectory() as tmpdir:\n            cfg = DictDefault(\n                {\n                    \"dynamic_checkpoint\": {\"enabled\": True, \"check_interval\": 10},\n                    \"output_dir\": tmpdir,\n                }\n            )\n            callback = DynamicCheckpointCallback(cfg)\n            assert callback.trigger_filename == DEFAULT_TRIGGER_FILENAME\n\n    def test_check_interval_default(self):\n        \"\"\"Test default check interval\"\"\"\n        with tempfile.TemporaryDirectory() as tmpdir:\n            cfg = DictDefault(\n                {\n                    \"dynamic_checkpoint\": {\"enabled\": True},\n                    \"output_dir\": tmpdir,\n                }\n            )\n            callback = DynamicCheckpointCallback(cfg)\n            assert callback.check_interval == 100  # Default from schema\n\n\nclass TestDynamicCheckpointFileDetection:\n    \"\"\"Test file-based checkpoint triggering\"\"\"\n\n    def test_trigger_file_detected_and_deleted(self):\n        \"\"\"Test that trigger file is detected and deleted\"\"\"\n        with tempfile.TemporaryDirectory() as tmpdir:\n            cfg = DictDefault(\n                {\n                    \"dynamic_checkpoint\": {\"enabled\": True, \"check_interval\": 1},\n                    \"output_dir\": tmpdir,\n                }\n            )\n            callback = DynamicCheckpointCallback(cfg)\n\n            trigger_file = Path(tmpdir) / DEFAULT_TRIGGER_FILENAME\n            trigger_file.touch()\n            assert trigger_file.exists()\n\n            args = Mock(output_dir=tmpdir)\n            state = Mock(global_step=1)\n            control = Mock(should_save=False)\n\n            with patch(\n                \"axolotl.utils.callbacks.dynamic_checkpoint.is_main_process\",\n                return_value=True,\n            ):\n                with patch(\n                    
\"axolotl.utils.callbacks.dynamic_checkpoint.is_distributed\",\n                    return_value=False,\n                ):\n                    result = callback.on_step_end(args, state, control)\n\n            assert not trigger_file.exists()\n            assert result.should_save is True\n\n    def test_check_interval_honored(self):\n        \"\"\"Test that file is only checked at check_interval steps\"\"\"\n        with tempfile.TemporaryDirectory() as tmpdir:\n            cfg = DictDefault(\n                {\n                    \"dynamic_checkpoint\": {\"enabled\": True, \"check_interval\": 10},\n                    \"output_dir\": tmpdir,\n                }\n            )\n            callback = DynamicCheckpointCallback(cfg)\n\n            args = Mock(output_dir=tmpdir)\n            control = Mock(should_save=False)\n\n            trigger_file = Path(tmpdir) / DEFAULT_TRIGGER_FILENAME\n            trigger_file.touch()\n\n            with patch(\n                \"axolotl.utils.callbacks.dynamic_checkpoint.is_main_process\",\n                return_value=True,\n            ):\n                with patch(\n                    \"axolotl.utils.callbacks.dynamic_checkpoint.is_distributed\",\n                    return_value=False,\n                ):\n                    # Step 5 - shouldn't check (not divisible by 10)\n                    state = Mock(global_step=5)\n                    result = callback.on_step_end(args, state, control)\n                    assert trigger_file.exists()  # Still there\n                    assert result.should_save is False\n\n                    # Step 10 - should check\n                    state = Mock(global_step=10)\n                    result = callback.on_step_end(args, state, control)\n                    assert not trigger_file.exists()  # Deleted\n                    assert result.should_save is True\n\n    def test_no_file_no_trigger(self):\n        \"\"\"Test that no trigger occurs when file doesn't exist\"\"\"\n     
   with tempfile.TemporaryDirectory() as tmpdir:\n            cfg = DictDefault(\n                {\n                    \"dynamic_checkpoint\": {\"enabled\": True, \"check_interval\": 1},\n                    \"output_dir\": tmpdir,\n                }\n            )\n            callback = DynamicCheckpointCallback(cfg)\n\n            args = Mock(output_dir=tmpdir)\n            state = Mock(global_step=1)\n            control = Mock(should_save=False)\n\n            with patch(\n                \"axolotl.utils.callbacks.dynamic_checkpoint.is_main_process\",\n                return_value=True,\n            ):\n                with patch(\n                    \"axolotl.utils.callbacks.dynamic_checkpoint.is_distributed\",\n                    return_value=False,\n                ):\n                    result = callback.on_step_end(args, state, control)\n\n            assert result.should_save is False\n\n    def test_file_deletion_error_handling(self):\n        \"\"\"Test that file deletion errors are handled gracefully\"\"\"\n        with tempfile.TemporaryDirectory() as tmpdir:\n            cfg = DictDefault(\n                {\n                    \"dynamic_checkpoint\": {\"enabled\": True, \"check_interval\": 1},\n                    \"output_dir\": tmpdir,\n                }\n            )\n            callback = DynamicCheckpointCallback(cfg)\n\n            trigger_file = Path(tmpdir) / DEFAULT_TRIGGER_FILENAME\n            trigger_file.touch()\n\n            args = Mock(output_dir=tmpdir)\n            state = Mock(global_step=1)\n            control = Mock(should_save=False)\n\n            with patch(\n                \"axolotl.utils.callbacks.dynamic_checkpoint.is_main_process\",\n                return_value=True,\n            ):\n                with patch(\n                    \"axolotl.utils.callbacks.dynamic_checkpoint.is_distributed\",\n                    return_value=False,\n                ):\n                    with patch.object(\n                 
       Path, \"unlink\", side_effect=OSError(\"Permission denied\")\n                    ):\n                        result = callback.on_step_end(args, state, control)\n\n            assert result.should_save is True\n\n\nclass TestDynamicCheckpointMultiGPU:\n    \"\"\"Test multi-GPU synchronization\"\"\"\n\n    def test_only_rank_0_checks_file(self):\n        \"\"\"Test that only rank 0 checks filesystem in multi-GPU setup\"\"\"\n        with tempfile.TemporaryDirectory() as tmpdir:\n            cfg = DictDefault(\n                {\n                    \"dynamic_checkpoint\": {\"enabled\": True, \"check_interval\": 1},\n                    \"output_dir\": tmpdir,\n                }\n            )\n            callback = DynamicCheckpointCallback(cfg)\n\n            trigger_file = Path(tmpdir) / DEFAULT_TRIGGER_FILENAME\n            trigger_file.touch()\n\n            args = Mock(output_dir=tmpdir)\n            state = Mock(global_step=1)\n            control = Mock(should_save=False)\n\n            # Rank 1 (not main process) - shouldn't check file\n            with patch(\n                \"axolotl.utils.callbacks.dynamic_checkpoint.is_main_process\",\n                return_value=False,\n            ):\n                with patch(\n                    \"axolotl.utils.callbacks.dynamic_checkpoint.is_distributed\",\n                    return_value=True,\n                ):\n                    with patch(\"torch.distributed.broadcast\") as mock_broadcast:\n                        with patch(\n                            \"axolotl.utils.callbacks.dynamic_checkpoint.barrier\"\n                        ):\n                            mock_tensor = MagicMock()\n                            mock_tensor.item.return_value = 0\n                            with patch(\"torch.tensor\", return_value=mock_tensor):\n                                callback.on_step_end(args, state, control)\n\n            assert trigger_file.exists()\n            # Broadcast should have been 
called\n            assert mock_broadcast.called\n\n    def test_broadcast_synchronization(self):\n        \"\"\"Test that trigger decision is broadcasted to all ranks\"\"\"\n        with tempfile.TemporaryDirectory() as tmpdir:\n            cfg = DictDefault(\n                {\n                    \"dynamic_checkpoint\": {\"enabled\": True, \"check_interval\": 1},\n                    \"output_dir\": tmpdir,\n                }\n            )\n            callback = DynamicCheckpointCallback(cfg)\n\n            trigger_file = Path(tmpdir) / DEFAULT_TRIGGER_FILENAME\n            trigger_file.touch()\n\n            args = Mock(output_dir=tmpdir)\n            state = Mock(global_step=1)\n            control = Mock(should_save=False)\n\n            # Rank 0 detects file\n            with patch(\n                \"axolotl.utils.callbacks.dynamic_checkpoint.is_main_process\",\n                return_value=True,\n            ):\n                with patch(\n                    \"axolotl.utils.callbacks.dynamic_checkpoint.is_distributed\",\n                    return_value=True,\n                ):\n                    with patch(\"torch.distributed.broadcast\") as mock_broadcast:\n                        with patch(\n                            \"axolotl.utils.callbacks.dynamic_checkpoint.barrier\"\n                        ) as mock_barrier:\n                            mock_tensor = MagicMock()\n                            mock_tensor.item.return_value = 1\n                            with patch(\"torch.tensor\", return_value=mock_tensor):\n                                with patch(\"torch.cuda.current_device\", return_value=0):\n                                    result = callback.on_step_end(args, state, control)\n\n            assert mock_broadcast.called\n            assert mock_barrier.called\n            # All ranks should trigger\n            assert result.should_save is True\n\n\nclass TestDynamicCheckpointSignalHandling:\n    \"\"\"Test signal-based 
checkpoint triggering\"\"\"\n\n    def test_signal_trigger_via_callback(self):\n        \"\"\"Test that signal flag triggers checkpoint save\"\"\"\n        with tempfile.TemporaryDirectory() as tmpdir:\n            cfg = DictDefault(\n                {\n                    \"dynamic_checkpoint\": {\n                        \"enabled\": True,\n                        \"check_interval\": 1,\n                        \"enable_signal\": True,\n                    },\n                    \"output_dir\": tmpdir,\n                }\n            )\n\n            with patch(\"signal.signal\"):\n                with patch(\n                    \"axolotl.utils.callbacks.dynamic_checkpoint.is_main_process\",\n                    return_value=True,\n                ):\n                    with patch(\n                        \"axolotl.utils.callbacks.dynamic_checkpoint.hasattr\",\n                        return_value=True,\n                    ):\n                        callback = DynamicCheckpointCallback(cfg)\n\n            callback.should_save_checkpoint = True\n\n            args = Mock(output_dir=tmpdir)\n            state = Mock(global_step=1)\n            control = Mock(should_save=False)\n\n            with patch(\n                \"axolotl.utils.callbacks.dynamic_checkpoint.is_main_process\",\n                return_value=True,\n            ):\n                with patch(\n                    \"axolotl.utils.callbacks.dynamic_checkpoint.is_distributed\",\n                    return_value=False,\n                ):\n                    result = callback.on_step_end(args, state, control)\n\n            assert result.should_save is True\n            assert callback.should_save_checkpoint is False\n\n    def test_signal_not_registered_when_disabled(self):\n        \"\"\"Test that signal handler is not registered when disabled\"\"\"\n        with tempfile.TemporaryDirectory() as tmpdir:\n            cfg = DictDefault(\n                {\n                    
\"dynamic_checkpoint\": {\n                        \"enabled\": True,\n                        \"check_interval\": 10,\n                        \"enable_signal\": False,\n                    },\n                    \"output_dir\": tmpdir,\n                }\n            )\n\n            with patch(\"signal.signal\") as mock_signal_register:\n                _ = DynamicCheckpointCallback(cfg)\n\n            assert not mock_signal_register.called\n\n\nclass TestDynamicCheckpointDisabled:\n    \"\"\"Test behavior when callback is disabled\"\"\"\n\n    def test_disabled_callback_does_nothing(self):\n        \"\"\"Test that disabled callback doesn't check or trigger\"\"\"\n        with tempfile.TemporaryDirectory() as tmpdir:\n            cfg = DictDefault(\n                {\n                    \"dynamic_checkpoint\": {\"enabled\": False},\n                    \"output_dir\": tmpdir,\n                }\n            )\n            callback = DynamicCheckpointCallback(cfg)\n\n            trigger_file = Path(tmpdir) / DEFAULT_TRIGGER_FILENAME\n            trigger_file.touch()\n\n            args = Mock(output_dir=tmpdir)\n            state = Mock(global_step=1)\n            control = Mock(should_save=False)\n\n            result = callback.on_step_end(args, state, control)\n\n            assert trigger_file.exists()\n            assert result.should_save is False\n"
  },
  {
    "path": "tests/utils/data/test_utils.py",
    "content": "\"\"\"\nUnit tests for data utility functions\n\"\"\"\n\nimport unittest\nfrom unittest.mock import MagicMock\n\nfrom datasets import Dataset\n\nfrom axolotl.utils.data.utils import handle_long_seq_in_dataset\nfrom axolotl.utils.dict import DictDefault\n\n\nclass TestHandleLongSeqInDataset(unittest.TestCase):\n    \"\"\"\n    Test class for handle_long_seq_in_dataset function\n    \"\"\"\n\n    def test_drop_strategy_removes_long_sequences(self):\n        \"\"\"Test that 'drop' strategy removes sequences longer than sequence_len\"\"\"\n        # Create dataset with mixed length sequences\n        dataset = Dataset.from_dict(\n            {\n                \"input_ids\": [\n                    [1, 2, 3],  # length 3 - keep\n                    [1, 2, 3, 4, 5],  # length 5 - keep\n                    [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11],  # length 11 - drop\n                    [1, 2],  # length 2 - keep\n                ]\n            }\n        )\n\n        cfg = DictDefault(\n            {\n                \"excess_length_strategy\": \"drop\",\n                \"min_sample_len\": 2,\n                \"dataset_num_proc\": None,\n                \"is_preprocess\": False,\n            }\n        )\n\n        result = handle_long_seq_in_dataset(dataset, sequence_len=10, cfg=cfg)\n\n        # Should have dropped the sequence with length 11\n        self.assertEqual(len(result), 3)\n        self.assertEqual(len(result[0][\"input_ids\"]), 3)\n        self.assertEqual(len(result[1][\"input_ids\"]), 5)\n        self.assertEqual(len(result[2][\"input_ids\"]), 2)\n\n    def test_drop_strategy_is_default(self):\n        \"\"\"Test that 'drop' is the default strategy when not specified\"\"\"\n        dataset = Dataset.from_dict(\n            {\n                \"input_ids\": [\n                    [1, 2, 3],\n                    [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11],  # length 11 - should drop\n                ]\n            }\n        )\n\n        cfg = 
DictDefault(\n            {\n                \"min_sample_len\": 2,\n                \"dataset_num_proc\": None,\n                \"is_preprocess\": False,\n            }\n        )\n\n        result = handle_long_seq_in_dataset(dataset, sequence_len=10, cfg=cfg)\n\n        # Should have dropped the long sequence\n        self.assertEqual(len(result), 1)\n\n    def test_truncate_strategy_truncates_long_sequences(self):\n        \"\"\"Test that 'truncate' strategy truncates sequences to sequence_len\"\"\"\n        dataset = Dataset.from_dict(\n            {\n                \"input_ids\": [\n                    [1, 2, 3],  # length 3 - keep as is\n                    [\n                        1,\n                        2,\n                        3,\n                        4,\n                        5,\n                        6,\n                        7,\n                        8,\n                        9,\n                        10,\n                        11,\n                        12,\n                    ],  # length 12 - truncate to 10\n                ]\n            }\n        )\n\n        cfg = DictDefault(\n            {\n                \"excess_length_strategy\": \"truncate\",\n                \"min_sample_len\": 2,\n                \"dataset_num_proc\": None,\n                \"is_preprocess\": False,\n            }\n        )\n\n        result = handle_long_seq_in_dataset(dataset, sequence_len=10, cfg=cfg)\n\n        # Should have 2 samples\n        self.assertEqual(len(result), 2)\n        # First sample unchanged\n        self.assertEqual(len(result[0][\"input_ids\"]), 3)\n        # Second sample truncated to 10\n        self.assertEqual(len(result[1][\"input_ids\"]), 10)\n        self.assertEqual(result[1][\"input_ids\"], [1, 2, 3, 4, 5, 6, 7, 8, 9, 10])\n\n    def test_truncate_strategy_truncates_all_auxiliary_fields(self):\n        \"\"\"Test that truncation applies to all auxiliary fields consistently\"\"\"\n        dataset = 
Dataset.from_dict(\n            {\n                \"input_ids\": [\n                    [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12],\n                ],\n                \"attention_mask\": [\n                    [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],\n                ],\n                \"labels\": [\n                    [-100, -100, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12],\n                ],\n                \"position_ids\": [\n                    [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11],\n                ],\n            }\n        )\n\n        cfg = DictDefault(\n            {\n                \"excess_length_strategy\": \"truncate\",\n                \"min_sample_len\": 2,\n                \"dataset_num_proc\": None,\n                \"is_preprocess\": False,\n            }\n        )\n\n        result = handle_long_seq_in_dataset(dataset, sequence_len=10, cfg=cfg)\n\n        # All fields should be truncated to 10\n        self.assertEqual(len(result[0][\"input_ids\"]), 10)\n        self.assertEqual(len(result[0][\"attention_mask\"]), 10)\n        self.assertEqual(len(result[0][\"labels\"]), 10)\n        self.assertEqual(len(result[0][\"position_ids\"]), 10)\n\n        # Verify content is correct\n        self.assertEqual(result[0][\"input_ids\"], [1, 2, 3, 4, 5, 6, 7, 8, 9, 10])\n        self.assertEqual(result[0][\"attention_mask\"], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1])\n        self.assertEqual(result[0][\"labels\"], [-100, -100, 3, 4, 5, 6, 7, 8, 9, 10])\n        self.assertEqual(result[0][\"position_ids\"], [0, 1, 2, 3, 4, 5, 6, 7, 8, 9])\n\n    def test_raise_strategy_raises_on_long_sequences(self):\n        \"\"\"Test that 'raise' strategy raises ValueError when encountering long sequences\"\"\"\n        dataset = Dataset.from_dict(\n            {\n                \"input_ids\": [\n                    [1, 2, 3],\n                    [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11],  # length 11 - should raise\n                ]\n            }\n        )\n\n        cfg = 
DictDefault(\n            {\n                \"excess_length_strategy\": \"raise\",\n                \"min_sample_len\": 2,\n                \"dataset_num_proc\": None,\n                \"is_preprocess\": False,\n            }\n        )\n\n        with self.assertRaises(ValueError):\n            handle_long_seq_in_dataset(dataset, sequence_len=10, cfg=cfg)\n\n    def test_min_sequence_len_filters_short_sequences(self):\n        \"\"\"Test that sequences shorter than min_sample_len are filtered out\"\"\"\n        dataset = Dataset.from_dict(\n            {\n                \"input_ids\": [\n                    [1],  # length 1 - drop (< min_sample_len=3)\n                    [1, 2],  # length 2 - drop\n                    [1, 2, 3],  # length 3 - keep\n                    [1, 2, 3, 4, 5],  # length 5 - keep\n                ]\n            }\n        )\n\n        cfg = DictDefault(\n            {\n                \"excess_length_strategy\": \"drop\",\n                \"min_sample_len\": 3,\n                \"dataset_num_proc\": None,\n                \"is_preprocess\": False,\n            }\n        )\n\n        result = handle_long_seq_in_dataset(dataset, sequence_len=10, cfg=cfg)\n\n        # Should only keep sequences with length >= 3\n        self.assertEqual(len(result), 2)\n        self.assertEqual(len(result[0][\"input_ids\"]), 3)\n        self.assertEqual(len(result[1][\"input_ids\"]), 5)\n\n    def test_dataset_without_input_ids_column(self):\n        \"\"\"Test that datasets without 'input_ids' column are returned unchanged\"\"\"\n        dataset = Dataset.from_dict(\n            {\n                \"chosen\": [1, 2, 3],\n                \"rejected\": [4, 5, 6],\n            }\n        )\n\n        cfg = DictDefault(\n            {\n                \"excess_length_strategy\": \"drop\",\n                \"min_sample_len\": 2,\n            }\n        )\n\n        result = handle_long_seq_in_dataset(dataset, sequence_len=10, cfg=cfg)\n\n        # Dataset 
should be unchanged\n        self.assertEqual(len(result), len(dataset))\n        self.assertListEqual(list(result.column_names), [\"chosen\", \"rejected\"])\n\n    def test_truncate_filters_short_before_truncating(self):\n        \"\"\"Test that truncate strategy filters short sequences before truncating long ones\n\n        This is important for efficiency - we should not waste time truncating\n        sequences that will be filtered out anyway.\n        \"\"\"\n        dataset = Dataset.from_dict(\n            {\n                \"input_ids\": [\n                    [1],  # length 1 - filter out first\n                    [1, 2, 3],  # length 3 - keep, no truncation needed\n                    [\n                        1,\n                        2,\n                        3,\n                        4,\n                        5,\n                        6,\n                        7,\n                        8,\n                        9,\n                        10,\n                        11,\n                        12,\n                    ],  # length 12 - keep and truncate\n                ]\n            }\n        )\n\n        cfg = DictDefault(\n            {\n                \"excess_length_strategy\": \"truncate\",\n                \"min_sample_len\": 2,\n                \"dataset_num_proc\": None,\n                \"is_preprocess\": False,\n            }\n        )\n\n        result = handle_long_seq_in_dataset(dataset, sequence_len=10, cfg=cfg)\n\n        # Should have filtered out the first (short) sequence\n        self.assertEqual(len(result), 2)\n        # Second sample unchanged\n        self.assertEqual(len(result[0][\"input_ids\"]), 3)\n        # Third sample truncated to 10\n        self.assertEqual(len(result[1][\"input_ids\"]), 10)\n\n    def test_case_insensitive_strategy(self):\n        \"\"\"Test that excess_length_strategy is case-insensitive\"\"\"\n        dataset = Dataset.from_dict(\n            {\n                \"input_ids\": 
[\n                    [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12],\n                ]\n            }\n        )\n\n        cfg = DictDefault(\n            {\n                \"excess_length_strategy\": \"TRUNCATE\",  # uppercase\n                \"min_sample_len\": 2,\n                \"dataset_num_proc\": None,\n                \"is_preprocess\": False,\n            }\n        )\n\n        result = handle_long_seq_in_dataset(dataset, sequence_len=10, cfg=cfg)\n\n        # Should still truncate\n        self.assertEqual(len(result[0][\"input_ids\"]), 10)\n\n    def test_raise_strategy_silently_drops_short_sequences(self):\n        \"\"\"Test that 'raise' strategy drops short sequences without raising\"\"\"\n        dataset = Dataset.from_dict(\n            {\n                \"input_ids\": [\n                    [1],  # length 1 - too short, should be dropped silently\n                    [1, 2, 3, 4, 5],  # length 5 - keep\n                ]\n            }\n        )\n\n        cfg = DictDefault(\n            {\n                \"excess_length_strategy\": \"raise\",\n                \"min_sample_len\": 3,\n                \"dataset_num_proc\": None,\n                \"is_preprocess\": False,\n            }\n        )\n\n        # Should NOT raise, just silently drop the short sequence\n        result = handle_long_seq_in_dataset(dataset, sequence_len=10, cfg=cfg)\n\n        self.assertEqual(len(result), 1)\n        self.assertEqual(len(result[0][\"input_ids\"]), 5)\n\n    def test_drop_boundary_sequence_equal_to_sequence_len(self):\n        \"\"\"Test that drop strategy keeps sequences with length exactly equal to sequence_len\"\"\"\n        dataset = Dataset.from_dict(\n            {\n                \"input_ids\": [\n                    [1, 2, 3, 4, 5, 6, 7, 8, 9, 10],  # length 10 == sequence_len\n                    [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11],  # length 11 > sequence_len\n                ]\n            }\n        )\n\n        cfg = DictDefault(\n         
   {\n                \"excess_length_strategy\": \"drop\",\n                \"min_sample_len\": 2,\n                \"dataset_num_proc\": None,\n                \"is_preprocess\": False,\n            }\n        )\n\n        result = handle_long_seq_in_dataset(dataset, sequence_len=10, cfg=cfg)\n\n        # Exactly equal should be kept, one over should be dropped\n        self.assertEqual(len(result), 1)\n        self.assertEqual(len(result[0][\"input_ids\"]), 10)\n\n    def test_truncate_boundary_sequence_equal_to_sequence_len(self):\n        \"\"\"Test that truncate strategy leaves sequences with length exactly equal to sequence_len unchanged\"\"\"\n        dataset = Dataset.from_dict(\n            {\n                \"input_ids\": [\n                    [1, 2, 3, 4, 5, 6, 7, 8, 9, 10],  # length 10 == sequence_len\n                ]\n            }\n        )\n\n        cfg = DictDefault(\n            {\n                \"excess_length_strategy\": \"truncate\",\n                \"min_sample_len\": 2,\n                \"dataset_num_proc\": None,\n                \"is_preprocess\": False,\n            }\n        )\n\n        result = handle_long_seq_in_dataset(dataset, sequence_len=10, cfg=cfg)\n\n        # Should be unchanged - not truncated\n        self.assertEqual(len(result), 1)\n        self.assertEqual(result[0][\"input_ids\"], [1, 2, 3, 4, 5, 6, 7, 8, 9, 10])\n\n    def test_empty_dataset(self):\n        \"\"\"Test that an empty dataset is handled gracefully\"\"\"\n        dataset = Dataset.from_dict({\"input_ids\": []})\n\n        cfg = DictDefault(\n            {\n                \"excess_length_strategy\": \"drop\",\n                \"min_sample_len\": 2,\n                \"dataset_num_proc\": None,\n                \"is_preprocess\": False,\n            }\n        )\n\n        result = handle_long_seq_in_dataset(dataset, sequence_len=10, cfg=cfg)\n\n        self.assertEqual(len(result), 0)\n\n    def 
test_all_sequences_dropped_returns_empty_dataset(self):\n        \"\"\"Test that dropping all sequences results in an empty dataset\"\"\"\n        dataset = Dataset.from_dict(\n            {\n                \"input_ids\": [\n                    [1],  # too short\n                    [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11],  # too long\n                ]\n            }\n        )\n\n        cfg = DictDefault(\n            {\n                \"excess_length_strategy\": \"drop\",\n                \"min_sample_len\": 5,\n                \"dataset_num_proc\": None,\n                \"is_preprocess\": False,\n            }\n        )\n\n        result = handle_long_seq_in_dataset(dataset, sequence_len=10, cfg=cfg)\n\n        self.assertEqual(len(result), 0)\n\n    def test_iterable_dataset_skips_processing(self):\n        \"\"\"Test that streaming datasets (column_names is None) are returned unchanged.\n\n        The skip check in _should_skip_processing triggers when column_names is\n        None, which happens with true streaming datasets loaded via\n        load_dataset(..., streaming=True).\n        \"\"\"\n        mock_dataset = MagicMock()\n        mock_dataset.column_names = None\n\n        cfg = DictDefault(\n            {\n                \"excess_length_strategy\": \"drop\",\n                \"min_sample_len\": 2,\n                \"dataset_num_proc\": None,\n                \"is_preprocess\": False,\n            }\n        )\n\n        result = handle_long_seq_in_dataset(mock_dataset, sequence_len=10, cfg=cfg)\n\n        # Should be returned unchanged (same object)\n        self.assertIs(result, mock_dataset)\n\n    def test_truncate_with_partial_auxiliary_fields(self):\n        \"\"\"Test truncation when only some auxiliary fields are present\"\"\"\n        dataset = Dataset.from_dict(\n            {\n                \"input_ids\": [\n                    [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12],\n                ],\n                \"labels\": [\n                 
   [-100, -100, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12],\n                ],\n                # No attention_mask or position_ids\n            }\n        )\n\n        cfg = DictDefault(\n            {\n                \"excess_length_strategy\": \"truncate\",\n                \"min_sample_len\": 2,\n                \"dataset_num_proc\": None,\n                \"is_preprocess\": False,\n            }\n        )\n\n        result = handle_long_seq_in_dataset(dataset, sequence_len=10, cfg=cfg)\n\n        self.assertEqual(len(result[0][\"input_ids\"]), 10)\n        self.assertEqual(len(result[0][\"labels\"]), 10)\n        self.assertEqual(result[0][\"input_ids\"], [1, 2, 3, 4, 5, 6, 7, 8, 9, 10])\n        self.assertEqual(result[0][\"labels\"], [-100, -100, 3, 4, 5, 6, 7, 8, 9, 10])\n        # Confirm no extra columns were introduced\n        self.assertListEqual(sorted(result.column_names), [\"input_ids\", \"labels\"])\n\n    def test_min_sample_len_defaults_to_two_when_not_set(self):\n        \"\"\"Test that min_sample_len defaults to 2 when not specified in config\"\"\"\n        dataset = Dataset.from_dict(\n            {\n                \"input_ids\": [\n                    [1],  # length 1 - should be dropped (< default 2)\n                    [1, 2],  # length 2 - should be kept (>= default 2)\n                    [1, 2, 3],  # length 3 - should be kept\n                ]\n            }\n        )\n\n        cfg = DictDefault(\n            {\n                \"excess_length_strategy\": \"drop\",\n                # min_sample_len not set\n                \"dataset_num_proc\": None,\n                \"is_preprocess\": False,\n            }\n        )\n\n        result = handle_long_seq_in_dataset(dataset, sequence_len=10, cfg=cfg)\n\n        self.assertEqual(len(result), 2)\n        self.assertEqual(len(result[0][\"input_ids\"]), 2)\n        self.assertEqual(len(result[1][\"input_ids\"]), 3)\n\n    def test_invalid_strategy_falls_through_to_drop(self):\n        
\"\"\"Test that an unrecognized strategy value falls through to drop behavior\"\"\"\n        dataset = Dataset.from_dict(\n            {\n                \"input_ids\": [\n                    [1, 2, 3],  # keep\n                    [\n                        1,\n                        2,\n                        3,\n                        4,\n                        5,\n                        6,\n                        7,\n                        8,\n                        9,\n                        10,\n                        11,\n                    ],  # length 11 - should be dropped\n                ]\n            }\n        )\n\n        cfg = DictDefault(\n            {\n                \"excess_length_strategy\": \"not_a_real_strategy\",\n                \"min_sample_len\": 2,\n                \"dataset_num_proc\": None,\n                \"is_preprocess\": False,\n            }\n        )\n\n        result = handle_long_seq_in_dataset(dataset, sequence_len=10, cfg=cfg)\n\n        # Should behave like 'drop'\n        self.assertEqual(len(result), 1)\n        self.assertEqual(len(result[0][\"input_ids\"]), 3)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/utils/lora/test_config_validation_lora.py",
    "content": "import pytest\n\nfrom axolotl.utils.config import validate_config\nfrom axolotl.utils.dict import DictDefault\n\n\nclass TestLoRAConfigValidation:\n    \"\"\"Test suite for LoRA/QLoRA configuration validation\"\"\"\n\n    def test_basic_configuration_validation(self):\n        \"\"\"Test basic LoRA configuration validation\"\"\"\n\n        valid_config = DictDefault(\n            {\n                \"adapter\": \"lora\",\n                \"lora_r\": 8,\n                \"lora_alpha\": 16,\n                \"lora_dropout\": 0.1,\n                \"lora_target_modules\": [\"q_proj\", \"v_proj\"],\n                \"datasets\": [{\"path\": \"dummy_dataset\", \"type\": \"alpaca\"}],\n                \"micro_batch_size\": 1,\n                \"gradient_accumulation_steps\": 1,\n                \"learning_rate\": 1e-5,\n                \"base_model\": \"dummy_model\",\n            }\n        )\n\n        result = validate_config(valid_config)\n        assert result[\"adapter\"] == \"lora\"\n\n        with pytest.raises(ValueError, match=\"not compatible with DoRA\"):\n            invalid_config = DictDefault(\n                {\n                    \"adapter\": \"lora\",\n                    \"lora_mlp_kernel\": True,\n                    \"peft_use_dora\": True,\n                    \"datasets\": [{\"path\": \"dummy_dataset\", \"type\": \"alpaca\"}],\n                    \"micro_batch_size\": 1,\n                    \"gradient_accumulation_steps\": 1,\n                    \"learning_rate\": 1e-5,\n                    \"base_model\": \"dummy_model\",\n                }\n            )\n            validate_config(invalid_config)\n\n    def test_qlora_4bit_validation(self):\n        \"\"\"Test QLoRA 4-bit configuration validation\"\"\"\n        valid_config = DictDefault(\n            {\n                \"adapter\": \"qlora\",\n                \"load_in_4bit\": True,\n                \"bnb_4bit_compute_dtype\": \"float16\",\n                \"datasets\": 
[{\"path\": \"dummy_dataset\", \"type\": \"alpaca\"}],\n                \"micro_batch_size\": 1,\n                \"gradient_accumulation_steps\": 1,\n                \"learning_rate\": 1e-5,\n                \"base_model\": \"dummy_model\",\n            }\n        )\n        result = validate_config(valid_config)\n        assert result[\"adapter\"] == \"qlora\"\n        assert result[\"load_in_4bit\"] is True\n\n        # Test QLoRA without 4-bit (should fail via PEFT validation)\n        with pytest.raises(ValueError, match=r\"Require cfg\\.load_in_4bit\"):\n            invalid_config = DictDefault(\n                {\n                    \"adapter\": \"qlora\",\n                    \"load_in_4bit\": False,\n                    \"datasets\": [{\"path\": \"dummy_dataset\", \"type\": \"alpaca\"}],\n                    \"micro_batch_size\": 1,\n                    \"gradient_accumulation_steps\": 1,\n                    \"learning_rate\": 1e-5,\n                    \"base_model\": \"dummy_model\",\n                }\n            )\n            validate_config(invalid_config)\n\n        # Test QLoRA with 8-bit (incompatible)\n        with pytest.raises(ValueError, match=\"Can't load qlora in 8bit\"):\n            invalid_config = DictDefault(\n                {\n                    \"adapter\": \"qlora\",\n                    \"load_in_8bit\": True,\n                    \"datasets\": [{\"path\": \"dummy_dataset\", \"type\": \"alpaca\"}],\n                    \"micro_batch_size\": 1,\n                    \"gradient_accumulation_steps\": 1,\n                    \"learning_rate\": 1e-5,\n                    \"base_model\": \"dummy_model\",\n                }\n            )\n            validate_config(invalid_config)\n\n    @pytest.mark.parametrize(\n        \"kernel_field\", [\"lora_mlp_kernel\", \"lora_qkv_kernel\", \"lora_o_kernel\"]\n    )\n    def test_lora_kernels_trust_remote_code_incompatible(self, kernel_field):\n        \"\"\"Test that lora kernels are 
incompatible with trust_remote_code\"\"\"\n        with pytest.raises(ValueError, match=\"not compatible with trust_remote_code\"):\n            invalid_config = DictDefault(\n                {\n                    \"adapter\": \"lora\",\n                    kernel_field: True,\n                    \"trust_remote_code\": True,\n                    \"datasets\": [{\"path\": \"dummy_dataset\", \"type\": \"alpaca\"}],\n                    \"micro_batch_size\": 1,\n                    \"gradient_accumulation_steps\": 1,\n                    \"learning_rate\": 1e-5,\n                    \"base_model\": \"dummy_model\",\n                }\n            )\n            validate_config(invalid_config)\n\n    def test_lora_kernels_trust_remote_code_false(self):\n        \"\"\"Test that lora kernels work when trust_remote_code is false\"\"\"\n        # Test with trust_remote_code=False, lora kernels should be allowed\n        valid_config = DictDefault(\n            {\n                \"adapter\": \"lora\",\n                \"lora_mlp_kernel\": True,\n                \"lora_qkv_kernel\": True,\n                \"lora_o_kernel\": True,\n                \"trust_remote_code\": False,\n                \"datasets\": [{\"path\": \"dummy_dataset\", \"type\": \"alpaca\"}],\n                \"micro_batch_size\": 1,\n                \"gradient_accumulation_steps\": 1,\n                \"learning_rate\": 1e-5,\n                \"base_model\": \"dummy_model\",\n            }\n        )\n        result = validate_config(valid_config)\n        assert result[\"lora_mlp_kernel\"] is True\n        assert result[\"lora_qkv_kernel\"] is True\n        assert result[\"lora_o_kernel\"] is True\n\n        # Test with trust_remote_code=None (unset), kernels should be allowed\n        valid_config = DictDefault(\n            {\n                \"adapter\": \"lora\",\n                \"lora_qkv_kernel\": True,\n                \"trust_remote_code\": None,\n                \"datasets\": [{\"path\": 
\"dummy_dataset\", \"type\": \"alpaca\"}],\n                \"micro_batch_size\": 1,\n                \"gradient_accumulation_steps\": 1,\n                \"learning_rate\": 1e-5,\n                \"base_model\": \"dummy_model\",\n            }\n        )\n        result = validate_config(valid_config)\n        assert result[\"lora_qkv_kernel\"] is True\n        assert result[\"trust_remote_code\"] is None\n"
  },
  {
    "path": "tests/utils/lora/test_freeze_lora.py",
    "content": "import importlib.util\nfrom unittest.mock import Mock\n\nimport pytest\nimport torch\nimport torch.nn as nn\n\nfrom axolotl.kernels.lora import get_lora_parameters\n\nPEFT_AVAILABLE = importlib.util.find_spec(\"peft\") is not None\n\n\nclass TestLoRAParameterFreezing:\n    \"\"\"Test suite for LoRA parameter freezing validation.\"\"\"\n\n    def setup_method(self):\n        self.dtype = torch.float32\n\n    def create_mock_lora_layer(\n        self, has_adapters=True, adapters_disabled=False, merged=False\n    ):\n        \"\"\"Create a mock LoRA layer for testing.\"\"\"\n        mock_layer = Mock()\n\n        base_layer = Mock()\n        base_layer.weight = torch.randn(512, 256, dtype=self.dtype)\n        base_layer.bias = torch.randn(512, dtype=self.dtype)\n\n        if has_adapters:\n            mock_layer.base_layer = base_layer\n            mock_layer.disable_adapters = adapters_disabled\n            mock_layer.merged = merged\n\n            mock_layer.active_adapters = [\"default\"]\n            mock_layer.lora_A = {\"default\": Mock()}\n            mock_layer.lora_B = {\"default\": Mock()}\n            mock_layer.scaling = {\"default\": 0.1}\n\n            mock_layer.lora_A[\"default\"].weight = torch.randn(16, 256, dtype=self.dtype)\n            mock_layer.lora_B[\"default\"].weight = torch.randn(512, 16, dtype=self.dtype)\n        else:\n            mock_layer.weight = base_layer.weight\n            mock_layer.bias = base_layer.bias\n\n        return mock_layer\n\n    def test_parameter_freezing_adapters_disabled(self):\n        \"\"\"Test that LoRA parameters are None when adapters are disabled.\"\"\"\n        layer = self.create_mock_lora_layer(has_adapters=True, adapters_disabled=True)\n\n        W, b, quant_state, A, B, s = get_lora_parameters(layer)\n\n        # Base parameters should be returned\n        assert W is not None\n        assert b is not None\n        # LoRA parameters should be None (frozen)\n        assert A is None\n    
    assert B is None\n        assert s is None\n\n    def test_parameter_freezing_adapters_merged(self):\n        \"\"\"Test that LoRA parameters are None when adapters are merged.\"\"\"\n        layer = self.create_mock_lora_layer(has_adapters=True, merged=True)\n\n        W, b, quant_state, A, B, s = get_lora_parameters(layer)\n\n        # Base parameters should be returned\n        assert W is not None\n        assert b is not None\n\n        # LoRA parameters should be None (frozen)\n        assert A is None\n        assert B is None\n        assert s is None\n\n    def test_parameter_freezing_no_adapters(self):\n        \"\"\"Test parameter behavior when no adapters are present.\"\"\"\n        layer = self.create_mock_lora_layer(has_adapters=False)\n\n        W, b, quant_state, A, B, s = get_lora_parameters(layer)\n\n        # Base parameters should be returned\n        assert W is not None\n        assert b is not None\n\n        # LoRA parameters should be None (frozen)\n        assert A is None\n        assert B is None\n        assert s is None\n\n    def test_parameter_active_adapters_enabled(self):\n        \"\"\"Test that LoRA parameters are returned when adapters are active.\"\"\"\n        layer = self.create_mock_lora_layer(\n            has_adapters=True, adapters_disabled=False, merged=False\n        )\n\n        W, b, quant_state, A, B, s = get_lora_parameters(layer)\n\n        # All parameters should be returned\n        assert W is not None\n        assert b is not None\n        assert A is not None\n        assert B is not None\n        assert s is not None\n        assert s == 0.1\n\n    def test_parameter_shapes_consistency(self):\n        \"\"\"Test that parameter shapes are consistent when active.\"\"\"\n        layer = self.create_mock_lora_layer(\n            has_adapters=True, adapters_disabled=False, merged=False\n        )\n\n        W, b, quant_state, A, B, s = get_lora_parameters(layer)\n\n        # Check shape consistency\n        
assert W.shape == (512, 256)\n        assert b.shape == (512,)\n        assert A.shape == (16, 256)\n        assert B.shape == (512, 16)\n\n    def test_parameter_dtypes_consistency(self):\n        \"\"\"Test that parameter dtypes are consistent.\"\"\"\n        layer = self.create_mock_lora_layer(\n            has_adapters=True, adapters_disabled=False, merged=False\n        )\n\n        W, b, quant_state, A, B, s = get_lora_parameters(layer)\n\n        assert W.dtype == self.dtype\n        assert b.dtype == self.dtype\n        assert A.dtype == self.dtype\n        assert B.dtype == self.dtype\n\n    def test_quantization_state_handling(self):\n        \"\"\"Test that quantization state is properly handled.\"\"\"\n        layer = self.create_mock_lora_layer(has_adapters=True)\n\n        quant_state_mock = Mock()\n        layer.base_layer.weight.quant_state = quant_state_mock\n\n        W, b, quant_state, A, B, s = get_lora_parameters(layer)\n\n        assert quant_state == quant_state_mock\n\n    def test_multiple_adapters_active_adapter_selection(self):\n        \"\"\"Test that the correct adapter is selected when multiple adapters exist.\"\"\"\n        layer = self.create_mock_lora_layer(\n            has_adapters=True, adapters_disabled=False, merged=False\n        )\n\n        layer.lora_A[\"adapter2\"] = Mock()\n        layer.lora_B[\"adapter2\"] = Mock()\n        layer.scaling[\"adapter2\"] = 0.2\n\n        layer.lora_A[\"adapter2\"].weight = torch.randn(16, 256, dtype=self.dtype)\n        layer.lora_B[\"adapter2\"].weight = torch.randn(512, 16, dtype=self.dtype)\n\n        layer.active_adapters = [\"adapter2\"]\n\n        W, b, quant_state, A, B, s = get_lora_parameters(layer)\n\n        assert s == 0.2\n        assert torch.equal(A, layer.lora_A[\"adapter2\"].weight)\n        assert torch.equal(B, layer.lora_B[\"adapter2\"].weight)\n\n\nclass TestLoRAParameterFreezingIntegration:\n    \"\"\"Integration tests for parameter freezing with actual LoRA 
layers.\"\"\"\n\n    @pytest.mark.skipif(\n        not PEFT_AVAILABLE, reason=\"PEFT not available for integration tests\"\n    )\n    def test_parameter_freezing_with_real_lora_layer(self):\n        \"\"\"Test parameter freezing with actual PEFT LoRA layer.\"\"\"\n        from peft import LoraConfig, get_peft_model\n\n        class SimpleModel(nn.Module):\n            def __init__(self):\n                super().__init__()\n                self.linear = nn.Linear(256, 512)\n\n            def forward(self, x):\n                return self.linear(x)\n\n        base_model = SimpleModel()\n        lora_config = LoraConfig(\n            r=16,\n            lora_alpha=32,\n            target_modules=[\"linear\"],\n            lora_dropout=0.1,\n        )\n        model = get_peft_model(base_model, lora_config)\n        lora_layer = model.base_model.model.linear\n        # Test with adapters enabled\n        W, b, quant_state, A, B, s = get_lora_parameters(lora_layer)\n        assert A is not None\n        assert B is not None\n        assert s is not None\n        # Test with adapters disabled\n        model.disable_adapter_layers()\n        W, b, quant_state, A, B, s = get_lora_parameters(lora_layer)\n        assert A is None\n        assert B is None\n        assert s is None\n\n    @pytest.mark.skipif(\n        not PEFT_AVAILABLE, reason=\"PEFT not available for integration tests\"\n    )\n    def test_parameter_freezing_gradient_behavior(self):\n        \"\"\"Test that frozen parameters don't receive gradients.\"\"\"\n        from peft import LoraConfig, get_peft_model\n\n        class SimpleModel(nn.Module):\n            def __init__(self):\n                super().__init__()\n                self.linear = nn.Linear(256, 512)\n\n            def forward(self, x):\n                return self.linear(x)\n\n        base_model = SimpleModel()\n        lora_config = LoraConfig(\n            r=16,\n            lora_alpha=32,\n            target_modules=[\"linear\"],\n      
      lora_dropout=0.1,\n        )\n        model = get_peft_model(base_model, lora_config)\n        x = torch.randn(1, 256)\n        target = torch.randn(1, 512)\n        model.enable_adapter_layers()\n        output = model(x)\n        loss = nn.MSELoss()(output, target)\n        loss.backward()\n        lora_layer = model.base_model.model.linear\n        has_lora_grads = any(\n            param.grad is not None\n            for name, param in lora_layer.named_parameters()\n            if \"lora_\" in name\n        )\n        assert has_lora_grads, (\n            \"LoRA parameters should have gradients when adapters are enabled\"\n        )\n        model.zero_grad()\n        model.disable_adapter_layers()\n        output = model(x)\n        loss = nn.MSELoss()(output, target)\n        any_requires_grad = any(param.requires_grad for param in model.parameters())\n        if any_requires_grad:\n            loss.backward()\n        has_lora_grads_disabled = any(\n            param.grad is not None\n            for name, param in lora_layer.named_parameters()\n            if \"lora_\" in name\n        )\n        assert not has_lora_grads_disabled, (\n            \"LoRA parameters should not have gradients when adapters are disabled\"\n        )\n        model.zero_grad()\n        del model, base_model, lora_layer, x, target, output, loss\n        torch.cuda.empty_cache() if torch.cuda.is_available() else None\n"
  },
  {
    "path": "tests/utils/lora/test_merge_lora.py",
    "content": "from unittest.mock import Mock, patch\n\nimport torch\n\nfrom axolotl.cli.merge_lora import do_merge_lora\nfrom axolotl.utils.dict import DictDefault\n\n\nclass TestAdapterMergeUnmerge:\n    \"\"\"Test suite for LoRA adapter merging/unmerging functionality\"\"\"\n\n    def setup_method(self):\n        self.dtype = torch.float32\n        self.device = torch.device(\"cpu\")\n\n    def create_mock_base_model(self, vocab_size=1000, hidden_size=256):\n        \"\"\"Create a mock base model with linear layers\"\"\"\n        mock_model = Mock()\n\n        mock_model.config = Mock()\n        mock_model.config.vocab_size = vocab_size\n        mock_model.config.hidden_size = hidden_size\n\n        mock_model.q_proj = Mock()\n        mock_model.q_proj.weight = torch.randn(\n            hidden_size, hidden_size, dtype=self.dtype\n        )\n        mock_model.q_proj.bias = torch.randn(hidden_size, dtype=self.dtype)\n\n        mock_model.v_proj = Mock()\n        mock_model.v_proj.weight = torch.randn(\n            hidden_size, hidden_size, dtype=self.dtype\n        )\n        mock_model.v_proj.bias = torch.randn(hidden_size, dtype=self.dtype)\n\n        return mock_model\n\n    def create_mock_lora_model(self, base_model, r=8, alpha=16):\n        \"\"\"Create a mock LoRA model wrapping the base model\"\"\"\n        mock_lora_model = Mock()\n        mock_lora_model.base_model = base_model\n\n        mock_lora_model.merge_and_unload = None\n        mock_lora_model.to = Mock(return_value=mock_lora_model)\n\n        mock_lora_model.generation_config = Mock()\n        mock_lora_model.config = Mock()\n\n        self.original_q_weight = base_model.q_proj.weight.clone()\n        self.original_v_weight = base_model.v_proj.weight.clone()\n\n        mock_lora_model.peft_config = {\"default\": Mock()}\n        mock_lora_model.peft_config[\"default\"].r = r\n        mock_lora_model.peft_config[\"default\"].lora_alpha = alpha\n\n        self.lora_A_q = torch.randn(\n          
  r, base_model.q_proj.weight.shape[1], dtype=self.dtype\n        )\n        self.lora_B_q = torch.randn(\n            base_model.q_proj.weight.shape[0], r, dtype=self.dtype\n        )\n\n        self.lora_A_v = torch.randn(\n            r, base_model.v_proj.weight.shape[1], dtype=self.dtype\n        )\n        self.lora_B_v = torch.randn(\n            base_model.v_proj.weight.shape[0], r, dtype=self.dtype\n        )\n\n        self.scaling = alpha / r\n\n        def mock_merge_and_unload(progressbar=False):\n            \"\"\"Simulate the actual merge operation\"\"\"\n            # Apply LoRA delta to base weights: W_new = W_base + (B @ A) * scaling\n            delta_q = (self.lora_B_q @ self.lora_A_q) * self.scaling\n            delta_v = (self.lora_B_v @ self.lora_A_v) * self.scaling\n\n            base_model.q_proj.weight = self.original_q_weight + delta_q\n            base_model.v_proj.weight = self.original_v_weight + delta_v\n\n            return base_model\n\n        mock_lora_model.merge_and_unload = mock_merge_and_unload\n        return mock_lora_model\n\n    def test_basic_lora_merge_unmerge_cycle(self):\n        \"\"\"Test: original_weights -> merge -> unmerge -> should equal original_weights\"\"\"\n\n        base_model = self.create_mock_base_model()\n        lora_model = self.create_mock_lora_model(base_model)\n\n        original_q_weight = self.original_q_weight.clone()\n        original_v_weight = self.original_v_weight.clone()\n\n        merged_model = lora_model.merge_and_unload()\n\n        assert not torch.equal(merged_model.q_proj.weight, original_q_weight)\n        assert not torch.equal(merged_model.v_proj.weight, original_v_weight)\n\n        delta_q = (self.lora_B_q @ self.lora_A_q) * self.scaling\n        delta_v = (self.lora_B_v @ self.lora_A_v) * self.scaling\n\n        unmerged_q_weight = merged_model.q_proj.weight - delta_q\n        unmerged_v_weight = merged_model.v_proj.weight - delta_v\n\n        assert 
torch.allclose(unmerged_q_weight, original_q_weight, atol=1e-6)\n        assert torch.allclose(unmerged_v_weight, original_v_weight, atol=1e-6)\n\n    def test_merge_weight_calculation_accuracy(self):\n        \"\"\"Test: merged_weight = base_weight + (lora_B @ lora_A * scaling)\"\"\"\n        base_model = self.create_mock_base_model()\n        lora_model = self.create_mock_lora_model(base_model, r=16, alpha=32)\n\n        expected_delta_q = (self.lora_B_q @ self.lora_A_q) * self.scaling\n        expected_merged_q = self.original_q_weight + expected_delta_q\n        merged_model = lora_model.merge_and_unload()\n\n        assert torch.allclose(merged_model.q_proj.weight, expected_merged_q, atol=1e-6)\n\n    @patch(\"axolotl.cli.merge_lora.load_model_and_tokenizer\")\n    def test_cli_do_merge_functionality(self, mock_load_model, tmp_path):\n        base_model = self.create_mock_base_model()\n        lora_model = self.create_mock_lora_model(base_model)\n        tokenizer = Mock()\n        processor = None\n\n        mock_load_model.return_value = (lora_model, tokenizer, processor)\n\n        cfg = DictDefault(\n            {\n                \"save_safetensors\": True,\n                \"torch_dtype\": torch.float32,\n                \"local_rank\": 0,\n                \"output_dir\": str(tmp_path),\n            }\n        )\n\n        with (\n            patch(\"pathlib.Path.mkdir\"),\n            patch.object(base_model, \"save_pretrained\") as mock_save_model,\n            patch.object(tokenizer, \"save_pretrained\") as mock_save_tokenizer,\n        ):\n            do_merge_lora(cfg=cfg)\n\n        mock_save_model.assert_called_once()\n        mock_save_tokenizer.assert_called_once()\n\n    def test_quantized_model_merge_compatibility(self):\n        \"\"\"Test 4-bit/8-bit model merging scenarios\"\"\"\n        base_model = self.create_mock_base_model()\n\n        # Mock quantized weights\n        base_model.q_proj.weight.quant_state = Mock()\n        
base_model.q_proj.weight.quant_state.dtype = torch.uint8\n\n        lora_model = self.create_mock_lora_model(base_model)\n\n        merged_model = lora_model.merge_and_unload()\n        assert merged_model is not None\n\n    @patch.dict(\"os.environ\", {\"CUDA_VISIBLE_DEVICES\": \"\"})\n    def test_memory_efficient_merge_with_cpu_offload(self, tmp_path):\n        \"\"\"Test lora_on_cpu configuration during merge\"\"\"\n        cfg = DictDefault(\n            {\n                \"lora_on_cpu\": True,\n                \"save_safetensors\": True,\n                \"output_dir\": str(tmp_path),\n                \"local_rank\": 0,\n            }\n        )\n\n        with patch(\"axolotl.cli.merge_lora.load_model_and_tokenizer\") as mock_load:\n            base_model = self.create_mock_base_model()\n            lora_model = self.create_mock_lora_model(base_model)\n            mock_load.return_value = (lora_model, Mock(), None)\n\n            with patch(\"pathlib.Path.mkdir\"), patch(\"torch.save\"):\n                do_merge_lora(cfg=cfg)\n\n            assert mock_load.called\n"
  },
  {
    "path": "tests/utils/schemas/validation/test_activation_offloading.py",
    "content": "\"\"\"Test for config validation for activation offloading.\"\"\"\n\nfrom axolotl.utils.config import validate_config\nfrom axolotl.utils.dict import DictDefault\n\n\nclass TestActivationOffloading:\n    \"\"\"\n    Test cases for activation offloading schema validation\n    \"\"\"\n\n    def test_gc_converts_offload_wo_lora(self, min_base_cfg):\n        cfg = (\n            DictDefault(\n                gradient_checkpointing=\"offload\",\n            )\n            | min_base_cfg\n        )\n\n        cfg = validate_config(cfg)\n        assert cfg.gradient_checkpointing is True\n        assert cfg.activation_offloading is True\n\n    def test_ac_offload_impl_noop_wo_adapter(self, min_base_cfg):\n        cfg = (\n            DictDefault(\n                gradient_checkpointing=True,\n                activation_offloading=True,\n            )\n            | min_base_cfg\n        )\n\n        cfg = validate_config(cfg)\n        assert cfg.gradient_checkpointing is True\n        assert cfg.activation_offloading is True\n"
  },
  {
    "path": "tests/utils/schemas/validation/test_default_values.py",
    "content": "\"\"\"Tests for default values for configurations\"\"\"\n\nfrom axolotl.utils.config import validate_config\nfrom axolotl.utils.dict import DictDefault\n\n\nclass TestDefaultConfigValues:\n    \"\"\"Tests for default values for configurations\"\"\"\n\n    def test_pad_to_sequence_len(self, min_base_cfg):\n        \"\"\"Tests that sample packing automatically sets pad_to_sequence_len to True\"\"\"\n        cfg = (\n            DictDefault(\n                sample_packing=True,\n            )\n            | min_base_cfg\n        )\n\n        cfg = validate_config(cfg)\n\n        assert cfg.pad_to_sequence_len is True\n"
  },
  {
    "path": "tests/utils/schemas/validation/test_fsdp.py",
    "content": "\"\"\"\ntests for pydantic fsdp validation\n\"\"\"\n\nimport pytest\n\nfrom axolotl.utils.config import validate_config\nfrom axolotl.utils.dict import DictDefault\n\n\nclass TestFSDPValidation:\n    \"\"\"\n    test class for pydantic fsdp validation\n    \"\"\"\n\n    def test_fsdp_version_from_fsdp_config(self, min_base_cfg):\n        cfg = min_base_cfg | DictDefault(\n            fsdp_config={\n                \"version\": 2,\n            },\n        )\n        cfg = validate_config(\n            cfg,\n        )\n        assert cfg.fsdp_version == 2\n\n    def test_fsdp_version_in_fsdp_config(self, min_base_cfg):\n        cfg = min_base_cfg | DictDefault(\n            fsdp_version=2,\n            fsdp_config={\n                \"reshard_after_forward\": True,\n            },\n        )\n        cfg = validate_config(\n            cfg,\n        )\n        assert cfg.fsdp_version == 2\n        assert cfg.fsdp_config.fsdp_version == 2\n\n    def test_fsdp_offload_w_8bit_optim(self, min_base_cfg):\n        cfg = min_base_cfg | DictDefault(\n            fsdp_config={\n                \"offload_params\": True,\n            },\n            optimizer=\"adamw_8bit\",\n            fsdp_version=1,\n        )\n        with pytest.raises(\n            ValueError, match=\"FSDP Offload not compatible with adamw_8bit\"\n        ):\n            validate_config(cfg)\n\n    def test_fsdp2_w_8bit_optim(self, min_base_cfg):\n        cfg = min_base_cfg | DictDefault(\n            fsdp_config={\n                \"offload_params\": True,\n            },\n            optimizer=\"adamw_8bit\",\n            fsdp_version=2,\n        )\n        with pytest.raises(\n            ValueError,\n            match=\"FSDP2 not compatible with adamw_8bit, use `adamw_torch_8bit` instead\",\n        ):\n            validate_config(cfg)\n\n    def test_fsdp2_w_cpu_ram_efficient_loading(self, min_base_cfg):\n        cfg = min_base_cfg | DictDefault(\n            load_in_8bit=True,\n     
       adapter=\"lora\",\n            fsdp_config={\n                \"cpu_ram_efficient_loading\": True,\n            },\n            fsdp_version=2,\n        )\n        validated_cfg = validate_config(cfg)\n        assert validated_cfg.fsdp_version == 2\n        assert validated_cfg.fsdp_config.cpu_ram_efficient_loading is True\n\n    def test_fsdp2_cpu_offload_pin_memory_requires_offload_params(self, min_base_cfg):\n        cfg = min_base_cfg | DictDefault(\n            fsdp_config={\n                \"cpu_offload_pin_memory\": False,\n                \"offload_params\": False,\n            },\n            fsdp_version=2,\n        )\n        with pytest.raises(\n            ValueError,\n            match=\"disabling cpu_offload_pin_memory requires enabling offload_params\",\n        ):\n            validate_config(cfg)\n\n    def test_fsdp1_cpu_offload_pin_memory_not_supported(self, min_base_cfg):\n        cfg = min_base_cfg | DictDefault(\n            fsdp_config={\n                \"cpu_offload_pin_memory\": False,\n                \"offload_params\": True,\n            },\n            fsdp_version=1,\n        )\n        with pytest.raises(\n            ValueError,\n            match=\"FSDP1 does not support disabling cpu_offload_pin_memory, please set `fsdp_version` to 2\",\n        ):\n            validate_config(cfg)\n\n    def test_fsdp2_cpu_offload_pin_memory_w_offload_params(self, min_base_cfg):\n        cfg = min_base_cfg | DictDefault(\n            fsdp_config={\n                \"cpu_offload_pin_memory\": False,\n                \"offload_params\": True,\n            },\n            fsdp_version=2,\n        )\n        validated_cfg = validate_config(cfg)\n        assert validated_cfg.fsdp_config.cpu_offload_pin_memory is False\n        assert validated_cfg.fsdp_config.offload_params is True\n\n    def test_fsdp_prefixes_removed(self, min_base_cfg):\n        cfg = min_base_cfg | DictDefault(\n            fsdp_config={\n                \"fsdp_version\": 
2,\n                \"fsdp_auto_wrap_policy\": \"TRANSFORMER_BASED_WRAP\",\n                \"fsdp_transformer_layer_cls_to_wrap\": \"LlamaDecoderLayer\",\n                \"fsdp_reshard_after_forward\": True,\n            }\n        )\n        cfg = validate_config(cfg)\n        assert cfg.fsdp_version == 2\n        assert cfg.fsdp_config.fsdp_version == 2\n        for key in cfg.fsdp_config.keys():\n            if key != \"fsdp_version\":\n                assert not key.startswith(\"fsdp_\")\n        assert cfg.fsdp_config.auto_wrap_policy == \"TRANSFORMER_BASED_WRAP\"\n        assert cfg.fsdp_config.transformer_layer_cls_to_wrap == \"LlamaDecoderLayer\"\n        assert cfg.fsdp_config.reshard_after_forward is True\n\n    def test_muon_fsdp1_rejected(self, min_base_cfg):\n        cfg = min_base_cfg | DictDefault(\n            optimizer=\"muon\",\n            fsdp_version=1,\n            fsdp_config={\"reshard_after_forward\": True},\n        )\n        with pytest.raises(\n            ValueError, match=\"Muon optimizer is only compatible with FSDP2\"\n        ):\n            validate_config(cfg)\n\n    @pytest.mark.parametrize(\n        \"rl\",\n        [\n            \"dpo\",\n            \"kto\",\n            \"orpo\",\n            \"ipo\",\n        ],\n    )\n    def test_fsdp2_dpo(self, min_base_cfg, rl):\n        cfg = min_base_cfg | DictDefault(\n            fsdp_version=2,\n            fsdp_config={\n                \"reshard_after_forward\": True,\n            },\n            rl=rl,\n            load_in_8bit=True,\n            adapter=\"lora\",\n            remove_unused_columns=False,\n        )\n        with pytest.raises(\n            ValueError,\n            match=\"FSDP2 does not support load_in_8bit or load_in_4bit with \",\n        ):\n            validate_config(cfg)\n"
  },
  {
    "path": "tests/utils/schemas/validation/test_moe_quant.py",
    "content": "\"\"\"Tests for MoE expert quantization config validation and PEFT patch idempotency.\"\"\"\n\nimport pytest\n\nfrom axolotl.utils.config import validate_config\nfrom axolotl.utils.dict import DictDefault\n\n\n@pytest.fixture()\ndef gpu_caps():\n    return {\n        \"compute_capability\": \"sm_89\",\n        \"bf16\": True,\n        \"tf32\": False,\n        \"n_gpu\": 1,\n        \"n_node\": 1,\n    }\n\n\n@pytest.fixture()\ndef env_caps():\n    return {\"torch_version\": \"2.7.0\"}\n\n\nclass TestQuantizeMoeExpertsValidation:\n    \"\"\"Test suite for quantize_moe_experts config validator.\"\"\"\n\n    def test_requires_adapter(self, min_base_cfg, gpu_caps, env_caps):\n        \"\"\"quantize_moe_experts without adapter should fail.\"\"\"\n        cfg = (\n            DictDefault(\n                quantize_moe_experts=True,\n            )\n            | min_base_cfg\n        )\n        with pytest.raises(ValueError, match=\"requires adapter\"):\n            validate_config(cfg, capabilities=gpu_caps, env_capabilities=env_caps)\n\n    def test_requires_quantization(self, min_base_cfg, gpu_caps, env_caps):\n        \"\"\"quantize_moe_experts without load_in_4bit/8bit should fail.\"\"\"\n        cfg = (\n            DictDefault(\n                quantize_moe_experts=True,\n                adapter=\"lora\",\n            )\n            | min_base_cfg\n        )\n        with pytest.raises(ValueError, match=\"requires load_in_4bit or load_in_8bit\"):\n            validate_config(cfg, capabilities=gpu_caps, env_capabilities=env_caps)\n\n    def test_valid_qlora_4bit(self, min_base_cfg, gpu_caps, env_caps):\n        \"\"\"quantize_moe_experts with qlora + 4bit should pass.\"\"\"\n        cfg = (\n            DictDefault(\n                quantize_moe_experts=True,\n                adapter=\"qlora\",\n                load_in_4bit=True,\n            )\n            | min_base_cfg\n        )\n        result = validate_config(cfg, capabilities=gpu_caps, 
env_capabilities=env_caps)\n        assert result[\"quantize_moe_experts\"] is True\n\n    def test_valid_lora_8bit(self, min_base_cfg, gpu_caps, env_caps):\n        \"\"\"quantize_moe_experts with lora + 8bit should pass.\"\"\"\n        cfg = (\n            DictDefault(\n                quantize_moe_experts=True,\n                adapter=\"lora\",\n                load_in_8bit=True,\n            )\n            | min_base_cfg\n        )\n        result = validate_config(cfg, capabilities=gpu_caps, env_capabilities=env_caps)\n        assert result[\"quantize_moe_experts\"] is True\n\n    def test_false_skips_validation(self, min_base_cfg, gpu_caps, env_caps):\n        \"\"\"quantize_moe_experts=false should not check adapter/quantization.\"\"\"\n        cfg = (\n            DictDefault(\n                quantize_moe_experts=False,\n            )\n            | min_base_cfg\n        )\n        result = validate_config(cfg, capabilities=gpu_caps, env_capabilities=env_caps)\n        assert result[\"quantize_moe_experts\"] is False\n\n    def test_rejects_lora_target_linear(self, min_base_cfg, gpu_caps, env_caps):\n        \"\"\"quantize_moe_experts with lora_target_linear should fail.\"\"\"\n        cfg = (\n            DictDefault(\n                quantize_moe_experts=True,\n                adapter=\"qlora\",\n                load_in_4bit=True,\n                lora_target_linear=True,\n            )\n            | min_base_cfg\n        )\n        with pytest.raises(ValueError, match=\"lora_target_linear is not compatible\"):\n            validate_config(cfg, capabilities=gpu_caps, env_capabilities=env_caps)\n\n    def test_default_is_false(self, min_base_cfg, gpu_caps, env_caps):\n        \"\"\"quantize_moe_experts should default to false.\"\"\"\n        cfg = DictDefault({}) | min_base_cfg\n        result = validate_config(cfg, capabilities=gpu_caps, env_capabilities=env_caps)\n        assert result[\"quantize_moe_experts\"] is False\n\n\nclass 
TestLoraTargetParametersDropout:\n    \"\"\"Test that lora_dropout must be 0 when lora_target_parameters is set.\"\"\"\n\n    def test_rejects_nonzero_dropout(self, min_base_cfg):\n        \"\"\"lora_dropout > 0 with lora_target_parameters should fail.\"\"\"\n        cfg = (\n            DictDefault(\n                adapter=\"lora\",\n                lora_target_parameters=[\"mlp.experts.gate_up_proj\"],\n                lora_dropout=0.1,\n                load_in_8bit=True,\n            )\n            | min_base_cfg\n        )\n        with pytest.raises(ValueError, match=\"lora_dropout must be 0\"):\n            validate_config(cfg)\n\n    def test_zero_dropout_passes(self, min_base_cfg):\n        \"\"\"lora_dropout=0 with lora_target_parameters should pass.\"\"\"\n        cfg = (\n            DictDefault(\n                adapter=\"lora\",\n                lora_target_parameters=[\"mlp.experts.gate_up_proj\"],\n                lora_dropout=0.0,\n                load_in_8bit=True,\n            )\n            | min_base_cfg\n        )\n        result = validate_config(cfg)\n        assert result[\"lora_dropout\"] == 0.0\n\n\nclass TestPeftPatchIdempotency:\n    \"\"\"Test that patch_peft_target_parameters_matching is idempotent.\"\"\"\n\n    def test_double_call_does_not_stack_wrappers(self):\n        \"\"\"Calling patch twice should not double-wrap _inject_parameters.\"\"\"\n        from peft.tuners.tuners_utils import BaseTuner\n\n        from axolotl.monkeypatch.moe_quant import (\n            patch_peft_target_parameters_matching,\n        )\n\n        original = BaseTuner._inject_parameters\n        try:\n            patch_peft_target_parameters_matching()\n            first_patched = BaseTuner._inject_parameters\n            patch_peft_target_parameters_matching()\n            second_patched = BaseTuner._inject_parameters\n            # Should be same function, not double-wrapped\n            assert first_patched is second_patched\n        finally:\n         
   BaseTuner._inject_parameters = original\n            patch_peft_target_parameters_matching._axolotl_patched = False\n\n\nclass TestMoeAdapterTrainMergeRoundtrip:\n    \"\"\"E2E: train adapter on quantized MoE experts, then merge onto plain model.\n\n    Verifies that param wrapping order during training matches merge, preventing\n    size mismatch errors when loading adapters in standard PEFT/vLLM.\n    \"\"\"\n\n    @staticmethod\n    def _make_classes():\n        \"\"\"Return FakeExperts and FakeModel classes shared by both model builders.\"\"\"\n        import torch\n        import torch.nn as nn\n\n        class FakeExperts(nn.Module):\n            def __init__(self):\n                super().__init__()\n                # Model definition order: gate_up_proj first, then down_proj.\n                self.gate_up_proj = nn.Parameter(torch.randn(4, 16, 8))\n                self.down_proj = nn.Parameter(torch.randn(4, 8, 16))\n\n            def forward(self, x):\n                x = torch.matmul(x, self.gate_up_proj[0].T)  # (batch, 16)\n                x = torch.matmul(x, self.down_proj[0].T)  # (batch, 8)\n                return x\n\n        class FakeModel(nn.Module):\n            def __init__(self):\n                super().__init__()\n                self.linear = nn.Linear(8, 8)\n                self.experts = FakeExperts()\n\n            def forward(self, x):\n                return self.linear(x) + self.experts(x)\n\n        return FakeExperts, FakeModel\n\n    @staticmethod\n    def _make_quantized_model():\n        \"\"\"Training model: parametrizations registered in alphabetical order.\"\"\"\n        import torch.nn as nn\n        import torch.nn.utils.parametrize as P\n\n        from axolotl.monkeypatch.moe_quant import _moe_load_state\n\n        _, FakeModel = TestMoeAdapterTrainMergeRoundtrip._make_classes()\n\n        class PassthroughParametrization(nn.Module):\n            def forward(self, x):\n                return x\n\n        model = 
FakeModel()\n\n        # Record definition order before parametrization (mirrors real loading).\n        _moe_load_state[\"expert_param_order\"][\"experts\"] = list(\n            model.experts._parameters.keys()\n        )\n\n        # Register in alphabetical order to expose the ordering mismatch.\n        P.register_parametrization(\n            model.experts, \"down_proj\", PassthroughParametrization(), unsafe=True\n        )\n        P.register_parametrization(\n            model.experts, \"gate_up_proj\", PassthroughParametrization(), unsafe=True\n        )\n        return model\n\n    @staticmethod\n    def _make_plain_model():\n        \"\"\"Merge model: no parametrizations — standard branch uses definition order.\"\"\"\n        _, FakeModel = TestMoeAdapterTrainMergeRoundtrip._make_classes()\n        return FakeModel()\n\n    def test_train_save_merge_no_size_mismatch(self, tmp_path):\n        \"\"\"Train on quantized experts, merge onto plain model — must not raise.\"\"\"\n        import torch\n        from peft import LoraConfig, PeftModel, get_peft_model\n        from peft.tuners.tuners_utils import BaseTuner\n\n        from axolotl.monkeypatch.moe_quant import (\n            _moe_load_state,\n            patch_peft_target_parameters_matching,\n        )\n\n        adapter_dir = tmp_path / \"adapter\"\n        lora_cfg = LoraConfig(\n            r=4,\n            lora_alpha=8,\n            target_modules=[],\n            target_parameters=[\"experts.gate_up_proj\", \"experts.down_proj\"],\n            lora_dropout=0.0,\n            bias=\"none\",\n        )\n        original_inject = BaseTuner._inject_parameters\n\n        # Training phase: quantized model (parametrized branch) with axolotl patch.\n        _moe_load_state[\"expert_param_order\"] = {}\n        patch_peft_target_parameters_matching()\n        try:\n            peft_model = get_peft_model(self._make_quantized_model(), lora_cfg)\n        finally:\n            BaseTuner._inject_parameters = 
original_inject\n            patch_peft_target_parameters_matching._axolotl_patched = False\n\n        optimizer = torch.optim.SGD(peft_model.parameters(), lr=1e-3)\n        for _ in range(3):\n            peft_model(torch.randn(2, 8)).sum().backward()\n            optimizer.step()\n            optimizer.zero_grad()\n        peft_model.save_pretrained(str(adapter_dir))\n\n        # Merge with standard PEFT (no axolotl patch) to verify external compatibility.\n        loaded = PeftModel.from_pretrained(self._make_plain_model(), str(adapter_dir))\n        merged = loaded.merge_and_unload()\n        assert merged is not None\n"
  },
  {
    "path": "tests/utils/test_grpo_rw_fnc.py",
    "content": "import os\n\nimport pytest\n\nfrom axolotl.core.trainers.grpo import GRPOStrategy\n\n\ndef test_get_rollout_func_loads_successfully():\n    \"\"\"Test that a valid rollout function can be loaded\"\"\"\n    rollout_func = GRPOStrategy.get_rollout_func(\"os.path.join\")\n    assert callable(rollout_func)\n    assert rollout_func == os.path.join\n\n\ndef test_get_rollout_func_invalid_module_raises_error():\n    \"\"\"Test that invalid module path raises clear ValueError\"\"\"\n    with pytest.raises(ValueError, match=\"Rollout function .* not found\"):\n        GRPOStrategy.get_rollout_func(\"nonexistent_module.my_func\")\n"
  },
  {
    "path": "tests/utils/test_import_helper.py",
    "content": "\"\"\"\ntest cases for axolotl.utils.import_helper\n\"\"\"\n\nimport pytest\n\nfrom axolotl.utils.import_helper import get_cls_from_module_str\n\n\ndef test_get_cls_from_module_str():\n    cls = get_cls_from_module_str(\"axolotl.core.trainers.base.AxolotlTrainer\")\n    assert cls.__name__ == \"AxolotlTrainer\"\n\n\ndef test_get_cls_from_module_str_empty_string():\n    with pytest.raises(ValueError, match=\"module_str must be a non-empty string\"):\n        get_cls_from_module_str(\"\")\n\n\ndef test_get_cls_from_module_str_whitespace_only():\n    with pytest.raises(ValueError, match=\"module_str must be a non-empty string\"):\n        get_cls_from_module_str(\"   \")\n\n\ndef test_get_cls_from_module_str_invalid_format():\n    with pytest.raises(ValueError, match=\"Invalid module string format\"):\n        get_cls_from_module_str(\"single_part\")\n\n\ndef test_get_cls_from_module_str_nonexistent_module():\n    with pytest.raises(ImportError, match=\"Failed to import module\"):\n        get_cls_from_module_str(\"nonexistent.module.Class\")\n\n\ndef test_get_cls_from_module_str_nonexistent_class():\n    with pytest.raises(AttributeError, match=\"Class 'NonExistentClass' not found\"):\n        get_cls_from_module_str(\"axolotl.core.trainers.base.NonExistentClass\")\n"
  },
  {
    "path": "tests/utils/test_mistral3_processor.py",
    "content": "\"\"\"Tests for Mistral3Processor with transformers v5 ProcessorMixin integration\"\"\"\n\nfrom unittest.mock import MagicMock\n\nimport pytest\nimport torch\nfrom transformers.feature_extraction_utils import BatchFeature\n\nfrom axolotl.utils.mistral.mistral3_processor import Mistral3Processor\nfrom axolotl.utils.mistral.mistral_tokenizer import HFMistralTokenizer\n\n\n@pytest.fixture()\ndef mock_tokenizer():\n    \"\"\"Create a mock HFMistralTokenizer that passes v5 ProcessorMixin isinstance checks.\"\"\"\n    return MagicMock(spec=HFMistralTokenizer)\n\n\n@pytest.fixture()\ndef processor(mock_tokenizer):\n    return Mistral3Processor(tokenizer=mock_tokenizer)\n\n\nclass TestMistral3ProcessorInit:\n    def test_tokenizer_is_set(self, processor, mock_tokenizer):\n        assert processor.tokenizer is mock_tokenizer\n\n    def test_chat_template_is_none(self, processor):\n        assert processor.chat_template is None\n\n    def test_audio_tokenizer_is_none(self, processor):\n        assert processor.audio_tokenizer is None\n\n\nclass TestApplyChatTemplateTokenized:\n    \"\"\"Test apply_chat_template with tokenize=True, return_dict=True\"\"\"\n\n    @pytest.fixture()\n    def batched_conversations(self):\n        return [\n            [\n                {\"role\": \"user\", \"content\": \"Describe this image.\"},\n                {\"role\": \"assistant\", \"content\": \"It is red.\"},\n            ],\n            [\n                {\"role\": \"user\", \"content\": \"What is this?\"},\n                {\"role\": \"assistant\", \"content\": \"A cat.\"},\n            ],\n        ]\n\n    def test_returns_batch_feature_with_pixel_values(\n        self, processor, mock_tokenizer, batched_conversations\n    ):\n        pixel_values = torch.randn(2, 3, 224, 224, dtype=torch.float64)\n        mock_tokenizer.apply_chat_template.return_value = {\n            \"input_ids\": torch.tensor([[1, 2, 3], [4, 5, 6]]),\n            \"attention_mask\": 
torch.tensor([[1, 1, 1], [1, 1, 1]]),\n            \"pixel_values\": pixel_values,\n        }\n\n        result = processor.apply_chat_template(\n            batched_conversations, tokenize=True, return_dict=True\n        )\n\n        assert isinstance(result, BatchFeature)\n        assert \"pixel_values\" in result\n        assert \"image_sizes\" in result\n        assert result[\"pixel_values\"].dtype == torch.float32\n        assert result[\"image_sizes\"].shape == (2, 2)\n        assert result[\"image_sizes\"][0].tolist() == [224, 224]\n\n    def test_returns_batch_feature_without_pixel_values(\n        self, processor, mock_tokenizer, batched_conversations\n    ):\n        mock_tokenizer.apply_chat_template.return_value = {\n            \"input_ids\": torch.tensor([[1, 2, 3], [4, 5, 6]]),\n            \"attention_mask\": torch.tensor([[1, 1, 1], [1, 1, 1]]),\n        }\n\n        result = processor.apply_chat_template(\n            batched_conversations, tokenize=True, return_dict=True\n        )\n\n        assert isinstance(result, BatchFeature)\n        assert \"input_ids\" in result\n        assert \"image_sizes\" not in result\n\n\nclass TestApplyChatTemplateNotTokenized:\n    def test_single_conversation_returns_unwrapped(self, processor, mock_tokenizer):\n        \"\"\"Single conversation (not batched) should return unwrapped result.\"\"\"\n        single_conversation = [\n            {\"role\": \"user\", \"content\": \"Hello\"},\n            {\"role\": \"assistant\", \"content\": \"Hi\"},\n        ]\n        mock_tokenizer.apply_chat_template.return_value = [\n            \"<s>[INST]Hello[/INST]Hi</s>\"\n        ]\n\n        result = processor.apply_chat_template(\n            single_conversation, tokenize=False, return_dict=False\n        )\n\n        assert result == \"<s>[INST]Hello[/INST]Hi</s>\"\n\n    def test_batched_conversations_returns_list(self, processor, mock_tokenizer):\n        batched = [\n            [\n                {\"role\": 
\"user\", \"content\": \"Hello\"},\n                {\"role\": \"assistant\", \"content\": \"Hi\"},\n            ],\n            [\n                {\"role\": \"user\", \"content\": \"Bye\"},\n                {\"role\": \"assistant\", \"content\": \"Bye\"},\n            ],\n        ]\n        mock_tokenizer.apply_chat_template.return_value = [\"text1\", \"text2\"]\n\n        result = processor.apply_chat_template(\n            batched, tokenize=False, return_dict=False\n        )\n\n        assert result == [\"text1\", \"text2\"]\n\n\nclass TestCall:\n    def test_delegates_to_tokenizer(self, processor, mock_tokenizer):\n        mock_tokenizer.return_value = {\n            \"input_ids\": [1, 2, 3],\n            \"attention_mask\": [1, 1, 1],\n        }\n\n        result = processor(\"Hello world\")\n\n        mock_tokenizer.assert_called_once()\n        assert isinstance(result, BatchFeature)\n\n\nclass TestReturnTensorsValidation:\n    def test_rejects_non_pt_return_tensors(self, processor):\n        conversation = [\n            {\"role\": \"user\", \"content\": \"Hello\"},\n            {\"role\": \"assistant\", \"content\": \"Hi\"},\n        ]\n\n        with pytest.raises(ValueError, match=r\"only supports.*return_tensors='pt'\"):\n            processor.apply_chat_template(\n                conversation, tokenize=True, return_dict=True, return_tensors=\"np\"\n            )\n"
  },
  {
    "path": "tests/utils/test_train.py",
    "content": "\"\"\"test for train checkpoint utils\"\"\"\n\nimport os\n\nfrom axolotl.utils.dict import DictDefault\nfrom axolotl.utils.train import determine_last_checkpoint\n\n\ndef test_determine_last_checkpoint(temp_dir):\n    cfg = DictDefault(\n        output_dir=temp_dir,\n    )\n    for cpt_idx in [1, 9, 10, 20]:\n        os.makedirs(\n            os.path.join(cfg.output_dir, f\"checkpoint-{cpt_idx}\"), exist_ok=True\n        )\n\n    last_checkpoint = determine_last_checkpoint(cfg, update=False)\n    assert last_checkpoint == os.path.join(cfg.output_dir, \"checkpoint-20\")\n\n    cfg.resume_from_checkpoint = None\n    cfg.auto_resume_from_checkpoints = True\n    determine_last_checkpoint(cfg, update=True)\n    assert cfg.resume_from_checkpoint == os.path.join(cfg.output_dir, \"checkpoint-20\")\n"
  }
]